# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import warnings
from typing import Any, List, Union, Dict, Optional
import torch
import torch.nn as nn
from ...serializer import Translatable, basic_unit
from ...utils import NoContextError
from .utils import generate_new_label, get_fixed_value
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
class LayerChoice(nn.Module):
    """
    Layer choice selects one of the ``candidates``, then apply it on inputs and return results.

    Layer choice does not allow itself to be nested.

    Parameters
    ----------
    candidates : list of nn.Module or OrderedDict
        A module list to be selected from.
    prior : list of float
        Prior distribution used in random sampling.
    label : str
        Identifier of the layer choice.

    Attributes
    ----------
    length : int
        Deprecated; not set by this class itself. ``len(layer_choice)`` is recommended.
    names : list of str
        Names of candidates.
    choices : list of Module
        Deprecated. A list of all candidate modules in the layer choice module.
        ``list(layer_choice)`` is recommended, which will serve the same purpose.

    Notes
    -----
    ``candidates`` can be a list of modules or an ordered dict of named modules, for example,

    .. code-block:: python

        self.op_choice = LayerChoice(OrderedDict([
            ("conv3x3", nn.Conv2d(3, 16, 128)),
            ("conv5x5", nn.Conv2d(5, 16, 128)),
            ("conv7x7", nn.Conv2d(7, 16, 128))
        ]))

    Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
    ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
    """

    # FIXME: prior is designed but not supported yet

    def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
                prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
        # ``__init__`` treats the deprecated ``key`` kwarg as the label; do the same here
        # so that the fixed-value lookup is made under the label the instance would get.
        if 'key' in kwargs:
            label = kwargs['key']
        try:
            chosen = get_fixed_value(label)
            if isinstance(candidates, list):
                # List candidates are addressed by position.
                return candidates[int(chosen)]
            else:
                return candidates[chosen]
        except NoContextError:
            # ``get_fixed_value`` signalled that there is no context: build a real LayerChoice.
            return super().__new__(cls)

    def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
                 prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
        """
        Register every candidate as a submodule and record the candidate names in order.

        Raises
        ------
        TypeError
            If ``candidates`` is neither a dict nor a list.
        """
        super(LayerChoice, self).__init__()
        if 'key' in kwargs:
            warnings.warn('"key" is deprecated. Assuming label.')
            label = kwargs['key']
        if 'return_mask' in kwargs:
            warnings.warn('"return_mask" is deprecated. Ignoring...')
        if 'reduction' in kwargs:
            warnings.warn('"reduction" is deprecated. Ignoring...')

        self.candidates = candidates
        # Default to a uniform distribution when no prior is given.
        self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
        assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
        self._label = generate_new_label(label)

        self.names = []
        if isinstance(candidates, dict):
            for name, module in candidates.items():
                assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
                    "Please don't use a reserved name '{}' for your module.".format(name)
                self.add_module(name, module)
                self.names.append(name)
        elif isinstance(candidates, list):
            # Unnamed candidates are registered under their index as a string.
            for i, module in enumerate(candidates):
                self.add_module(str(i), module)
                self.names.append(str(i))
        else:
            raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
        self._first_module = self._modules[self.names[0]]  # to make the dummy forward meaningful

    @property
    def key(self):
        """Deprecated alias of :attr:`label`; emits a ``DeprecationWarning``."""
        return self._key()

    @torch.jit.ignore
    def _key(self):
        warnings.warn('Using key to access the identifier of LayerChoice is deprecated. Please use label instead.',
                      category=DeprecationWarning)
        return self._label

    @property
    def label(self):
        """Identifier of this layer choice."""
        return self._label

    def __getitem__(self, idx):
        """Return a candidate by name (``str``) or by position / slice in ``names`` order."""
        if isinstance(idx, str):
            return self._modules[idx]
        return list(self)[idx]

    def __setitem__(self, idx, module):
        """
        Replace an existing candidate, addressed by name (``str``) or by position.

        Raises
        ------
        KeyError
            If ``idx`` is a name that is not an existing candidate. Adding choices is not
            supported; without this check the module would be registered as a submodule
            but stay invisible to ``names``, ``len()`` and iteration.
        """
        if isinstance(idx, str):
            if idx not in self.names:
                raise KeyError(f"'{idx}' is not an existing candidate; adding more choices is not supported.")
            key = idx
        else:
            key = self.names[idx]
        return setattr(self, key, module)

    def __delitem__(self, idx):
        """Remove candidate(s) addressed by name, position or slice, from both submodules and ``names``."""
        if isinstance(idx, slice):
            for key in self.names[idx]:
                delattr(self, key)
        else:
            if isinstance(idx, str):
                # Translate the name into its position so ``names`` can be updated below.
                key, idx = idx, self.names.index(idx)
            else:
                key = self.names[idx]
            delattr(self, key)
        del self.names[idx]

    def __len__(self):
        return len(self.names)

    def __iter__(self):
        # Iterate in ``names`` order rather than relying on ``_modules`` ordering.
        return map(lambda name: self._modules[name], self.names)

    @property
    def choices(self):
        """Deprecated. Use ``list(layer_choice)`` instead; emits a ``DeprecationWarning``."""
        return self._choices()

    @torch.jit.ignore
    def _choices(self):
        warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", category=DeprecationWarning)
        return list(self)

    def forward(self, x):
        """Dummy forward: warn, then apply the candidate that was first at construction time."""
        warnings.warn('You should not run forward of this module directly.')
        return self._first_module(x)

    def __repr__(self):
        return f'LayerChoice({self.candidates}, label={repr(self.label)})'
class ValueChoice(Translatable, nn.Module):
    """
    ValueChoice is to choose one from ``candidates``.

    In most use scenarios, ValueChoice should be passed to the init parameters of a serializable module. For example,

    .. code-block:: python

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, nn.ValueChoice([32, 64]), kernel_size=nn.ValueChoice([3, 5, 7]))

            def forward(self, x):
                return self.conv(x)

    In case you want to search a parameter that is used repeatedly, this is also possible by sharing the same value choice instance.
    (Sharing the label should have the same effect.) For example,

    .. code-block:: python

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                hidden_dim = nn.ValueChoice([128, 512])
                self.fc = nn.Sequential(
                    nn.Linear(64, hidden_dim),
                    nn.Linear(hidden_dim, 10)
                )

                # the following code has the same effect.
                # self.fc = nn.Sequential(
                #     nn.Linear(64, nn.ValueChoice([128, 512], label='dim')),
                #     nn.Linear(nn.ValueChoice([128, 512], label='dim'), 10)
                # )

            def forward(self, x):
                return self.fc(x)

    Note that ValueChoice should be used directly. Transformations like ``nn.Linear(32, nn.ValueChoice([64, 128]) * 2)``
    are not supported.

    Another common use case is to initialize the values to choose from in init and call the module in forward to get the chosen value.
    Usually, this is used to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx``.
    For example,

    .. code-block:: python

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout_rate = nn.ValueChoice([0., 1.])

            def forward(self, x):
                return F.dropout(x, self.dropout_rate())

    Parameters
    ----------
    candidates : list
        List of values to choose from.
    prior : list of float
        Prior distribution to sample from.
    label : str
        Identifier of the value choice.
    """

    # FIXME: prior is designed but not supported yet

    def __new__(cls, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
        try:
            # When a fixed value is available for this label, the "choice" is that value itself.
            return get_fixed_value(label)
        except NoContextError:
            return super().__new__(cls)

    def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
        super().__init__()
        self.candidates = candidates
        # Default to a uniform distribution when no prior is given.
        self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
        assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
        self._label = generate_new_label(label)
        # History of ``__getitem__`` keys to apply on a candidate (see ``access``).
        self._accessor = []

    @property
    def label(self):
        """Identifier of this value choice."""
        return self._label

    def forward(self):
        """Dummy forward: warn, then return the first candidate."""
        warnings.warn('You should not run forward of this module directly.')
        return self.candidates[0]

    def _translate(self):
        # Will function as a value when used in serializer.
        return self.access(self.candidates[0])

    def __repr__(self):
        return f'ValueChoice({self.candidates}, label={repr(self.label)})'

    def access(self, value):
        """
        Apply the recorded chain of ``__getitem__`` keys to ``value``.

        Returns ``value`` unchanged when no key has been recorded.

        Raises
        ------
        KeyError
            If one of the lookups raises ``KeyError``; the message shows the full key chain.
        """
        if not self._accessor:
            return value
        try:
            v = value
            for a in self._accessor:
                v = v[a]
        except KeyError as err:
            raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}') from err
        return v

    def __copy__(self):
        # A shallow copy is the very same choice.
        return self

    def __deepcopy__(self, memo):
        """
        Create a new ValueChoice with the same candidates, prior, label and accessor chain.

        The candidates object is shared with the original; the prior and accessor lists are copied.
        """
        # Carry the prior over; otherwise a custom prior would silently be replaced by the uniform default.
        new_item = ValueChoice(self.candidates, prior=list(self.prior), label=self.label)
        new_item._accessor = [*self._accessor]
        memo[id(self)] = new_item
        return new_item

    def __getitem__(self, item):
        """
        Get a sub-element of value choice.

        The underlying implementation is to clone the current instance, and append item to "accessor", which records all
        the history getitem calls. For example, when accessor is ``[a, b, c]``, the value choice will return ``vc[a][b][c]``
        where ``vc`` is the original value choice.
        """
        access = copy.deepcopy(self)
        access._accessor.append(item)
        # Fail early: the extended key chain must be applicable to every candidate.
        for candidate in self.candidates:
            access.access(candidate)
        return access
@basic_unit
class Placeholder(nn.Module):
    """
    Identity module that only carries a label and arbitrary keyword metadata.

    Parameters
    ----------
    label : str
        Identifier stored on the module as ``label``.
    **related_info
        Arbitrary keyword arguments, stored as the dict ``related_info``.
    """

    def __init__(self, label, **related_info):
        # Initialise nn.Module first so attribute assignment goes through a fully set-up module.
        super().__init__()
        self.label = label
        self.related_info = related_info

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x