Source code for nni.nas.evaluator.evaluator

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

__all__ = ['Evaluator', 'MutableEvaluator', 'FrozenEvaluator']

import logging
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Iterable, Generator, TYPE_CHECKING, cast
from typing_extensions import Literal

import nni
from nni.common.serializer import is_traceable, SerializableObject
from nni.mutable import Mutable, Sample, LabeledMutable, SampleValidationError
from nni.mutable.mutable import _mutable_equal
from nni.runtime.trial_command_channel import TrialCommandChannel, set_default_trial_command_channel, get_default_trial_command_channel
from nni.typehint import ParameterRecord, TrialMetric, Parameters

if TYPE_CHECKING:
    from nni.nas.space import ExecutableModelSpace

_logger = logging.getLogger(__name__)


class Evaluator:
    """Base class of evaluator.

    To users, the evaluator is to assess the quality of a model and return a score.
    When an evaluator is defined, it usually accepts a few arguments, such as
    basic runtime information (e.g., whether to use GPU), the dataset used,
    as well as hyper-parameters (such as learning rate).
    These parameters can sometimes be tunable and searched by algorithms (see :class:`MutableEvaluator`).

    Different evaluators could have different use scenarios and requirements on the model.
    For example, :class:`~nni.nas.evaluator.pytorch.Classification` is tailored for classification models,
    and assumes the model has a ``forward`` method that takes a batch of data and returns logits.
    Evaluators might also have different assumptions, some of which are requirements of certain algorithms.
    The evaluator with the most freedom is :class:`~nni.nas.evaluator.FunctionalEvaluator`,
    but it's also incompatible with some algorithms.

    To developers, the evaluator is to implement *all the logic involving forward/backward of neural networks*.
    Sometimes the algorithm requires training and searching at the same time (e.g., one-shot algorithms).
    In that case, although the searching part doesn't logically belong to the evaluator,
    it is still the evaluator's responsibility to implement it,
    and the search algorithms will make sure to properly manipulate the evaluator to achieve the goal.

    .. tip::

        Inside the evaluator, you can use standard :doc:`NNI trial APIs </reference/hpo>` to communicate with the exploration strategy.
        Common usages include:

        1. Use :func:`nni.get_current_parameter` to get the current :class:`~nni.nas.space.ExecutableModelSpace`.
           Notice that :class:`~nni.nas.space.ExecutableModelSpace` is not a directly-runnable model (e.g., a PyTorch model),
           which is different from the model received in :meth:`evaluate`.
           :class:`~nni.nas.space.ExecutableModelSpace` objects are useful for debugging,
           as well as for some evaluators which need to know extra details about how the model is sampled.
        2. Use :func:`nni.report_intermediate_result` to report intermediate results.
        3. Use :func:`nni.report_final_result` to report final results.

        These APIs are only available when the evaluator is executed by NNI.
        We recommend using ``nni.get_current_parameter() is not None`` to check whether the APIs are available before using them.

        Please AVOID using :func:`nni.get_next_parameter`,
        because the NAS framework has already handled the logic of retrieving the next parameter.
        Incorrectly using :func:`nni.get_next_parameter` may cause unexpected behavior.
    """

    def evaluate(self, model: Any) -> Any:
        """Run evaluation of a model. The model is usually a concrete model.

        The return value of :meth:`evaluate` can be anything. Typically it's used for test purposes.

        Subclasses should override this.
        """
        raise NotImplementedError()

    def __repr__(self):
        return f'{self.__class__.__name__}({self.extra_repr()})'

    def extra_repr(self) -> str:
        return ''

    @staticmethod
    @contextmanager
    def mock_runtime(model: ExecutableModelSpace) -> Generator[None, None, None]:
        """
        Context manager to mock trial APIs for standalone usage.

        Under the with-context of this method, :func:`nni.get_current_parameter` will return the given model.

        NOTE: This method might become a utility in the trial command channel in the future.

        Parameters
        ----------
        model
            The model to be evaluated. It should be an :class:`~nni.nas.space.ExecutableModelSpace` object.

        Examples
        --------
        This method should be mostly used when testing an evaluator. A typical use case is as follows:

        >>> frozen_model = model_space.freeze(sample)
        >>> with evaluator.mock_runtime(frozen_model):
        ...     evaluator.evaluate(frozen_model.executable_model())
        """
        if nni.get_current_parameter() is not None:
            raise RuntimeError("Cannot mock trial when trial APIs are already available.")

        from nni.nas.space import ExecutableModelSpace
        if not isinstance(model, ExecutableModelSpace):
            raise TypeError("model should be an ExecutableModelSpace object.")

        trial_command_channel = get_default_trial_command_channel()
        original_params = nni.trial._params
        original_seq = nni.trial._intermediate_seq
        try:
            set_default_trial_command_channel(_EvaluatorMockTrialCommandChannel(model))
            assert nni.get_next_parameter() is model
            yield
        finally:
            set_default_trial_command_channel(trial_command_channel)
            # NOTE: Mocking might have side effects on nni.trial._params,
            # which could cause trial_available() to still return true after the context.
            # The following restores the original state to address the problem.
            nni.trial._params = original_params
            nni.trial._intermediate_seq = original_seq

    @staticmethod
    def _load(**ir: Any) -> 'Evaluator':
        """Subclasses implement ``_load`` for their own serialization."""
        raise NotImplementedError()

    def _dump(self) -> Any:
        """Subclasses implement ``_dump`` for their own serialization."""
        raise NotImplementedError()

    # def _execute(self, model: ExecutableModelSpace) -> Any:
    def _execute(self, model: Any) -> Any:
        """Advanced users can overwrite this to avoid instantiation of the deep learning model.

        For internal use only.
        """
        executable_model = model.executable_model()
        return self.evaluate(executable_model)
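

# --- Illustrative sketch (not part of the original module) -------------------
# A minimal sketch of a custom Evaluator subclass, assuming a hypothetical
# ``score_fn`` callable that maps a concrete model to a float. It shows the
# pattern recommended in the class docstring above: override
# ``evaluate``/``extra_repr`` and guard trial APIs with
# ``nni.get_current_parameter() is not None``. Prefer
# nni.nas.evaluator.FunctionalEvaluator in real code.


class _ExampleScoreEvaluator(Evaluator):
    """Hypothetical example evaluator, used for illustration only."""

    def __init__(self, score_fn: Callable[[Any], float]):
        self.score_fn = score_fn

    def evaluate(self, model: Any) -> float:
        score = self.score_fn(model)
        # Trial APIs are only available when the evaluator is executed (or mocked) by NNI.
        if nni.get_current_parameter() is not None:
            nni.report_final_result(score)
        return score

    def extra_repr(self) -> str:
        return f'score_fn={self.score_fn!r}'
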

class MutableEvaluator(Mutable, Evaluator):
    """
    Evaluator that has tunable parameters itself (e.g., learning rate).

    The tunable parameters must be an argument of the evaluator's instantiation,
    or an argument of the arguments' instantiation, and so on.

    To use this feature, there are two requirements:

    1. The evaluator must inherit :class:`MutableEvaluator` rather than :class:`Evaluator`.
    2. Make sure the init arguments have been saved in ``trace_kwargs``,
       and the instance can be cloned with ``trace_copy``.
       The easiest way is to wrap the evaluator with :func:`nni.trace`.
       If a mutable parameter exists somewhere in a nested instantiation,
       all the levels must be wrapped with :func:`nni.trace`.

    Examples
    --------
    >>> def get_data(shuffle):
    ...     ...
    >>> @nni.trace                                  # 1. must wrap here
    ... class MyOwnEvaluator(MutableEvaluator):     # 2. must inherit MutableEvaluator
    ...     def __init__(self, lr, data):
    ...         ...
    >>> evaluator = MyOwnEvaluator(
    ...     lr=Categorical([0.1, 0.01]),            # the argument can be tunable
    ...     data=nni.trace(get_data)(               # if there are mutable parameters inside, this must also have nni.trace
    ...         shuffle=Categorical([False, True])
    ...     )
    ... )
    >>> evaluator.simplify()
    {'global/1': Categorical([0.1, 0.01], label='global/1'), 'global/2': Categorical([False, True], label='global/2')}
    """

    @staticmethod
    def freeze_traceable_object(obj: Any, sample: dict[str, Any]) -> Any:
        # Could return:
        # 1. The same type as obj, if it is not traceable or nothing is changed.
        # 2. SerializableObject, if something is mutated; the return value is merely an empty shell of parameters.
        if not is_traceable(obj, must_be_instance=True):
            return obj

        updates = {}

        # Iterate over all the arguments that have been used to instantiate the object.
        for key, param in obj.trace_kwargs.items():
            if isinstance(param, Mutable):
                updates[key] = param.freeze(sample)
            elif is_traceable(param):
                # Recurse into nested traceable objects.
                sub_update = MutableEvaluator.freeze_traceable_object(param, sample)
                if sub_update is not param:  # if mutated
                    updates[key] = sub_update

        if updates:
            mutated_obj = obj.trace_copy()            # Make a copy
            mutated_obj.trace_kwargs.update(updates)  # Mutate
            # The full mutated object should be instantiated here,
            # but it's postponed to the next call of evaluate to save memory.
            # mutated_obj = mutated_obj.get()
            return mutated_obj

        return obj

    @staticmethod
    def expand_trace_kwargs(obj: Any) -> Iterator[Mutable]:
        if not is_traceable(obj, must_be_instance=True):
            raise TypeError(f'{obj} is not traceable.')
        for pos, param in enumerate(obj.trace_args):
            if isinstance(param, Mutable):
                raise ValueError(f'We currently do not support mutable parameters in positional arguments: {obj} (position: {pos})')
        for param in obj.trace_kwargs.values():
            if isinstance(param, Mutable):
                yield param
            # Recursively yield all the nested parameters
            if is_traceable(param, must_be_instance=True):
                yield from MutableEvaluator.expand_trace_kwargs(param)

    def freeze(self, sample: Sample) -> FrozenEvaluator | MutableEvaluator:
        """Upon freeze, :class:`MutableEvaluator` will freeze all the mutable parameters (as well as nested parameters),
        and return a :class:`FrozenEvaluator`.

        The evaluator will not be fully initialized, in order to save memory,
        especially when parameters contain large objects such as datasets.
        To use the evaluator, call :meth:`FrozenEvaluator.get` to get the fully usable evaluator.

        Returns
        -------
        The frozen evaluator.
        """
        frozen_evaluator = self.freeze_traceable_object(self, sample)
        assert isinstance(frozen_evaluator, (MutableEvaluator, SerializableObject)), \
            'The evaluator must still be a MutableEvaluator or SerializableObject after freezing.'
        if type(frozen_evaluator) is SerializableObject:
            return FrozenEvaluator(frozen_evaluator)
        else:
            assert isinstance(frozen_evaluator, MutableEvaluator)
            return frozen_evaluator

    def leaf_mutables(self, is_leaf: Callable[[Mutable], bool]) -> Iterable[LabeledMutable]:
        for param in self.expand_trace_kwargs(self):
            yield from param.leaf_mutables(is_leaf)

    def check_contains(self, sample: Sample) -> SampleValidationError | None:
        for mutable in self.expand_trace_kwargs(self):
            exception = mutable.check_contains(sample)
            if exception is not None:
                exception.paths.append(str(mutable))
                return exception
        return None

    def is_mutable(self) -> bool:
        """Whether some arguments of the evaluator are mutable.

        Although the evaluator is a :class:`MutableEvaluator`, it may contain no mutable parameters,
        i.e., all its parameters (including nested ones) are fixed values.
        Returns false if there are none.
        """
        # A compact way to check whether an iterator is not empty, without expanding it.
        for _ in self.expand_trace_kwargs(self):
            return True
        return False

    @classmethod
    def _load(cls, **ir: Any) -> MutableEvaluator:
        return cls(**ir)  # type: ignore

    def _dump(self) -> dict:
        if not is_traceable(self):
            raise TypeError(f'{self} is not traceable.')
        return self.trace_kwargs  # type: ignore

    def __eq__(self, other: Any) -> bool:
        if type(self) is not type(other):
            return False
        if not is_traceable(self):
            raise TypeError(f'{self} is not traceable.')
        return _mutable_equal(self.trace_kwargs, other.trace_kwargs)  # type: ignore
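

# --- Illustrative sketch (not part of the original module) -------------------
# A minimal sketch of the freeze flow, assuming ``Categorical`` is importable
# from nni.mutable (as in the MutableEvaluator docstring above).
# ``_DemoEvaluator`` and ``_example_freeze_flow`` are hypothetical names used
# for illustration; the helper is never called by the module itself.


def _example_freeze_flow():
    from nni.mutable import Categorical

    @nni.trace
    class _DemoEvaluator(MutableEvaluator):
        def __init__(self, lr):
            self.lr = lr

        def evaluate(self, model):
            return self.lr

    evaluator = _DemoEvaluator(lr=Categorical([0.1, 0.01], label='lr'))
    assert evaluator.is_mutable()                # the ``lr`` argument is tunable
    frozen = evaluator.freeze({'lr': 0.1})       # typically a FrozenEvaluator (an empty shell)
    if isinstance(frozen, FrozenEvaluator):
        frozen = frozen.get()                    # lazily instantiate the full evaluator
    return frozen.lr                             # expected to be 0.1 under these assumptions
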

class FrozenEvaluator(Evaluator):
    """:meth:`MutableEvaluator.freeze` returns a :class:`FrozenEvaluator`,
    which is purely an empty shell (i.e., symbol and init parameters), and not instantiated.

    For the evaluator itself and its parameters, if its constructor is decorated with ``nni.trace``,
    :attr:`frozen_obj` will be a :class:`SerializableObject` that contains the constructor (class / function)
    as well as the arguments that have been used to instantiate the object.

    When :meth:`evaluate` is invoked, it instantiates the full evaluator recursively and executes it.
    """

    def __init__(self, frozen_obj: SerializableObject):
        self.frozen_obj = frozen_obj
        self._instance: Evaluator | None = None

    @staticmethod
    def recursive_instantiate(obj: Any) -> Any:
        if type(obj) is not SerializableObject:
            # It has been instantiated
            return obj

        # obj must be a SerializableObject here.
        updates = {}
        for key, param in obj.trace_kwargs.items():
            sub_update = FrozenEvaluator.recursive_instantiate(param)
            if sub_update is not param:
                updates[key] = sub_update
        if updates:
            obj.trace_kwargs.update(updates)

        # Calling get() of a SerializableObject gives a fully instantiated object.
        # We can't bypass this get() even if updates is empty, because obj is never instantiated.
        return obj.get()

    def get(self) -> Evaluator:
        """Instantiate the full evaluator.

        The instantiated evaluator will be cached and reused next time.
        """
        if self._instance is not None:
            return self._instance
        self._instance = self.recursive_instantiate(self.frozen_obj)
        assert isinstance(self._instance, Evaluator), \
            f'Instantiated evaluator must be an Evaluator, got {type(self._instance)}'
        return self._instance

    def evaluate(self, model: Any) -> Any:
        """Call :meth:`get` and execute the evaluator."""
        return self.get().evaluate(model)

    def extra_repr(self) -> str:
        return f'frozen_obj={self.frozen_obj}, initialized={self._instance is not None}'

    @classmethod
    def _load(cls, **ir: Any) -> 'FrozenEvaluator':
        return ir['trace_symbol'](*ir['trace_args'], **ir['trace_kwargs'])

    def _dump(self) -> Any:
        return {
            'trace_symbol': self.trace_symbol,
            'trace_args': self.trace_args,
            'trace_kwargs': self.trace_kwargs,
        }

    # The following APIs are provided as shortcuts to frozen_obj.

    @property
    def trace_symbol(self):
        return self.frozen_obj.trace_symbol

    @property
    def trace_args(self):
        return self.frozen_obj.trace_args

    @property
    def trace_kwargs(self):
        return self.frozen_obj.trace_kwargs

    def trace_copy(self):
        return FrozenEvaluator(self.frozen_obj.trace_copy())


class _EvaluatorMockTrialCommandChannel(TrialCommandChannel):
    """Mock a trial command channel for evaluator debugging."""

    def __init__(self, model: ExecutableModelSpace):
        self.model = model

    def receive_parameter(self) -> ParameterRecord | None:
        return {
            'parameter_id': 0,
            'parameters': cast(Parameters, self.model)
        }

    def send_metric(self, type: Literal['PERIODICAL', 'FINAL'], parameter_id: int | None,  # pylint: disable=redefined-builtin
                    trial_job_id: str, sequence: int, value: TrialMetric) -> None:
        if type == 'FINAL':
            self.model.metrics.final = value
            _logger.info('[Mock] Final metric: %s', value)
        else:
            self.model.metrics.add_intermediate(value)
            _logger.info('[Mock] Intermediate metric: %s', value)
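

# --- Illustrative sketch (not part of the original module) -------------------
# A hypothetical helper showing the standalone debugging flow end to end,
# combining ``ExecutableModelSpace.freeze``, ``MutableEvaluator.freeze`` and
# ``Evaluator.mock_runtime``. ``model_space``, ``sample`` and ``evaluator`` are
# assumed to be provided by the caller; metrics reported inside ``evaluate``
# are routed by _EvaluatorMockTrialCommandChannel onto ``frozen_model.metrics``.


def _example_standalone_debug(model_space: 'ExecutableModelSpace', sample: Sample, evaluator: Evaluator):
    frozen_model = model_space.freeze(sample)
    if isinstance(evaluator, MutableEvaluator):
        evaluator = evaluator.freeze(sample)      # also freeze tunable hyper-parameters
    with Evaluator.mock_runtime(frozen_model):
        result = evaluator.evaluate(frozen_model.executable_model())
    # ``frozen_model.metrics`` now holds whatever was reported via nni.report_*_result.
    return result, frozen_model.metrics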