Source code for nni.retiarii.oneshot.pytorch.proxyless

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device


_logger = logging.getLogger(__name__)


class ArchGradientFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        detached_x = x.detach()
        detached_x.requires_grad = x.requires_grad
        with torch.enable_grad():
            output = run_func(detached_x)
        ctx.save_for_backward(detached_x, output)
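        # Return a detached result; gradients are re-routed through the custom backward below.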
        return output.data

    @staticmethod
    def backward(ctx, grad_output):
        detached_x, output = ctx.saved_tensors

        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)

        return grad_x[0], binary_grads, None, None
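

# The Function above implements the ProxylessNAS gradient estimate for the binary gates:
# only the sampled candidate runs in the forward pass, while the backward pass scores every
# candidate k with grad_{g_k} ~= sum(out_k * grad_output). This lets the architecture
# distribution be updated without keeping all candidates in the autograd graph at once.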


class ProxylessLayerChoice(nn.Module):
    def __init__(self, ops):
        super(ProxylessLayerChoice, self).__init__()
        self.ops = nn.ModuleList(ops)
        self.alpha = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
        self._binary_gates = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
        self.sampled = None

    def forward(self, *args):
        def run_function(ops, active_id):
            def forward(_x):
                return ops[active_id](_x)
            return forward

        def backward_function(ops, active_id, binary_gates):
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(len(ops)):
                        if k != active_id:
                            out_k = ops[k](_x.data)
                        else:
                            out_k = _output.data
                        grad_k = torch.sum(out_k * grad_output)
                        binary_grads[k] = grad_k
                return binary_grads
            return backward

        assert len(args) == 1
        x = args[0]
        return ArchGradientFunction.apply(
            x, self._binary_gates, run_function(self.ops, self.sampled),
            backward_function(self.ops, self.sampled, self._binary_gates)
        )

    def resample(self):
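        # Sample one candidate from softmax(alpha) and make the binary gates one-hot,
        # so the next forward/backward touches only the sampled op.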
        probs = F.softmax(self.alpha, dim=-1)
        sample = torch.multinomial(probs, 1)[0].item()
        self.sampled = sample
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[sample] = 1.0

    def finalize_grad(self):
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self.alpha.grad is None:
                self.alpha.grad = torch.zeros_like(self.alpha.data)
            probs = F.softmax(self.alpha, dim=-1)
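            # Chain rule through the softmax: with p = softmax(alpha),
            # d p_j / d alpha_i = p_j * (delta_ij - p_i), hence
            # d L / d alpha_i = sum_j (d L / d g_j) * p_j * (delta_ij - p_i).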
            for i in range(len(self.ops)):
                for j in range(len(self.ops)):
                    self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])

    def export(self):
        return torch.argmax(self.alpha).item()

    def export_prob(self):
        return F.softmax(self.alpha, dim=-1)
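

# A minimal, hypothetical sketch of the sample/forward/finalize cycle that the trainer
# below drives. `_demo_layer_choice` is for illustration only; it is not part of the
# public API and is never called by the trainer.
def _demo_layer_choice():
    choice = ProxylessLayerChoice([nn.Linear(8, 8), nn.Identity()])
    choice.resample()                          # one-hot binary gates for a sampled op
    x = torch.randn(4, 8, requires_grad=True)
    out = choice(x)                            # forward runs only the sampled candidate
    out.sum().backward()                       # fills choice._binary_gates.grad
    choice.finalize_grad()                     # converts gate grads into alpha grads
    return choice.export()                     # index of the strongest candidate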


class ProxylessInputChoice(nn.Module):
    def __init__(self, *args, **kwargs):
        raise NotImplementedError('Input choice is not supported for ProxylessNAS.')


class HardwareLatencyEstimator:
    def __init__(self, applied_hardware, model, dummy_input=(1, 3, 224, 224), dump_lat_table='data/latency_table.yaml'):
        import nn_meter  # pylint: disable=import-error
        _logger.info(f'Load latency predictor for applied hardware: {applied_hardware}.')
        self.predictor_name = applied_hardware
        self.latency_predictor = nn_meter.load_latency_predictor(applied_hardware)
        self.block_latency_table = self._form_latency_table(model, dummy_input, dump_lat_table=dump_lat_table)

    def _form_latency_table(self, model, dummy_input, dump_lat_table):
        latency_table = {}

        from nni.retiarii.converter import convert_to_graph
        from nni.retiarii.converter.graph_gen import GraphConverterWithShape
        from nni.retiarii.converter.utils import flatten_model_graph_without_layerchoice, is_layerchoice_node
        script_module = torch.jit.script(model)
        base_model_ir = convert_to_graph(script_module, model,
                                         converter=GraphConverterWithShape(), dummy_input=torch.randn(*dummy_input))

        # form the latency of layerchoice blocks for the latency table
        temp_ir_model = base_model_ir.fork()
        cell_nodes = base_model_ir.get_cell_nodes()
        layerchoice_nodes = [node for node in cell_nodes if is_layerchoice_node(node)]
        for lc_node in layerchoice_nodes:
            cand_lat = {}
            for candidate in lc_node.operation.parameters['candidates']:
                node_graph = base_model_ir.graphs.get(candidate)
                if node_graph is not None:
                    temp_ir_model._root_graph_name = node_graph.name
                    latency = self.latency_predictor.predict(temp_ir_model, model_type='nni-ir')
                else:
                    _logger.warning(f"Could not find graph for layerchoice candidate {candidate}")
                    latency = 0
                cand_lat[candidate.split('_')[-1]] = float(latency)
            latency_table[lc_node.operation.parameters['label']] = cand_lat

        # form the latency of the stationary block in the latency table
        temp_ir_model._root_graph_name = base_model_ir._root_graph_name
        temp_ir_model = flatten_model_graph_without_layerchoice(temp_ir_model)
        latency = self.latency_predictor.predict(temp_ir_model, model_type='nni-ir')
        latency_table['stationary_block'] = {'root': float(latency)}
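        # Illustrative shape of the finished table (values made up for the example):
        #   {'conv1': {'0': 1.2, '1': 0.8}, 'stationary_block': {'root': 10.5}}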

        # save latency table
        if dump_lat_table:
            import os
            import yaml
            os.makedirs(os.path.dirname(dump_lat_table), exist_ok=True)
            with open(dump_lat_table, 'a') as fp:
                yaml.dump([{
                    "applied_hardware": self.predictor_name,
                    'latency_table': latency_table
                    }], fp)
        _logger.info("Latency lookup table built")

        return latency_table

    def cal_expected_latency(self, current_architecture_prob):
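        # Differentiable expected latency under the current architecture distribution:
        #   E[lat] = lat(stationary) + sum over layerchoices of sum_i p_i * lat_i.
        # Keeping the probabilities as tensors lets gradients flow back into alpha.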
        lat = self.block_latency_table['stationary_block']['root']
        for module_name, probs in current_architecture_prob.items():
            assert len(probs) == len(self.block_latency_table[module_name])
            # torch.stack (rather than torch.tensor) keeps the autograd graph intact
            lat += torch.sum(torch.stack([probs[i] * self.block_latency_table[module_name][str(i)]
                                          for i in range(len(probs))]))
        return lat

    def export_latency(self, current_architecture):
        lat = self.block_latency_table['stationary_block']['root']
        for module_name, selected_module in current_architecture.items():
            lat += self.block_latency_table[module_name][str(selected_module)]
        return lat


class ProxylessTrainer(BaseOneShotTrainer):
    """
    Proxyless trainer.

    Parameters
    ----------
    model : nn.Module
        PyTorch model to be trained.
    loss : callable
        Receives logits and ground truth label, returns a loss tensor.
    metrics : callable
        Receives logits and ground truth label, returns a dict of metrics.
    optimizer : Optimizer
        The optimizer used for optimizing the model.
    num_epochs : int
        Number of epochs planned for training.
    dataset : Dataset
        Dataset for training. Will be split for training weights and architecture weights.
    warmup_epochs : int
        Number of epochs to warm up model parameters.
    batch_size : int
        Batch size.
    workers : int
        Workers for data loading.
    device : torch.device
        ``torch.device("cpu")`` or ``torch.device("cuda")``.
    log_frequency : int
        Step count per logging.
    arc_learning_rate : float
        Learning rate of architecture parameters.
    grad_reg_loss_type : str
        Regularization type for the hardware-related loss; allowed types include
        - ``"mul#log"``: ``regularized_loss = (torch.log(expected_latency) / math.log(self.ref_latency)) ** beta``
        - ``"add#linear"``: ``regularized_loss = reg_lambda * (expected_latency - self.ref_latency) / self.ref_latency``
        - ``None``: do not apply loss regularization.
    grad_reg_loss_params : dict
        Regularization params; allowed params include
        - ``"alpha"`` and ``"beta"``, required when ``grad_reg_loss_type == "mul#log"``
        - ``"lambda"``, required when ``grad_reg_loss_type == "add#linear"``
    applied_hardware : str
        Target hardware to constrain the model's latency. Latency is predicted by Microsoft
        nn-Meter (https://github.com/microsoft/nn-Meter).
    dummy_input : tuple
        The shape of the dummy input used when profiling on the target hardware.
    ref_latency : float
        Reference latency on the applied hardware (ms).
    """

    def __init__(self, model, loss, metrics, optimizer,
                 num_epochs, dataset, warmup_epochs=0,
                 batch_size=64, workers=4, device=None, log_frequency=None,
                 arc_learning_rate=1.0E-3, grad_reg_loss_type=None, grad_reg_loss_params=None,
                 applied_hardware=None, dummy_input=(1, 3, 224, 224),
                 ref_latency=65.0):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.warmup_epochs = warmup_epochs
        self.dataset = dataset
        self.batch_size = batch_size
        self.workers = workers
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
        self.log_frequency = log_frequency

        # latency predictor
        if applied_hardware:
            self.latency_estimator = HardwareLatencyEstimator(applied_hardware, self.model, dummy_input)
        else:
            self.latency_estimator = None
        self.reg_loss_type = grad_reg_loss_type
        self.reg_loss_params = {} if grad_reg_loss_params is None else grad_reg_loss_params
        self.ref_latency = ref_latency

        self.model.to(self.device)
        self.nas_modules = []
        replace_layer_choice(self.model, ProxylessLayerChoice, self.nas_modules)
        replace_input_choice(self.model, ProxylessInputChoice, self.nas_modules)
        for _, module in self.nas_modules:
            module.to(self.device)

        # deduplicating control parameters with the same label (as in DARTS) is not supported yet
        self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
                                           weight_decay=0, betas=(0, 0.999), eps=1e-8)
        self._init_dataloader()

    def _init_dataloader(self):
        n_train = len(self.dataset)
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset,
                                                        batch_size=self.batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=self.workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset,
                                                        batch_size=self.batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=self.workers)

    def _train_one_epoch(self, epoch):
        self.model.train()
        meters = AverageMeterGroup()
        for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
            trn_X, trn_y = to_device(trn_X, self.device), to_device(trn_y, self.device)
            val_X, val_y = to_device(val_X, self.device), to_device(val_y, self.device)

            if epoch >= self.warmup_epochs:
                # 1) train architecture parameters
                for _, module in self.nas_modules:
                    module.resample()
                self.ctrl_optim.zero_grad()
                logits, loss = self._logits_and_loss_for_arch_update(val_X, val_y)
                loss.backward()
                for _, module in self.nas_modules:
                    module.finalize_grad()
                self.ctrl_optim.step()

            # 2) train model parameters
            for _, module in self.nas_modules:
                module.resample()
            self.optimizer.zero_grad()
            logits, loss = self._logits_and_loss_for_weight_update(trn_X, trn_y)
            loss.backward()
            self.optimizer.step()

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            if self.latency_estimator:
                metrics["latency"] = self._export_latency()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                _logger.info("Epoch [%s/%s] Step [%s/%s]  %s", epoch + 1,
                             self.num_epochs, step + 1, len(self.train_loader), meters)

    def _logits_and_loss_for_arch_update(self, X, y):
        """Return logits and loss for the architecture parameter update."""
        logits = self.model(X)
        ce_loss = self.loss(logits, y)
        if not self.latency_estimator:
            return logits, ce_loss

        current_architecture_prob = {}
        for module_name, module in self.nas_modules:
            probs = module.export_prob()
            current_architecture_prob[module_name] = probs
        expected_latency = self.latency_estimator.cal_expected_latency(current_architecture_prob)

        if self.reg_loss_type == 'mul#log':
            import math
            alpha = self.reg_loss_params.get('alpha', 1)
            beta = self.reg_loss_params.get('beta', 0.6)
            # noinspection PyUnresolvedReferences
            reg_loss = (torch.log(expected_latency) / math.log(self.ref_latency)) ** beta
            return logits, alpha * ce_loss * reg_loss
        elif self.reg_loss_type == 'add#linear':
            reg_lambda = self.reg_loss_params.get('lambda', 2e-1)
            reg_loss = reg_lambda * (expected_latency - self.ref_latency) / self.ref_latency
            return logits, ce_loss + reg_loss
        elif self.reg_loss_type is None:
            return logits, ce_loss
        else:
            raise ValueError(f'Unsupported regularization loss type: {self.reg_loss_type}')

    def _logits_and_loss_for_weight_update(self, X, y):
        """Return logits and loss for the weight parameter update."""
        logits = self.model(X)
        loss = self.loss(logits, y)
        return logits, loss

    def _export_latency(self):
        current_architecture = {}
        for module_name, module in self.nas_modules:
            selected_module = module.export()
            current_architecture[module_name] = selected_module
        return self.latency_estimator.export_latency(current_architecture)

    def fit(self):
        for i in range(self.num_epochs):
            self._train_one_epoch(i)

    @torch.no_grad()
    def export(self):
        result = dict()
        for name, module in self.nas_modules:
            if name not in result:
                result[name] = module.export()
        return result
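

# A hedged end-to-end sketch of typical usage, kept in comments so importing this module
# stays side-effect free. `MySearchSpace` and `my_dataset` are placeholders for
# illustration, not part of this module:
#
#     model = MySearchSpace()                     # nn.Module containing LayerChoice ops
#     trainer = ProxylessTrainer(
#         model,
#         loss=nn.CrossEntropyLoss(),
#         metrics=lambda logits, y: {'acc': (logits.argmax(-1) == y).float().mean().item()},
#         optimizer=torch.optim.SGD(model.parameters(), lr=0.05),
#         num_epochs=10,
#         dataset=my_dataset,
#     )
#     trainer.fit()
#     print(trainer.export())                     # {layerchoice_label: selected_index}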