Source code for nni.nas.pytorch.nasbench201.nasbench201

from collections import OrderedDict
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice

from .nasbench201_ops import Pooling, ReLUConvBN, Zero, FactorizedReduce


class NASBench201Cell(nn.Module):
    """
    Builtin cell structure of NAS-Bench-201.

    One cell contains four nodes. The first node serves as the input node, accepting the output
    of the previous cell. Every other node is connected to all previous nodes. Each edge carries
    one operation, chosen from a fixed candidate set, that transforms the tensor from the source
    node to the target node. Every non-input node sums the outputs of its incoming edges, and the
    output of the last node is returned as the output of the cell.

    Only edges leaving the input node (node 0) apply ``stride`` and map ``C_in`` to ``C_out``
    channels. Intermediate nodes already hold ``C_out``-channel tensors at the strided
    resolution, so edges leaving them use stride 1 and ``C_out`` input channels. When
    ``stride == 1`` and ``C_in == C_out`` this distinction makes no difference.

    Parameters
    ----------
    cell_id : str
        the name of this cell
    C_in : int
        the number of input channels of the cell
    C_out : int
        the number of output channels of the cell
    stride : int
        stride applied by the operations on the edges leaving the input node
    bn_affine : bool
        If set to ``True``, all ``torch.nn.BatchNorm2d`` in this cell will have learnable
        affine parameters. Default: True
    bn_momentum : float
        the value used for the running_mean and running_var computation. Default: 0.1
    bn_track_running_stats : bool
        When set to ``True``, all ``torch.nn.BatchNorm2d`` in this cell track the running
        mean and variance. Default: True
    """

    def __init__(self, cell_id, C_in, C_out, stride, bn_affine=True, bn_momentum=0.1,
                 bn_track_running_stats=True):
        super(NASBench201Cell, self).__init__()

        self.NUM_NODES = 4
        self.layers = nn.ModuleList()

        def candidate_ops(layer_idx):
            # Build the candidate operations for an edge leaving node ``layer_idx``.
            # Only node 0 holds the raw cell input; every later node holds the sum of
            # C_in -> C_out ops that already applied the stride, so its edges must read
            # C_out channels and must not stride again.
            op_c_in = C_in if layer_idx == 0 else C_out
            op_stride = stride if layer_idx == 0 else 1
            bn_args = (bn_affine, bn_momentum, bn_track_running_stats)
            # The insertion order defines each candidate's position inside the LayerChoice,
            # so it must not be changed.
            return OrderedDict([
                ("none", Zero(op_c_in, C_out, op_stride)),
                ("avg_pool_3x3", Pooling(op_c_in, C_out, op_stride, *bn_args)),
                ("conv_3x3", ReLUConvBN(op_c_in, C_out, 3, op_stride, 1, 1, *bn_args)),
                ("conv_1x1", ReLUConvBN(op_c_in, C_out, 1, op_stride, 0, 1, *bn_args)),
                ("skip_connect", nn.Identity() if op_stride == 1 and op_c_in == C_out
                 else FactorizedReduce(op_c_in, C_out, op_stride, *bn_args)),
            ])

        # layers[i][j] is the edge from node j to node i. layers[0] is left empty (node 0 is
        # the input) so that forward() can index layers directly by target node.
        for i in range(self.NUM_NODES):
            node_ops = nn.ModuleList()
            for j in range(i):
                # NOTE(review): the key does not include cell_id, so every cell built by this
                # class uses the same keys "j_i"; presumably intentional so that all cells share
                # one sampled architecture -- confirm before changing.
                node_ops.append(LayerChoice(candidate_ops(j), key="%d_%d" % (j, i), reduction="mean"))
            self.layers.append(node_ops)

        self.in_dim = C_in
        self.out_dim = C_out
        self.cell_id = cell_id

    def forward(self, input):  # pylint: disable=W0622
        """
        Parameters
        ----------
        input : torch.Tensor
            the output of the previous layer

        Returns
        -------
        the output of the last node of the cell
        """
        nodes = [input]
        for i in range(1, self.NUM_NODES):
            # Node i sums the outputs of the edges coming from every earlier node k < i.
            node_feature = sum(self.layers[i][k](nodes[k]) for k in range(i))
            nodes.append(node_feature)
        return nodes[-1]