# Source code for nni.nas.pytorch.base_trainer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from abc import ABC, abstractmethod


class BaseTrainer(ABC):
    """
    Abstract interface for trainers.

    Subclasses must override all four methods: ``train``, ``validate``,
    ``export`` and ``checkpoint``. Because each one is an ``abstractmethod``,
    ``BaseTrainer`` (or a subclass that leaves any of them unimplemented)
    cannot be instantiated. The base implementations only raise
    ``NotImplementedError``, so delegating to them via ``super()`` fails loudly.
    """

    @abstractmethod
    def train(self):
        """
        Override the method to train.
        """
        raise NotImplementedError

    @abstractmethod
    def validate(self):
        """
        Override the method to validate.
        """
        raise NotImplementedError

    @abstractmethod
    def export(self, file: str):
        """
        Override the method to export to file.

        Parameters
        ----------
        file : str
            File path to export to.
        """
        raise NotImplementedError

    @abstractmethod
    def checkpoint(self):
        """
        Override to dump a checkpoint.
        """
        raise NotImplementedError