Customize A New Trainer

Trainers are needed to evaluate the performance of newly explored models. In the NAS scenario, this further divides into two use cases:

  1. Classic trainers: trainers that are used to train and evaluate a single model.

  2. One-shot trainers: trainers that handle training and searching simultaneously, in an end-to-end fashion.

Classic trainers

All classic trainers must inherit from nni.retiarii.trainer.BaseTrainer, implement the fit method, and be decorated with @register_trainer if they are intended to be used together with Retiarii. The decorator serializes the trainer class and its arguments, as NNI requires.

The __init__ method of a trainer should take the model as its first argument; the remaining arguments should be named (*args and **kwargs may not work as expected) and JSON serializable. This means that, currently, passing a complex object such as torchvision.datasets.ImageNet() is not supported. Trainers should use the standard NNI API to communicate with tuning algorithms: nni.report_intermediate_result for periodic metrics and nni.report_final_result for final metrics.
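
As an illustration of these rules, the following framework-free sketch shows what a decorator like @register_trainer might record. It is not NNI's implementation: register_trainer_sketch, the trainer_spec attribute, and MyTrainer are hypothetical names, and leaving the model out of the record is an assumption. The sketch binds the arguments to parameter names, so positional arguments are also recorded by name, and it calls json.dumps so that a non-serializable argument fails early with a TypeError.

```python
import functools
import inspect
import json


def register_trainer_sketch(cls):
    """Hypothetical stand-in for @register_trainer: wraps __init__ so every
    instance remembers its class path and constructor arguments."""
    original_init = cls.__init__
    signature = inspect.signature(original_init)
    # Parameter 0 is `self`; parameter 1 is the model, which this sketch
    # assumes is handled separately and therefore does not record.
    model_param = list(signature.parameters)[1]

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        # Bind positional and keyword arguments to their parameter names.
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        init_args = {
            name: value
            for name, value in bound.arguments.items()
            if name not in ("self", model_param)
        }
        # json.dumps raises TypeError for objects such as dataset instances.
        json.dumps(init_args)
        self.trainer_spec = {
            "trainer": f"{cls.__module__}.{cls.__qualname__}",
            "kwargs": init_args,
        }
        original_init(self, *args, **kwargs)

    cls.__init__ = wrapped_init
    return cls


@register_trainer_sketch
class MyTrainer:
    def __init__(self, model, learning_rate=0.1, batch_size=32):
        self.model = model
        self.learning_rate = learning_rate
        self.batch_size = batch_size

    def fit(self):
        pass  # training loop omitted in this sketch


trainer = MyTrainer(object(), learning_rate=0.01)
print(trainer.trainer_spec["kwargs"])  # {'learning_rate': 0.01, 'batch_size': 32}
```

Recording arguments by name rather than by position is what makes the "named arguments" rule matter: with *args or **kwargs, the recorded names no longer map cleanly onto constructor parameters.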

An example is as follows:
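
The following is a minimal, framework-free sketch of a classic trainer, assuming that fit() takes no arguments. It is not taken from NNI: BaseTrainerSketch stands in for nni.retiarii.trainer.BaseTrainer, the two report_* functions are local stubs that only record values where a real trainer would call nni.report_intermediate_result and nni.report_final_result, and the toy linear model and data are hypothetical. The @register_trainer decorator is omitted so the sketch runs without NNI installed.

```python
import abc

# Local stubs standing in for nni.report_intermediate_result and
# nni.report_final_result; they only record the reported values.
REPORTED = {"intermediate": [], "final": None}


def report_intermediate_result(metric):
    REPORTED["intermediate"].append(metric)


def report_final_result(metric):
    REPORTED["final"] = metric


class BaseTrainerSketch(abc.ABC):
    """Hypothetical stand-in for nni.retiarii.trainer.BaseTrainer."""

    @abc.abstractmethod
    def fit(self):
        """Train and evaluate the model (assumed signature: no arguments)."""


class LinearModel:
    """Toy model y = w * x with a single scalar weight."""

    def __init__(self):
        self.w = 0.0

    def forward(self, x):
        return self.w * x


class RegressionTrainer(BaseTrainerSketch):
    # The model comes first; every other argument is named and JSON serializable.
    def __init__(self, model, learning_rate=0.05, num_epochs=5, slope=3.0):
        self.model = model
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        # The data is built inside the trainer from plain numbers
        # instead of being passed in as a dataset object.
        self.data = [(float(x), slope * x) for x in range(1, 5)]

    def mse(self):
        errors = [(self.model.forward(x) - y) ** 2 for x, y in self.data]
        return sum(errors) / len(errors)

    def fit(self):
        for _ in range(self.num_epochs):
            for x, y in self.data:
                # d/dw of (w * x - y) ** 2 is 2 * x * (w * x - y).
                grad = 2 * x * (self.model.forward(x) - y)
                self.model.w -= self.learning_rate * grad
            report_intermediate_result(self.mse())  # periodic metric
        report_final_result(self.mse())  # final metric


model = LinearModel()
RegressionTrainer(model, learning_rate=0.05, num_epochs=5).fit()
print(round(model.w, 3))  # 3.0
```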

One-shot trainers

One-shot trainers should inherit from nni.retiarii.trainer.BaseOneShotTrainer, which is essentially the same as BaseTrainer but adds one extra method, export(), which is expected to return the best architecture found by the search.
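
To illustrate the fit()/export() contract, here is a toy sketch in plain Python. BaseOneShotTrainerSketch is a hypothetical stand-in for nni.retiarii.trainer.BaseOneShotTrainer, and the dictionary returned by export() (each choice label mapped to the index of the chosen candidate) is an assumed format, not necessarily NNI's. The search here is an exhaustive evaluation of each candidate, not a real one-shot algorithm.

```python
import abc


class BaseOneShotTrainerSketch(abc.ABC):
    """Hypothetical stand-in for nni.retiarii.trainer.BaseOneShotTrainer."""

    @abc.abstractmethod
    def fit(self):
        """Train and search at the same time."""

    @abc.abstractmethod
    def export(self):
        """Return the best architecture found by the search."""


class ExhaustiveTrainer(BaseOneShotTrainerSketch):
    """Toy trainer that scores every candidate of every choice on fixed data.
    It illustrates the fit()/export() contract only."""

    def __init__(self, search_space, data):
        # search_space maps a choice label to a list of candidate functions;
        # one-shot trainers may take arbitrary Python objects like these.
        self.search_space = search_space
        self.data = data
        self.losses = {}

    def fit(self):
        for label, candidates in self.search_space.items():
            # Sum of squared errors of each candidate on the data.
            self.losses[label] = [
                sum((fn(x) - y) ** 2 for x, y in self.data) for fn in candidates
            ]

    def export(self):
        # For every choice, keep the index of the candidate with the lowest loss.
        return {
            label: min(range(len(losses)), key=losses.__getitem__)
            for label, losses in self.losses.items()
        }


space = {"activation": [lambda x: x, lambda x: max(x, 0.0), lambda x: -x]}
data = [(-1.0, 0.0), (2.0, 2.0), (3.0, 3.0)]
trainer = ExhaustiveTrainer(space, data)
trainer.fit()
print(trainer.export())  # {'activation': 1}
```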

Writing a one-shot trainer is very different from writing a classic trainer. First, there are no restrictions on the __init__ arguments: any Python arguments are acceptable. Second, the model fed into a one-shot trainer may contain Retiarii-specific modules such as LayerChoice and InputChoice. Such a model cannot forward-propagate directly, so the trainer must decide how to handle those modules.
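
The sketch below shows, in plain Python, what deciding how to handle those modules can look like: the trainer walks the model and swaps every choice module for something that can run. All names are hypothetical stand-ins (Module, LayerChoiceSketch, replace_layer_choices), not Retiarii's LayerChoice or its replace_layer_choice utility, and the strategy used here (keep the first candidate) is only the simplest possible one.

```python
class Module:
    """Minimal module: child modules are stored as plain attributes."""

    def named_children(self):
        # Snapshot as a list so callers may reassign attributes while iterating.
        return [(name, value) for name, value in vars(self).items()
                if isinstance(value, Module)]


class LayerChoiceSketch(Module):
    """Placeholder holding several candidate modules; it cannot run on its own."""

    def __init__(self, candidates, label):
        self.candidates = candidates  # stored as a list, so not a direct child
        self.label = label

    def forward(self, x):
        raise NotImplementedError(f"choice {self.label!r} must be replaced first")


class Scale(Module):
    def __init__(self, factor):
        self.factor = factor

    def forward(self, x):
        return self.factor * x


class Net(Module):
    def __init__(self):
        self.first = Scale(2.0)
        self.block = LayerChoiceSketch([Scale(1.0), Scale(10.0)], label="block")

    def forward(self, x):
        return self.block.forward(self.first.forward(x))


def replace_layer_choices(root, make_replacement):
    """Recursively swap every LayerChoiceSketch below root for
    make_replacement(choice); return the labels that were replaced."""
    replaced = []
    for name, child in root.named_children():
        if isinstance(child, LayerChoiceSketch):
            setattr(root, name, make_replacement(child))
            replaced.append(child.label)
        else:
            replaced.extend(replace_layer_choices(child, make_replacement))
    return replaced


net = Net()
# Simplest strategy: fix every choice to its first candidate.
labels = replace_layer_choices(net, lambda choice: choice.candidates[0])
print(labels, net.forward(3.0))  # ['block'] 6.0
```

Passing the replacement strategy in as a function keeps the tree walk generic: a trainer that mixes all candidates instead of picking one only needs a different make_replacement.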

A typical example is DartsTrainer, where learnable parameters are used to combine the multiple choices in a LayerChoice. Retiarii provides easy-to-use utility functions for replacing modules, namely replace_layer_choice and replace_input_choice. A simplified example is as follows:
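
The sketch below illustrates the idea in plain Python rather than PyTorch and is not the DartsTrainer code. DartsMixedOp is a hypothetical name; the candidates are parameter-free functions, so only the architecture weights alpha are learned, and the alternating weight/architecture updates on separate data splits used by DARTS are omitted. The output is the softmax(alpha)-weighted sum of all candidate outputs, and export() keeps the candidate with the largest alpha.

```python
import math


def softmax(values):
    # Subtract the max for numerical stability.
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


class DartsMixedOp:
    """Hypothetical DARTS-style replacement for a LayerChoice."""

    def __init__(self, candidates):
        self.candidates = candidates
        self.alpha = [0.0] * len(candidates)  # architecture parameters

    def weights(self):
        return softmax(self.alpha)

    def forward(self, x):
        return sum(p * op(x) for p, op in zip(self.weights(), self.candidates))

    def alpha_step(self, x, grad_out, lr):
        # With out = sum_i p_i * o_i and p = softmax(alpha),
        # d out / d alpha_j = p_j * (o_j - out); apply the chain rule.
        outs = [op(x) for op in self.candidates]
        p = self.weights()
        out = sum(pi * oi for pi, oi in zip(p, outs))
        for j in range(len(self.alpha)):
            self.alpha[j] -= lr * grad_out * p[j] * (outs[j] - out)

    def export(self):
        # Keep the candidate with the largest architecture weight.
        return max(range(len(self.alpha)), key=self.alpha.__getitem__)


candidates = [lambda x: x, lambda x: 2 * x, lambda x: 0.0]
mixed = DartsMixedOp(candidates)
data = [(x, 2.0 * x) for x in (1.0, 2.0, 3.0)]  # the target matches candidate 1
for _ in range(50):
    for x, y in data:
        grad_out = 2 * (mixed.forward(x) - y)  # d/d(out) of (out - y) ** 2
        mixed.alpha_step(x, grad_out, lr=0.05)
print(mixed.export())  # 1
```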

The full code of DartsTrainer is available in the Retiarii source code on GitHub: nni/retiarii/trainer/pytorch/darts.py.