DARTS

Introduction

The paper DARTS: Differentiable Architecture Search addresses the scalability challenge of architecture search by formulating the task in a differentiable manner. The method is based on a continuous relaxation of the architecture representation, allowing efficient search of the architecture using gradient descent.
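
Concretely, the relaxation replaces the categorical choice of an operation on each edge (i, j) of a cell with a softmax-weighted mixture over the candidate operation set O, parameterized by architecture weights alpha (as in the paper):

\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp\bigl(\alpha^{(i,j)}_{o}\bigr)}{\sum_{o' \in \mathcal{O}} \exp\bigl(\alpha^{(i,j)}_{o'}\bigr)} \, o(x)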

The authors’ code optimizes the network weights and the architecture weights alternately on mini-batches. They further explore using second-order optimization (unrolling) instead of first-order optimization to improve performance.
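
To make the alternating schedule concrete, here is a minimal first-order sketch in plain PyTorch. It is not NNI’s DartsTrainer; the two candidate operations, optimizers, hyper-parameters, and random mini-batches below are all placeholders.

import torch
import torch.nn as nn

# Two candidate operations mixed by architecture logits alpha (the relaxation above).
ops = nn.ModuleList([nn.Linear(8, 2), nn.Linear(8, 2)])
alpha = nn.Parameter(torch.zeros(2))
loss_fn = nn.CrossEntropyLoss()

weight_optimizer = torch.optim.SGD(ops.parameters(), lr=0.025, momentum=0.9)
arch_optimizer = torch.optim.Adam([alpha], lr=3e-4)

def forward(x):
    w = torch.softmax(alpha, dim=-1)
    return w[0] * ops[0](x) + w[1] * ops[1](x)

for step in range(10):
    # random mini-batches standing in for the training / validation splits
    trn_x, trn_y = torch.randn(4, 8), torch.randint(0, 2, (4,))
    val_x, val_y = torch.randn(4, 8), torch.randint(0, 2, (4,))

    # 1. update the architecture weights (alpha) on a validation batch
    arch_optimizer.zero_grad()
    loss_fn(forward(val_x), val_y).backward()
    arch_optimizer.step()

    # 2. update the network weights on a training batch
    weight_optimizer.zero_grad()
    loss_fn(forward(trn_x), trn_y).backward()
    weight_optimizer.step()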

The implementation on NNI is based on the official implementation and a popular third-party repo. DARTS on NNI is designed to be general enough to work with arbitrary search spaces. A CNN search space tailored for CIFAR10, the same as in the original paper, is implemented as a use case of DARTS.

Reproduction Results

The above example is meant to reproduce the results in the paper; we run experiments with both first-order and second-order optimization. Due to time limits, we retrain only the best architecture derived from the search phase, and we repeat the experiment only once. Our results are currently on par with the results reported in the paper. We will add more results when they are ready. The table below reports CIFAR10 test error (%).

                          In paper         Reproduction
First order (CIFAR10)     3.00 +/- 0.14    2.78
Second order (CIFAR10)    2.76 +/- 0.09    2.89

Examples

CNN Search Space

Example code

# In case the NNI code is not cloned yet. If it is already cloned, skip this line and enter the code folder.
git clone https://github.com/Microsoft/nni.git

# search for the best architecture
cd nni/examples/nas/darts
python3 search.py

# train the best architecture
python3 retrain.py --arc-checkpoint ./checkpoints/epoch_49.json
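
To show how the pieces fit together, here is a hedged sketch of driving DartsTrainer directly on a toy search space. It is not the contents of search.py: the model, datasets, and hyper-parameters are placeholders, the LayerChoice import path follows NNI’s 1.x NAS API, and the export() call is assumed from the base trainer interface.

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.darts import DartsTrainer

# A toy search space with a single LayerChoice, just to keep the sketch short.
class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = LayerChoice([
            nn.Conv2d(3, 16, 3, padding=1),   # candidate op 1: 3x3 convolution
            nn.Conv2d(3, 16, 5, padding=2),   # candidate op 2: 5x5 convolution
        ])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        x = torch.relu(self.conv(x))
        return self.fc(self.pool(x).flatten(1))

def accuracy(logits, target):
    # metrics callable: receives logits and labels, returns a dict of metrics
    return {"acc": (logits.argmax(dim=-1) == target).float().mean().item()}

transform = transforms.ToTensor()
dataset_train = torchvision.datasets.CIFAR10("./data", train=True, download=True, transform=transform)
dataset_valid = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=transform)

model = ToyNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)

trainer = DartsTrainer(
    model,
    loss=nn.CrossEntropyLoss(),
    metrics=accuracy,
    optimizer=optimizer,
    num_epochs=2,                    # placeholder; the paper searches for 50 epochs
    dataset_train=dataset_train,     # split internally for weight / architecture updates
    dataset_valid=dataset_valid,
    batch_size=64,
    log_frequency=10,
    unrolled=False,                  # True enables second-order optimization
)
trainer.train()
trainer.export("checkpoint.json")    # export the searched architecture (assumed base-trainer method)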

Reference

PyTorch

class nni.nas.pytorch.darts.DartsTrainer(model, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, arc_learning_rate=0.0003, unrolled=False)[source]

DARTS trainer.

Parameters:
  • model (nn.Module) – PyTorch model to be trained.
  • loss (callable) – Receives logits and ground truth labels, and returns a loss tensor.
  • metrics (callable) – Receives logits and ground truth labels, and returns a dict of metrics.
  • optimizer (Optimizer) – The optimizer used for optimizing the model.
  • num_epochs (int) – Number of epochs planned for training.
  • dataset_train (Dataset) – Dataset for training. It will be split into two parts, one for training the network weights and one for training the architecture weights.
  • dataset_valid (Dataset) – Dataset for testing.
  • mutator (DartsMutator) – Use this if you want to customize your own DartsMutator. By default, a DartsMutator will be instantiated.
  • batch_size (int) – Batch size.
  • workers (int) – Workers for data loading.
  • device (torch.device) – torch.device("cpu") or torch.device("cuda").
  • log_frequency (int) – Number of steps between two logging events.
  • callbacks (list of Callback) – List of callbacks to trigger at events.
  • arc_learning_rate (float) – Learning rate of architecture parameters.
  • unrolled (bool) – True if using second-order optimization, otherwise first-order optimization.
train_one_epoch(epoch)[source]

Train one epoch.

Parameters: epoch (int) – Epoch number starting from 0.
validate_one_epoch(epoch)[source]

Validate one epoch.

Parameters: epoch (int) – Epoch number starting from 0.
class nni.nas.pytorch.darts.DartsMutator(model)[source]

Connects the model in a DARTS (differentiable) way.

An extra connection is automatically inserted for each LayerChoice. When this connection is selected, there is no operation on this LayerChoice (namely a ZeroOp); in that case, every element in the exported choice list is false (not chosen).

All input choices are fully connected in the search phase. On exporting, an input choice selects its inputs based on the keys in choose_from. If a key is the key of a LayerChoice, the top logit of the corresponding LayerChoice joins the input choice’s competition against the other logits. Otherwise, the logit is assumed to be 0.

It is possible to cut branches by setting the parameter choices at a particular position to -inf. After softmax, the value becomes 0, and the framework ignores 0 values and does not connect them. Note that the gradient at the -inf location will be 0, but since arithmetic involving -inf can produce nan, you need to handle the gradient update phase carefully.
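
For illustration, here is a small plain-PyTorch snippet (not tied to a particular mutator) showing that a -inf logit turns into a weight of exactly 0 after softmax:

import torch

alpha = torch.nn.Parameter(torch.zeros(3))   # three candidate connections
with torch.no_grad():
    alpha[1] = float("-inf")                 # cut the second branch
weights = torch.softmax(alpha, dim=-1)
print(weights)                               # tensor([0.5000, 0.0000, 0.5000], grad_fn=...)
# The gradient arriving at the -inf position is 0, but updates that do arithmetic
# with -inf (e.g. weight decay) can produce nan, hence the caveat above.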

choices

Dict that maps keys of LayerChoices to weighted-connection float tensors.

Type: ParameterDict
sample_final()[source]

Override to implement this method to iterate over mutables and make decisions that are final for export and retraining.

Returns: A mapping from keys of mutables to decisions.
Return type: dict

sample_search()[source]

Override to implement this method to iterate over mutables and make decisions.

Returns: A mapping from keys of mutables to decisions.
Return type: dict
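
As a quick illustration of the choices attribute, the hedged sketch below builds a DartsMutator for a toy model and prints the softmaxed connection weights. The toy model is a placeholder and the LayerChoice import path follows NNI’s 1.x NAS API.

import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.darts import DartsMutator

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        # one LayerChoice with two candidate operations
        self.op = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                               nn.Conv2d(3, 8, 5, padding=2)])

    def forward(self, x):
        return self.op(x)

mutator = DartsMutator(TinyNet())
for key, alpha in mutator.choices.items():
    # alpha holds the raw logits; softmax turns them into connection weights
    print(key, torch.softmax(alpha, dim=-1).tolist())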

Limitations

  • DARTS doesn’t support DataParallel and needs to be customized in order to support DistributedDataParallel.