# nni.mutable.frozen source code

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

# NOTE: this string follows the ``from __future__`` import, so Python does not bind it
# as the module docstring (``__doc__``); it is kept as a description for readers.
"""Utilities to freeze mutables upon their creation (either before or after),
such that when a proper context is provided, the mutables look exactly the same as their frozen versions.
"""

# Public API of this module.
__all__ = [
    'ensure_frozen', 'frozen_context', 'frozen_factory',
]

import logging
from contextlib import contextmanager
from typing import Any, Callable

from .mutable import Mutable, Sample
from .utils import NoContextError, ContextStack


_logger = logging.getLogger(__name__)

# Key under which the frozen_context layers are stored in ``ContextStack``.
# NOTE(review): referenced by ensure_frozen / frozen_context / bypass, but its definition was
# missing from this copy of the file; confirm it matches the upstream value.
_FROZEN_CONTEXT_KEY = '_frozen_context'

# Global switch read by ensure_frozen: while True, a ``strict=True`` call that finds no sample
# raises RuntimeError instead of falling back to ``robust_default``.
_ENSURE_FROZEN_STRICT = True

def ensure_frozen(mutable: Mutable | Any, *, strict: bool = True, sample: Sample | None = None,
                  retries: int = 1000) -> Any:
    """Ensure a mutable is frozen.

    Used when passing the mutable to a function which doesn't accept a mutable.
    If the argument is not a mutable, it is returned unchanged.
    Otherwise, :meth:`~nni.mutable.Mutable.freeze` is called if a sample is available.
    If ``sample`` is None, :func:`ensure_frozen` first tries to fill it with the content of
    :class:`frozen_context`. If there is still no sample, :meth:`~nni.mutable.Mutable.robust_default`
    is called on the mutable, unless the strict check forbids it.

    Parameters
    ----------
    mutable : nni.mutable.Mutable or any
        The mutable to freeze.
    strict
        Whether to raise an error if sample context is not provided and not found.
        The check only applies while the module-level ``_ENSURE_FROZEN_STRICT`` switch is true.
    sample
        The context to freeze the mutable with.
    retries
        Control the number of retries in case :meth:`~nni.mutable.Mutable.robust_default` is called.
        A negative value disables that fallback: an error is raised instead.

    Returns
    -------
    Any
        ``mutable`` itself if it is not a :class:`~nni.mutable.Mutable`; otherwise its frozen value.

    Raises
    ------
    RuntimeError
        If no sample is given or found in context, and either ``retries`` is negative
        or both ``strict`` and ``_ENSURE_FROZEN_STRICT`` are true.
        Exceptions from :meth:`~nni.mutable.Mutable.freeze` are logged and re-raised unchanged.

    Examples
    --------
    >>> with frozen_context({'a': 2}):
    ...     ensure_frozen(Categorical([1, 2, 3], label='a'))
    2
    >>> ensure_frozen(Categorical([1, 2, 3]), strict=False)
    1
    >>> ensure_frozen(Categorical([1, 2, 3], label='a'), sample={'a': 2}, strict=False)
    2
    >>> ensure_frozen('anything', strict=False)
    'anything'
    """
    if not isinstance(mutable, Mutable):
        return mutable

    # If we're in a frozen context, use its merged sample (None when there is no context).
    if sample is None:
        sample = frozen_context.current()

    if sample is not None:
        # Use freeze (rather than a plain lookup) to detect potential label mismatch errors.
        try:
            return mutable.freeze(sample)
        except Exception:
            # Add a hint for the common misuse, then propagate the original error.
            _logger.error(
                'Failed to freeze mutable %s with sample %s. '
                'In NAS, please make sure to have registered it via add_mutable(). '
                'Otherwise, please make sure you are not inside a frozen_context.',
                mutable, sample)
            raise

    if retries < 0 or (_ENSURE_FROZEN_STRICT and strict):
        raise RuntimeError(
            f'No frozen context is found for {mutable!r}. Assuming no context. \n'
            'If you are using NAS, you are probably using `ensure_frozen` in forward, or outside the init of ModelSpace. '
            'Please avoid doing this as they will lead to erroneous results.'
        )

    # TODO: Currently only mutable parameters in NAS evaluator end up here.
    # It might cause consistency issues between multiple parameters without context.
    # I don't want to throw a warning here, but there should be a smarter way to do this.
    return mutable.robust_default(retries=retries)
