Source code for nni.retiarii.oneshot.pytorch.strategy

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Strategy integration of one-shot.

This file is put here simply because it relies on "pytorch".
For consistency, please consider importing strategies from ``nni.retiarii.strategy``.
For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires PyTorch to be installed, of course).

When adding or modifying a strategy in this file, don't forget to link it in strategy/oneshot.py.
"""

from __future__ import annotations

import warnings
from typing import Any, Type

import torch.nn as nn

from nni.retiarii.graph import Model
from nni.retiarii.strategy.base import BaseStrategy
from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule

from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule


class OneShotStrategy(BaseStrategy):
    """Wrap an one-shot lightning module as a one-shot strategy."""

    def __init__(self, oneshot_module: Type[BaseOneShotLightningModule], **kwargs):
        self.oneshot_module = oneshot_module
        self.oneshot_kwargs = kwargs

        self.model: BaseOneShotLightningModule | None = None

    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
        """
        One-shot strategies typically require fusing the train and validation dataloaders in an ad-hoc way.
        As one-shot strategies don't try to open the black box of a batch,
        these dataloaders can theoretically be
        `any dataloader type supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.

        Returns
        -------
        A tuple of preprocessed train dataloaders and validation dataloaders.
        """
        return train_dataloaders, val_dataloaders
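    # A hedged illustration of a typical override (the real overrides live in the concrete
    # strategies below): returning a dict as the "train" dataloader makes Lightning wrap it
    # into a ``CombinedLoader``, so each training batch arrives as a dict with the same keys::
    #
    #     def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
    #         # training_step then receives {'train': train_batch, 'val': val_batch}
    #         return {'train': train_dataloaders, 'val': val_dataloaders}, None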

    def run(self, base_model: Model, applied_mutators):
        # One-shot strategy doesn't use ``applied_mutators``;
        # it gets the "mutators" on its own.

        _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

        if not isinstance(base_model.python_object, nn.Module):
            raise TypeError('Model is not an nn.Module. ' + _reason)
        py_model: nn.Module = base_model.python_object

        if applied_mutators:
            raise ValueError('Mutator list is not empty. ' + _reason)

        if not isinstance(base_model.evaluator, Lightning):
            raise TypeError('Evaluator needs to be a Lightning evaluator to make one-shot strategy work.')

        evaluator_module: LightningModule = base_model.evaluator.module
        evaluator_module.running_mode = 'oneshot'
        evaluator_module.set_model(py_model)

        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
        evaluator: Lightning = base_model.evaluator
        if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
            raise TypeError('Training and validation dataloaders are both required to be set in the evaluator for one-shot strategy.')
        train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
        evaluator.trainer.fit(self.model, train_loader, val_loader)

    def export_top_models(self, top_k: int = 1) -> list[Any]:
        """The behavior of export top models in strategy depends on the implementation of inner one-shot module."""
        if self.model is None:
            raise RuntimeError('One-shot strategy needs to be run before export.')
        if top_k != 1:
            warnings.warn('One-shot strategy currently only supports exporting top-1 model.', RuntimeWarning)
        return [self.model.export()]
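    # A hedged usage sketch (illustrative; normally ``RetiariiExperiment`` drives this rather
    # than user code): once ``run`` has finished, the architecture chosen by the one-shot
    # algorithm can be retrieved from the strategy::
    #
    #     darts_strategy = DARTS()
    #     darts_strategy.run(base_model, applied_mutators=[])
    #     exported_arch = darts_strategy.export_top_models()[0]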


class DARTS(OneShotStrategy):
    __doc__ = DartsLightningModule._darts_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(DartsLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # By returning a dict, we make a CombinedLoader (in Lightning)
        return {
            'train': train_dataloaders,
            'val': val_dataloaders
        }, None


class Proxyless(OneShotStrategy):
    __doc__ = ProxylessLightningModule._proxyless_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(ProxylessLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        return {
            'train': train_dataloaders,
            'val': val_dataloaders
        }, None


class GumbelDARTS(OneShotStrategy):
    __doc__ = GumbelDartsLightningModule._gumbel_darts_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(GumbelDartsLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        return {
            'train': train_dataloaders,
            'val': val_dataloaders
        }, None


class ENAS(OneShotStrategy):
    __doc__ = EnasLightningModule._enas_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(EnasLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # Import locally to avoid import error on legacy PL version
        from .dataloader import ConcatLoader
        return ConcatLoader({
            'train': train_dataloaders,
            'val': val_dataloaders
        }), None


class RandomOneShot(OneShotStrategy):
    __doc__ = RandomSamplingLightningModule._random_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(RandomSamplingLightningModule, **kwargs)