# Source code for nni.nas.pytorch.search_space_zoo.darts_cell

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import OrderedDict

import torch
import torch.nn as nn
from nni.nas.pytorch import mutables

from .darts_ops import PoolBN, SepConv, DilConv, FactorizedReduce, DropPath, StdConv


class Node(nn.Module):
    """
    A single DARTS node: one searchable operation per incoming edge, followed by
    an input switch that selects among the edge outputs.
    """

    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        """
        Build the mutable edges feeding this node.

        Parameters
        ---
        node_id: str
            prefix used to derive the mutable keys: ``<node_id>_p<i>`` for the
            i-th incoming edge and ``<node_id>_switch`` for the input switch
        num_prev_nodes: int
            the number of previous nodes in this cell; one edge is created for each
        channels: int
            number of channels every candidate operation is built with
        num_downsample_connect: int
            the first ``num_downsample_connect`` edges use stride 2, all others stride 1
        """
        super().__init__()
        edge_keys = ["{}_p{}".format(node_id, idx) for idx in range(num_prev_nodes)]

        self.ops = nn.ModuleList()
        for idx, edge_key in enumerate(edge_keys):
            # Only the leading edges are downsampled.
            edge_stride = 1 if idx >= num_downsample_connect else 2
            candidates = self._build_candidates(channels, edge_stride)
            self.ops.append(mutables.LayerChoice(candidates, key=edge_key))

        self.drop_path = DropPath()
        switch_key = "{}_switch".format(node_id)
        self.input_switch = mutables.InputChoice(choose_from=edge_keys, n_chosen=2, key=switch_key)

    @staticmethod
    def _build_candidates(channels, stride):
        """Return the ordered name -> module mapping of candidate operations for one edge."""
        candidates = OrderedDict()
        candidates["maxpool"] = PoolBN('max', channels, 3, stride, 1, affine=False)
        candidates["avgpool"] = PoolBN('avg', channels, 3, stride, 1, affine=False)
        if stride == 1:
            candidates["skipconnect"] = nn.Identity()
        else:
            # A plain identity cannot change the spatial size, so a reducing op is used instead.
            candidates["skipconnect"] = FactorizedReduce(channels, channels, affine=False)
        candidates["sepconv3x3"] = SepConv(channels, channels, 3, stride, 1, affine=False)
        candidates["sepconv5x5"] = SepConv(channels, channels, 5, stride, 2, affine=False)
        candidates["dilconv3x3"] = DilConv(channels, channels, 3, stride, 2, 2, affine=False)
        candidates["dilconv5x5"] = DilConv(channels, channels, 5, stride, 4, 2, affine=False)
        return candidates

    def forward(self, prev_nodes):
        """
        Apply each edge operation to its matching previous node, drop-path the
        results, and hand them to the input switch.

        Parameters
        ---
        prev_nodes: list
            outputs of the previous nodes, one per edge, in edge order
        """
        assert len(self.ops) == len(prev_nodes)

        # Phase 1: run every edge operation.
        edge_outputs = []
        for edge_op, source in zip(self.ops, prev_nodes):
            edge_outputs.append(edge_op(source))

        # Phase 2: drop-path every produced output; ``None`` entries pass through untouched.
        switch_inputs = []
        for edge_out in edge_outputs:
            if edge_out is None:
                switch_inputs.append(None)
            else:
                switch_inputs.append(self.drop_path(edge_out))

        return self.input_switch(switch_inputs)


class DartsCell(nn.Module):
    """
    Builtin Darts Cell structure.

    A cell holds two fixed input nodes followed by ``n_nodes`` searchable nodes.
    The two input nodes carry the preprocessed outputs of the previous previous
    cell and of the previous cell respectively. Every searchable node receives
    all tensors produced before it (both inputs and all earlier searchable
    nodes) through mutable operations. The cell output is the channel-wise
    concatenation of the ``n_nodes`` searchable node outputs, so the number of
    output channels is ``n_nodes`` times ``channels``.

    Parameters
    ---
    n_nodes: int
        the number of nodes contained in this cell
    channels_pp: int
        the number of previous previous cell's output channels
    channels_p: int
        the number of previous cell's output channels
    channels: int
        the number of output channels for each node
    reduction_p: bool
        Is previous cell a reduction cell
    reduction: bool
        is current cell a reduction cell
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        # Both values are constant for the whole cell, so compute them once.
        cell_type = "reduce" if reduction else "normal"
        # In a reduction cell the two cell inputs are the edges that get downsampled.
        num_downsample_connect = 2 if reduction else 0
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            # ``depth`` doubles as the number of tensors available to this node.
            node_id = "{}_n{}".format(cell_type, depth)
            self.mutable_ops.append(Node(node_id, depth, channels, num_downsample_connect))

    def forward(self, pprev, prev):
        """
        Parameters
        ---
        pprev: torch.Tensor
            the output of the previous previous layer
        prev: torch.Tensor
            the output of the previous layer

        Returns
        ---
        torch.Tensor
            outputs of all searchable nodes concatenated along dimension 1
        """
        tensors = [self.preproc0(pprev), self.preproc1(prev)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)

        # The two preprocessed inputs are excluded from the cell output.
        output = torch.cat(tensors[2:], dim=1)
        return output