# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Can't have annotations on because PyTorch JIT doesn't support it.
# from __future__ import annotations
import functools
import warnings
from typing import (Any, Iterator, List, Optional, Dict, Union, Tuple, cast)
from typing_extensions import Literal
import torch
import torch.nn as nn
from nni.mutable import Categorical, CategoricalMultiple, Sample, SampleValidationError, ensure_frozen
from .base import MutableModule, recursive_freeze
# Public names of this module. (Some listed names, e.g. InputChoice,
# may be defined in a part of the file outside this chunk.)
__all__ = [
    # APIs
    'LayerChoice',
    'InputChoice',
    'ValueChoice',
    'MutationAnchor',
    # Fixed module
    'ChosenInputs',
    # Type utils
    'ReductionType',
]
class ValueChoice(Categorical):
    """For backward compatibility only. Please use :class:`nni.mutable.Categorical` instead."""

    @functools.wraps(Categorical.__init__)
    def __init__(self, *args, **kwargs):
        # Warn on every construction, then delegate straight to Categorical.
        warnings.warn('ValueChoice is deprecated, please use `nni.choice` instead', DeprecationWarning)
        super().__init__(*args, **kwargs)

    @property
    def candidates(self) -> list:
        # Legacy accessor: old code read ``candidates``; Categorical calls it ``values``.
        return self.values
# NOTE: a Sphinx "[docs]" link artifact from the documentation build appeared
# here; it is not part of the source and has been commented out.
class LayerChoice(MutableModule):
    """
    Layer choice selects one of the ``candidates``, then apply it on inputs and return results.

    It allows users to put several candidate operations (e.g., PyTorch modules), one of them is chosen in each explored model.

    *New in v2.2:* Layer choice can be nested.

    Parameters
    ----------
    candidates : list of nn.Module or OrderedDict
        A module list to be selected from.
    weights : list of float
        Prior distribution used in random sampling.
    label : str
        Identifier of the layer choice.

    Attributes
    ----------
    length : int
        Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
    names : list of str
        Names of candidates.
    choices : list of Module
        Deprecated. A list of all candidate modules in the layer choice module.
        ``list(layer_choice)`` is recommended, which will serve the same purpose.

    Examples
    --------
    ::

        # import nni.nas.nn.pytorch as nn
        # declared in `__init__` method
        self.layer = nn.LayerChoice([
            ops.PoolBN('max', channels, 3, stride, 1),
            ops.SepConv(channels, channels, 3, stride, 1),
            nn.Identity()
        ])
        # invoked in `forward` method
        out = self.layer(x)

    Notes
    -----
    ``candidates`` can be a list of modules or a ordered dict of named modules, for example,

    .. code-block:: python

        self.op_choice = LayerChoice(OrderedDict([
            ("conv3x3", nn.Conv2d(3, 16, 128)),
            ("conv5x5", nn.Conv2d(5, 16, 128)),
            ("conv7x7", nn.Conv2d(7, 16, 128))
        ]))

    Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
    ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
    """

    # Candidate keys that would shadow attributes of the LayerChoice API itself.
    # Hoisted to a class constant so it is not rebuilt on every loop iteration.
    _RESERVED_NAMES = frozenset(["length", "reduction", "return_mask", "_key", "key", "names"])

    def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
                 weights: Optional[List[float]] = None, label: Optional[str] = None):
        super().__init__()
        _names, _modules = self._init_names(candidates)
        for name, module in zip(_names, _modules):
            self.add_module(str(name), module)
        self.choice = self._inner_choice(_names, weights=weights, label=label)
        self.add_mutable(self.choice)
        # Freeze one concrete candidate up-front so ``forward`` has something
        # to dispatch to (e.g. during dry runs / shape inference).
        self._dry_run_choice = ensure_frozen(self.choice)
        # Names are kept as original types. They need to be converted to str for getattr.
        self.names: Union[List[str], List[int]] = _names

    @torch.jit.unused
    @property
    def label(self) -> str:
        """Label of the underlying categorical choice."""
        return self.choice.label

    @torch.jit.unused
    @property
    def candidates(self) -> Union[Dict[str, nn.Module], List[nn.Module]]:
        """Restore the ``candidates`` parameters passed to the constructor.

        Useful when creating a new layer choices based on this one.
        """
        # List-constructed choices have integer names 0..n-1; everything else
        # must have come from a dict.
        if all(isinstance(name, int) for name in self.names) and self.names == list(range(len(self))):
            return list(self)
        else:
            return {cast(str, name): self[name] for name in self.names}

    @staticmethod
    def _inner_choice(names: List[str], weights: Optional[List[float]], label: Optional[str]) -> Categorical:
        """Build the categorical mutable that backs this layer choice."""
        return Categorical(names, weights=weights, label=label)

    @staticmethod
    def _init_names(candidates: Union[Dict[str, nn.Module], List[nn.Module]]) -> Tuple[List[str], List[nn.Module]]:
        """Normalize ``candidates`` into parallel lists of names and modules.

        Dict candidates keep their string keys; list candidates are named by index.

        Raises
        ------
        TypeError
            If ``candidates`` is neither a dict nor a list, or a dict key is not a str.
        ValueError
            If a dict key clashes with a reserved attribute name.
        """
        names, modules = [], []
        if isinstance(candidates, dict):
            for name, module in candidates.items():
                # Type-check first, then validate against reserved names.
                if not isinstance(name, str):
                    raise TypeError(f'Key of candidates must be str, got {type(name)}.')
                # Explicit raise (not assert) so the check survives ``python -O``.
                if name in LayerChoice._RESERVED_NAMES:
                    raise ValueError("Please don't use a reserved name '{}' for your module.".format(name))
                names.append(name)
                modules.append(module)
        elif isinstance(candidates, list):
            for i, module in enumerate(candidates):
                names.append(i)
                modules.append(module)
        else:
            raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
        return names, modules

    def check_contains(self, sample: Sample) -> Optional[SampleValidationError]:
        """Validate ``sample`` for this choice and, recursively, for the chosen candidate.

        Returns the first :class:`SampleValidationError` found (with its path
        annotated), or ``None`` when the sample is valid.
        """
        exception = self.choice.check_contains(sample)
        if exception is not None:
            return exception
        sample_val = self.choice.freeze(sample)
        module = self[sample_val]
        if isinstance(module, MutableModule):
            exception = module.check_contains(sample)
            if exception is not None:
                exception.paths.append(sample_val)
                return exception
        else:
            # Chosen module is a plain nn.Module; validate its mutable descendants.
            for name, submodule in MutableModule.named_mutable_descendants(module):  # type: ignore
                exception = submodule.check_contains(sample)
                if exception is not None:
                    exception.paths.append(name)
                    exception.paths.append(sample_val)
                    return exception
        return None

    def freeze(self, sample: Sample) -> nn.Module:
        """Return the candidate selected by ``sample``, with its own mutables frozen too."""
        self.validate(sample)
        sample_val = self.choice.freeze(sample)
        return recursive_freeze(self[sample_val], sample)[0]

    @classmethod
    def create_fixed_module(cls, sample: dict, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
                            weights: Optional[List[float]] = None, label: Optional[str] = None):
        """Shortcut used in fixed-architecture mode: directly return the chosen
        candidate without constructing a ``LayerChoice`` instance."""
        names, _ = cls._init_names(candidates)
        chosen = cls._inner_choice(names, weights, label).freeze(sample)
        if isinstance(candidates, list):
            result = candidates[int(chosen)]
        else:
            result = candidates[chosen]
        return result

    def __getitem__(self, idx: Union[int, str]) -> nn.Module:
        if idx not in self.names:
            raise KeyError(f'{idx!r} is not found in {self.names!r}.')
        return cast(nn.Module, self._modules[str(idx)])

    def __setitem__(self, idx, module):
        if idx not in self.names:
            raise KeyError(f'{idx!r} is not found in {self.names!r}. Note we disallow adding new choices to LayerChoice.')
        return setattr(self, str(idx), module)

    def __delitem__(self, idx):
        raise RuntimeError('Deleting choices from LayerChoice is not supported yet.')

    def __len__(self):
        return len(self.names)

    def __iter__(self) -> Iterator[nn.Module]:
        return map(lambda name: cast(nn.Module, self._modules[str(name)]), self.names)

    def forward(self, x):
        # The input argument can be arbitrary positional / keyword arguments,
        # but JIT is unhappy with the unrestricted cases.
        # The forward of layer choice is simply running the first candidate module.
        # It shouldn't be called directly by users in most cases.
        for name, child_module in self.named_children():
            # Explicitly cast str here to make JIT happy
            if str(name) == str(self._dry_run_choice):
                return child_module(x)
        raise RuntimeError('dry_run_choice is not available. This should not happen.')

    def extra_repr(self):
        return f'label={self.label!r}'
# Allowed reduction policies for combining multiple chosen input tensors.
ReductionType = Literal['mean', 'concat', 'sum', 'none']
class ChosenInputs(nn.Module):
    """
    A module that chooses from a tensor list and outputs a reduced tensor.
    The already-chosen version of InputChoice.

    When forward, ``chosen`` will be used to select inputs from ``candidate_inputs``,
    and ``reduction`` will be used to choose from those inputs to form a tensor.

    Attributes
    ----------
    chosen : list of int
        Indices of chosen inputs.
    reduction : ``mean`` | ``concat`` | ``sum`` | ``none``
        How to reduce the inputs when multiple are selected.
    """

    def __init__(self, chosen: Union[List[int], int], reduction: ReductionType):
        super().__init__()
        # A single index is normalized into a one-element list.
        self.chosen = [chosen] if not isinstance(chosen, list) else chosen
        self.reduction = reduction

    def forward(self, candidate_inputs: List[torch.Tensor]) -> Optional[torch.Tensor]:
        """
        Compute the reduced input based on ``chosen`` and ``reduction``.
        """
        selected = [candidate_inputs[index] for index in self.chosen]
        return self._tensor_reduction(self.reduction, selected)  # type: ignore

    @staticmethod
    def _tensor_reduction(reduction_type: str, tensor_list: List[torch.Tensor]) -> Union[List[torch.Tensor], torch.Tensor, None]:
        # 'none' short-circuits first: the (possibly empty) list is returned untouched.
        if reduction_type == 'none':
            return tensor_list
        if len(tensor_list) == 0:
            return None  # empty. return None for now
        if len(tensor_list) == 1:
            # A single tensor needs no reduction regardless of the policy.
            return tensor_list[0]
        if reduction_type == 'sum':
            return cast(torch.Tensor, sum(tensor_list))
        if reduction_type == 'mean':
            return cast(torch.Tensor, sum(tensor_list) / len(tensor_list))
        if reduction_type == 'concat':
            # Concatenate along the channel dimension.
            return torch.cat(tensor_list, dim=1)
        raise ValueError(f'Unrecognized reduction policy: "{reduction_type}"')
class MutationAnchor(MutableModule):
    """
    The API that creates an empty module for later mutations.
    For advanced usages only.
    """

    def __init__(self, *, label: str, **kwargs):
        super().__init__()
        # Identifier of this anchor.
        self.label = label
        # Extra keyword arguments are stored as-is for later use.
        self.kwargs = kwargs

    def forward(self, x):
        """
        Forward of placeholder is not meaningful.
        It returns input directly.
        """
        return x