# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import atexit
from enum import Enum
import logging
from pathlib import Path
import socket
from subprocess import Popen
import time
from typing import Any, cast
import psutil
from typing_extensions import Literal
from nni.runtime.log import start_experiment_logging, stop_experiment_logging
from nni.tools.nnictl.config_utils import Experiments
from .config import ExperimentConfig
from .data import TrialJob, TrialMetricData, TrialResult
from . import launcher
from . import management
from . import rest
from ..tools.nnictl.command_utils import kill_command
# Module-level logger used for all experiment lifecycle messages in this file.
_logger = logging.getLogger(name='nni.experiment')
class RunMode(Enum):
    """
    Config lifecycle and output redirection of NNI manager process.

    - Background: stop NNI manager when Python script exits; do not print NNI manager log. (default)
    - Foreground: stop NNI manager when Python script exits; print NNI manager log to stdout.
    - Detach: do not stop NNI manager when Python script exits.

    NOTE: This API is non-stable and is likely to get refactored in upcoming release.
    """
    # TODO:
    # NNI manager should treat log level more seriously so we can default to "foreground" without being too verbose.
    Background = 'background'
    Foreground = 'foreground'
    Detach = 'detach'
class Experiment:
    """
    Manage NNI experiment.

    You can either specify an :class:`ExperimentConfig` object, or a training service name.
    If a platform name is used, a blank config template for that training service will be generated.

    When configuration is completed, use :meth:`Experiment.run` to launch the experiment.

    Parameters
    ----------
    config_or_platform
        See :class:`~nni.experiment.config.ExperimentConfig`.
    id
        Experiment ID. If not specified, a random ID will be generated.
        If specified, the ID should be unique to avoid conflict with existing experiments.
        The only case when you need to specify an existing ID is when you want to resume an experiment.

    Example
    -------
    .. code-block::

        experiment = Experiment('remote')
        experiment.config.trial_command = 'python3 trial.py'
        experiment.config.machines.append(RemoteMachineConfig(ip=..., user_name=...))
        ...
        experiment.run(8080)

    Attributes
    ----------
    config
        Experiment configuration.
    id
        Experiment ID.
    port
        Web portal port. Or ``None`` if the experiment is not running.
    """

    # Maps a camelCase profile key to the ``update_type`` query string expected by
    # the ``PUT /experiment`` REST endpoint (see `_update_experiment_profile`).
    _UPDATE_QUERY_TYPES = {
        'trialConcurrency': '?update_type=TRIAL_CONCURRENCY',
        'maxExperimentDuration': '?update_type=MAX_EXEC_DURATION',
        'searchSpace': '?update_type=SEARCH_SPACE',
        'maxTrialNumber': '?update_type=MAX_TRIAL_NUM',
    }

    def __init__(
        self,
        config_or_platform: ExperimentConfig | str | list[str] | None,
        id: str | None = None  # pylint: disable=redefined-builtin
    ):
        self.config: ExperimentConfig | None = None

        if id is not None:
            if not management.is_valid_experiment_id(id):
                raise ValueError(f'Invalid experiment ID: {id}. Experiment ID should only contain digits, alphanumeric characters, '
                                 'hyphens, and underscores, and should be no longer than 32 characters.')
            self.id = id
        else:
            self.id = management.generate_experiment_id()

        self.port: int | None = None
        # Handle of the NNI manager process: set by `_start_nni_manager` (launcher result)
        # or by `connect` (a psutil.Process attached by pid).
        self._proc: Popen | psutil.Process | None = None
        # Forwarded to the launcher; switched to 'resume' / 'view' by `resume()` / `view()`.
        self._action: Literal['create', 'resume', 'view'] = 'create'
        self.url_prefix: str | None = None

        if isinstance(config_or_platform, (str, list)):
            self.config = ExperimentConfig(config_or_platform)
        else:
            self.config = config_or_platform

    def _start_logging(self, debug: bool) -> None:
        """Start writing this experiment's log to ``<working dir>/<id>/log/experiment.log``."""
        assert self.config is not None
        config = self.config.canonical_copy()

        log_file = Path(config.experiment_working_directory, self.id, 'log', 'experiment.log')
        log_file.parent.mkdir(parents=True, exist_ok=True)
        # 'trace' is mapped to 'debug' here (presumably the Python-side logging has no 'trace' level).
        log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level
        start_experiment_logging(self.id, log_file, cast(str, log_level))

    def _start_nni_manager(self, port: int, debug: bool, run_mode: RunMode = RunMode.Background,
                           tuner_command_channel: str | None = None,
                           tags: list[str] | None = None) -> None:
        """
        Launch the NNI manager process for this experiment and log the web portal URLs.

        Parameters
        ----------
        port
            Port of the NNI manager / web portal.
        debug
            Whether to start NNI manager in debug mode.
        run_mode
            See :class:`RunMode`.
        tuner_command_channel
            Forwarded unchanged to ``launcher.start_experiment``.
        tags
            Experiment tags. ``None`` (the default) means no tags.

        Raises
        ------
        RuntimeError
            If the config enables NNI annotation, which the Python API does not support.
        """
        assert self.config is not None
        config = self.config.canonical_copy()
        if config.use_annotation:
            raise RuntimeError('NNI annotation is not supported by Python experiment API.')

        # Build a fresh list per call rather than sharing a mutable default argument.
        tags = [] if tags is None else tags

        self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode,
                                               self.url_prefix, tuner_command_channel, tags)
        assert self._proc is not None

        # Only reached once the launcher returned a process; until then `self.port` keeps its previous value.
        self.port = port

        # Advertise the configured manager IP plus every local IPv4 interface address.
        ips = [config.nni_manager_ip]
        for interfaces in psutil.net_if_addrs().values():
            for interface in interfaces:
                if interface.family == socket.AF_INET:
                    ips.append(interface.address)
        ips = [f'http://{ip}:{port}' for ip in ips if ip]
        msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips)
        _logger.info(msg)

    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
        """
        Start the experiment in background.

        This method will raise exception on failure.
        If it returns, the experiment should have been successfully started.

        Parameters
        ----------
        port
            The port of web UI.
        debug
            Whether to start in debug mode.
        run_mode
            Running the experiment in foreground or background
        """
        if run_mode is not RunMode.Detach:
            # If the experiment ends normally without KeyboardInterrupt, stop won't be automatically called.
            # As a result, NNI manager will continue to run in the background, even after run() exits.
            # To kill it, either call stop() manually, or atexit will clean it up at process exit.
            atexit.register(self.stop)

        self._start_logging(debug)
        self._start_nni_manager(port, debug, run_mode, None, [])

    def _stop_logging(self) -> None:
        """Stop the per-experiment log started by `_start_logging`."""
        stop_experiment_logging(self.id)

    def _stop_nni_manager(self) -> None:
        """
        Ask NNI manager to shut down via REST and wait for its process to exit.

        Any failure during the graceful path is logged and the process is killed instead.
        No-op when no process handle is held.
        """
        if self._proc is not None:
            try:
                rest.delete(self.port, '/experiment', self.url_prefix)
                self._proc.wait()
            except Exception as e:
                _logger.exception(e)
                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
                kill_command(self._proc.pid)

        self.port = None
        self._proc = None

    def stop(self) -> None:
        """
        Stop the experiment.
        """
        _logger.info('Stopping experiment, please wait...')
        atexit.unregister(self.stop)

        try:
            _logger.info('Saving experiment checkpoint...')
            self.save_checkpoint()
        finally:
            # Stop NNI manager and logging even if checkpointing failed,
            # so the background process is not leaked.
            _logger.info('Stopping NNI manager, if any...')
            self._stop_nni_manager()
            self._stop_logging()

        _logger.info('Experiment stopped.')

    def _wait_completion(self) -> bool:
        """
        Poll the experiment status every 10 seconds until it finishes.

        Returns ``True`` on ``DONE`` / ``STOPPED`` and ``False`` on ``ERROR``.
        """
        while True:
            status = self.get_status()
            if status in ('DONE', 'STOPPED'):
                return True
            if status == 'ERROR':
                return False
            time.sleep(10)

    def _run_impl(self, port: int, wait_completion: bool, debug: bool) -> bool | None:
        """Shared implementation of `run` and `resume`; see `run` for the contract."""
        try:
            self.start(port, debug)
            if wait_completion:
                return self._wait_completion()
        except KeyboardInterrupt:
            _logger.warning('KeyboardInterrupt detected')
            self.stop()
        # NOTE: stop is not called if wait is successful without interrupt.
        return None

    def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
        """
        Run the experiment.

        Using Ctrl-C will :meth:`stop` the experiment.
        Otherwise the experiment won't be :meth:`stop`\\ ped even if the method returns.
        It has to be manually :meth:`stop`\\ ped, or atexit will :meth:`stop` it at process exit.

        Parameters
        ----------
        port
            The port on which NNI manager will run. It will also be the port of web portal.
        wait_completion
            If ``wait_completion`` is ``True``, this function will block until experiment finish or error.
        debug
            Set log level to debug.

        Returns
        -------
        If ``wait_completion`` is ``False``, this function will non-block and return None immediately.
        Otherwise, return ``True`` when experiment done; or return ``False`` when experiment failed.
        """
        return self._run_impl(port, wait_completion, debug)

    def run_or_resume(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
        """
        Call :meth:`run` or :meth:`resume` based on the return value of :meth:`has_checkpoint`.

        Parameters and return values are the same as :meth:`run`.
        """
        if self.has_checkpoint():
            _logger.info('Checkpoint is found. Resume the experiment: %s', self.id)
            return self.resume(port, wait_completion, debug)
        else:
            _logger.info('No checkpoint is found. Start a new experiment: %s', self.id)
            return self.run(port, wait_completion, debug)

    def has_checkpoint(self) -> bool:
        """
        Check whether a checkpoint of current experiment ID exists.

        Returns
        -------
        ``True`` if checkpoint is found; ``False`` otherwise.
        """
        # First check whether a checkpoint exists.
        experiments_dict = Experiments().get_all_experiments()
        if self.id in experiments_dict:
            _logger.debug('Checkpoint is found in experiment manifest. The experiment can be resumed: %r', experiments_dict[self.id])
            return True
        else:
            _logger.debug('No checkpoint with %s is found in experiment manifest.', self.id)
            return False

    def load_checkpoint(self) -> None:
        """
        Load checkpoint from local file system.

        Restores the status of the experiment instance.
        """
        # HPO basically only needs to load the config.
        # In case the current experiment already has a config,
        # respect the new config's working directory.
        if self.config is not None:
            experiment_working_directory = self.config.canonical_copy().experiment_working_directory
        else:
            experiment_working_directory = None

        # Load the config regardless of whether current config is provided or not.
        config = launcher.get_stopped_experiment_config(self.id, exp_dir=experiment_working_directory)

        if self.config is not None:
            # If `self.config` is set, do some validation.
            from .config.utils import diff
            config_diff = diff(self.config, config, 'Current', 'Loaded')
            if config_diff:
                _logger.warning('Config is found but does not match the current config:\n%s', config_diff)
                _logger.warning('Current config will NOT be overridden by the loaded config.')
            else:
                _logger.info('Current config matches the loaded config.')
        else:
            # If `self.config` is not set, use the loaded config.
            _logger.debug('Current config is None. Loaded config will be used: %r', config)
            self.config = config

    def save_checkpoint(self) -> None:
        """
        Save the experiment status to local file system.
        """
        # HPO experiment doesn't need to do this because the state has already been saved by underlying components.
        pass

    @classmethod
    def connect(cls, port: int):
        """
        Connect to an existing experiment.

        Parameters
        ----------
        port
            The port of web UI.
        """
        experiment = cls(None)
        experiment.port = port
        experiment.id = experiment.get_experiment_profile().get('id')
        status = experiment.get_status()

        # Look up the NNI manager pid in the local experiment manifest (the same registry
        # `has_checkpoint` consults). This class defines no `get_experiment_metadata`
        # accessor, so calling it would raise AttributeError.
        # NOTE(review): assumes manifest entries are dicts that may carry a 'pid' key — confirm.
        manifest_entry = Experiments().get_all_experiments().get(experiment.id) or {}
        pid = manifest_entry.get('pid')
        if pid is None:
            _logger.warning('Get experiment pid failed, can not stop experiment by stop().')
        else:
            experiment._proc = psutil.Process(pid)
        _logger.info('Connect to port %d success, experiment id is %s, status is %s.', port, experiment.id, status)
        return experiment

    def resume(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
        """
        Resume a stopped experiment.

        Parameters
        ----------
        port
            The port of web UI.
        wait_completion
            If true, run in the foreground. If false, run in the background.
        debug
            Whether to start in debug mode.

        Returns
        -------
        See :meth:`run`.
        """
        # Backward compatibility:
        # We will stop supporting experiment_id as keyword arguments instantly right now,
        # because keeping it compatible will be very tricky and not worth the effort.
        # But experiment_id as positional argument is still supported for now.
        if isinstance(self, str):
            _logger.warning('Experiment.resume(id) is deprecated (and has already stopped working for non-HPO experiments). '
                            'Use Experiment(id).resume() instead.')
            # Assumes the type is `Experiment`, self is experiment_id.
            self = Experiment(None, id=self)

        if not self.has_checkpoint():
            raise RuntimeError(f'Experiment {self.id} does not exist thus cannot be resumed.')

        self.load_checkpoint()
        self._action = 'resume'
        return self._run_impl(port, wait_completion, debug)

    def view(self, port: int = 8080, non_blocking: bool = False) -> Experiment:
        """
        View a stopped experiment.

        Parameters
        ----------
        port
            The port of web UI.
        non_blocking
            If false, run in the foreground. If true, run in the background.

        Returns
        -------
        Return self instance.
        """
        # Backward compatibility
        if isinstance(self, str):
            _logger.warning('Experiment.view(id) is deprecated (and has already stopped working for non-HPO experiments). '
                            'Use Experiment(id).view() instead.')
            # Assumes the type is `Experiment`, self is experiment_id.
            self = Experiment(None, id=self)

        self._action = 'view'

        if not self.has_checkpoint():
            raise RuntimeError(f'Experiment {self.id} does not exist thus cannot be viewed.')

        self.load_checkpoint()
        # Detach: no atexit hook; the blocking loop below stops the manager explicitly.
        self.start(port=port, debug=False, run_mode=RunMode.Detach)
        if not non_blocking:
            try:
                while True:
                    time.sleep(10)
            except KeyboardInterrupt:
                _logger.warning('KeyboardInterrupt detected')
            finally:
                self.stop()
        return self

    def get_status(self) -> str:
        """
        Return experiment status as a str.

        Returns
        -------
        str
            Experiment status.
        """
        resp = rest.get(self.port, '/check-status', self.url_prefix)
        return resp['status']

    def get_trial_job(self, trial_job_id: str):
        """
        Return a trial job.

        Parameters
        ----------
        trial_job_id: str
            Trial job id.

        Returns
        -------
        TrialJob
            A `TrialJob` instance corresponding to `trial_job_id`.
        """
        resp = rest.get(self.port, f'/trial-jobs/{trial_job_id}', self.url_prefix)
        return TrialJob(**resp)

    def list_trial_jobs(self):
        """
        Return information for all trial jobs as a list.

        Returns
        -------
        list
            List of `TrialJob`.
        """
        resp = rest.get(self.port, '/trial-jobs', self.url_prefix)
        return [TrialJob(**trial_job) for trial_job in resp]

    def get_job_statistics(self):
        """
        Return trial job statistics information as a dict.

        Returns
        -------
        dict
            Job statistics information.
        """
        resp = rest.get(self.port, '/job-statistics', self.url_prefix)
        return resp

    def get_job_metrics(self, trial_job_id=None):
        """
        Return trial job metrics.

        Parameters
        ----------
        trial_job_id: str
            Trial job id. If this parameter is None, all trial jobs' metrics will be returned.

        Returns
        -------
        dict
            Each key is a trialJobId, the corresponding value is a list of `TrialMetricData`.
        """
        api = f'/metric-data/{trial_job_id}' if trial_job_id else '/metric-data'
        resp = rest.get(self.port, api, self.url_prefix)
        metric_dict = {}
        for metric in resp:
            # Group metrics by trial, preserving the order returned by the server.
            metric_dict.setdefault(metric["trialJobId"], []).append(TrialMetricData(**metric))
        return metric_dict

    def get_experiment_profile(self):
        """
        Return experiment profile as a dict.

        Returns
        -------
        dict
            The profile of the experiment.
        """
        resp = rest.get(self.port, '/experiment', self.url_prefix)
        return resp

    def export_data(self):
        """
        Return exported information for all trial jobs.

        Returns
        -------
        list
            List of `TrialResult`.
        """
        resp = rest.get(self.port, '/export-data', self.url_prefix)
        return [TrialResult(**trial_result) for trial_result in resp]

    def _get_query_type(self, key: str):
        """
        Return the ``?update_type=...`` query string for a profile key.

        Raises
        ------
        ValueError
            If ``key`` is not an updatable profile key. (Previously this returned
            ``None`` and produced the bogus endpoint ``/experimentNone``.)
        """
        try:
            return self._UPDATE_QUERY_TYPES[key]
        except KeyError:
            raise ValueError(f'Unsupported experiment profile key: {key!r}. '
                             f'Expected one of {list(self._UPDATE_QUERY_TYPES)}.') from None

    def _update_experiment_profile(self, key: str, value: Any):
        """
        Update an experiment's profile

        Parameters
        ----------
        key: str
            One of `['trialConcurrency', 'maxExperimentDuration', 'searchSpace', 'maxTrialNumber']`.
        value: Any
            New value of the key.
        """
        api = f'/experiment{self._get_query_type(key)}'
        experiment_profile = self.get_experiment_profile()
        experiment_profile['params'][key] = value
        rest.put(self.port, api, experiment_profile, self.url_prefix)
        _logger.info('Successfully update %s.', key)

    def update_trial_concurrency(self, value: int):
        """
        Update an experiment's trial_concurrency

        Parameters
        ----------
        value: int
            New trial_concurrency value.
        """
        self._update_experiment_profile('trialConcurrency', value)

    def update_max_experiment_duration(self, value: str):
        """
        Update an experiment's max_experiment_duration

        Parameters
        ----------
        value: str
            Strings like '1m' for one minute or '2h' for two hours.
            SUFFIX may be 's' for seconds, 'm' for minutes, 'h' for hours or 'd' for days.
        """
        self._update_experiment_profile('maxExperimentDuration', value)

    def update_search_space(self, value: dict):
        """
        Update the experiment's search_space.
        TODO: support searchspace file.

        Parameters
        ----------
        value: dict
            New search_space.
        """
        self._update_experiment_profile('searchSpace', value)

    def update_max_trial_number(self, value: int):
        """
        Update an experiment's max_trial_number

        Parameters
        ----------
        value: int
            New max_trial_number value.
        """
        self._update_experiment_profile('maxTrialNumber', value)

    def kill_trial_job(self, trial_job_id: str):
        """
        Kill a trial job.

        Parameters
        ----------
        trial_job_id: str
            Trial job id.
        """
        rest.delete(self.port, f'/trial-jobs/{trial_job_id}', self.url_prefix)