ENAS
Introduction
The paper "Efficient Neural Architecture Search via Parameter Sharing" uses parameter sharing between child models to accelerate the NAS process. In ENAS, a controller learns to discover neural network architectures by searching for an optimal subgraph within a large computational graph. The controller is trained with policy gradient to select a subgraph that maximizes the expected reward on the validation set. Meanwhile, the model corresponding to the selected subgraph is trained to minimize a canonical cross-entropy loss.
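The policy-gradient (REINFORCE) update mentioned above can be illustrated on a toy problem. The following is a minimal sketch, not NNI's controller (which is an LSTM that makes many decisions per architecture): it trains a single categorical decision, such as which operation one layer uses, with a moving-average baseline. The helper names reinforce_step and toy_reward are hypothetical, and plain Python floats stand in for tensors.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(logits, reward_fn, baseline, rng, lr=0.5, decay=0.9):
    """One REINFORCE update for a single categorical decision (hypothetical helper).

    Samples a choice, scores it with reward_fn, and moves the logits along
    (reward - baseline) * d log p(choice) / d logits.
    """
    probs = softmax(logits)
    choice = rng.choices(range(len(logits)), weights=probs)[0]
    reward = reward_fn(choice)
    advantage = reward - baseline
    # For a softmax policy, d log p(choice) / d logit_k = 1[k == choice] - p_k.
    new_logits = [
        logit + lr * advantage * ((1.0 if k == choice else 0.0) - p)
        for k, (logit, p) in enumerate(zip(logits, probs))
    ]
    # The baseline is an exponential moving average of past rewards.
    new_baseline = decay * baseline + (1 - decay) * reward
    return new_logits, new_baseline


def toy_reward(choice):
    """Hypothetical reward: only operation 2 works well on the validation set."""
    return 1.0 if choice == 2 else 0.0


rng = random.Random(0)
logits, baseline = [0.0, 0.0, 0.0], 0.0
for _ in range(500):
    logits, baseline = reinforce_step(logits, toy_reward, baseline, rng)
print([round(p, 3) for p in softmax(logits)])
```

Subtracting the baseline does not change the expected gradient; it only reduces its variance, which is why ENAS keeps a running baseline of rewards.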
The NNI implementation is based on the official TensorFlow implementation. It includes a general-purpose reinforcement-learning controller and a trainer that trains the target network and this controller alternately. Following the paper, we have also implemented the macro and micro search spaces on CIFAR-10 to demonstrate how to use these trainers. Since code to train from scratch on NNI is not ready yet, reproduction results are currently unavailable.
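The alternating schedule can be pictured with the following simplified sketch. It only records the order of phases in one epoch, using the trainer parameters documented in the reference section (child_steps, mutator_steps, mutator_steps_aggregate, test_arc_per_epoch). The helper enas_epoch_schedule is hypothetical, and the exact loop structure inside NNI's EnasTrainer is an assumption inferred from those parameter descriptions.

```python
from collections import Counter


def enas_epoch_schedule(child_steps=500, mutator_steps=50,
                        mutator_steps_aggregate=20, test_arc_per_epoch=1):
    """List the phases of one ENAS epoch in order (hypothetical helper).

    Defaults mirror the documented EnasTrainer defaults.
    """
    events = []
    # Phase 1: update the shared child weights; in this sketch each minibatch
    # trains a freshly sampled subgraph while the controller is frozen.
    events += ["child_step"] * child_steps
    # Phase 2: update the controller while the shared weights are frozen.
    # Gradients from mutator_steps_aggregate sampled architectures are
    # accumulated before each controller optimizer step.
    for _ in range(mutator_steps):
        events += ["controller_backward"] * mutator_steps_aggregate
        events.append("controller_optimizer_step")
    # Phase 3: sample a few architectures and evaluate them directly.
    events += ["test_arc"] * test_arc_per_epoch
    return events


counts = Counter(enas_epoch_schedule())
print(counts)
```

Under these assumptions, one epoch with the default values performs 500 child-model updates and 50 controller updates, each of which aggregates 20 sampled architectures.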
Examples
CIFAR-10 Macro/Micro Search Space
# In case the NNI code is not cloned yet, clone it first. If it is already cloned, skip this line.
git clone https://github.com/Microsoft/nni.git
# search for the best architecture (run from the root of the nni repository)
cd examples/nas/enas
# search in the macro search space
python3 search.py --search-for macro
# search in the micro search space
python3 search.py --search-for micro
# view more options for search
python3 search.py -h
Reference
PyTorch

class nni.nas.pytorch.enas.EnasTrainer(model, loss, metrics, reward_function, optimizer, num_epochs, dataset_train, dataset_valid, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500, mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4, test_arc_per_epoch=1)
ENAS trainer.
Parameters:
- model (nn.Module) – PyTorch model to be trained.
- loss (callable) – Receives logits and ground-truth labels and returns a loss tensor.
- metrics (callable) – Receives logits and ground-truth labels and returns a dict of metrics.
- reward_function (callable) – Receives logits and ground-truth labels and returns a tensor, which is fed to the RL controller as the reward.
- optimizer (Optimizer) – The optimizer used to optimize the model.
- num_epochs (int) – Number of epochs planned for training.
- dataset_train (Dataset) – Dataset for training. It will be split for training weights and architecture weights.
- dataset_valid (Dataset) – Dataset for testing.
- mutator (EnasMutator) – Use when customizing your own mutator or a mutator with customized parameters.
- batch_size (int) – Batch size.
- workers (int) – Number of workers for data loading.
- device (torch.device) – torch.device("cpu") or torch.device("cuda").
- log_frequency (int) – Number of steps between log messages.
- callbacks (list of Callback) – List of callbacks to trigger at events.
- entropy_weight (float) – Weight of the sample entropy loss.
- skip_weight (float) – Weight of the skip penalty loss.
- baseline_decay (float) – Decay factor of the baseline. The new baseline will be equal to baseline_decay * baseline_old + reward * (1 - baseline_decay).
- child_steps (int) – Number of minibatches used for model training per epoch.
- mutator_lr (float) – Learning rate for the RL controller.
- mutator_steps_aggregate (int) – Number of steps that will be aggregated into one minibatch for the RL controller.
- mutator_steps (int) – Number of minibatches per epoch of RL controller learning.
- aux_weight (float) – Weight of the auxiliary head loss. aux_weight * aux_loss will be added to the total loss.
- test_arc_per_epoch (int) – Number of architectures chosen for direct testing after each epoch.
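The following sketch shows one plausible way the loss-related parameters above fit together for a single sampled architecture. It follows the documented baseline formula and the documented aux_weight rule; how the entropy bonus and the skip penalty enter the controller loss is an assumption modeled on the ENAS paper. The functions controller_loss and child_loss are hypothetical helpers on plain floats, not NNI functions.

```python
def controller_loss(reward, baseline, log_prob, entropy, skip_penalty,
                    entropy_weight=0.0001, skip_weight=0.8,
                    baseline_decay=0.999, mutator_steps_aggregate=20):
    """Return (loss, new_baseline) for one sampled architecture (hypothetical).

    log_prob is the summed log-probability of the controller's decisions,
    entropy is their summed entropy, and skip_penalty is the
    skip-connection penalty term.
    """
    # Assumed: an entropy bonus is folded into the reward to encourage exploration.
    reward = reward + entropy_weight * entropy
    # Documented update: baseline_decay * baseline_old + reward * (1 - baseline_decay).
    baseline = baseline_decay * baseline + reward * (1 - baseline_decay)
    # REINFORCE: minimizing this raises the probability of above-baseline architectures.
    loss = -log_prob * (reward - baseline)
    # Assumed: the weighted skip penalty is added to the policy-gradient loss.
    loss += skip_weight * skip_penalty
    # Dividing makes gradients accumulated over mutator_steps_aggregate samples an average.
    return loss / mutator_steps_aggregate, baseline


def child_loss(main_loss, aux_loss, aux_weight=0.4):
    """Total child-model loss with the auxiliary head, as documented for aux_weight."""
    return main_loss + aux_weight * aux_loss


loss, baseline = controller_loss(reward=0.5, baseline=0.0, log_prob=-2.0,
                                 entropy=10.0, skip_penalty=0.1,
                                 entropy_weight=0.01, skip_weight=1.0,
                                 baseline_decay=0.5, mutator_steps_aggregate=2)
print(loss, baseline, child_loss(1.0, 0.5))
```

With baseline_decay close to 1 (the default is 0.999), the baseline changes slowly, so the advantage reward - baseline reflects how a sample compares with the long-run average reward.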

class nni.nas.pytorch.enas.EnasMutator(model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction='sum')
A mutator that mutates the graph with RL.
Parameters:
- model (nn.Module) – PyTorch model.
- lstm_size (int) – Number of hidden units in the controller LSTM.
- lstm_num_layers (int) – Number of layers in the stacked LSTM.
- tanh_constant (float) – Logits will be equal to tanh_constant * tanh(logits). tanh is not applied if this value is None.
- cell_exit_extra_step (bool) – If true, the RL controller performs an extra step at the exit of each MutableScope, dumps the hidden state, and marks it as the hidden state of this MutableScope. This aligns with the original implementation of the paper.
- skip_target (float) – Target probability that a skip connection will appear.
- temperature (float) – Temperature constant that divides the logits.
- branch_bias (float) – Manual bias applied to make some operations more likely to be chosen. Currently this is implemented with a hard-coded match rule that aligns with the original repository: if a mutable has reduce in its key, all of its op choices whose type name contains conv initially receive a bias of +self.branch_bias, while the others receive a bias of -self.branch_bias.
- entropy_reduction (str) – How the entropy of a multi-input choice is reduced. Can be either sum or mean.
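The sketch below illustrates how tanh_constant, temperature, branch_bias and entropy_reduction could shape a single controller decision, following the parameter descriptions above. Applying the temperature before the tanh squashing and adding the branch bias last is an assumption about ordering, as is the case-insensitive conv match. The helper names and the mutable key "reduce_cell_op_1" are hypothetical, and plain lists stand in for tensors.

```python
import math


def transform_logits(logits, temperature=None, tanh_constant=1.5, bias=None):
    """Apply temperature, tanh squashing and an optional bias (hypothetical helper)."""
    if temperature is not None:
        # Temperature divides the logits; larger values flatten the distribution.
        logits = [x / temperature for x in logits]
    if tanh_constant is not None:
        # Bound logits to (-tanh_constant, tanh_constant); skipped when None.
        logits = [tanh_constant * math.tanh(x) for x in logits]
    if bias is not None:
        logits = [x + b for x, b in zip(logits, bias)]
    return logits


def branch_bias_vector(mutable_key, op_type_names, branch_bias=0.25):
    """Bias rule described for branch_bias (case-insensitive match is an assumption)."""
    if "reduce" not in mutable_key:
        return [0.0] * len(op_type_names)
    return [branch_bias if "conv" in name.lower() else -branch_bias
            for name in op_type_names]


def reduce_entropy(entropies, entropy_reduction="sum"):
    """Combine the per-input entropies of a multi-input choice."""
    if entropy_reduction == "sum":
        return sum(entropies)
    if entropy_reduction == "mean":
        return sum(entropies) / len(entropies)
    raise ValueError("entropy_reduction must be 'sum' or 'mean'")


bias = branch_bias_vector("reduce_cell_op_1", ["SepConv3x3", "MaxPool3x3", "Identity"])
print(bias)
print(transform_logits([0.0, 2.0, -2.0], temperature=2.0, bias=bias))
print(reduce_entropy([0.5, 0.7, 0.9], "mean"))
```

Bounding the logits with tanh keeps the controller from becoming overconfident early in the search, while the branch bias only sets an initial preference that the RL updates can later override.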