# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import warnings
from collections import OrderedDict
import torch.nn as nn
from nni.nas.pytorch.utils import global_mutable_counting
# Module-level logger; Mutable.__init__ uses it to warn when a non-string key is converted.
logger = logging.getLogger(__name__)
class Mutable(nn.Module):
    """
    Mutable is designed to function as a normal layer, with all necessary operators' weights.
    States and weights of architectures should be included in mutator, instead of the layer itself.

    Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
    decisions among different mutables. In mutator's implementation, mutators should use the key to
    distinguish different mutables. Mutables that share the same key should be "similar" to each other.

    Currently the default scope for keys is global. By default, a key is the class name followed by
    the value of a global counter (starting from 1), which produces unique ids.

    Parameters
    ----------
    key : str
        The key of mutable. Non-string keys are converted with ``str()`` and a warning is logged.
        If ``None``, a key is generated automatically.

    Notes
    -----
    The counter is program level, but mutables are model level. In case multiple models are defined, and
    you want to have `counter` starting from 1 in the second model, it's recommended to assign keys manually
    instead of using automatic keys.
    """

    def __init__(self, key=None):
        super().__init__()
        if key is not None:
            if not isinstance(key, str):
                key = str(key)
                logger.warning("Warning: key \"%s\" is not string, converted to string.", key)
            self._key = key
        else:
            self._key = self.__class__.__name__ + str(global_mutable_counting())
        # Hook slots; unset by default and not read anywhere in this class.
        self.init_hook = self.forward_hook = None

    def __deepcopy__(self, memodict=None):
        """Reject deep copies; mutables do not support them."""
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def __call__(self, *args, **kwargs):
        """
        Run the module, refusing to do so until a mutator has been attached.

        Raises
        ------
        ValueError
            If no mutator has been set on this mutable.
        """
        self._check_built()
        return super().__call__(*args, **kwargs)

    def set_mutator(self, mutator):
        """
        Attach the mutator that controls this mutable. May only be called once.

        Raises
        ------
        RuntimeError
            If a mutator has already been attached.
        """
        if "mutator" in self.__dict__:
            raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
                               "Or did you apply multiple fixed architectures?")
        # Written straight into ``__dict__`` so that ``nn.Module.__setattr__`` is bypassed
        # (presumably to keep the mutator from being registered as a submodule).
        self.__dict__["mutator"] = mutator

    @property
    def key(self):
        """
        Read-only property of key.
        """
        return self._key

    @property
    def name(self):
        """
        After the search space is parsed, it will be the module name of the mutable.
        Until a name has been assigned, the key of the mutable is returned instead.
        """
        # Fall back to the actual key value, not the literal attribute name "_key".
        return self._name if hasattr(self, "_name") else self._key

    @name.setter
    def name(self, name):
        self._name = name

    def _check_built(self):
        """
        Ensure a mutator has been attached.

        Raises
        ------
        ValueError
            If ``set_mutator`` has not been called on this mutable.
        """
        if not hasattr(self, "mutator"):
            raise ValueError(
                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
                "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
class MutableScope(Mutable):
    """
    Mutable scope marks a subgraph/submodule to help mutators make better decisions.

    If not annotated with mutable scope, search space will be flattened as a list. However, some mutators might
    need to leverage the concept of a "cell". So if a module is defined as a mutable scope, everything in it will
    look like "sub-search-space" in the scope. Scopes can be nested.

    There are two ways mutators can use mutable scope. One is to traverse the search space as a tree during initialization
    and reset. The other is to implement `enter_mutable_scope` and `exit_mutable_scope`. They are called before and after
    the forward method of the class inheriting mutable scope.

    Mutable scopes are also mutables that are listed in the mutator.mutables (search space), but they are not supposed
    to appear in the dict of choices.

    Parameters
    ----------
    key : str
        Key of mutable scope.
    """

    def __init__(self, key):
        super().__init__(key=key)

    def __call__(self, *args, **kwargs):
        """
        Run the scope's forward pass between ``enter_mutable_scope`` and ``exit_mutable_scope``.

        Raises
        ------
        ValueError
            If no mutator has been set on this scope.
        """
        # Validate before the ``try``: the ``finally`` clause dereferences ``self.mutator``,
        # so checking inside it would replace the explanatory ValueError with an
        # AttributeError whenever the mutator is missing.
        self._check_built()
        try:
            self.mutator.enter_mutable_scope(self)
            return super().__call__(*args, **kwargs)
        finally:
            self.mutator.exit_mutable_scope(self)
class LayerChoice(Mutable):
    """
    Layer choice selects one of the ``op_candidates``, then apply it on inputs and return results.
    In rare cases, it can also select zero or many.

    Layer choice does not allow itself to be nested.

    Parameters
    ----------
    op_candidates : list of nn.Module or OrderedDict
        A module list to be selected from.
    reduction : str
        ``mean``, ``concat``, ``sum`` or ``none``. Policy if multiples are selected.
        If ``none``, a list is returned. ``mean`` returns the average. ``sum`` returns the sum.
        ``concat`` concatenate the list at dimension 1.
    return_mask : bool
        If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
    key : str
        Key of the layer choice.

    Attributes
    ----------
    length : int
        Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
    names : list of str
        Names of candidates.
    choices : list of Module
        Deprecated. A list of all candidate modules in the layer choice module.
        ``list(layer_choice)`` is recommended, which will serve the same purpose.

    Notes
    -----
    ``op_candidates`` can be a list of modules or an ordered dict of named modules, for example,

    .. code-block:: python

        self.op_choice = LayerChoice(OrderedDict([
            ("conv3x3", nn.Conv2d(16, 128, 3, padding=1)),
            ("conv5x5", nn.Conv2d(16, 128, 5, padding=2)),
            ("conv7x7", nn.Conv2d(16, 128, 7, padding=3))
        ]))

    Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
    ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
    """

    def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
        super().__init__(key=key)
        # Candidate names in selection order; each name is also the submodule name.
        self.names = []
        if isinstance(op_candidates, OrderedDict):
            for name, module in op_candidates.items():
                # These names would collide with attributes/properties of this class.
                assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
                    "Please don't use a reserved name '{}' for your module.".format(name)
                self.add_module(name, module)
                self.names.append(name)
        elif isinstance(op_candidates, list):
            # Unnamed candidates are registered under their positional index.
            for i, module in enumerate(op_candidates):
                self.add_module(str(i), module)
                self.names.append(str(i))
        else:
            raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
        self.reduction = reduction
        self.return_mask = return_mask

    def __getitem__(self, idx):
        """Return a candidate by name (``str``) or by position (anything else, e.g. int or slice)."""
        if isinstance(idx, str):
            return self._modules[idx]
        return list(self)[idx]

    def __setitem__(self, idx, module):
        """Replace the candidate identified by name (``str``) or by integer position."""
        key = idx if isinstance(idx, str) else self.names[idx]
        return setattr(self, key, module)

    def __delitem__(self, idx):
        """Remove the candidate(s) identified by name, integer position or slice."""
        if isinstance(idx, slice):
            for key in self.names[idx]:
                delattr(self, key)
        else:
            if isinstance(idx, str):
                # Translate the name into its position so ``names`` can be updated below.
                key, idx = idx, self.names.index(idx)
            else:
                key = self.names[idx]
            delattr(self, key)
        del self.names[idx]

    @property
    def length(self):
        """Deprecated alias of ``len(layer_choice)``."""
        # stacklevel=2 attributes the warning to the caller, so the default
        # DeprecationWarning filter can actually surface it to user code.
        warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.",
                      DeprecationWarning, stacklevel=2)
        return len(self)

    def __len__(self):
        return len(self.names)

    def __iter__(self):
        """Iterate over candidate modules in the order given by ``names``."""
        return map(lambda name: self._modules[name], self.names)

    @property
    def choices(self):
        """Deprecated alias of ``list(layer_choice)``."""
        warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.",
                      DeprecationWarning, stacklevel=2)
        return list(self)

    def forward(self, *args, **kwargs):
        """
        Delegate the forward computation to ``mutator.on_forward_layer_choice``.

        Returns
        -------
        tuple of tensors
            Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
        """
        out, mask = self.mutator.on_forward_layer_choice(self, *args, **kwargs)
        if self.return_mask:
            return out, mask
        return out