Source code for nni.nas.pytorch.enas.mutator

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope


class StackedLSTMCell(nn.Module):
    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
                                           for _ in range(self.lstm_num_layers)])

    def forward(self, inputs, hidden):
        prev_c, prev_h = hidden
        next_c, next_h = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # the current implementation only supports a batch size of 1,
            # but the algorithm does not necessarily have this limitation
            inputs = curr_h[-1].view(1, -1)
        return next_c, next_h
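
# Example (not part of the original module): a minimal sketch of driving StackedLSTMCell by hand,
# assuming hidden size 64 and two stacked layers. The variable names are illustrative only; inside
# EnasMutator this bookkeeping is handled by _initialize() and _lstm_next_step().
#
#     cell = StackedLSTMCell(layers=2, size=64, bias=False)
#     inputs = torch.zeros(1, 64)                    # one controller step, batch size 1
#     c = [torch.zeros(1, 64) for _ in range(2)]     # one cell state per stacked layer
#     h = [torch.zeros(1, 64) for _ in range(2)]     # one hidden state per stacked layer
#     c, h = cell(inputs, (c, h))                    # h[-1] is what EnasMutator reads as the top-layer output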


class EnasMutator(Mutator):
    """
    A mutator that mutates the graph with RL.

    Parameters
    ----------
    model : nn.Module
        PyTorch model.
    lstm_size : int
        Controller LSTM hidden units.
    lstm_num_layers : int
        Number of layers for stacked LSTM.
    tanh_constant : float
        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
    cell_exit_extra_step : bool
        If true, the RL controller will perform an extra step at the exit of each MutableScope, dump the hidden
        state and mark it as the hidden state of this MutableScope. This is to align with the original
        implementation of the paper.
    skip_target : float
        Target probability that a skip connection will appear.
    temperature : float
        Temperature constant that divides the logits.
    branch_bias : float
        Manual bias applied to make some operations more likely to be chosen.
        Currently this is implemented with a hardcoded match rule that aligns with the original repo.
        If a mutable has ``reduce`` in its key, all its op choices that contain ``conv`` in their typename
        will receive a bias of ``+self.branch_bias`` initially, while the others receive ``-self.branch_bias``.
    entropy_reduction : str
        Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
    """

    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
                 skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
        super().__init__(model)
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.temperature = temperature
        self.cell_exit_extra_step = cell_exit_extra_step
        self.skip_target = skip_target
        self.branch_bias = branch_bias

        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),  # pylint: disable=not-callable
                                         requires_grad=False)
        assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
        self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
        self.bias_dict = nn.ParameterDict()

        self.max_layer_choice = 0
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                if self.max_layer_choice == 0:
                    self.max_layer_choice = len(mutable)
                assert self.max_layer_choice == len(mutable), \
                    "ENAS mutator requires all layer choices to have the same number of candidates."
                # We are judging by keys and module types to add biases to layer choices. Needs refactor.
                if "reduce" in mutable.key:
                    def is_conv(choice):
                        return "conv" in str(type(choice)).lower()
                    bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
                                         for choice in mutable])
                    self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)

        self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
        self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)

    def sample_search(self):
        self._initialize()
        self._sample(self.mutables)
        return self._choices

    def sample_final(self):
        return self.sample_search()

    def _sample(self, tree):
        mutable = tree.mutable
        if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_layer_choice(mutable)
        elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
            self._choices[mutable.key] = self._sample_input_choice(mutable)
        for child in tree.children:
            self._sample(child)
        if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
            if self.cell_exit_extra_step:
                self._lstm_next_step()
            self._mark_anchor(mutable.key)

    def _initialize(self):
        self._choices = dict()
        self._anchors_hid = dict()
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _lstm_next_step(self):
        self._c, self._h = self.lstm(self._inputs, (self._c, self._h))

    def _mark_anchor(self, key):
        self._anchors_hid[key] = self._h[-1]

    def _sample_layer_choice(self, mutable):
        self._lstm_next_step()
        logit = self.soft(self._h[-1])
        if self.temperature is not None:
            logit /= self.temperature
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]
        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        self._inputs = self.embedding(branch_id)
        return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)

    def _sample_input_choice(self, mutable):
        query, anchors = [], []
        for label in mutable.choose_from:
            if label not in self._anchors_hid:
                self._lstm_next_step()
                self._mark_anchor(label)  # empty loop, fill not found
            query.append(self.attn_anchor(self._anchors_hid[label]))
            anchors.append(self._anchors_hid[label])
        query = torch.cat(query, 0)
        query = torch.tanh(query + self.attn_query(self._h[-1]))
        query = self.v_attn(query)
        if self.temperature is not None:
            query /= self.temperature
        if self.tanh_constant is not None:
            query = self.tanh_constant * torch.tanh(query)

        if mutable.n_chosen is None:
            logit = torch.cat([-query, query], 1)  # pylint: disable=invalid-unary-operand-type

            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, skip)
            self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
        else:
            assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
            logit = query.view(1, -1)
            index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
            log_prob = self.cross_entropy_loss(logit, index)
            self._inputs = anchors[index.item()]

        self.sample_log_prob += self.entropy_reduction(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
        self.sample_entropy += self.entropy_reduction(entropy)
        return skip.bool()
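
# Example (not part of the original module): a minimal usage sketch. `MyModelSpace` is a
# hypothetical search space built from LayerChoice/InputChoice mutables; in practice the
# mutator is usually constructed and stepped by the ENAS trainer rather than by hand.
#
#     model = MyModelSpace()
#     mutator = EnasMutator(model, lstm_size=64, tanh_constant=1.5, entropy_reduction="sum")
#     mutator.reset()                        # sample a subgraph and bind it to the model
#     logits = model(batch_of_inputs)        # forward pass through the sampled subgraph
#     log_prob = mutator.sample_log_prob     # accumulated log-prob of the sampled decisions (REINFORCE term)
#     entropy = mutator.sample_entropy       # accumulated entropy of the sampled decisions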