The paper "Efficient Neural Architecture Search via Parameter Sharing" (ENAS) uses parameter sharing between child models to accelerate the NAS process. In ENAS, a controller learns to discover neural network architectures by searching for an optimal subgraph within a large computational graph. The controller is trained with policy gradient to select a subgraph that maximizes the expected reward on the validation set. Meanwhile, the model corresponding to the selected subgraph is trained to minimize a canonical cross-entropy loss.
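
The policy-gradient update of the controller can be illustrated with a toy REINFORCE loop. This is a minimal sketch and not the NNI implementation: the controller is reduced to one vector of logits per layer, `toy_reward` is a hypothetical stand-in for the validation accuracy of the sampled child model, and training of the shared child weights is omitted entirely.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_architecture(controller_logits, rng):
    """Sample one op index per layer, i.e. one subgraph of the search space."""
    choices = []
    for logits in controller_logits:
        probs = softmax(logits)
        choices.append(rng.choices(range(len(probs)), weights=probs)[0])
    return choices


def reinforce_step(controller_logits, choices, advantage, lr):
    """Gradient ascent on advantage * log_prob(choices).

    Uses d log softmax(l)[k] / d l[j] = 1[j == k] - p[j].
    """
    for logits, k in zip(controller_logits, choices):
        probs = softmax(logits)
        for j in range(len(logits)):
            grad = (1.0 if j == k else 0.0) - probs[j]
            logits[j] += lr * advantage * grad


def toy_reward(choices):
    """Hypothetical reward: fraction of layers that picked op 2."""
    return sum(1.0 for c in choices if c == 2) / len(choices)


rng = random.Random(0)
controller_logits = [[0.0, 0.0, 0.0] for _ in range(2)]  # 2 layers, 3 ops each
baseline = 0.0
for _ in range(2000):
    choices = sample_architecture(controller_logits, rng)
    reward = toy_reward(choices)
    reinforce_step(controller_logits, choices, reward - baseline, lr=0.1)
    baseline = 0.9 * baseline + 0.1 * reward  # moving-average baseline

final_probs = [softmax(logits) for logits in controller_logits]
```

After the loop, `final_probs` should place most of the probability mass on op 2 in both layers, because that is the only choice the toy reward favors.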

The NNI implementation is based on the official TensorFlow implementation. It includes a general-purpose reinforcement-learning controller and a trainer that trains the target network and this controller alternately. Following the paper, we have also implemented the macro and micro search spaces on CIFAR10 to demonstrate how to use these trainers. Since the code to train from scratch on NNI is not ready yet, reproduction results are currently unavailable.
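
The alternating schedule can be pictured as two phases per epoch. The sketch below is an illustration only, assuming the phases are sized by the `child_steps`, `mutator_steps` and `mutator_steps_aggregate` arguments of `EnasTrainer`; the function name and the three callbacks are hypothetical and stand in for the real training steps.

```python
def run_alternating_schedule(num_epochs, child_steps, mutator_steps,
                             mutator_steps_aggregate,
                             child_step, controller_sample, controller_update):
    """Alternate between training the shared model and the RL controller."""
    for _ in range(num_epochs):
        # Phase 1: train the target network on sampled architectures.
        for _ in range(child_steps):
            child_step()
        # Phase 2: train the controller; each controller mini-batch
        # aggregates several sampled architectures before one update.
        for _ in range(mutator_steps):
            for _ in range(mutator_steps_aggregate):
                controller_sample()
            controller_update()


# Count how often each hypothetical callback fires.
counts = {"child": 0, "sample": 0, "update": 0}


def bump(name):
    def _inc():
        counts[name] += 1
    return _inc


run_alternating_schedule(2, 3, 2, 4,
                         bump("child"), bump("sample"), bump("update"))
```

With 2 epochs, 3 child steps, 2 controller mini-batches and 4 aggregated samples each, the callbacks fire 6, 16 and 4 times respectively.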


CIFAR10 Macro/Micro Search Space

Example code

# Clone the NNI code if you have not done so yet. If it is already cloned, skip this line and enter the code folder.
git clone https://github.com/Microsoft/nni.git

# enter the ENAS example folder (relative to the NNI code folder) to search for the best architecture
cd examples/nas/enas

# search in macro search space
python3 search.py --search-for macro

# search in micro search space
python3 search.py --search-for micro

# view more options for search
python3 search.py -h



class nni.nas.pytorch.enas.EnasTrainer(model, loss, metrics, reward_function, optimizer, num_epochs, dataset_train, dataset_valid, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500, mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4, test_arc_per_epoch=1)

ENAS trainer.

  • model (nn.Module) – PyTorch model to be trained.
  • loss (callable) – Receives logits and ground truth label, return a loss tensor.
  • metrics (callable) – Receives logits and ground truth label, return a dict of metrics.
  • reward_function (callable) – Receives logits and ground truth label, return a tensor, which will be fed to the RL controller as reward.
  • optimizer (Optimizer) – The optimizer used for optimizing the model.
  • num_epochs (int) – Number of epochs planned for training.
  • dataset_train (Dataset) – Dataset for training. Will be split for training weights and architecture weights.
  • dataset_valid (Dataset) – Dataset for testing.
  • mutator (EnasMutator) – Use when customizing your own mutator or a mutator with customized parameters.
  • batch_size (int) – Batch size.
  • workers (int) – Workers for data loading.
  • device (torch.device) – torch.device("cpu") or torch.device("cuda").
  • log_frequency (int) – Step count per logging.
  • callbacks (list of Callback) – list of callbacks to trigger at events.
  • entropy_weight (float) – Weight of sample entropy loss.
  • skip_weight (float) – Weight of skip penalty loss.
  • baseline_decay (float) – Decay factor of baseline. New baseline will be equal to baseline_decay * baseline_old + reward * (1 - baseline_decay).
  • child_steps (int) – How many mini-batches for model training per epoch.
  • mutator_lr (float) – Learning rate for RL controller.
  • mutator_steps_aggregate (int) – Number of steps that will be aggregated into one mini-batch for RL controller.
  • mutator_steps (int) – Number of mini-batches for each epoch of RL controller learning.
  • aux_weight (float) – Weight of auxiliary head loss. aux_weight * aux_loss will be added to total loss.
  • test_arc_per_epoch (int) – How many architectures are chosen for direct test after each epoch.
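
The two formulas stated in the parameter descriptions (the baseline update and the auxiliary-loss term) can be written out directly. This is a small sketch in plain Python; the helper names `update_baseline` and `total_loss` are hypothetical and not part of the NNI API.

```python
def update_baseline(baseline_old, reward, baseline_decay=0.999):
    """Exponential moving average of the reward, as described for baseline_decay."""
    return baseline_decay * baseline_old + reward * (1 - baseline_decay)


def total_loss(main_loss, aux_loss, aux_weight=0.4):
    """Add the weighted auxiliary head loss to the main loss."""
    return main_loss + aux_weight * aux_loss


# A smaller decay than the default is used here so the effect is visible
# after only a few updates.
baseline = 0.0
for reward in [0.5, 0.5, 0.5]:
    baseline = update_baseline(baseline, reward, baseline_decay=0.9)

combined = total_loss(main_loss=1.0, aux_loss=0.5)
```

With the default `baseline_decay=0.999` the baseline moves very slowly, so it tracks a long-run average of the reward rather than recent samples.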

train_one_epoch(epoch)

Train one epoch.

Parameters: epoch (int) – Epoch number starting from 0.

validate_one_epoch(epoch)

Validate one epoch.

Parameters: epoch (int) – Epoch number starting from 0.

class nni.nas.pytorch.enas.EnasMutator(model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction='sum')

A mutator that mutates the graph with RL.

  • model (nn.Module) – PyTorch model.
  • lstm_size (int) – Controller LSTM hidden units.
  • lstm_num_layers (int) – Number of layers for stacked LSTM.
  • tanh_constant (float) – Logits will be equal to tanh_constant * tanh(logits). Don’t use tanh if this value is None.
  • cell_exit_extra_step (bool) – If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
  • skip_target (float) – Target probability that skipconnect will appear.
  • temperature (float) – Temperature constant that divides the logits.
  • branch_bias (float) – Manual bias applied to make some operations more likely to be chosen. Currently this is implemented with a hardcoded match rule that aligns with the original repo. If a mutable has a reduce in its key, all its op choices that contain conv in their typename will receive a bias of +self.branch_bias initially, while others receive a bias of -self.branch_bias.
  • entropy_reduction (str) – Can be one of sum and mean. How the entropy of multi-input-choice is reduced.
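
The effect of `temperature`, `tanh_constant` and `branch_bias` on the controller logits can be sketched in plain Python. This is an illustration under stated assumptions, not the NNI code: the helper names are hypothetical, applying the temperature before the tanh is an assumed order, the conv match is assumed to be case-insensitive, and mutables without reduce in their key are assumed to receive no bias.

```python
import math


def shape_logits(logits, temperature=None, tanh_constant=1.5):
    """Optionally divide logits by a temperature, then squash them with tanh."""
    if temperature is not None:
        logits = [x / temperature for x in logits]
    if tanh_constant is not None:
        # Bounds every logit to the range (-tanh_constant, tanh_constant).
        logits = [tanh_constant * math.tanh(x) for x in logits]
    return logits


def initial_branch_bias(mutable_key, op_type_names, branch_bias=0.25):
    """Initial per-op bias following the match rule described for branch_bias."""
    if "reduce" not in mutable_key:
        return [0.0] * len(op_type_names)  # assumption: no bias otherwise
    return [branch_bias if "conv" in name.lower() else -branch_bias
            for name in op_type_names]


squashed = shape_logits([0.0, 100.0])
scaled = shape_logits([2.0], temperature=2.0, tanh_constant=None)
bias = initial_branch_bias("reduce_0", ["SepConv", "MaxPool"])
```

Bounding the logits keeps the sampling distribution from collapsing onto a single choice too early, which is why a finite `tanh_constant` is the default.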

sample_final()

Override this method to iterate over mutables and make decisions that are final for export and retraining.

Returns: A mapping from key of mutables to decisions.
Return type: dict

sample_search()

Override this method to iterate over mutables and make decisions.

Returns: A mapping from key of mutables to decisions.
Return type: dict