The paper Efficient Neural Architecture Search via Parameter Sharing (ENAS) uses parameter sharing between child models to accelerate the NAS process. In ENAS, a controller learns to discover neural network architectures by searching for an optimal subgraph within a large computational graph. The controller is trained with policy gradient to select a subgraph that maximizes the expected reward on the validation set. Meanwhile, the model corresponding to the selected subgraph is trained to minimize a canonical cross-entropy loss.
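To make the alternating procedure concrete, below is a minimal PyTorch-style sketch of the two training phases. This is not NNI's actual code: `controller`, `shared_model`, `model_optim`, `ctrl_optim`, and the data loaders are hypothetical placeholders, and the real trainer additionally applies entropy and skip penalties.

```python
import torch
import torch.nn.functional as F

baseline, baseline_decay = 0.0, 0.999

for (x_train, y_train), (x_val, y_val) in zip(train_loader, valid_loader):
    # Phase 1: train the shared weights of a sampled subgraph with cross entropy.
    arch, _ = controller.sample()          # hypothetical: sample a subgraph
    loss = F.cross_entropy(shared_model(x_train, arch), y_train)
    model_optim.zero_grad()
    loss.backward()
    model_optim.step()

    # Phase 2: train the controller with REINFORCE. Validation accuracy is the
    # reward; a moving-average baseline reduces the variance of the gradient.
    arch, log_prob = controller.sample()   # log_prob must carry gradients
    with torch.no_grad():
        reward = (shared_model(x_val, arch).argmax(-1) == y_val).float().mean().item()
    baseline = baseline_decay * baseline + (1 - baseline_decay) * reward
    ctrl_loss = -log_prob * (reward - baseline)
    ctrl_optim.zero_grad()
    ctrl_loss.backward()
    ctrl_optim.step()
```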
The implementation on NNI is based on the official implementation in TensorFlow, and includes a general-purpose reinforcement-learning controller and a trainer that trains the target network and this controller alternately. Following the paper, we have also implemented the macro and micro search spaces on CIFAR10 to demonstrate how to use these trainers. Since the code to train from scratch on NNI is not ready yet, reproduction results are currently unavailable.
CIFAR10 Macro/Micro Search Space
```bash
# In case NNI code is not cloned. If the code is cloned already, ignore this line and enter code folder.
git clone https://github.com/Microsoft/nni.git

# search the best architecture
cd examples/nas/oneshot/enas

# search in macro search space
python3 search.py --search-for macro

# search in micro search space
python3 search.py --search-for micro

# view more options for search
python3 search.py -h
```
- class nni.retiarii.oneshot.pytorch.EnasTrainer(model, loss, metrics, reward_function, optimizer, num_epochs, dataset, batch_size=64, workers=4, device=None, log_frequency=None, grad_clip=5.0, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None)
model (nn.Module) – PyTorch model to be trained.
loss (callable) – Receives logits and ground truth label, returns a loss tensor.
metrics (callable) – Receives logits and ground truth label, returns a dict of metrics.
reward_function (callable) – Receives logits and ground truth label, returns a scalar tensor that is fed to the RL controller as the reward.
optimizer (Optimizer) – The optimizer used for optimizing the model.
num_epochs (int) – Number of epochs planned for training.
dataset (Dataset) – Dataset for training. Will be split for training weights and architecture weights.
batch_size (int) – Batch size.
workers (int) – Workers for data loading.
device (torch.device) – The device where the model runs, e.g., torch.device("cuda") or torch.device("cpu").
log_frequency (int) – Number of steps between two logging writes.
grad_clip (float) – Gradient clipping. Set to 0 to disable. Default: 5.
entropy_weight (float) – Weight of sample entropy loss.
skip_weight (float) – Weight of skip penalty loss.
baseline_decay (float) – Decay factor of the moving-average baseline. The new baseline is baseline_decay * baseline_old + reward * (1 - baseline_decay).
ctrl_lr (float) – Learning rate for RL controller.
ctrl_steps_aggregate (int) – Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_kwargs (dict) – Optional kwargs that will be passed to the controller.
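The following sketch shows how the trainer might be instantiated and started, based on the signature documented above. `MyModelSpace` is a hypothetical model space, and `trainer.fit()` is assumed to be the entry point for training; treat this as an illustration rather than a verified recipe.

```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from nni.retiarii.oneshot.pytorch import EnasTrainer

model = MyModelSpace()  # hypothetical: a searchable model defined with NNI primitives

dataset = datasets.CIFAR10(root="./data", train=True, download=True,
                           transform=transforms.ToTensor())

def metrics(logits, targets):
    # Returns a dict of metrics, as required by the `metrics` argument.
    return {"acc": (logits.argmax(dim=-1) == targets).float().mean().item()}

def reward_function(logits, targets):
    # Returns a tensor that is fed to the RL controller as the reward.
    return (logits.argmax(dim=-1) == targets).float().mean()

trainer = EnasTrainer(
    model=model,
    loss=nn.CrossEntropyLoss(),
    metrics=metrics,
    reward_function=reward_function,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9),
    num_epochs=10,
    dataset=dataset,   # split internally for weight training and architecture search
    batch_size=64,
    log_frequency=10,
)
trainer.fit()  # assumed training entry point
```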