# Source code for nni.nas.mutable.mutator

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import warnings
from typing import (Any, Iterable, List, Optional, Tuple, cast)

from nni.nas.execution import Model, Mutation, ModelStatus

# Alias for whatever a sampler may return from ``choice()``; it is plain ``typing.Any``.
Choice = Any

# Names exported by ``from nni.nas.mutable.mutator import *``.
__all__ = ['Sampler', 'Mutator', 'InvalidMutation']

class Sampler:
    """
    Decides which candidate is taken whenever `Mutator.choice()` is called.

    Subclasses must override `choice`. The `mutation_start` and `mutation_end`
    hooks are optional and do nothing unless overridden.
    """

    def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
        """
        Return the value picked from ``candidates``.

        ``index`` is the 0-based position of this call among the ``choice()``
        calls made by ``mutator`` during its current ``apply()``.
        Always raises ``NotImplementedError`` in this base class.
        """
        raise NotImplementedError

    def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
        """Hook run by ``Mutator.apply()`` right before ``mutator.mutate(model)``; no-op by default."""

    def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
        """Hook run by ``Mutator.apply()`` right after ``mutator.mutate(model)`` returns; no-op by default."""
class Mutator:
    """
    Mutates graphs in a model to generate a new model.

    `Mutator` is used in two places:

    1. Inherit `Mutator` to implement graph mutation logic.
    2. Use a `Mutator` subclass to implement a NAS strategy.

    In scenario 1, the subclass should implement the `Mutator.mutate()` interface using `Mutator.choice()`.

    In scenario 2, the strategy should use the constructor or `Mutator.bind_sampler()` to initialize the
    subclass, and then use `Mutator.apply()` to mutate a model.
    For certain mutator subclasses, the strategy or sampler can use `Mutator.dry_run()` to predict choice
    candidates.

    If a mutator has a label, in most cases it means that this mutator is applied to nodes with this label.
    """

    def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
        self.sampler: Optional[Sampler] = sampler
        if label is None:
            # stacklevel=2 attributes the warning to the caller of __init__ rather than to this line.
            warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.',
                          DeprecationWarning, stacklevel=2)
        self.label: str = label
        # Per-call state. It is only meaningful while `apply()` is running; otherwise the model and index are None.
        self._cur_model: Optional[Model] = None
        self._cur_choice_idx: Optional[int] = None
        self._cur_samples: List[Choice] = []

    def bind_sampler(self, sampler: Sampler) -> 'Mutator':
        """
        Set the sampler that will handle `Mutator.choice` calls.

        Returns ``self`` so the call can be chained.
        """
        self.sampler = sampler
        return self

    def apply(self, model: Model) -> Model:
        """
        Apply this mutator to a model and return the mutated model.

        The model is copied (via ``model.fork()``) before mutation, and the original model is not modified.
        On success, the copy gets a ``Mutation`` entry appended to its ``history`` and is marked
        ``ModelStatus.Frozen``.

        If ``mutate()`` or a sampler hook raises, the exception propagates. The mutator's per-call
        state is still reset, so it is never left looking as if a mutation were in progress.
        """
        assert self.sampler is not None
        copy = model.fork()
        self._cur_model = copy
        self._cur_choice_idx = 0
        self._cur_samples = []
        try:
            self.sampler.mutation_start(self, copy)
            self.mutate(copy)
            self.sampler.mutation_end(self, copy)
            copy.history.append(Mutation(self, self._cur_samples, model, copy))
            copy.status = ModelStatus.Frozen
        finally:
            # Reset even on failure; a later stray `choice()` then fails its assertion
            # instead of silently recording against a stale model.
            self._cur_model = None
            self._cur_choice_idx = None
        return copy

    def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
        """
        Dry-run the mutator on a model to collect choice candidates.

        Each `choice()` call is answered by a recording sampler that always picks the first candidate.
        Returns ``(recorded_candidate_lists, mutated_model)``. The previously bound sampler is restored
        afterwards, even if the mutation raises.

        If you invoke this method multiple times on the same or different models, it may or may not return
        identical results, depending on how the subclass implements `Mutator.mutate()`.
        """
        sampler_backup = self.sampler
        recorder = _RecorderSampler()
        self.sampler = recorder
        try:
            new_model = self.apply(model)
        finally:
            # Without this, an exception would leave the mutator permanently bound to the recorder.
            self.sampler = sampler_backup
        return recorder.recorded_candidates, new_model

    def mutate(self, model: Model) -> None:
        """
        Abstract method to be implemented by subclasses. Mutate ``model`` in place.

        Implementations call `choice()` to let the bound sampler make decisions.
        """
        raise NotImplementedError()

    def choice(self, candidates: Iterable[Choice]) -> Choice:
        """
        Ask the bound sampler to pick one of ``candidates`` and return the pick.

        Only valid while `apply()` is running, i.e. from inside `mutate()`. The pick is recorded
        in the current mutation's samples, and the choice index advances by one.
        """
        assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
        ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
        self._cur_samples.append(ret)
        self._cur_choice_idx += 1
        return ret
class _RecorderSampler(Sampler):
    """
    Sampler used by `Mutator.dry_run()`.

    It remembers every candidate list it is offered and always answers with the first entry.
    """

    def __init__(self):
        # One entry per `choice()` call, in call order.
        self.recorded_candidates: List[List[Choice]] = []

    def choice(self, candidates: List[Choice], *args) -> Choice:
        # Record first, so that an empty candidate list is still logged before indexing fails.
        history = self.recorded_candidates
        history.append(candidates)
        picked = candidates[0]
        return picked
class InvalidMutation(Exception):
    """
    Exported exception type. Judging by its name, it signals a mutation that is not valid.

    NOTE(review): nothing in this module raises it; confirm the raising sites in mutator implementations.
    """