# Source code for nni.nas.space.frozen
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['current_model', 'model_context']
import copy
from typing import Optional
from nni.mutable import frozen_context, Sample
# [docs]
def current_model() -> Optional[Sample]:
    """Return the model sample of the enclosing :func:`model_context`, if any.

    The returned sample is expected to match
    :attr:`nni.nas.space.ExecutableModelSpace.sample`.

    The result is only meaningful inside :func:`model_context`. By default,
    only the execution of :class:`~nni.nas.space.SimplifiedModelSpace` sets
    that context, which makes :func:`current_model` usable while the model
    is being re-instantiated.

    Returns
    -------
    Model sample (i.e., architecture dict) before freezing, produced by strategy.
    ``None`` when not called inside :func:`model_context`.
    """
    active = frozen_context.current()
    # A frozen context only counts when it carries a truthy ``__arch__``
    # marker; any other context was not set up for an architecture.
    if active is not None and active.get('__arch__'):
        # Work on a shallow copy so the live context keeps its marker.
        sample = copy.copy(active)
        sample.pop('__arch__')
        return sample
    return None
# [docs]
def model_context(sample: Sample) -> frozen_context:
    """Build a context stack for the given model sample (i.e., architecture dict).

    Intended to be paired with :func:`current_model`.

    :func:`model_context` is read-only, and should not be used to modify
    the architecture dict.
    """
    # Copy the sample into a fresh dict, then tag it with the ``__arch__``
    # marker that :func:`current_model` looks for.
    tagged = {**sample}
    tagged['__arch__'] = True
    return frozen_context(tagged)