Model Evaluator¶
A model evaluator is responsible for training and validating each generated model; it is how the performance of every newly explored model is measured.
Customize Evaluator with Any Function¶
The simplest way to customize a new evaluator is with FunctionalEvaluator, which is very easy when training code is already available. Users only need to write a fit function that wraps everything, usually including the training, validation and testing of a single model. This function takes one positional argument (model_cls) and possibly some keyword arguments. The keyword arguments (other than model_cls) are fed to FunctionalEvaluator as its initialization parameters (note that they will be serialized). This gives users full control, but exposes less information to the framework, so further optimizations like CGO may not be feasible. An example is as follows:
import nni
from nni.retiarii.evaluator import FunctionalEvaluator
from nni.retiarii.experiment.pytorch import RetiariiExperiment

def fit(model_cls, dataloader):
    # ``train`` and ``test`` are placeholders for your own training/testing routines.
    model = model_cls()
    train(model, dataloader)
    acc = test(model, dataloader)
    nni.report_final_result(acc)

# The dataloader will be serialized, thus ``nni.trace`` is needed here.
# See serialization tutorial for more details.
evaluator = FunctionalEvaluator(fit, dataloader=nni.trace(DataLoader)(foo, bar))
experiment = RetiariiExperiment(base_model, evaluator, mutators, strategy)
Tip
When using customized evaluators, if you want to visualize models, you need to export your model and save it to $NNI_OUTPUT_DIR/model.onnx in your evaluator. For example:
import os
from pathlib import Path

import torch

def fit(model_cls):
    model = model_cls()
    onnx_path = Path(os.environ.get('NNI_OUTPUT_DIR', '.')) / 'model.onnx'
    onnx_path.parent.mkdir(exist_ok=True)
    dummy_input = torch.randn(10, 3, 224, 224)
    torch.onnx.export(model, dummy_input, onnx_path)
    # the rest of the training code here
If the conversion succeeds, the model can be visualized with powerful tools such as Netron.
Use Evaluators to Train and Evaluate Models¶
Users can use evaluators to train or evaluate a single, concrete architecture. This is very useful when:

- Debugging your evaluator against a baseline model.
- Fully training, validating and testing your model after the search process is complete.

The usage is shown below:
# Class definition of a single model, for example, ResNet.
class SingleModel(nn.Module):
    def __init__(self):  # Can't have init parameters here.
        ...

# Use a callable returning a model
evaluator.evaluate(SingleModel)
# Or initialize the model beforehand
evaluator.evaluate(SingleModel())
The underlying implementation of evaluate() depends on the concrete evaluator you use. For example, if FunctionalEvaluator is used, it runs your customized fit function. If a lightning evaluator like nni.retiarii.evaluator.pytorch.Classification is used, it invokes trainer.fit() of Lightning.
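For instance, evaluating one concrete model with a built-in lightning evaluator might look like the following sketch, where train_loader, val_loader and MyModel are hypothetical placeholders standing in for your own dataloaders and model class:

from nni.retiarii.evaluator.pytorch import Classification

# train_loader / val_loader / MyModel are hypothetical placeholders.
evaluator = Classification(train_dataloaders=train_loader,
                           val_dataloaders=val_loader,
                           max_epochs=10)
# This triggers trainer.fit() of Lightning on the single model.
evaluator.evaluate(MyModel)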
To evaluate an architecture exported from an experiment (i.e., from export_top_models()), use nni.retiarii.fixed_arch() to instantiate the exported model:
from nni.retiarii import fixed_arch

with fixed_arch(exported_model):
    model = ModelSpace()

# Then use evaluator.evaluate
evaluator.evaluate(model)
Tip
Thanks to nni.retiarii.utils.original_state_dict_hooks(), the trained checkpoint of a super-net produced by one-shot strategies can be ported to the concrete chosen architecture. This is helpful in implementing recent multi-stage NAS algorithms like SPOS.
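A minimal sketch of this porting step (assuming original_state_dict_hooks() is used as a context manager, and that exported_model and supernet_checkpoint are illustrative placeholders for an exported architecture and a state dict saved from the super-net):

from nni.retiarii import fixed_arch
from nni.retiarii.utils import original_state_dict_hooks

# ``exported_model`` and ``supernet_checkpoint`` are illustrative placeholders.
with fixed_arch(exported_model):
    model = ModelSpace()

# Inside the hook context, load_state_dict maps the super-net's original
# parameter names onto the chosen architecture; strict=False ignores
# weights belonging to branches that were not selected.
with original_state_dict_hooks(model):
    model.load_state_dict(supernet_checkpoint, strict=False)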
Evaluators with PyTorch-Lightning¶
Use Built-in Evaluators¶
NNI provides some commonly used model evaluators for users' convenience. These evaluators are built upon the awesome library PyTorch-Lightning. Read the reference for their detailed usage.
- nni.retiarii.evaluator.pytorch.Classification: for classification tasks.
- nni.retiarii.evaluator.pytorch.Regression: for regression tasks.
We recommend reading the serialization tutorial before using these evaluators. A few notes to summarize the tutorial:

- nni.retiarii.evaluator.pytorch.DataLoader should be used in place of torch.utils.data.DataLoader.
- The datasets used in the data loader should be decorated with nni.trace() recursively.

For example,
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST

transform = nni.trace(transforms.Compose)([nni.trace(transforms.ToTensor)(), nni.trace(transforms.Normalize)((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)(root='data/mnist', train=True, download=True, transform=transform)
test_dataset = nni.trace(MNIST)(root='data/mnist', train=False, download=True, transform=transform)

# pl.DataLoader and pl.Classification are already traced and support serialization.
evaluator = pl.Classification(train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
                              val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                              max_epochs=10)
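Regression is constructed in the same way; a minimal sketch, assuming train_dataset and test_dataset here stand for regression datasets yielding (input, target) pairs of float tensors:

# train_dataset / test_dataset are illustrative regression datasets,
# traced with nni.trace as shown above.
evaluator = pl.Regression(train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
                          val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                          max_epochs=10)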
Customize Evaluator with PyTorch-Lightning¶
Another approach is to write training code in PyTorch-Lightning style: write a LightningModule that defines all elements needed for training (e.g., loss function, optimizer), and define a trainer that takes (optional) dataloaders to execute the training. Before that, please read the PyTorch-Lightning documentation to learn its basic concepts and components.
In practice, a new training module in Retiarii should inherit nni.retiarii.evaluator.pytorch.LightningModule, which has a set_model method that will be called after __init__ to save the candidate model (generated by a strategy) as self.model. The rest of the process (such as training_step) is the same as writing any other lightning module. Evaluators should also communicate with strategies via two API calls: nni.report_intermediate_result() for periodical metrics and nni.report_final_result() for final metrics, added in on_validation_epoch_end and teardown respectively.
An example is as follows:
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii.evaluator.pytorch.lightning import LightningModule  # please import this one

@nni.trace
class AutoEncoder(LightningModule):
    def __init__(self):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28)
        )

    def forward(self, x):
        embedding = self.model(x)  # let's search for encoder
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # It is independent of forward.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.model(x)  # model is the one that is searched for
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.model(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self.trainer.callback_metrics['val_loss'].item())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())
Note
If you are trying to use your customized evaluator with a one-shot strategy, bear in mind that your defined methods will be reassembled into another LightningModule, which might result in extra constraints when writing the LightningModule. For example, your validation step could appear elsewhere (e.g., in training_step). This prohibits you from returning arbitrary objects from validation_step.
Then, users need to wrap everything (including the LightningModule, trainer and dataloaders) into an nni.retiarii.evaluator.pytorch.Lightning object, and pass this object into a Retiarii experiment.
import nni.retiarii.evaluator.pytorch.lightning as pl
from nni.retiarii.experiment.pytorch import RetiariiExperiment

lightning = pl.Lightning(AutoEncoder(),
                         pl.Trainer(max_epochs=10),
                         train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
                         val_dataloaders=pl.DataLoader(test_dataset, batch_size=100))
experiment = RetiariiExperiment(base_model, lightning, mutators, strategy)
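The experiment can then be launched as usual. A minimal sketch, assuming a local training service (the trial budget and port below are illustrative):

from nni.retiarii.experiment.pytorch import RetiariiExeConfig

exp_config = RetiariiExeConfig('local')  # run trials on the local machine
exp_config.trial_concurrency = 1  # illustrative settings
exp_config.max_trial_number = 20
experiment.run(exp_config, 8081)  # 8081 is an arbitrary port for the web UI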