Source code for nni.nas.pytorch.callbacks

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)


class Callback:
    """
    Callback provides an easy way to react to events like begin/end of epochs.
    """

    def __init__(self):
        self.model = None
        self.mutator = None
        self.trainer = None

    def build(self, model, mutator, trainer):
        """
        Callback needs to be built with model, mutator and trainer, to get updates from them.

        Parameters
        ----------
        model : nn.Module
            Model to be trained.
        mutator : nn.Module
            Mutator that mutates the model.
        trainer : BaseTrainer
            Trainer that is to call the callback.
        """
        self.model = model
        self.mutator = mutator
        self.trainer = trainer

    def on_epoch_begin(self, epoch):
        """
        Implement this to do something at the beginning of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_epoch_end(self, epoch):
        """
        Implement this to do something at the end of an epoch.

        Parameters
        ----------
        epoch : int
            Epoch number, starting from 0.
        """
        pass

    def on_batch_begin(self, epoch):
        """
        Implement this to do something at the beginning of a batch.
        """
        pass

    def on_batch_end(self, epoch):
        """
        Implement this to do something at the end of a batch.
        """
        pass
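# Illustrative sketch, not part of the module: a custom callback subclasses
# ``Callback`` and overrides only the hooks it needs. Because the trainer calls
# ``build()`` before training, ``self.model``, ``self.mutator`` and
# ``self.trainer`` are available inside the hooks. The class name and log
# messages below are hypothetical.
class _ExampleLoggingCallback(Callback):
    """Example callback that logs epoch boundaries."""

    def on_epoch_begin(self, epoch):
        _logger.info("Epoch %d begins", epoch)

    def on_epoch_end(self, epoch):
        _logger.info("Epoch %d ends; model type is %s", epoch, type(self.model).__name__)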
class LRSchedulerCallback(Callback):
    """
    Calls the scheduler at the end of every epoch.

    Parameters
    ----------
    scheduler : LRScheduler
        Scheduler to be called.
    mode : str
        Stepping mode of the scheduler. Currently only "epoch" (step once per epoch) is supported.
    """

    def __init__(self, scheduler, mode="epoch"):
        super().__init__()
        assert mode == "epoch"
        self.scheduler = scheduler
        self.mode = mode

    def on_epoch_end(self, epoch):
        """
        Call ``self.scheduler.step()`` on epoch end.
        """
        self.scheduler.step()
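# Usage sketch, not part of the module: wrap any standard PyTorch LR scheduler
# and hand the resulting callback to a trainer. The optimizer and scheduler
# choices here are illustrative, and the helper name is hypothetical.
def _lr_scheduler_callback_example(model):
    from torch.optim.lr_scheduler import CosineAnnealingLR
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    scheduler = CosineAnnealingLR(optimizer, T_max=50)
    # The scheduler will now step once at the end of each training epoch.
    return LRSchedulerCallback(scheduler)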
class ArchitectureCheckpoint(Callback):
    """
    Calls ``trainer.export()`` at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump the current architecture to ``<checkpoint_dir>/epoch_{number}.json`` on epoch end.
        """
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
        _logger.info("Saving architecture to %s", dest_path)
        self.trainer.export(dest_path)
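# Usage sketch, not part of the module: an exported architecture JSON can later
# be applied to a fresh copy of the model for retraining. ``apply_fixed_architecture``
# is assumed to be available from ``nni.nas.pytorch.fixed`` in the same package;
# the checkpoint path below is illustrative.
def _retrain_from_checkpoint_example(model):
    from nni.nas.pytorch.fixed import apply_fixed_architecture
    apply_fixed_architecture(model, "./checkpoints/epoch_9.json")
    return model  # model now has the chosen architecture fixed in place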
class ModelCheckpoint(Callback):
    """
    Saves the model state dict at the end of every epoch.

    Parameters
    ----------
    checkpoint_dir : str
        Location to save checkpoints.
    """

    def __init__(self, checkpoint_dir):
        super().__init__()
        self.checkpoint_dir = checkpoint_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def on_epoch_end(self, epoch):
        """
        Dump the state dict to ``<checkpoint_dir>/epoch_{number}.pth.tar`` on every epoch end.
        For a ``DataParallel`` model, the state dict of the wrapped module is exported instead.
        """
        if isinstance(self.model, nn.DataParallel):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
        _logger.info("Saving model to %s", dest_path)
        torch.save(state_dict, dest_path)
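# End-to-end sketch, not part of the module: wiring several callbacks into a
# trainer. ``EnasTrainer`` and its argument names follow the trainers shipped
# alongside this package but are an assumption here; substitute whichever
# trainer you use. The trainer is expected to call ``build()`` on each callback
# and fire the epoch hooks during ``train()``.
def _trainer_with_callbacks_example(model, dataset_train, dataset_valid,
                                    metrics, reward_function):
    from nni.nas.pytorch import enas
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    trainer = enas.EnasTrainer(
        model,
        loss=nn.CrossEntropyLoss(),
        metrics=metrics,
        reward_function=reward_function,
        optimizer=optimizer,
        batch_size=128,
        num_epochs=10,
        dataset_train=dataset_train,
        dataset_valid=dataset_valid,
        callbacks=[LRSchedulerCallback(scheduler),
                   ArchitectureCheckpoint("./checkpoints"),
                   ModelCheckpoint("./checkpoints")],
    )
    trainer.train()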