nni.nas.hub.pytorch.shufflenet 源代码

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Union, cast

import torch
import torch.nn as nn
import nni
from nni.mutable import MutableExpression
from nni.nas.nn.pytorch import LayerChoice, ModelSpace, MutableConv2d, MutableBatchNorm2d

from .utils.pretrained import load_pretrained_weight


class ShuffleNetBlock(nn.Module):
    """
    Describe the basic building block of shuffle net, as described in
    `ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices <https://arxiv.org/pdf/1707.01083.pdf>`__.

    When stride = 1, the block expects an input with ``2 * input channels``. Otherwise input channels.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: Union[int, MutableExpression[int]], *,
                 kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
        super().__init__()
        assert stride in [1, 2]
        assert kernel_size in [3, 5, 7]
        self.channels = in_channels // 2 if stride == 1 else in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.pad = kernel_size // 2
        self.oup_main = out_channels - self.channels
        self.affine = affine
        assert self.oup_main > 0

        self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))

        if stride == 2:
            self.branch_proj = nn.Sequential(
                # dw
                MutableConv2d(self.channels, self.channels, kernel_size, stride, self.pad,
                              groups=self.channels, bias=False),
                MutableBatchNorm2d(self.channels, affine=affine),
                # pw-linear
                MutableConv2d(self.channels, self.channels, 1, 1, 0, bias=False),
                MutableBatchNorm2d(self.channels, affine=affine),
                nn.ReLU(inplace=True)
            )
        else:
            # empty block to be compatible with torchscript
            self.branch_proj = nn.Sequential()

    def forward(self, x):
        if self.stride == 2:
            x_proj, x = self.branch_proj(x), x
        else:
            x_proj, x = self._channel_shuffle(x)
        return torch.cat((x_proj, self.branch_main(x)), 1)

    def _decode_point_depth_conv(self, sequence):
        result = []
        first_depth = first_point = True
        pc: int = self.channels
        c: int = self.channels
        for i, token in enumerate(sequence):
            # compute output channels of this conv
            if i + 1 == len(sequence):
                assert token == "p", "Last conv must be point-wise conv."
                c = self.oup_main
            elif token == "p" and first_point:
                c = cast(int, self.mid_channels)
            if token == "d":
                # depth-wise conv
                if isinstance(pc, int) and isinstance(c, int):
                    # check can only be done for static channels
                    assert pc == c, "Depth-wise conv must not change channels."
                result.append(MutableConv2d(pc, c, self.kernel_size, self.stride if first_depth else 1, self.pad,
                                            groups=c, bias=False))
                result.append(MutableBatchNorm2d(c, affine=self.affine))
                first_depth = False
            elif token == "p":
                # point-wise conv
                result.append(MutableConv2d(pc, c, 1, 1, 0, bias=False))
                result.append(MutableBatchNorm2d(c, affine=self.affine))
                result.append(nn.ReLU(inplace=True))
                first_point = False
            else:
                raise ValueError("Conv sequence must be d and p.")
            pc = c
        return result

    def _channel_shuffle(self, x):
        bs, num_channels, height, width = x.size()
        # NOTE: this line is commented for torchscript
        # assert (num_channels % 4 == 0)
        x = x.reshape(bs * num_channels // 2, 2, height * width)
        x = x.permute(1, 0, 2)
        x = x.reshape(2, -1, num_channels // 2, height, width)
        return x[0], x[1]


class ShuffleXceptionBlock(ShuffleNetBlock):
    """
    The ``choice_x`` version of shuffle net block, described in
    `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: Union[int, MutableExpression[int]],
                 *, stride: int, affine: bool = True):
        super().__init__(in_channels, out_channels, mid_channels,
                         kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)


[文档] class ShuffleNetSpace(ModelSpace): """ The search space proposed in `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__. The basic building block design is inspired by a state-of-the-art manually-designed network -- `ShuffleNetV2 <https://openaccess.thecvf.com/content_ECCV_2018/html/Ningning_Light-weight_CNN_Architecture_ECCV_2018_paper.html>`__. There are 20 choice blocks in total. Each choice block has 4 candidates, namely ``choice 3``, ``choice 5``, ``choice_7`` and ``choice_x`` respectively. They differ in kernel sizes and the number of depthwise convolutions. The size of the search space is :math:`4^{20}`. Parameters ---------- num_labels : int Number of classes for the classification head. Default: 1000. channel_search : bool If true, for each building block, the number of ``mid_channels`` (output channels of the first 1x1 conv in each building block) varies from 0.2x to 1.6x (quantized to multiple of 0.2). Here, "k-x" means k times the number of default channels. Otherwise, 1.0x is used by default. Default: false. affine : bool Apply affine to all batch norm. Default: true. """ def __init__(self, num_labels: int = 1000, channel_search: bool = False, affine: bool = True): super().__init__() self.num_labels = num_labels self.channel_search = channel_search self.affine = affine # the block number in each stage. 4 stages in total. 20 blocks in total. self.stage_repeats = [4, 4, 8, 4] # output channels for all stages, including the very first layer and the very last layer self.stage_out_channels = [-1, 16, 64, 160, 320, 640, 1024] # building first layer out_channels = self.stage_out_channels[1] self.first_conv = nn.Sequential( nn.Conv2d(3, out_channels, 3, 2, 1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), ) feature_blocks = [] global_block_idx = 0 for stage_idx, num_repeat in enumerate(self.stage_repeats): for block_idx in range(num_repeat): # count global index to give names to choices global_block_idx += 1 # get ready for input and output in_channels = out_channels out_channels = self.stage_out_channels[stage_idx + 2] stride = 2 if block_idx == 0 else 1 # mid channels can be searched base_mid_channels = out_channels // 2 if self.channel_search: k_choice_list = [int(base_mid_channels * (.2 * k)) for k in range(1, 9)] mid_channels = nni.choice(f'channel_{global_block_idx}', k_choice_list) else: mid_channels = int(base_mid_channels) mid_channels = cast(Union[int, MutableExpression[int]], mid_channels) choice_block = LayerChoice(dict( k3=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine), k5=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine), k7=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine), xcep=ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine) ), label=f'layer_{global_block_idx}') feature_blocks.append(choice_block) self.features = nn.Sequential(*feature_blocks) # final layers last_conv_channels = self.stage_out_channels[-1] self.conv_last = nn.Sequential( nn.Conv2d(out_channels, last_conv_channels, 1, 1, 0, bias=False), nn.BatchNorm2d(last_conv_channels, affine=affine), nn.ReLU(inplace=True), ) self.globalpool = nn.AdaptiveAvgPool2d((1, 1)) self.dropout = nn.Dropout(0.1) self.classifier = nn.Sequential( nn.Linear(last_conv_channels, num_labels, bias=False), ) self._initialize_weights() def forward(self, x): x = self.first_conv(x) x = self.features(x) x = self.conv_last(x) x = self.globalpool(x) x = self.dropout(x) x = x.contiguous().view(-1, self.stage_out_channels[-1]) x = self.classifier(x) return x def _initialize_weights(self): for name, m in self.named_modules(): if isinstance(m, nn.Conv2d): if 'first' in name: torch.nn.init.normal_(m.weight, 0, 0.01) # type: ignore else: torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1]) # type: ignore if m.bias is not None: torch.nn.init.constant_(m.bias, 0) # type: ignore elif isinstance(m, nn.BatchNorm2d): if m.weight is not None: torch.nn.init.constant_(m.weight, 1) # type: ignore if m.bias is not None: torch.nn.init.constant_(m.bias, 0.0001) # type: ignore if m.running_mean is not None: torch.nn.init.constant_(m.running_mean, 0) # type: ignore elif isinstance(m, nn.BatchNorm1d): if m.weight is not None: torch.nn.init.constant_(m.weight, 1) # type: ignore if m.bias is not None: torch.nn.init.constant_(m.bias, 0.0001) # type: ignore if m.running_mean is not None: torch.nn.init.constant_(m.running_mean, 0) # type: ignore elif isinstance(m, nn.Linear): torch.nn.init.normal_(m.weight, 0, 0.01) # type: ignore if m.bias is not None: torch.nn.init.constant_(m.bias, 0) # type: ignore @classmethod def load_searched_model( cls, name: str, pretrained: bool = False, download: bool = False, progress: bool = True ) -> nn.Module: if name == 'spos': # NOTE: Need BGR tensor, with no normalization # https://github.com/ultmaster/spacehub-conversion/blob/371a4fd6646b4e11eda3f61187f7c9a1d484b1ca/cutils.py#L63 arch = { 'layer_1': 'k7', 'layer_2': 'k5', 'layer_3': 'k3', 'layer_4': 'k5', 'layer_5': 'k7', 'layer_6': 'k3', 'layer_7': 'k7', 'layer_8': 'k3', 'layer_9': 'k7', 'layer_10': 'k3', 'layer_11': 'k7', 'layer_12': 'xcep', 'layer_13': 'k3', 'layer_14': 'k3', 'layer_15': 'k3', 'layer_16': 'k3', 'layer_17': 'xcep', 'layer_18': 'k7', 'layer_19': 'xcep', 'layer_20': 'xcep' } else: raise ValueError(f'Unsupported architecture with name: {name}') model_factory = cls.frozen_factory(arch) model = model_factory() if pretrained: weight_file = load_pretrained_weight(name, download=download, progress=progress) pretrained_weights = torch.load(weight_file) model.load_state_dict(pretrained_weights) return model