import logging
from typing import Any, List
import torch
import torch.nn as nn
from ...utils import add_record, blackbox_module, uid, version_larger_equal
_logger = logging.getLogger(__name__)
# NOTE: support pytorch version >= 1.5.0
__all__ = [
'LayerChoice', 'InputChoice', 'Placeholder',
'Module', 'Sequential', 'ModuleList', # TODO: 'ModuleDict', 'ParameterList', 'ParameterDict',
'Identity', 'Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d',
'ConvTranspose2d', 'ConvTranspose3d', 'Threshold', 'ReLU', 'Hardtanh', 'ReLU6',
'Sigmoid', 'Tanh', 'Softmax', 'Softmax2d', 'LogSoftmax', 'ELU', 'SELU', 'CELU', 'GLU', 'GELU', 'Hardshrink',
'LeakyReLU', 'LogSigmoid', 'Softplus', 'Softshrink', 'MultiheadAttention', 'PReLU', 'Softsign', 'Softmin',
'Tanhshrink', 'RReLU', 'AvgPool1d', 'AvgPool2d', 'AvgPool3d', 'MaxPool1d', 'MaxPool2d',
'MaxPool3d', 'MaxUnpool1d', 'MaxUnpool2d', 'MaxUnpool3d', 'FractionalMaxPool2d', "FractionalMaxPool3d",
'LPPool1d', 'LPPool2d', 'LocalResponseNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d', 'InstanceNorm1d',
'InstanceNorm2d', 'InstanceNorm3d', 'LayerNorm', 'GroupNorm', 'SyncBatchNorm',
'Dropout', 'Dropout2d', 'Dropout3d', 'AlphaDropout', 'FeatureAlphaDropout',
'ReflectionPad1d', 'ReflectionPad2d', 'ReplicationPad2d', 'ReplicationPad1d', 'ReplicationPad3d',
'CrossMapLRN2d', 'Embedding', 'EmbeddingBag', 'RNNBase', 'RNN', 'LSTM', 'GRU', 'RNNCellBase', 'RNNCell',
'LSTMCell', 'GRUCell', 'PixelShuffle', 'Upsample', 'UpsamplingNearest2d', 'UpsamplingBilinear2d',
'PairwiseDistance', 'AdaptiveMaxPool1d', 'AdaptiveMaxPool2d', 'AdaptiveMaxPool3d', 'AdaptiveAvgPool1d',
'AdaptiveAvgPool2d', 'AdaptiveAvgPool3d', 'TripletMarginLoss', 'ZeroPad2d', 'ConstantPad1d', 'ConstantPad2d',
'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
'Flatten', 'Hardsigmoid'
]
if version_larger_equal(torch.__version__, '1.6.0'):
__all__.append('Hardswish')
if version_larger_equal(torch.__version__, '1.7.0'):
__all__.extend(['Unflatten', 'SiLU', 'TripletMarginWithDistanceLoss'])
[docs]class LayerChoice(nn.Module):
def __init__(self, op_candidates, reduction=None, return_mask=False, key=None):
super(LayerChoice, self).__init__()
self.op_candidates = op_candidates
self.label = key if key is not None else f'layerchoice_{uid()}'
self.key = self.label # deprecated, for backward compatibility
for i, module in enumerate(op_candidates): # deprecated, for backward compatibility
self.add_module(str(i), module)
if reduction or return_mask:
_logger.warning('input arguments `reduction` and `return_mask` are deprecated!')
[docs] def forward(self, x):
return x
class ValueChoice:
"""
The instance of this class can only be used as input argument,
when instantiating a pytorch module.
TODO: can also be used in training approach
"""
def __init__(self, candidate_values: List[Any]):
self.candidate_values = candidate_values
class Placeholder(nn.Module):
def __init__(self, label, related_info):
add_record(id(self), related_info)
self.label = label
self.related_info = related_info
super(Placeholder, self).__init__()
def forward(self, x):
return x
class ChosenInputs(nn.Module):
"""
"""
def __init__(self, chosen: List[int], reduction: str):
super().__init__()
self.chosen = chosen
self.reduction = reduction
def forward(self, candidate_inputs):
return self._tensor_reduction(self.reduction, [candidate_inputs[i] for i in self.chosen])
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
# the following are pytorch modules
Module = nn.Module
class Sequential(nn.Sequential):
def __init__(self, *args):
add_record(id(self), {})
super(Sequential, self).__init__(*args)
class ModuleList(nn.ModuleList):
def __init__(self, *args):
add_record(id(self), {})
super(ModuleList, self).__init__(*args)
Identity = blackbox_module(nn.Identity)
Linear = blackbox_module(nn.Linear)
Conv1d = blackbox_module(nn.Conv1d)
Conv2d = blackbox_module(nn.Conv2d)
Conv3d = blackbox_module(nn.Conv3d)
ConvTranspose1d = blackbox_module(nn.ConvTranspose1d)
ConvTranspose2d = blackbox_module(nn.ConvTranspose2d)
ConvTranspose3d = blackbox_module(nn.ConvTranspose3d)
Threshold = blackbox_module(nn.Threshold)
ReLU = blackbox_module(nn.ReLU)
Hardtanh = blackbox_module(nn.Hardtanh)
ReLU6 = blackbox_module(nn.ReLU6)
Sigmoid = blackbox_module(nn.Sigmoid)
Tanh = blackbox_module(nn.Tanh)
Softmax = blackbox_module(nn.Softmax)
Softmax2d = blackbox_module(nn.Softmax2d)
LogSoftmax = blackbox_module(nn.LogSoftmax)
ELU = blackbox_module(nn.ELU)
SELU = blackbox_module(nn.SELU)
CELU = blackbox_module(nn.CELU)
GLU = blackbox_module(nn.GLU)
GELU = blackbox_module(nn.GELU)
Hardshrink = blackbox_module(nn.Hardshrink)
LeakyReLU = blackbox_module(nn.LeakyReLU)
LogSigmoid = blackbox_module(nn.LogSigmoid)
Softplus = blackbox_module(nn.Softplus)
Softshrink = blackbox_module(nn.Softshrink)
MultiheadAttention = blackbox_module(nn.MultiheadAttention)
PReLU = blackbox_module(nn.PReLU)
Softsign = blackbox_module(nn.Softsign)
Softmin = blackbox_module(nn.Softmin)
Tanhshrink = blackbox_module(nn.Tanhshrink)
RReLU = blackbox_module(nn.RReLU)
AvgPool1d = blackbox_module(nn.AvgPool1d)
AvgPool2d = blackbox_module(nn.AvgPool2d)
AvgPool3d = blackbox_module(nn.AvgPool3d)
MaxPool1d = blackbox_module(nn.MaxPool1d)
MaxPool2d = blackbox_module(nn.MaxPool2d)
MaxPool3d = blackbox_module(nn.MaxPool3d)
MaxUnpool1d = blackbox_module(nn.MaxUnpool1d)
MaxUnpool2d = blackbox_module(nn.MaxUnpool2d)
MaxUnpool3d = blackbox_module(nn.MaxUnpool3d)
FractionalMaxPool2d = blackbox_module(nn.FractionalMaxPool2d)
FractionalMaxPool3d = blackbox_module(nn.FractionalMaxPool3d)
LPPool1d = blackbox_module(nn.LPPool1d)
LPPool2d = blackbox_module(nn.LPPool2d)
LocalResponseNorm = blackbox_module(nn.LocalResponseNorm)
BatchNorm1d = blackbox_module(nn.BatchNorm1d)
BatchNorm2d = blackbox_module(nn.BatchNorm2d)
BatchNorm3d = blackbox_module(nn.BatchNorm3d)
InstanceNorm1d = blackbox_module(nn.InstanceNorm1d)
InstanceNorm2d = blackbox_module(nn.InstanceNorm2d)
InstanceNorm3d = blackbox_module(nn.InstanceNorm3d)
LayerNorm = blackbox_module(nn.LayerNorm)
GroupNorm = blackbox_module(nn.GroupNorm)
SyncBatchNorm = blackbox_module(nn.SyncBatchNorm)
Dropout = blackbox_module(nn.Dropout)
Dropout2d = blackbox_module(nn.Dropout2d)
Dropout3d = blackbox_module(nn.Dropout3d)
AlphaDropout = blackbox_module(nn.AlphaDropout)
FeatureAlphaDropout = blackbox_module(nn.FeatureAlphaDropout)
ReflectionPad1d = blackbox_module(nn.ReflectionPad1d)
ReflectionPad2d = blackbox_module(nn.ReflectionPad2d)
ReplicationPad2d = blackbox_module(nn.ReplicationPad2d)
ReplicationPad1d = blackbox_module(nn.ReplicationPad1d)
ReplicationPad3d = blackbox_module(nn.ReplicationPad3d)
CrossMapLRN2d = blackbox_module(nn.CrossMapLRN2d)
Embedding = blackbox_module(nn.Embedding)
EmbeddingBag = blackbox_module(nn.EmbeddingBag)
RNNBase = blackbox_module(nn.RNNBase)
RNN = blackbox_module(nn.RNN)
LSTM = blackbox_module(nn.LSTM)
GRU = blackbox_module(nn.GRU)
RNNCellBase = blackbox_module(nn.RNNCellBase)
RNNCell = blackbox_module(nn.RNNCell)
LSTMCell = blackbox_module(nn.LSTMCell)
GRUCell = blackbox_module(nn.GRUCell)
PixelShuffle = blackbox_module(nn.PixelShuffle)
Upsample = blackbox_module(nn.Upsample)
UpsamplingNearest2d = blackbox_module(nn.UpsamplingNearest2d)
UpsamplingBilinear2d = blackbox_module(nn.UpsamplingBilinear2d)
PairwiseDistance = blackbox_module(nn.PairwiseDistance)
AdaptiveMaxPool1d = blackbox_module(nn.AdaptiveMaxPool1d)
AdaptiveMaxPool2d = blackbox_module(nn.AdaptiveMaxPool2d)
AdaptiveMaxPool3d = blackbox_module(nn.AdaptiveMaxPool3d)
AdaptiveAvgPool1d = blackbox_module(nn.AdaptiveAvgPool1d)
AdaptiveAvgPool2d = blackbox_module(nn.AdaptiveAvgPool2d)
AdaptiveAvgPool3d = blackbox_module(nn.AdaptiveAvgPool3d)
TripletMarginLoss = blackbox_module(nn.TripletMarginLoss)
ZeroPad2d = blackbox_module(nn.ZeroPad2d)
ConstantPad1d = blackbox_module(nn.ConstantPad1d)
ConstantPad2d = blackbox_module(nn.ConstantPad2d)
ConstantPad3d = blackbox_module(nn.ConstantPad3d)
Bilinear = blackbox_module(nn.Bilinear)
CosineSimilarity = blackbox_module(nn.CosineSimilarity)
Unfold = blackbox_module(nn.Unfold)
Fold = blackbox_module(nn.Fold)
AdaptiveLogSoftmaxWithLoss = blackbox_module(nn.AdaptiveLogSoftmaxWithLoss)
TransformerEncoder = blackbox_module(nn.TransformerEncoder)
TransformerDecoder = blackbox_module(nn.TransformerDecoder)
TransformerEncoderLayer = blackbox_module(nn.TransformerEncoderLayer)
TransformerDecoderLayer = blackbox_module(nn.TransformerDecoderLayer)
Transformer = blackbox_module(nn.Transformer)
Flatten = blackbox_module(nn.Flatten)
Hardsigmoid = blackbox_module(nn.Hardsigmoid)
if version_larger_equal(torch.__version__, '1.6.0'):
Hardswish = blackbox_module(nn.Hardswish)
if version_larger_equal(torch.__version__, '1.7.0'):
SiLU = blackbox_module(nn.SiLU)
Unflatten = blackbox_module(nn.Unflatten)
TripletMarginWithDistanceLoss = blackbox_module(nn.TripletMarginWithDistanceLoss)