Source code for nni.nas.hub.pytorch.nasbench201

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, Dict

import torch
import torch.nn as nn

from nni.nas.nn.pytorch import ModelSpace
from .modules.nasbench201 import NasBench201Cell


__all__ = ['NasBench201']


OPS_WITH_STRIDE = {
    'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
    'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
    'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
    'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
    'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
    else FactorizedReduce(C_in, C_out, stride),
}

PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']
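

# Illustrative sketch (not part of the original module): each entry in
# OPS_WITH_STRIDE is a factory taking (C_in, C_out, stride). For example,
# 'conv_3x3' with stride 1 preserves the spatial size. The helper name
# below is hypothetical; it is only defined, never called at import time.
def _demo_op_factory():
    op = OPS_WITH_STRIDE['conv_3x3'](16, 16, 1)
    y = op(torch.zeros(2, 16, 32, 32))
    assert y.shape == (2, 16, 32, 32)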


class ReLUConvBN(nn.Module):
    """ReLU activation, then convolution, then batch normalization."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super().__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out)
        )

    def forward(self, x):
        return self.op(x)


class SepConv(nn.Module):
    """Depthwise convolution followed by a pointwise 1x1 convolution."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super().__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out),
        )

    def forward(self, x):
        return self.op(x)

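# Illustrative sketch (not part of the original module, which does not use
# SepConv itself): the depthwise + pointwise pair changes channels while the
# stride controls the spatial size.
def _demo_sep_conv():
    op = SepConv(16, 32, kernel_size=3, stride=2, padding=1, dilation=1)
    y = op(torch.zeros(2, 16, 32, 32))
    assert y.shape == (2, 32, 16, 16)  # stride 2 halves the spatial size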

class Pooling(nn.Module):
    """3x3 average or max pooling, preceded by a 1x1 ReLU-Conv-BN when channel counts differ."""

    def __init__(self, C_in, C_out, stride, mode):
        super().__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            # Pooling cannot change the channel count, so adjust it first.
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
        if mode == 'avg':
            self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max':
            self.op = nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError('Invalid mode={:} in Pooling'.format(mode))

    def forward(self, x):
        if self.preprocess:
            x = self.preprocess(x)
        return self.op(x)


class Zero(nn.Module):
    """The 'none' operation: produces an all-zero tensor of the expected output shape."""

    def __init__(self, C_in, C_out, stride):
        super().__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.is_zero = True

    def forward(self, x):
        if self.C_in == self.C_out:
            if self.stride == 1:
                return x.mul(0.)
            else:
                return x[:, :, ::self.stride, ::self.stride].mul(0.)
        else:
            # Different channel counts: allocate zeros with the target channel
            # dimension. Spatial dimensions are kept as-is; in this file the
            # op factories always build 'none' with stride 1.
            shape = list(x.shape)
            shape[1] = self.C_out
            return x.new_zeros(shape)  # new_zeros inherits x's dtype and device

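# Illustrative sketch (not part of the original module): with equal channel
# counts and stride 2, Zero subsamples the zeroed input.
def _demo_zero():
    z = Zero(16, 16, 2)
    y = z(torch.zeros(2, 16, 32, 32))
    assert y.shape == (2, 16, 16, 16)
    assert y.abs().sum().item() == 0.0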

class FactorizedReduce(nn.Module):
    """Halve the spatial size with two parallel strided 1x1 convolutions.

    The second branch sees the input shifted by one pixel, and the two
    outputs are concatenated channel-wise.
    """

    def __init__(self, C_in, C_out, stride):
        super().__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)
        if stride == 2:
            C_outs = [C_out // 2, C_out - C_out // 2]
            self.convs = nn.ModuleList()
            for i in range(2):
                self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        else:
            raise ValueError('Invalid stride : {:}'.format(stride))
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        x = self.relu(x)
        y = self.pad(x)  # pad right/bottom so the shifted branch also covers odd sizes
        out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out

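# Illustrative sketch (not part of the original module): FactorizedReduce
# handles odd input sizes, and each branch contributes half of the output
# channels.
def _demo_factorized_reduce():
    fr = FactorizedReduce(16, 32, stride=2)
    y = fr(torch.zeros(2, 16, 33, 33))
    assert y.shape == (2, 32, 17, 17)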

class ResNetBasicblock(nn.Module):
    """Residual basic block used between cell stacks to downsample and widen the feature map."""

    def __init__(self, inplanes, planes, stride):
        super().__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
        if stride == 2:
            self.downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
        elif inplanes != planes:
            self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
        else:
            self.downsample = None
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)

        if self.downsample is not None:
            inputs = self.downsample(inputs)  # residual
        return inputs + basicblock

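# Illustrative sketch (not part of the original module): a stride-2 basic
# block halves the resolution and changes channels via its downsample path.
def _demo_resnet_basicblock():
    block = ResNetBasicblock(16, 32, stride=2)
    y = block(torch.zeros(2, 16, 32, 32))
    assert y.shape == (2, 32, 16, 16)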

class NasBench201(ModelSpace):
    """The full search space proposed by `NAS-Bench-201 <https://arxiv.org/abs/2001.00326>`__.

    It's a stack of :class:`~nni.nas.hub.pytorch.modules.NasBench201Cell`.

    Parameters
    ----------
    stem_out_channels
        The output channels of the stem.
    num_modules_per_stack
        The number of modules (cells) in each stack.
        Each cell is a :class:`~nni.nas.hub.pytorch.modules.NasBench201Cell`.
    num_labels
        Number of categories for classification.
    """

    def __init__(self,
                 stem_out_channels: int = 16,
                 num_modules_per_stack: int = 5,
                 num_labels: int = 10):
        super().__init__()
        self.channels = C = stem_out_channels
        self.num_modules = N = num_modules_per_stack
        self.num_labels = num_labels

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )

        # Three stacks of N cells each, separated by two residual reduction
        # blocks that double the channels and halve the resolution.
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                ops: Dict[str, Callable[[int, int], nn.Module]] = {
                    prim: self._make_op_factory(prim) for prim in PRIMITIVES
                }
                cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
            self.cells.append(cell)
            C_prev = C_curr

        self.lastact = nn.Sequential(
            nn.BatchNorm2d(C_prev),
            nn.ReLU(inplace=True)
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, self.num_labels)

    def _make_op_factory(self, prim):
        # Bind `prim` via a helper method so every factory captures its own
        # primitive instead of sharing a late-bound loop variable.
        return lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1)

    def forward(self, inputs):
        feature = self.stem(inputs)
        for cell in self.cells:
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return logits
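

# Usage sketch (not part of the original module). Instantiating the space is
# plain PyTorch; how a concrete architecture is sampled depends on the NNI
# version in use, so that part is only hinted at in comments.
if __name__ == '__main__':
    space = NasBench201(stem_out_channels=16, num_modules_per_stack=5, num_labels=10)
    # Three stacks of 5 cells plus 2 reduction blocks in between: 17 layers.
    print(len(space.cells))  # 17
    # To obtain a trainable network, freeze the space with an architecture
    # sample (e.g. via an NNI exploration strategy); the frozen model then
    # accepts standard CIFAR-sized input:
    #     model = space.freeze(sample)                # `sample` maps choice labels to ops
    #     logits = model(torch.zeros(1, 3, 32, 32))   # -> shape (1, 10)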