# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import functools
import logging
import warnings
from typing import Any, Dict, Sequence, List, Tuple, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum
_logger = logging.getLogger(__name__)
__all__ = [
'DifferentiableMixedLayer', 'DifferentiableMixedInput',
'DifferentiableMixedRepeat', 'DifferentiableMixedCell',
'MixedOpDifferentiablePolicy',
]
class GumbelSoftmax(nn.Softmax):
"""Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""
dim: int
def __init__(self, dim: int = -1) -> None:
super().__init__(dim)
self.tau = 1
self.hard = False
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)
class DifferentiableMixedLayer(BaseSuperNetModule):
"""
    Mixed layer, whose forward pass is a weighted sum of several candidate layers.
    Proposed in `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__.
    The weight ``alpha`` is usually learnable, and optimized on the validation dataset.
    The differentiable sampling layer requires all operators to return the same shape for a given input,
    as all outputs are weighted-summed to produce the final output.
Parameters
----------
    paths : list[tuple[str, nn.Module]]
        Layers to choose from. Each is a tuple of a name and its module.
alpha : Tensor
Tensor that stores the "learnable" weights.
softmax : nn.Module
Customizable softmax function. Usually ``nn.Softmax(-1)``.
label : str
Name of the choice.
Attributes
----------
    op_names : list[str]
        Operator names.
label : str
Name of the choice.
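
    Examples
    --------
    A minimal, illustrative sketch (the candidate modules, shapes and label below are made up)::

        paths = [('conv', nn.Conv2d(3, 3, 3, padding=1)), ('skip', nn.Identity())]
        alpha = nn.Parameter(torch.randn(len(paths)) * 1E-3)
        layer = DifferentiableMixedLayer(paths, alpha, nn.Softmax(-1), 'demo_layer')
        out = layer(torch.randn(1, 3, 8, 8))        # weighted sum over both candidates
        arch = list(layer.parameters(arch=True))    # architecture parameters only, i.e., alpha
        rest = list(layer.parameters())             # model weights, alpha excluded
        layer.export({})                            # e.g. {'demo_layer': 'conv'}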
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self,
paths: list[tuple[str, nn.Module]],
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__()
self.op_names = []
if len(alpha) != len(paths):
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).')
for name, module in paths:
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
self.label = label
self._arch_alpha = alpha
self._softmax = softmax
    def resample(self, memo):
        """Do nothing. Differentiable layer doesn't need to resample."""
return {}
    def export(self, memo):
"""Choose the operator with the maximum logit."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
True, size=len(self.op_names))}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, LayerChoice):
size = len(module)
if module.label in memo:
alpha = memo[module.label]
if len(alpha) != size:
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
else:
alpha = nn.Parameter(torch.randn(size) * 1E-3) # this can be reinitialized later
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(list(module.named_children()), alpha, softmax, module.label)
    def reduction(self, items: list[Any], weights: list[float]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
    def forward(self, *args, **kwargs):
        """The forward of a mixed layer accepts the same arguments as its sub-layers."""
all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names]
return self.reduction(all_op_results, self._softmax(self._arch_alpha))
    def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters."""
for _, p in self.named_parameters(*args, **kwargs):
yield p
    def named_parameters(self, *args, **kwargs):
"""Named parameters excluding architecture parameters."""
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name == par_name for par_name in self._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
    """Implements differentiable sampling in a mixed operation.
One mixed operation can have multiple value choices in its arguments.
Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters``
filters out multiple parameters with ``_arch_alpha`` as its prefix.
    When this class is asked for ``forward_argument``, it returns a distribution,
    i.e., a dict from candidate value to probability, computed from its weights.
All the parameters (``_arch_alpha``, ``parameters()``, ``_softmax``) are
saved as attributes of ``operation``, rather than ``self``,
because this class itself is not a ``nn.Module``, and saved parameters here
won't be optimized.
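
    For example, for a single leaf value choice with candidates ``[3, 5, 7]``, the distribution
    handed to the underlying operation looks like this (a sketch; the numbers are illustrative)::

        alpha = torch.tensor([1.0, 2.0, 0.5])
        probs = torch.softmax(alpha, -1)
        dict(zip([3, 5, 7], probs.tolist()))   # roughly {3: 0.23, 5: 0.63, 7: 0.14}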
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
        # Sampling arguments. This should have the same keys as `operation.mutable_arguments`.
operation._arch_alpha = nn.ParameterDict()
for name, spec in operation.search_space_spec().items():
if name in memo:
alpha = memo[name]
if len(alpha) != spec.size:
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
operation._arch_alpha[name] = alpha
operation.parameters = functools.partial(self.parameters, module=operation) # bind self
operation.named_parameters = functools.partial(self.named_parameters, module=operation)
operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
@staticmethod
def parameters(module, *args, **kwargs):
for _, p in module.named_parameters(*args, **kwargs):
yield p
@staticmethod
def named_parameters(module, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super(module.__class__, module).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
    def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Differentiable. Do nothing in resample."""
return {}
    def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is argmax for each leaf value choice."""
result = {}
for name, spec in operation.search_space_spec().items():
if name in memo:
continue
chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
result[name] = spec.values[chosen_index]
return result
def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
if name in operation.mutable_arguments:
weights: dict[str, torch.Tensor] = {
label: cast(nn.Module, operation._softmax)(alpha) for label, alpha in cast(dict, operation._arch_alpha).items()
}
return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
return operation.init_arguments[name]
class DifferentiableMixedRepeat(BaseSuperNetModule):
    """
    Implementation of Repeat in a differentiable supernet.
    Result is a weighted sum of possible prefixes, sliced by possible depths.
    If the output is not a single tensor, it will be summed at every independent dimension.
See :func:`weighted_sum` for details.
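
    For instance, assuming the candidate depths are 1, 2 and 3 with probabilities ``p1``, ``p2``
    and ``p3``, the result is conceptually (a sketch; the blocks and shapes below are made up)::

        blocks = [nn.Linear(4, 4) for _ in range(3)]
        p1, p2, p3 = torch.softmax(torch.randn(3), -1)   # depth probabilities
        x = torch.randn(2, 4)
        h1 = blocks[0](x)
        h2 = blocks[1](h1)
        h3 = blocks[2](h2)
        out = p1 * h1 + p2 * h2 + p3 * h3                # weighted sum of the three prefixes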
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self,
blocks: list[nn.Module],
depth: ChoiceOf[int],
softmax: nn.Module,
memo: dict[str, Any]):
super().__init__()
self.blocks = blocks
self.depth = depth
self._softmax = softmax
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._arch_alpha = nn.ParameterDict()
for name, spec in self._space_spec.items():
if name in memo:
alpha = memo[name]
if len(alpha) != spec.size:
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
self._arch_alpha[name] = alpha
    def resample(self, memo):
"""Do nothing."""
return {}
    def export(self, memo):
"""Choose argmax for each leaf value choice."""
result = {}
for name, spec in self._space_spec.items():
if name in memo:
continue
chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
result[name] = spec.values[chosen_index]
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
# Only interesting when depth is mutable
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
    def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
def forward(self, x):
weights: dict[str, torch.Tensor] = {
label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
}
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
res: list[torch.Tensor] = []
weight_list: list[float] = []
depths: list[int] = []
for i, block in enumerate(self.blocks, start=1): # start=1 because depths are 1, 2, 3, 4...
x = block(x)
if i in depth_weights:
weight_list.append(depth_weights[i])
res.append(x)
depths.append(i)
return self.reduction(res, weight_list, depths)
class DifferentiableMixedCell(PathSamplingCell):
    """Implementation of Cell in a differentiable context.
    An architecture parameter is created on each edge of the fully-connected graph.
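
    On each edge ``(i, j)`` of that graph, the output is a softmax-weighted mixture of all
    candidate operators, which is what :meth:`forward` computes edge by edge. A sketch with
    made-up candidates and shapes::

        ops = [nn.Conv2d(8, 8, 3, padding=1), nn.Identity()]   # candidates on one edge
        alpha = torch.randn(len(ops)) * 1E-3                   # this edge's architecture parameter
        weights = torch.softmax(alpha, -1)
        state_j = torch.randn(1, 8, 4, 4)                      # output of a previous node
        edge_out = sum(w * op(state_j) for w, op in zip(weights, ops))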
"""
# TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
# Possibly need another refactor here.
def __init__(
self, op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor, concat_dim,
memo, mutate_kwargs, label
):
super().__init__(
op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor,
concat_dim, memo, mutate_kwargs, label
)
self._arch_alpha = nn.ParameterDict()
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for j in range(i):
edge_label = f'{label}/{i}_{j}'
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
if edge_label in memo:
alpha = memo[edge_label]
if len(alpha) != len(op):
raise ValueError(
f'Architecture parameter size of same label {edge_label} conflict: '
f'{len(alpha)} vs. {len(op)}'
)
else:
alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
self._arch_alpha[edge_label] = alpha
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
    def resample(self, memo):
        """A differentiable cell doesn't need to resample."""
return {}
    def export(self, memo):
"""Tricky export.
Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
        We don't avoid selecting operations like ``none`` here, because excluding them would effectively define a different search space.
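
        A tiny worked example of the reordering trick for one node with two predecessors
        (the weights are made up)::

            all_weights = [(0.5, 1, 'conv'), (0.4, 1, 'pool'), (0.3, 0, 'conv')]  # sorted by weight
            # first occurrences of input indices 1 and 0 are pulled to the front:
            # -> [(0.5, 1, 'conv'), (0.3, 0, 'conv'), (0.4, 1, 'pool')]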
"""
exported = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
# Tuple of (weight, input_index, op_name)
all_weights: list[tuple[float, int, str]] = []
for j in range(i):
for k, name in enumerate(self.op_names):
all_weights.append((
float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
j, name,
))
all_weights.sort(reverse=True)
            # We first prefer inputs from different input_index.
            # If we have no other choice, we start to accept duplicates.
            # Therefore we gather the first occurrence of each distinct input_index to the front.
first_occurrence_index: list[int] = [
all_weights.index( # The index of
                    next(filter(lambda t: t[1] == j, all_weights))  # First occurrence of j
)
for j in range(i) # For j < i
]
first_occurrence_index.sort() # Keep them ordered too.
all_weights = [all_weights[k] for k in first_occurrence_index] + \
[w for j, w in enumerate(all_weights) if j not in first_occurrence_index]
_logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
for k in range(self.num_ops_per_node):
# all_weights could be too short in case ``num_ops_per_node`` is too large.
_, j, op_name = all_weights[k % len(all_weights)]
exported[f'{self.label}/op_{i}_{k}'] = op_name
exported[f'{self.label}/input_{i}_{k}'] = j
return exported
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
processed_inputs: list[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: list[torch.Tensor] = self.preprocessor(processed_inputs)
for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
current_state = []
            for j in range(i):  # for every previous tensor
op_results = torch.stack([op(states[j]) for op in ops[j].values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
current_state.append(edge_sum)
states.append(sum(current_state)) # type: ignore
# Always merge all
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p