# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
class StdConv(nn.Module):
    """
    Pointwise convolution block. The structure is 1x1 Conv2d ==> BatchNorm2d ==> ReLU.

    The convolution has no bias and the batch norm has no learnable affine
    parameters.

    Parameters
    ----------
    C_in: int
        the number of input channels
    C_out: int
        the number of output channels
    """

    def __init__(self, C_in, C_out):
        super().__init__()
        layers = [
            nn.Conv2d(C_in, C_out, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU(),
        ]
        # Unpacked positionally so the sub-modules keep the names "0", "1", "2".
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv / batch-norm / ReLU stack."""
        return self.conv(x)
class PoolBranch(nn.Module):
    """
    Pooling structure for Macro search. The input first passes through a
    ``StdConv`` block, then the pooling operation, then BatchNorm2d.

    Parameters
    ----------
    pool_type: str
        only accept ``max`` for MaxPool and ``avg`` for AvgPool
    C_in: int
        the number of input channels
    C_out: int
        the number of output channels
    kernel_size: int
        size of the pooling window
    stride: int
        stride of the pooling window
    padding: int
        padding passed to the pooling layer
    affine: bool
        whether the trailing BatchNorm2d has learnable affine parameters
    """

    def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
        super().__init__()
        # Creation order is kept (preproc, pool, bn) so parameter ordering is unchanged.
        self.preproc = StdConv(C_in=C_in, C_out=C_out)
        self.pool = Pool(pool_type=pool_type, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(num_features=C_out, affine=affine)

    def forward(self, x):
        """Apply preprocessing conv, pooling and batch norm, in that order."""
        return self.bn(self.pool(self.preproc(x)))
class SeparableConv(nn.Module):
    """
    Depthwise separable convolution: a per-channel (grouped) convolution
    followed by a 1x1 pointwise convolution. Neither convolution has a bias.

    Parameters
    ----------
    C_in: int
        the number of input channels
    C_out: int
        the number of output channels
    kernel_size: int
        size of the depthwise convolving kernel
    stride: int
        stride of the depthwise convolution
    padding: int
        padding passed to the depthwise convolution
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super().__init__()
        # groups == C_in gives every input channel its own spatial filter.
        self.depthwise = nn.Conv2d(
            in_channels=C_in,
            out_channels=C_in,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=C_in,
            bias=False,
        )
        # The 1x1 convolution mixes channels and maps C_in -> C_out.
        self.pointwise = nn.Conv2d(in_channels=C_in, out_channels=C_out, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.depthwise(x))
class ConvBranch(nn.Module):
    """
    Conv structure for Macro search. The input first passes through a
    ``StdConv`` block, then the main convolution (separable or regular),
    followed by BatchNorm2d and ReLU.

    Parameters
    ----------
    C_in: int
        the number of input channels
    C_out: int
        the number of output channels
    kernel_size: int
        size of the convolving kernel
    stride: int
        stride of the convolution
    padding: int
        padding passed to the main convolution
    separable: bool
        if true, use ``SeparableConv`` as the main convolution; otherwise a regular ``nn.Conv2d``
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
        super().__init__()
        self.preproc = StdConv(C_in, C_out)
        # The main convolution works on C_out channels because preproc has
        # already mapped C_in -> C_out.
        self.conv = (
            SeparableConv(C_out, C_out, kernel_size, stride, padding)
            if separable
            else nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
        )
        self.postproc = nn.Sequential(nn.BatchNorm2d(C_out, affine=False), nn.ReLU())

    def forward(self, x):
        """Apply preprocessing conv, main conv and BN/ReLU, in that order."""
        return self.postproc(self.conv(self.preproc(x)))
class FactorizedReduce(nn.Module):
    """
    Reduce the spatial size with two stride-2 1x1 convolutions whose outputs
    are concatenated along the channel axis and passed through BatchNorm2d.

    The first convolution samples the input as is; the second samples the
    input shifted by one pixel along both spatial axes.

    Parameters
    ----------
    C_in: int
        the number of input channels
    C_out: int
        the number of output channels; the first path produces ``C_out // 2``
        channels and the second path the remaining ``C_out - C_out // 2``
    affine: bool
        whether the BatchNorm2d has learnable affine parameters
    """

    def __init__(self, C_in, C_out, affine=False):
        super().__init__()
        half = C_out // 2
        self.conv1 = nn.Conv2d(C_in, half, 1, stride=2, padding=0, bias=False)
        # Give the second path whatever is left so the concatenation always has
        # exactly C_out channels, even when C_out is odd (same as before when even).
        self.conv2 = nn.Conv2d(C_in, C_out - half, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        """Return the batch-normalised concatenation of both reduced paths."""
        shifted = x[:, :, 1:, 1:]
        pad_h = x.size(2) % 2
        pad_w = x.size(3) % 2
        if pad_h or pad_w:
            # With an odd height/width the shifted view yields one fewer output
            # row/column than the unshifted path; zero-pad the bottom/right edge
            # so both paths agree. Even-sized inputs take the original path.
            shifted = nn.functional.pad(shifted, (0, pad_w, 0, pad_h))
        out = torch.cat([self.conv1(x), self.conv2(shifted)], dim=1)
        out = self.bn(out)
        return out
class Pool(nn.Module):
    """
    Pooling structure wrapping either ``nn.MaxPool2d`` or ``nn.AvgPool2d``.

    Parameters
    ----------
    pool_type: str
        only accept ``max`` for MaxPool and ``avg`` for AvgPool (case-insensitive)
    kernel_size: int
        size of the pooling window
    stride: int
        stride of the pooling window
    padding: int
        padding passed to the pooling layer

    Raises
    ------
    ValueError
        if ``pool_type`` is neither ``max`` nor ``avg``
    """

    def __init__(self, pool_type, kernel_size, stride, padding):
        super().__init__()
        kind = pool_type.lower()
        if kind == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif kind == 'avg':
            # Padded zeros are excluded from the average.
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            # Name the rejected value so a bad configuration is easy to trace.
            raise ValueError(
                "Unsupported pool_type {!r}: expected 'max' or 'avg'.".format(pool_type)
            )

    def forward(self, x):
        """Apply the selected pooling layer to ``x``."""
        return self.pool(x)
class SepConvBN(nn.Module):
    """
    Implement SepConv followed by BatchNorm. The structure is ReLU ==> SepConv ==> BN.

    The separable convolution always uses stride 1 and the batch norm has
    learnable affine parameters.

    Parameters
    ----------
    C_in: int
        the number of input channels
    C_out: int
        the number of output channels
    kernel_size: int
        size of the convolving kernel
    padding: int
        padding passed to the separable convolution
    """

    def __init__(self, C_in, C_out, kernel_size, padding):
        super().__init__()
        unit_stride = 1
        self.relu = nn.ReLU()
        self.conv = SeparableConv(C_in, C_out, kernel_size, unit_stride, padding)
        self.bn = nn.BatchNorm2d(C_out, affine=True)

    def forward(self, x):
        """Apply ReLU, separable convolution and batch norm, in that order."""
        return self.bn(self.conv(self.relu(x)))