Source code for nni.experiment.experiment

import atexit
import logging
from pathlib import Path
import socket
from subprocess import Popen
import time
from typing import Optional, Union, List, overload, Any

import json_tricks
import colorama
import psutil

import nni.runtime.log

from .config import ExperimentConfig, AlgorithmConfig
from .data import TrialJob, TrialMetricData, TrialResult
from . import launcher
from . import management
from . import rest
from ..tools.nnictl.command_utils import kill_command

nni.runtime.log.init_logger_experiment()
_logger = logging.getLogger('nni.experiment')


class Experiment:
    """
    Create and stop an NNI experiment.

    Attributes
    ----------
    config
        Experiment configuration.
    port
        Web UI port of the experiment, or `None` if it is not running.
    """

    @overload
    def __init__(self, config: ExperimentConfig) -> None:
        """
        Prepare an experiment. Use `Experiment.start()` to launch it.

        Parameters
        ----------
        config
            Experiment configuration.
        """
        ...

    @overload
    def __init__(self, training_service: Union[str, List[str]]) -> None:
        """
        Prepare an experiment, leaving configuration fields to be set later.

        Example usage::

            experiment = Experiment('remote')
            experiment.config.trial_command = 'python3 trial.py'
            experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
            ...
            experiment.start(8080)

        Parameters
        ----------
        training_service
            Name of training service.
            Supported values: "local", "remote", "openpai", "aml", "kubeflow",
            "frameworkcontroller", "adl", and hybrid training service.
        """
        ...

    def __init__(self, config=None, training_service=None):
        self.config: Optional[ExperimentConfig] = None
        self.id: Optional[str] = None
        self.port: Optional[int] = None
        self._proc: Optional[Popen] = None

        args = [config, training_service]  # deal with overloading
        if isinstance(args[0], (str, list)):
            self.config = ExperimentConfig(args[0])
            self.config.tuner = AlgorithmConfig(name='_none_', class_args={})
            self.config.assessor = AlgorithmConfig(name='_none_', class_args={})
            self.config.advisor = AlgorithmConfig(name='_none_', class_args={})
        else:
            self.config = args[0]

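    # Illustrative usage sketch (not part of the original source): constructing an
    # Experiment from a full ExperimentConfig object, i.e. the first overload above.
    # The trial command and code directory values are hypothetical placeholders.
    #
    #     config = ExperimentConfig('local')
    #     config.trial_command = 'python3 trial.py'
    #     config.trial_code_directory = '.'
    #     experiment = Experiment(config)
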
    def start(self, port: int = 8080, debug: bool = False) -> None:
        """
        Start the experiment in background.

        This method will raise an exception on failure.
        If it returns, the experiment should have been started successfully.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        """
        atexit.register(self.stop)

        self.id = management.generate_experiment_id()

        if self.config.experiment_working_directory is not None:
            log_dir = Path(self.config.experiment_working_directory, self.id, 'log')
        else:
            log_dir = Path.home() / f'nni-experiments/{self.id}/log'
        nni.runtime.log.start_experiment_log(self.id, log_dir, debug)

        self._proc = launcher.start_experiment(self.id, self.config, port, debug)
        assert self._proc is not None
        self.port = port  # port will be None if start up failed

        ips = [self.config.nni_manager_ip]
        for interfaces in psutil.net_if_addrs().values():
            for interface in interfaces:
                if interface.family == socket.AF_INET:
                    ips.append(interface.address)
        ips = [f'http://{ip}:{port}' for ip in ips if ip]
        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
        _logger.info(msg)

    def stop(self) -> None:
        """
        Stop background experiment.
        """
        _logger.info('Stopping experiment, please wait...')
        atexit.unregister(self.stop)

        if self.id is not None:
            nni.runtime.log.stop_experiment_log(self.id)
        if self._proc is not None:
            try:
                rest.delete(self.port, '/experiment')
            except Exception as e:
                _logger.exception(e)
                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
                kill_command(self._proc.pid)

        self.id = None
        self.port = None
        self._proc = None
        _logger.info('Experiment stopped')

    def run(self, port: int = 8080, debug: bool = False) -> bool:
        """
        Run the experiment.

        This function blocks until the experiment finishes or fails.

        Return `True` when the experiment is done, or `False` when it fails.
        """
        self.start(port, debug)
        try:
            while True:
                time.sleep(10)
                status = self.get_status()
                if status == 'DONE' or status == 'STOPPED':
                    return True
                if status == 'ERROR':
                    return False
        except KeyboardInterrupt:
            _logger.warning('KeyboardInterrupt detected')
        finally:
            self.stop()

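    # Illustrative usage sketch (not part of the original source): `run()` blocks
    # until the experiment completes, unlike `start()` which returns immediately.
    # The search space, tuner, and config values below are hypothetical placeholders.
    #
    #     experiment = Experiment('local')
    #     experiment.config.trial_command = 'python3 trial.py'
    #     experiment.config.search_space = {'lr': {'_type': 'loguniform', '_value': [1e-4, 1e-1]}}
    #     experiment.config.tuner = AlgorithmConfig(name='TPE', class_args={'optimize_mode': 'maximize'})
    #     experiment.config.max_trial_number = 10
    #     experiment.config.trial_concurrency = 2
    #     succeeded = experiment.run(port=8080)  # True if the experiment reached DONE or STOPPED
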
    @classmethod
    def connect(cls, port: int):
        """
        Connect to an existing experiment.

        Parameters
        ----------
        port
            The port of web UI.
        """
        experiment = Experiment()
        experiment.port = port
        experiment.id = experiment.get_experiment_profile().get('id')
        status = experiment.get_status()
        pid = experiment.get_experiment_metadata(experiment.id).get('pid')
        if pid is None:
            _logger.warning('Failed to get experiment pid, cannot stop the experiment with stop().')
        else:
            experiment._proc = psutil.Process(pid)
        _logger.info('Connected to port %d successfully, experiment id is %s, status is %s.', port, experiment.id, status)
        return experiment

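    # Illustrative usage sketch (not part of the original source): attaching to an
    # experiment already running on port 8080 (a hypothetical value) and querying it.
    #
    #     experiment = Experiment.connect(8080)
    #     print(experiment.get_status())
    #     print(len(experiment.list_trial_jobs()), 'trial jobs so far')
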
    def get_status(self) -> str:
        """
        Return experiment status as a str.

        Returns
        -------
        str
            Experiment status.
        """
        resp = rest.get(self.port, '/check-status')
        return resp['status']

    def get_trial_job(self, trial_job_id: str):
        """
        Return a trial job.

        Parameters
        ----------
        trial_job_id: str
            Trial job id.

        Returns
        -------
        TrialJob
            A `TrialJob` instance corresponding to `trial_job_id`.
        """
        resp = rest.get(self.port, '/trial-jobs/{}'.format(trial_job_id))
        return TrialJob(**resp)

    def list_trial_jobs(self):
        """
        Return information for all trial jobs as a list.

        Returns
        -------
        list
            List of `TrialJob`.
        """
        resp = rest.get(self.port, '/trial-jobs')
        return [TrialJob(**trial_job) for trial_job in resp]

    def get_job_statistics(self):
        """
        Return trial job statistics information as a dict.

        Returns
        -------
        dict
            Job statistics information.
        """
        resp = rest.get(self.port, '/job-statistics')
        return resp

    def get_job_metrics(self, trial_job_id=None):
        """
        Return trial job metrics.

        Parameters
        ----------
        trial_job_id: str
            Trial job id. If this parameter is None, metrics of all trial jobs will be returned.

        Returns
        -------
        dict
            Each key is a trialJobId; the corresponding value is a list of `TrialMetricData`.
        """
        api = '/metric-data/{}'.format(trial_job_id) if trial_job_id else '/metric-data'
        resp = rest.get(self.port, api)
        metric_dict = {}
        for metric in resp:
            trial_id = metric["trialJobId"]
            if trial_id not in metric_dict:
                metric_dict[trial_id] = [TrialMetricData(**metric)]
            else:
                metric_dict[trial_id].append(TrialMetricData(**metric))
        return metric_dict

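    # Illustrative usage sketch (not part of the original source): iterating over
    # the returned dict, which maps each trial job id to its list of `TrialMetricData`.
    #
    #     metrics = experiment.get_job_metrics()
    #     for trial_id, trial_metrics in metrics.items():
    #         print(trial_id, len(trial_metrics), 'metric records')
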
    def get_experiment_profile(self):
        """
        Return experiment profile as a dict.

        Returns
        -------
        dict
            The profile of the experiment.
        """
        resp = rest.get(self.port, '/experiment')
        return resp

    def get_experiment_metadata(self, exp_id: str):
        """
        Return experiment metadata with specified exp_id as a dict.

        Returns
        -------
        dict
            The specified experiment metadata.
        """
        experiments_metadata = self.get_all_experiments_metadata()
        for metadata in experiments_metadata:
            if metadata['id'] == exp_id:
                return metadata
        return {}

    def get_all_experiments_metadata(self):
        """
        Return all experiments metadata as a list.

        Returns
        -------
        list
            The experiments metadata.
        """
        resp = rest.get(self.port, '/experiments-info')
        return resp

    def export_data(self):
        """
        Return exported information for all trial jobs.

        Returns
        -------
        list
            List of `TrialResult`.
        """
        resp = rest.get(self.port, '/export-data')
        return [TrialResult(**trial_result) for trial_result in resp]

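    # Illustrative usage sketch (not part of the original source): reading exported
    # trial results. The attribute names `parameter` and `value` are assumptions
    # about the `TrialResult` fields, not confirmed by this listing.
    #
    #     results = experiment.export_data()
    #     for result in results:
    #         print(result.parameter, result.value)
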
    def _get_query_type(self, key: str):
        if key == 'trialConcurrency':
            return '?update_type=TRIAL_CONCURRENCY'
        if key == 'maxExecDuration':
            return '?update_type=MAX_EXEC_DURATION'
        if key == 'searchSpace':
            return '?update_type=SEARCH_SPACE'
        if key == 'maxTrialNum':
            return '?update_type=MAX_TRIAL_NUM'

    def _update_experiment_profile(self, key: str, value: Any):
        """
        Update an experiment's profile.

        Parameters
        ----------
        key: str
            One of `['trialConcurrency', 'maxExecDuration', 'searchSpace', 'maxTrialNum']`.
        value: Any
            New value of the key.
        """
        api = '/experiment{}'.format(self._get_query_type(key))
        experiment_profile = self.get_experiment_profile()
        experiment_profile['params'][key] = value
        rest.put(self.port, api, experiment_profile)
        logging.info('Successfully updated %s.', key)

    def update_trial_concurrency(self, value: int):
        """
        Update an experiment's trial_concurrency

        Parameters
        ----------
        value: int
            New trial_concurrency value.
        """
        self._update_experiment_profile('trialConcurrency', value)

    def update_max_experiment_duration(self, value: str):
        """
        Update an experiment's max_experiment_duration

        Parameters
        ----------
        value: str
            Strings like '1m' for one minute or '2h' for two hours.
            The suffix may be 's' for seconds, 'm' for minutes, 'h' for hours or 'd' for days.
        """
        self._update_experiment_profile('maxExecDuration', value)

    def update_search_space(self, value: dict):
        """
        Update the experiment's search_space.
        TODO: support searchspace file.

        Parameters
        ----------
        value: dict
            New search_space.
        """
        value = json_tricks.dumps(value)
        self._update_experiment_profile('searchSpace', value)

    def update_max_trial_number(self, value: int):
        """
        Update an experiment's max_trial_number

        Parameters
        ----------
        value: int
            New max_trial_number value.
        """
        self._update_experiment_profile('maxTrialNum', value)

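# Illustrative usage sketch (not part of the original source): adjusting a running
# experiment from the same process. The values below are hypothetical placeholders.
#
#     experiment.update_trial_concurrency(4)
#     experiment.update_max_experiment_duration('2h')
#     experiment.update_max_trial_number(50)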