Source code for nni.retiarii.nn.pytorch.component

import copy
from collections import OrderedDict
from typing import Callable, List, Union, Tuple, Optional

import torch
import torch.nn as nn

from .api import LayerChoice, InputChoice
from .nn import ModuleList

from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .utils import generate_new_label, get_fixed_value
from ...utils import NoContextError


__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']


class Repeat(nn.Module):
    """
    Repeat a block a variable number of times.

    Parameters
    ----------
    blocks : function, list of function, module or list of module
        The block to be repeated. If not a list, it will be replicated into a list.
        If a list, it should be of length ``max_depth``; the modules will be instantiated
        in order and a prefix of them will be used. If a function, it will be called to
        instantiate a module. Otherwise the module will be deep-copied.
    depth : int or tuple of int
        If a single number, the block will be repeated a fixed number of times.
        If a tuple, it should be ``(min, max)``, meaning that the block will be repeated
        at least ``min`` times and at most ``max`` times.
    """

    def __new__(cls,
                blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
                depth: Union[int, Tuple[int, int]],
                label: Optional[str] = None):
        try:
            repeat = get_fixed_value(label)
            return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
        except NoContextError:
            return super().__new__(cls)

    def __init__(self,
                 blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
                 depth: Union[int, Tuple[int, int]],
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)
        self.min_depth = depth if isinstance(depth, int) else depth[0]
        self.max_depth = depth if isinstance(depth, int) else depth[1]
        assert self.max_depth >= self.min_depth > 0
        self.blocks = nn.ModuleList(self._replicate_and_instantiate(blocks, self.max_depth))

    @property
    def label(self):
        return self._label
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x
    @staticmethod
    def _replicate_and_instantiate(blocks, repeat):
        if not isinstance(blocks, list):
            if isinstance(blocks, nn.Module):
                blocks = [blocks] + [copy.deepcopy(blocks) for _ in range(repeat - 1)]
            else:
                blocks = [blocks for _ in range(repeat)]
        assert len(blocks) > 0
        assert repeat <= len(blocks), f'Not enough blocks to be used. {repeat} expected, only found {len(blocks)}.'
        blocks = blocks[:repeat]
        if not isinstance(blocks[0], nn.Module):
            blocks = [b() for b in blocks]
        return blocks
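
# A minimal usage sketch of ``Repeat`` (an illustration added here, not part of the original
# NNI source). The helper name ``_example_repeat`` and the block sizes are hypothetical and
# only show the API shape: a block factory plus a (min, max) depth range.
def _example_repeat():
    def block():
        # Factory called once per replication; each call produces an independent block.
        return nn.Sequential(nn.Linear(64, 64), nn.ReLU())

    # The number of repetitions becomes a searchable choice in [2, 5], identified by 'rep'.
    # Once the choice is fixed, ``__new__`` collapses this into a plain nn.Sequential
    # containing the chosen prefix of blocks.
    return Repeat(block, depth=(2, 5), label='rep')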
class Cell(nn.Module):
    """
    Cell structure [zophnas]_ [zophnasnet]_ that is popularly used in NAS literature.

    A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
    ``op_candidates``, and takes one input from previous nodes and predecessors. A predecessor is an input of the cell.
    The output of the cell is the concatenation of some of the nodes in the cell (currently all of them).

    Parameters
    ----------
    op_candidates : function or list of module
        A list of modules to choose from, or a function that returns a list of modules.
    num_nodes : int
        Number of nodes in the cell.
    num_ops_per_node : int
        Number of operators in each node. The output of each node is the sum of all operators in the node. Default: 1.
    num_predecessors : int
        Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1.
    merge_op : str
        Currently only ``all`` is supported, which differs slightly from what is described in the references. Default: all.
    label : str
        Identifier of the cell. Cells sharing the same label will semantically share the same choice.

    References
    ----------
    .. [zophnas] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning".
        https://arxiv.org/abs/1611.01578
    .. [zophnasnet] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
        "Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
    """

    # TODO:
    # Support loose-end concat (shape inference on the following cells).
    # How to dynamically create convolution with stride as the first node?

    def __init__(self,
                 op_candidates: Union[Callable, List[nn.Module]],
                 num_nodes: int,
                 num_ops_per_node: int = 1,
                 num_predecessors: int = 1,
                 merge_op: str = 'all',
                 label: str = None):
        super().__init__()
        self._label = generate_new_label(label)
        self.ops = ModuleList()
        self.inputs = ModuleList()
        self.num_nodes = num_nodes
        self.num_ops_per_node = num_ops_per_node
        self.num_predecessors = num_predecessors
        for i in range(num_nodes):
            self.ops.append(ModuleList())
            self.inputs.append(ModuleList())
            for k in range(num_ops_per_node):
                if isinstance(op_candidates, list):
                    assert len(op_candidates) > 0 and isinstance(op_candidates[0], nn.Module)
                    ops = copy.deepcopy(op_candidates)
                else:
                    ops = op_candidates()
                self.ops[-1].append(LayerChoice(ops, label=f'{self.label}__op_{i}_{k}'))
                self.inputs[-1].append(InputChoice(i + num_predecessors, 1, label=f'{self.label}/input_{i}_{k}'))
        assert merge_op in ['all']  # TODO: loose_end
        self.merge_op = merge_op

    @property
    def label(self):
        return self._label
    def forward(self, x: List[torch.Tensor]):
        states = x
        for ops, inps in zip(self.ops, self.inputs):
            current_state = []
            for op, inp in zip(ops, inps):
                current_state.append(op(inp(states)))
            current_state = torch.sum(torch.stack(current_state), 0)
            states.append(current_state)
        return torch.cat(states[self.num_predecessors:], 1)
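
# A minimal usage sketch of ``Cell`` (an illustration added here, not part of the original
# NNI source). The helper name ``_example_cell`` and the two candidate operations are
# hypothetical; they only demonstrate how the constructor arguments fit together.
def _example_cell():
    def op_candidates():
        # Each edge in the cell chooses between these two channel-preserving operations.
        return [
            nn.Conv2d(16, 16, 3, padding=1),
            nn.MaxPool2d(3, stride=1, padding=1),
        ]

    # 4 nodes, 1 predecessor: forward expects a list containing a single [N, 16, H, W]
    # tensor and returns the concatenation of the 4 node outputs, i.e. [N, 64, H, W].
    return Cell(op_candidates, num_nodes=4, num_ops_per_node=1, num_predecessors=1, label='cell')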
class NasBench201Cell(nn.Module):
    """
    Cell structure proposed in NAS-Bench-201 [nasbench201]_ .

    This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
    For every i < j, there is an edge from the i-th node to the j-th node.
    Each edge in this DAG is associated with an operation transforming the hidden state from the source node
    to the target node. All possible operations are selected from a predefined operation set, defined in
    ``op_candidates``. Each of the ``op_candidates`` should be a callable that accepts an input dimension and an
    output dimension, and returns a ``Module``.

    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.

    The space size of this cell is :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation
    candidates, and :math:`N` is defined by ``num_tensors``.

    Parameters
    ----------
    op_candidates : list of callable
        Operation candidates. Each should be a function that accepts the input and output feature dimensions
        and returns an ``nn.Module``.
    in_features : int
        Input dimension of the cell.
    out_features : int
        Output dimension of the cell.
    num_tensors : int
        Number of tensors in the cell (input included). Default: 4.
    label : str
        Identifier of the cell. Cells sharing the same label will semantically share the same choice.

    References
    ----------
    .. [nasbench201] Dong, X. and Yang, Y., 2020. NAS-Bench-201: Extending the scope of reproducible neural
        architecture search. arXiv preprint arXiv:2001.00326.
    """

    @staticmethod
    def _make_dict(x):
        if isinstance(x, list):
            return OrderedDict([(str(i), t) for i, t in enumerate(x)])
        return OrderedDict(x)

    def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
                 in_features: int, out_features: int, num_tensors: int = 4,
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)

        self.layers = nn.ModuleList()
        self.in_features = in_features
        self.out_features = out_features
        self.num_tensors = num_tensors

        op_candidates = self._make_dict(op_candidates)

        for tid in range(1, num_tensors):
            node_ops = nn.ModuleList()
            for j in range(tid):
                inp = in_features if j == 0 else out_features
                op_choices = OrderedDict([(key, cls(inp, out_features)) for key, cls in op_candidates.items()])
                # put __ here to be compatible with base engine
                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))
            self.layers.append(node_ops)

    def forward(self, inputs):
        tensors = [inputs]
        for layer in self.layers:
            current_tensor = []
            for i, op in enumerate(layer):
                current_tensor.append(op(tensors[i]))
            current_tensor = torch.sum(torch.stack(current_tensor), 0)
            tensors.append(current_tensor)
        return tensors[-1]
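
# A minimal usage sketch of ``NasBench201Cell`` (an illustration added here, not part of the
# original NNI source). The helper name ``_example_nb201_cell`` and the two convolutional
# candidates are hypothetical; each candidate is a callable taking (in_features, out_features).
def _example_nb201_cell():
    op_candidates = [
        lambda in_f, out_f: nn.Conv2d(in_f, out_f, 3, padding=1),
        lambda in_f, out_f: nn.Conv2d(in_f, out_f, 1),
    ]
    # A 4-tensor cell: every pair (i, j) with i < j gets a LayerChoice over the two
    # candidates, so this cell alone has a search space of size 2^6.
    return NasBench201Cell(op_candidates, in_features=32, out_features=32, num_tensors=4, label='nb201')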