# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device
_logger = logging.getLogger(__name__)
class ArchGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
detached_x = x.detach()
detached_x.requires_grad = x.requires_grad
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
return grad_x[0], binary_grads, None, None
class ProxylessLayerChoice(nn.Module):
def __init__(self, ops):
super(ProxylessLayerChoice, self).__init__()
self.ops = nn.ModuleList(ops)
self.alpha = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
self._binary_gates = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
self.sampled = None
def forward(self, *args):
def run_function(ops, active_id):
def forward(_x):
return ops[active_id](_x)
return forward
def backward_function(ops, active_id, binary_gates):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(len(ops)):
if k != active_id:
out_k = ops[k](_x.data)
else:
out_k = _output.data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
assert len(args) == 1
x = args[0]
return ArchGradientFunction.apply(
x, self._binary_gates, run_function(self.ops, self.sampled),
backward_function(self.ops, self.sampled, self._binary_gates)
)
def resample(self):
probs = F.softmax(self.alpha, dim=-1)
sample = torch.multinomial(probs, 1)[0].item()
self.sampled = sample
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[sample] = 1.0
def finalize_grad(self):
binary_grads = self._binary_gates.grad
with torch.no_grad():
if self.alpha.grad is None:
self.alpha.grad = torch.zeros_like(self.alpha.data)
probs = F.softmax(self.alpha, dim=-1)
for i in range(len(self.ops)):
for j in range(len(self.ops)):
self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
def export(self):
return torch.argmax(self.alpha).item()
class ProxylessInputChoice(nn.Module):
def __init__(self, *args, **kwargs):
raise NotImplementedError('Input choice is not supported for ProxylessNAS.')
[docs]class ProxylessTrainer(BaseOneShotTrainer):
"""
Proxyless trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
warmup_epochs : int
Number of epochs to warmup model parameters.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
arc_learning_rate : float
Learning rate of architecture parameters.
"""
def __init__(self, model, loss, metrics, optimizer,
num_epochs, dataset, warmup_epochs=0,
batch_size=64, workers=4, device=None, log_frequency=None,
arc_learning_rate=1.0E-3):
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.num_epochs = num_epochs
self.warmup_epochs = warmup_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.model.to(self.device)
self.nas_modules = []
replace_layer_choice(self.model, ProxylessLayerChoice, self.nas_modules)
replace_input_choice(self.model, ProxylessInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.optimizer = optimizer
# we do not support deduplicate control parameters with same label (like DARTS) yet.
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
weight_decay=0, betas=(0, 0.999), eps=1e-8)
self._init_dataloader()
def _init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = to_device(trn_X, self.device), to_device(trn_y, self.device)
val_X, val_y = to_device(val_X, self.device), to_device(val_y, self.device)
if epoch >= self.warmup_epochs:
# 1) train architecture parameters
for _, module in self.nas_modules:
module.resample()
self.ctrl_optim.zero_grad()
logits, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
for _, module in self.nas_modules:
module.finalize_grad()
self.ctrl_optim.step()
# 2) train model parameters
for _, module in self.nas_modules:
module.resample()
self.optimizer.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, trn_y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _logits_and_loss(self, X, y):
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def fit(self):
for i in range(self.num_epochs):
self._train_one_epoch(i)
@torch.no_grad()
def export(self):
result = dict()
for name, module in self.nas_modules:
if name not in result:
result[name] = module.export()
return result