class frozen_context(ContextStack):
    """
    Context manager to set a sample into context.

    Then the sample will be retrievable from an arbitrary level of function calls
    via :meth:`frozen_context.current`.

    There are two use cases:

    1. Setting a global sample so that some modules can directly create the frozen version,
       rather than first-create-and-freeze.
    2. Sharing default / dry-run samples when the search space is dynamically created.

    The implementation is basically adding another layer of dict on top of a global stack.
    When retrieved, all dicts in the stack will be merged, from the bottom to the top.
    When updated, only the dict on the top will be updated.

    Parameters
    ----------
    sample
        The sample to be set into context. ``None`` or an empty sample pushes a fresh empty dict.
        A non-empty dict is pushed as-is (not copied), so :meth:`update` mutates that very dict.

    Returns
    -------
    Context manager that provides a frozen context.

    Examples
    --------
    ::

        def some_func():
            print(frozen_context.current()['learning_rate'])  # 0.1

        with frozen_context({'learning_rate': 0.1}):
            some_func()
    """

    def __init__(self, sample: Sample | None = None):
        # ``sample or {}`` gives each empty layer its own dict for later ``update()`` calls.
        super().__init__(_FROZEN_CONTEXT_KEY, sample or {})

    @staticmethod
    def top_context() -> Sample:
        """Return the value on the top of the frozen-context stack.

        That value is the sample dict of the innermost active :class:`frozen_context`
        (``current`` and ``update`` both treat stack entries as dicts).
        Propagates whatever ``ContextStack.top`` raises when no context is active.
        """
        return ContextStack.top(_FROZEN_CONTEXT_KEY)

    @staticmethod
    def current() -> dict | None:
        """Retrieve the current frozen context.

        If multiple layers have been found, they are merged from bottom to top,
        so upper layers win on conflicting keys. The result is a new (shallow-copied) dict.

        Returns
        -------
        The merged sample in frozen context, or ``None`` if no context is active.

        Raises
        ------
        TypeError
            If a layer on the stack is not a dict.
        """
        try:
            # Probe only: raises NoContextError when no frozen_context is active.
            ContextStack.top(_FROZEN_CONTEXT_KEY)
        except NoContextError:
            return None

        sample: Sample = {}
        for ctx in ContextStack.stack(_FROZEN_CONTEXT_KEY):
            if not isinstance(ctx, dict):
                raise TypeError(f'Expect architecture to be a dict, found: {ctx}')
            sample.update(ctx)
        return sample

    @staticmethod
    def update(sample: Sample) -> None:
        """
        Update the current dry run context. Only the topmost context will be updated.

        Parameters
        ----------
        sample
            The sample to be updated into context.

        Raises
        ------
        RuntimeError
            If no frozen context is active.
        TypeError
            If the topmost context is not a dict.
        """
        try:
            ctx = ContextStack.top(_FROZEN_CONTEXT_KEY)
        except NoContextError:
            raise RuntimeError('No frozen context is found. Please use frozen_context() to create one.') from None
        # Explicit check instead of ``assert`` so it survives ``python -O``; mirrors ``current()``.
        if not isinstance(ctx, dict):
            raise TypeError(f'Expect architecture to be a dict, found: {ctx}')
        ctx.update(sample)

    @staticmethod
    @contextmanager
    def bypass():
        """
        Ignore the most recent :class:`frozen_context`.

        This is useful in creating a search space within a ``frozen_context()`` context.
        Under the hood, it only disables the most recent frozen context, which means,
        if it's currently in a nested with-frozen-arch context, multiple ``bypass()`` contexts are required.
        The popped layer is pushed back on exit, even if the body raises.

        Examples
        --------
        >>> with frozen_context(arch_dict):
        ...     with frozen_context.bypass():
        ...         model_space = ModelSpace()
        """
        no_context = object()  # unique sentinel: nothing was popped, so nothing to restore
        try:
            sample = ContextStack.pop(_FROZEN_CONTEXT_KEY)
        except IndexError:
            # Context unavailable: bypass degrades to a no-op.
            # (assumes ContextStack.pop signals an empty stack with IndexError)
            sample = no_context
        try:
            yield
        finally:
            if sample is not no_context:
                ContextStack.push(_FROZEN_CONTEXT_KEY, sample)
class frozen_factory:
    """Wrap a callable so that every invocation runs inside a fixed :class:`frozen_context`.

    Parameters
    ----------
    callable
        The function (or class) to be invoked.
    sample
        The sample to be used as the frozen context. A plain sample is wrapped into a
        :class:`frozen_context` once, at construction time; an existing
        :class:`frozen_context` is reused as-is.

    Examples
    --------
    >>> factory = frozen_factory(ModelSpaceClass, {"choice1": 3})
    >>> model = factory(channels=16, classes=10)
    """

    # NOTE: mutations on ``init_args`` and ``init_kwargs`` themselves are not supported.

    def __init__(self, callable: Callable[..., Any], sample: Sample | frozen_context):  # pylint: disable=redefined-builtin
        self.callable = callable
        # Normalize once so that __call__ can always enter ``self.sample`` directly.
        self.sample = sample if isinstance(sample, frozen_context) else frozen_context(sample)

    def __call__(self, *init_args, **init_kwargs):
        # The same frozen_context object is re-entered on every call.
        with self.sample:
            return self.callable(*init_args, **init_kwargs)

    def __repr__(self):
        target = self.callable
        arch = self.sample.value
        return f'frozen_factory(callable={target}, arch={arch})'