# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Callable, Dict
import torch
import torch.nn as nn
from nni.nas.nn.pytorch import ModelSpace
from .modules.nasbench201 import NasBench201Cell
__all__ = ['NasBench201']
OPS_WITH_STRIDE = {
'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
else FactorizedReduce(C_in, C_out, stride),
}
PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out)
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out),
)
def forward(self, x):
return self.op(x)
class Pooling(nn.Module):
def __init__(self, C_in, C_out, stride, mode):
super(Pooling, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
if mode == 'avg':
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max':
self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else:
raise ValueError('Invalid mode={:} in Pooling'.format(mode))
def forward(self, x):
if self.preprocess:
x = self.preprocess(x)
return self.op(x)
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1:
return x.mul(0.)
else:
return x[:, :, ::self.stride, ::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
inputs = self.downsample(inputs) # residual
return inputs + basicblock
[docs]
class NasBench201(ModelSpace):
"""The full search space proposed by `NAS-Bench-201 <https://arxiv.org/abs/2001.00326>`__.
It's a stack of :class:`~nni.nas.hub.pytorch.modules.NasBench201Cell`.
Parameters
----------
stem_out_channels
The output channels of the stem.
num_modules_per_stack
The number of modules (cells) in each stack. Each cell is a :class:`~nni.nas.hub.pytorch.modules.NasBench201Cell`.
num_labels
Number of categories for classification.
"""
def __init__(self,
stem_out_channels: int = 16,
num_modules_per_stack: int = 5,
num_labels: int = 10):
super().__init__()
self.channels = C = stem_out_channels
self.num_modules = N = num_modules_per_stack
self.num_labels = num_labels
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for C_curr, reduction in zip(layer_channels, layer_reductions):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
ops: Dict[str, Callable[[int, int], nn.Module]] = {
prim: self._make_op_factory(prim) for prim in PRIMITIVES
}
cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
self.cells.append(cell)
C_prev = C_curr
self.lastact = nn.Sequential(
nn.BatchNorm2d(C_prev),
nn.ReLU(inplace=True)
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, self.num_labels)
def _make_op_factory(self, prim):
return lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1)
def forward(self, inputs):
feature = self.stem(inputs)
for cell in self.cells:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits