The paper DARTS: Differentiable Architecture Search addresses the scalability challenge of architecture search by formulating the task in a differentiable manner. Instead of searching over a discrete set of candidate architectures, the method relaxes the architecture representation to be continuous, which allows the architecture to be searched efficiently with gradient descent.
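To make the continuous relaxation concrete, here is a minimal sketch that assumes scalar inputs and three hypothetical candidate operations. In DARTS the operations act on feature maps, and each edge of a cell has its own architecture weights (alpha). Each edge outputs a softmax(alpha)-weighted sum of all candidate operations, so the choice of operation becomes differentiable. The `derive_op` rule below is a simplification: it just keeps the operation with the largest alpha, whereas the paper additionally excludes the zero operation and keeps only the strongest incoming edges per node.

```python
import math
from typing import Callable, Dict, List

def softmax(values: List[float]) -> List[float]:
    # Subtract the maximum before exponentiating for numerical stability
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def mixed_op(x: float,
             ops: Dict[str, Callable[[float], float]],
             alphas: Dict[str, float]) -> float:
    # Continuous relaxation: instead of picking one operation,
    # output the softmax(alpha)-weighted sum of all candidate operations
    names = list(ops)
    weights = softmax([alphas[name] for name in names])
    return sum(w * ops[name](x) for w, name in zip(weights, names))

def derive_op(alphas: Dict[str, float]) -> str:
    # Simplified discretization: keep the operation with the largest alpha
    return max(alphas, key=alphas.get)

# Hypothetical candidate operations acting on a scalar input
ops = {
    "identity": lambda x: x,
    "double": lambda x: 2.0 * x,
    "zero": lambda x: 0.0,
}

# With equal alphas every operation gets weight 1/3, i.e. (3 + 6 + 0) / 3
print(mixed_op(3.0, ops, {"identity": 0.0, "double": 0.0, "zero": 0.0}))
print(derive_op({"identity": 0.1, "double": 1.5, "zero": -0.3}))
```

Because the alphas enter the output only through a softmax, their gradients can be computed with ordinary backpropagation, which is what makes gradient-based search possible.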
The authors' code optimizes the network weights and the architecture weights alternately, one mini-batch at a time. They further explore using second-order optimization (unrolling) instead of first-order optimization to improve performance.
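The difference between first-order and second-order (unrolled) updates can be seen on a toy problem. The sketch below is not NNI's implementation; it assumes a hypothetical bilevel problem with a scalar weight `w` and a scalar architecture parameter `a`, where the validation loss depends on `a` only through `w`. The first-order update treats `w` as a constant, so its gradient with respect to `a` is always zero. The second-order update first takes a virtual step w' = w - xi * dL_train/dw and approximates the resulting mixed second-derivative term with a central finite difference, as proposed in the DARTS paper. Here `xi` is set equal to the weight learning rate and a fixed `eps` is used (the paper scales epsilon by the gradient norm).

```python
from typing import Tuple

# Hypothetical toy bilevel problem:
#   L_train(w, a) = (w - a)**2   (training loss, drives w towards a)
#   L_val(w, a)   = (w - 1)**2   (validation loss, depends on a only through w)

def grad_train(w: float, a: float) -> Tuple[float, float]:
    # Returns (dL_train/dw, dL_train/da)
    return 2.0 * (w - a), -2.0 * (w - a)

def grad_val(w: float, a: float) -> Tuple[float, float]:
    # Returns (dL_val/dw, dL_val/da)
    return 2.0 * (w - 1.0), 0.0

def arch_grad(w: float, a: float, unrolled: bool,
              xi: float = 0.1, eps: float = 1e-2) -> float:
    if not unrolled:
        # First order: treat w as a constant and use dL_val/da directly
        return grad_val(w, a)[1]
    # Second order: one virtual SGD step on the training loss
    w_prime = w - xi * grad_train(w, a)[0]
    g_w, g_a = grad_val(w_prime, a)
    # Central finite difference for (d^2 L_train / da dw) * g_w:
    # [dL_train/da(w + eps*g_w) - dL_train/da(w - eps*g_w)] / (2*eps)
    hvp = (grad_train(w + eps * g_w, a)[1] - grad_train(w - eps * g_w, a)[1]) / (2.0 * eps)
    return g_a - xi * hvp

def search(unrolled: bool, steps: int = 300,
           lr_w: float = 0.1, lr_a: float = 0.1) -> Tuple[float, float]:
    w, a = 0.0, 0.0
    for _ in range(steps):
        # Alternate: one architecture step on the validation loss ...
        a -= lr_a * arch_grad(w, a, unrolled)
        # ... then one weight step on the training loss
        w -= lr_w * grad_train(w, a)[0]
    return w, a

print(search(unrolled=False))  # (0.0, 0.0): the first-order gradient w.r.t. a is always 0 here
print(search(unrolled=True))   # approaches (1.0, 1.0)
```

The second-order variant costs extra gradient evaluations per step (the virtual step plus the two finite-difference evaluations), which is the usual trade-off against its more accurate architecture gradient.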
The NNI implementation is based on the official implementation and a popular third-party repository. DARTS in NNI is designed to work with arbitrary search spaces. As a use case, a CNN search space tailored for CIFAR10, identical to the one in the original paper, is implemented.
The example above is meant to reproduce the results in the paper; we run experiments with both first-order and second-order optimization. Due to time limits, we retrain only the best architecture derived from the search phase, and we run each experiment only once. Our results are currently on par with those reported in the paper. More results will be added when they are ready.
Configuration | Test error (%)
First order (CIFAR10) | 3.00 +/- 0.14
Second order (CIFAR10) | 2.76 +/- 0.09
CNN Search Space
```bash
# In case the NNI code is not cloned yet. If it is already cloned, skip the next two lines and enter the repository root instead.
git clone https://github.com/Microsoft/nni.git
cd nni

# Search for the best architecture
cd examples/nas/oneshot/darts
python3 search.py

# Retrain the best architecture found by the search
python3 retrain.py --arc-checkpoint ./checkpoints/epoch_49.json
```
class nni.retiarii.oneshot.pytorch.DartsTrainer(model, loss, metrics, optimizer, num_epochs, dataset, grad_clip=5.0, learning_rate=0.0025, batch_size=64, workers=4, device=None, log_frequency=None, arc_learning_rate=0.0003, unrolled=False)
Parameters:
model (nn.Module) – PyTorch model to be trained.
loss (callable) – Receives logits and ground-truth labels and returns a loss tensor.
metrics (callable) – Receives logits and ground-truth labels and returns a dict of metrics.
optimizer (Optimizer) – The optimizer used for optimizing the model.
num_epochs (int) – Number of epochs planned for training.
dataset (Dataset) – Dataset for training. It is split into two parts: one for training the model weights and one for training the architecture weights.
grad_clip (float) – Gradient clipping threshold. Set to 0 to disable. Default: 5.0.
learning_rate (float) – Learning rate for optimizing the model.
batch_size (int) – Batch size.
workers (int) – Number of workers for data loading.
device (torch.device) – Device on which to run training, e.g. torch.device("cpu") or torch.device("cuda").
log_frequency (int) – Number of steps between two consecutive log outputs.
arc_learning_rate (float) – Learning rate for the architecture parameters.
unrolled (bool) – True if using second-order optimization, else first-order optimization.
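As an illustration of the `loss` and `metrics` contracts described above, here is a minimal PyTorch sketch. The name `accuracy_metrics` and its `"acc"` key are hypothetical choices; any callables with the same inputs and outputs should fit. The commented-out wiring is an assumption as well: it presumes NNI is installed, that `model` contains NNI choice modules (e.g. `LayerChoice`), and that `train_set` is a torch `Dataset`; the optimizer settings are arbitrary examples, not recommended values.

```python
from typing import Dict

import torch
import torch.nn as nn

# loss: receives logits and ground-truth labels, returns a scalar loss tensor
loss_fn = nn.CrossEntropyLoss()

def accuracy_metrics(logits: torch.Tensor, target: torch.Tensor) -> Dict[str, float]:
    # metrics: receives logits and ground-truth labels, returns a dict of metrics
    pred = logits.argmax(dim=1)
    return {"acc": (pred == target).float().mean().item()}

# Hypothetical wiring (requires NNI, a model with NNI choice modules, and a dataset):
#
# from nni.retiarii.oneshot.pytorch import DartsTrainer
# optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9)
# trainer = DartsTrainer(model, loss_fn, accuracy_metrics, optimizer,
#                        num_epochs=50, dataset=train_set, unrolled=False)

logits = torch.tensor([[2.0, 0.5], [0.1, 1.0], [3.0, -1.0]])
target = torch.tensor([0, 1, 1])
print(accuracy_metrics(logits, target))
print(loss_fn(logits, target))
```

Returning plain Python floats from the metrics callable keeps the dict easy to log; returning tensors would likely work too, but that is not confirmed here.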
DARTS doesn’t support DataParallel and needs to be customized in order to support DistributedDataParallel.