# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
import os
import warnings
from typing import Any, TypeVar, Type
from nni.common.serializer import is_traceable, is_wrapped_with_trace, trace, _copy_class_wrapper_attributes
from .misc import ModelNamespace
__all__ = ['get_init_parameters_or_fail', 'serialize', 'serialize_cls', 'basic_unit', 'model_wrapper',
'is_basic_unit', 'is_model_wrapped']
T = TypeVar('T')
def get_init_parameters_or_fail(obj: Any):
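    """
    Return the init arguments (``trace_kwargs``) recorded by ``nni.trace`` for a traceable object,
    or raise ``ValueError`` with a hint if the object is not traceable.

    A minimal illustrative use (``MyCustomOp`` is a hypothetical module decorated with ``@basic_unit``):

    .. code-block:: python

        op = MyCustomOp(hidden_units=128)
        get_init_parameters_or_fail(op)  # -> {'hidden_units': 128}
    """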
if is_traceable(obj):
return obj.trace_kwargs
raise ValueError(f'Object {obj} needs to be serializable but `trace_kwargs` is not available. '
'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
                     'If it is a customized module, please decorate it with @basic_unit. '
'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
'try to use @nni.trace.')
def serialize(cls, *args, **kwargs):
"""
To create an serializable instance inline without decorator. For example,
.. code-block:: python
self.op = serialize(MyCustomOp, hidden_units=128)
"""
warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(learning_rate=1e-4) instead.',
category=DeprecationWarning)
return trace(cls)(*args, **kwargs)
def serialize_cls(cls):
"""
    Create a serializable class.
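
    A minimal illustrative use (``MyCustomOp`` is a stand-in for any user-defined class):

    .. code-block:: python

        SerializableOp = serialize_cls(MyCustomOp)
        op = SerializableOp(hidden_units=128)  # init arguments are recorded for re-instantiation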
"""
warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
'Try to use nni.trace instead.', category=DeprecationWarning)
return trace(cls)
def basic_unit(cls: T, basic_unit_tag: bool = True) -> T:
"""
To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.
``basic_unit_tag`` is true by default. If set to false, it will not be explicitly mark as a basic unit, and
graph parser will continue to parse. Currently, this is to handle a special case in ``nn.Sequential``.
Although ``basic_unit`` calls ``trace`` in its implementation, it is not for serialization. Rather, it is meant
to capture the initialization arguments for mutation. Also, graph execution engine will stop digging into the inner
modules when it reaches a module that is decorated with ``basic_unit``.
.. code-block:: python
@basic_unit
class PrimitiveOp(nn.Module):
...
"""
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
if _check_wrapped(cls, 'basic_unit'):
return cls
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.' # type: ignore
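    # ``trace`` records the init arguments; the tag marks the class as a primitive
    # so the graph parser will not dig into its inner modules.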
cls = trace(cls)
cls._nni_basic_unit = basic_unit_tag # type: ignore
_torchscript_patch(cls)
return cls
def model_wrapper(cls: T) -> T:
"""
Wrap the base model (search space). For example,
.. code-block:: python
@model_wrapper
class MyModel(nn.Module):
...
The wrapper serves two purposes:
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in namespace so that the auto label counting in each model stably starts from zero.
Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
But in future, we might enforce ``@model_wrapper`` to be required for base model.
"""
# Internal flag. See nni.trace
nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
if nni_trace_flag.lower() == 'disable':
return cls
if _check_wrapped(cls, 'model_wrapper'):
return cls
import torch.nn as nn
    assert issubclass(cls, nn.Module), 'When using @model_wrapper, the class must be a subclass of nn.Module.'  # type: ignore
# subclass can still use trace info
wrapper = trace(cls, inheritable=True)
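    # Enter a fresh ``ModelNamespace`` on every instantiation so that auto-generated
    # labels (uids) restart from zero for each model.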
class reset_wrapper(wrapper):
def __init__(self, *args, **kwargs):
self._model_namespace = ModelNamespace()
with self._model_namespace:
super().__init__(*args, **kwargs)
_copy_class_wrapper_attributes(wrapper, reset_wrapper)
reset_wrapper.__wrapped__ = getattr(wrapper, '__wrapped__', wrapper)
reset_wrapper._nni_model_wrapper = True
reset_wrapper._traced = True
_torchscript_patch(cls)
return reset_wrapper
def is_basic_unit(cls_or_instance) -> bool:
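    """Check whether a class, or the class of an instance, is decorated with ``@basic_unit``."""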
if not inspect.isclass(cls_or_instance):
cls_or_instance = cls_or_instance.__class__
return getattr(cls_or_instance, '_nni_basic_unit', False)
def is_model_wrapped(cls_or_instance) -> bool:
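    """Check whether a class, or the class of an instance, is decorated with ``@model_wrapper``."""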
if not inspect.isclass(cls_or_instance):
cls_or_instance = cls_or_instance.__class__
return getattr(cls_or_instance, '_nni_model_wrapper', False)
def _check_wrapped(cls: Type, rewrap: str) -> bool:
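    """Return true if ``cls`` is already wrapped with ``rewrap``; raise ``TypeError`` if it is wrapped with something else."""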
wrapped = None
if is_model_wrapped(cls):
wrapped = 'model_wrapper'
elif is_basic_unit(cls):
wrapped = 'basic_unit'
elif is_wrapped_with_trace(cls):
wrapped = 'nni.trace'
if wrapped:
if wrapped != rewrap:
raise TypeError(f'{cls} is already wrapped with {wrapped}. Cannot rewrap with {rewrap}.')
return True
return False
def _torchscript_patch(cls) -> None:
# HACK: for torch script
# https://github.com/pytorch/pytorch/pull/45261
# https://github.com/pytorch/pytorch/issues/54688
# I'm not sure whether there will be potential issues
import torch
    if hasattr(cls, '_get_nni_attr'):  # might not exist on non-linux platforms
cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
if hasattr(cls, 'trace_symbol'):
        # these must either all exist or all be absent
try:
cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
cls.trace_args = torch.jit.unused(cls.trace_args)
cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)
cls.trace_copy = torch.jit.ignore(cls.trace_copy)
except AttributeError as e:
if 'property' in str(e):
raise RuntimeError('Trace on PyTorch module failed. Your PyTorch version might be outdated. '
                                   'Please try to upgrade PyTorch.') from e
raise