# Source code for nni.compression.distillation.basic_distiller

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from collections import defaultdict
import logging
from typing import Any, Callable, Dict, List, overload

import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_map

from ..base.compressor import Compressor, Distiller, _DISTILLATION_TARGET_SPACES
from ..base.wrapper import ModuleWrapper, register_wrappers
from ..utils import Evaluator, _EVALUATOR_DOCSTRING

_logger = logging.getLogger(__name__)


class TeacherModelBasedDistiller(Distiller):
    __doc__ = r"""The base class for distillers that need a teacher model.

    Subclasses must implement :meth:`compute_distill_loss`.

    Parameters
    ----------
    model
        The student model to be distilled.
    config_list
        A list of dict, each dict configures which modules need to be distilled, and how to distill them.
        Please refer :doc:`Compression Config Specification </compression/config_list>` for more information.
    evaluator
        {evaluator_docstring}
    teacher_model
        The distillation teacher model.
    teacher_predict
        A callable function with two inputs (batch, model).

        Example::

            def teacher_predict(batch, teacher_model):
                return teacher_model(**batch)

    origin_loss_lambda
        A scaling factor to control the original loss scale.
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator,
                 teacher_model: torch.nn.Module, teacher_predict: Callable[[Any, torch.nn.Module], torch.Tensor],
                 origin_loss_lambda: float = 1.):
        ...

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator,
                 teacher_model: torch.nn.Module, teacher_predict: Callable[[Any, torch.nn.Module], torch.Tensor],
                 origin_loss_lambda: float = 1., existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        ...

    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator,
                 teacher_model: torch.nn.Module, teacher_predict: Callable[[Any, torch.nn.Module], torch.Tensor],
                 origin_loss_lambda: float = 1., existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        assert model is not teacher_model, 'Student model and teacher model should not be the same.'
        super().__init__(model=model, config_list=config_list, evaluator=evaluator,
                         existed_wrappers=existed_wrappers)

        self.teacher_model = teacher_model
        self.teacher_predict = teacher_predict
        self.origin_loss_lambda = origin_loss_lambda
        # The defaults must be filled in before the teacher wrappers are registered:
        # `_register_teacher_wrappers` iterates over `target_space.link` as a list.
        self._set_default_link()
        self._set_default_lambda()
        self._teacher_module_wrappers, target_spaces = self._register_teacher_wrappers()
        self._teacher_target_spaces: _DISTILLATION_TARGET_SPACES = target_spaces  # type: ignore
        self._teacher_is_wrapped = False
        self.wrap_teacher_model()

    @classmethod
    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], teacher_model: torch.nn.Module,
                        teacher_predict: Callable[[Any, torch.nn.Module], torch.Tensor], origin_loss_lambda: float = 1.,
                        evaluator: Evaluator | None = None):
        """
        Build the distiller from an existing compressor by delegating to the parent ``from_compressor``,
        forwarding the teacher related arguments as keyword arguments.
        """
        return super().from_compressor(compressor, new_config_list, evaluator=evaluator, teacher_model=teacher_model,
                                       teacher_predict=teacher_predict, origin_loss_lambda=origin_loss_lambda)

    def _set_default_link(self) -> None:
        """
        Normalize ``link`` of every student target space to a list of names.

        ``None`` and ``'auto'`` are replaced by the student module name, and a single string is wrapped in a list.
        """
        for module_name, ts in self._target_spaces.items():
            for target_space in ts.values():
                link = target_space.link
                if link is None or link == 'auto':
                    link = module_name
                # NOTE: the links can not be validated against the teacher target spaces here,
                # `self._teacher_target_spaces` is only created after this method ran in `__init__`.
                target_space.link = [link] if isinstance(link, str) else link

    def _set_default_lambda(self) -> None:
        """
        Set ``lambda_`` of every student target space to ``1.`` if it is not configured.
        """
        for ts in self._target_spaces.values():
            for target_space in ts.values():
                target_space.lambda_ = target_space.lambda_ if target_space.lambda_ is not None else 1.

    def _register_teacher_wrappers(self):
        """
        Register wrappers on the teacher model for every module name used as a link by the student target spaces.

        Returns
        -------
        The result of ``register_wrappers``, unpacked by ``__init__`` as (teacher module wrappers, teacher target spaces).
        """
        # Collect, for each linked teacher module, the names of all targets that point to it.
        link2targets = defaultdict(set)
        for ts in self._target_spaces.values():
            for target_name, target_space in ts.items():
                for link in target_space.link:
                    link2targets[link].add(target_name)
        teacher_config_list = [{
            'op_names': [link],
            'target_names': list(target_names)
        } for link, target_names in link2targets.items()]
        return register_wrappers(self.teacher_model, teacher_config_list, mode=self.mode)

    def wrap_teacher_model(self) -> None:
        """
        Traverse all teacher wrappers and execute ModuleWrapper.wrap()
        """
        if self._teacher_is_wrapped:
            # Only a warning: the wrappers are still traversed below, as before.
            _logger.warning('The teacher model has been wrapped, no need to wrap again.')
        for wrapper in self._teacher_module_wrappers.values():
            wrapper.wrap()
        self._teacher_is_wrapped = True

    def unwrap_teacher_model(self) -> None:
        """
        Traverse all teacher wrappers and execute ModuleWrapper.unwrap()
        """
        if not self._teacher_is_wrapped:
            # Only a warning: the wrappers are still traversed below, as before.
            _logger.warning('The teacher model is not wrapped, can not unwrap it.')
        for wrapper in self._teacher_module_wrappers.values():
            wrapper.unwrap()
        self._teacher_is_wrapped = False

    def _register_loss_patch(self, evaluator: Evaluator) -> None:
        """
        Patch the evaluator loss so that the distillation loss is added to the scaled original loss.
        """
        def loss_patch(original_loss, batch):
            # The return value of the teacher prediction is discarded,
            # the call is only made to run the teacher model on the batch without gradient tracking.
            with torch.no_grad():
                self.teacher_predict(batch, self.teacher_model)
            return self.origin_loss_lambda * original_loss + self.compute_distill_loss()

        evaluator.patch_loss(loss_patch)

    def compute_distill_loss(self):
        """
        Compute the distillation loss that is added to the original loss, must be implemented by subclasses.
        """
        raise NotImplementedError()

    def _single_compress(self, max_steps: int | None, max_epochs: int | None):
        """
        Run the distillation on its own by delegating to ``_fusion_compress``.
        """
        self._fusion_compress(max_steps, max_epochs)

    def _fuse_preprocess(self, evaluator: Evaluator) -> None:
        """
        Register the loss patch on the evaluator.
        """
        self._register_loss_patch(evaluator)

    def _fuse_postprocess(self, evaluator: Evaluator) -> None:
        """
        Nothing needs to be done here for this distiller.
        """
        pass


class DynamicLayerwiseDistiller(TeacherModelBasedDistiller):
    __doc__ = r"""
    Each student model distillation target (i.e., the output of a layer in the student model) will link to a list of
    teacher model distillation targets in this distiller.
    During distillation, a student target will compute a list of distillation losses with each of its linked teacher targets,
    then choose the minimum loss in the loss list as current student target distillation loss.
    The final distillation loss is the sum of each student target distillation loss multiplied by lambda.
    The final training loss is original loss multiplied by origin_loss_lambda add final distillation loss.

    Parameters
    ----------
    model
        The student model to be distilled.
    config_list
        Config list to configure how to distill.
        Common keys please refer :doc:`Compression Config Specification </compression/config_list>`.

        Specific keys:

        * 'lambda': By default, 1.
          This is a scaling factor to control the loss scale, the final loss used during training is
          ``(origin_loss_lambda * origin_loss + sum(lambda_i * distill_loss_i))``.
          Here ``i`` represents the ``i-th`` distillation target.
          The higher the value of lambda, the greater the contribution of the corresponding distillation target to the loss.
        * 'link': By default, 'auto'.
          'auto' or a teacher module name or a list of teacher module names,
          the module name(s) of teacher module(s) will align with student module(s) configured in this config.
          If 'auto' is set, will use student module name as the link,
          usually requires the teacher model and the student model to be isomorphic.
        * 'apply_method': By default, 'mse'.
          'mse' and 'kl' are supported right now. 'mse' means the MSE loss, usually used to distill hidden states.
          'kl' means the KL loss, usually used to distill logits.

    evaluator
        {evaluator_docstring}
    teacher_model
        The distillation teacher model.
    teacher_predict
        A callable function with two inputs (batch, model).

        Example::

            def teacher_predict(batch, teacher_model):
                return teacher_model(**batch)

    origin_loss_lambda
        A scaling factor to control the original loss scale.
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    def compute_distill_loss(self):
        """
        Sum, over all student targets, the minimum lambda-scaled loss between the student hidden state
        and the hidden state of each linked teacher target, then clean all student and teacher target spaces.

        Returns
        -------
        The accumulated distillation loss, ``0`` if no loss could be computed.
        """
        distill_loss = 0
        for _, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                stu_hs = target_space.hidden_state
                loss_list = []
                for link in target_space.link:
                    teacher_target_space = self._teacher_target_spaces[link][target_name]
                    tea_hs = teacher_target_space.hidden_state
                    # A pair is skipped when either side has no hidden state.
                    if stu_hs is not None and tea_hs is not None:
                        tea_hs = tea_hs.to(stu_hs.device)
                        if target_space.apply_method == 'mse':
                            loss_list.append(target_space.lambda_ * F.mse_loss(stu_hs, tea_hs))
                        elif target_space.apply_method == 'kl':
                            # Both inputs are divided by 2 before the (log-)softmax,
                            # the resulting KL divergence is scaled back by 2 ** 2.
                            kl_loss = F.kl_div((stu_hs / 2).log_softmax(dim=-1), (tea_hs / 2).softmax(dim=-1),
                                               reduction='batchmean') * (2 ** 2)
                            loss_list.append(target_space.lambda_ * kl_loss)
                # Only the smallest loss among the linked teacher targets contributes.
                if loss_list:
                    distill_loss = distill_loss + min(loss_list)

        # Clean the student and the teacher target spaces once the loss is computed.
        for _, ts in self._target_spaces.items():
            for _, target_space in ts.items():
                target_space.clean()
        for _, ts in self._teacher_target_spaces.items():
            for _, target_space in ts.items():
                target_space.clean()
        return distill_loss
class Adaptive1dLayerwiseDistiller(TeacherModelBasedDistiller):
    __doc__ = r"""
    This distiller will adaptively align the last dimension between student distillation target and teacher distillation target
    by adding a trainable ``torch.nn.Linear`` between them.
    (If the last dimensions between student and teacher have already aligned, won't add a new linear layer.)

    Note that this distiller need call ``Adaptive1dLayerwiseDistiller.track_forward(...)`` first to get the shape of each
    distillation target to initialize the linear layer before call ``Adaptive1dLayerwiseDistiller.compress(...)``.

    Parameters
    ----------
    model
        The student model to be distilled.
    config_list
        Config list to configure how to distill.
        Common keys please refer :doc:`Compression Config Specification </compression/config_list>`.

        Specific keys:

        * 'lambda': By default, 1.
          This is a scaling factor to control the loss scale, the final loss used during training is
          ``(origin_loss_lambda * origin_loss + sum(lambda_i * distill_loss_i))``.
          Here ``i`` represents the ``i-th`` distillation target.
          The higher the value of lambda, the greater the contribution of the corresponding distillation target to the loss.
        * 'link': By default, 'auto'.
          'auto' or a teacher module name or a list of teacher module names,
          the module name(s) of teacher module(s) will align with student module(s) configured in this config.
          If 'auto' is set, will use student module name as the link,
          usually requires the teacher model and the student model to be isomorphic.
        * 'apply_method': By default, 'mse'.
          'mse' and 'kl' are supported right now. 'mse' means the MSE loss, usually used to distill hidden states.
          'kl' means the KL loss, usually used to distill logits.

    evaluator
        {evaluator_docstring}
    teacher_model
        The distillation teacher model.
    teacher_predict
        A callable function with two inputs (batch, model).

        Example::

            def teacher_predict(batch, teacher_model):
                return teacher_model(**batch)

    origin_loss_lambda
        A scaling factor to control the original loss scale.
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    def track_forward(self, *args, **kwargs):
        """
        Run the parent ``track_forward`` and then forward the same inputs through the teacher model
        without gradient tracking.

        Tensors in ``args`` and ``kwargs`` are moved to the device of the first teacher parameter.
        If the teacher model has no parameter, the inputs are passed as they are.
        """
        super().track_forward(*args, **kwargs)
        with torch.no_grad():
            # Guard against a teacher without parameters, `next` would raise StopIteration otherwise.
            first_param = next(iter(self.teacher_model.parameters()), None)
            if first_param is not None:
                model_device = first_param.device
                args = tree_map(lambda x: x.to(model_device) if isinstance(x, torch.Tensor) else x, args)
                kwargs = tree_map(lambda x: x.to(model_device) if isinstance(x, torch.Tensor) else x, kwargs)
            self.teacher_model(*args, **kwargs)

    def _register_trans_linear(self):
        """
        Create ``self.trans_linears[module_name][target_name]`` for every student target.

        The entry is ``None`` if the last dimensions of the student and the teacher hidden states are equal,
        otherwise a ``torch.nn.Linear`` mapping the student last dimension to the teacher last dimension,
        placed on the device of the student hidden state.
        """
        self.trans_linears = defaultdict(dict)
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                # For performance reasons only one link is supported...
                assert isinstance(target_space.link, str) or len(target_space.link) == 1, \
                    f'only support set one link for target in {self.__class__.__name__}'
                stu_hs = target_space.hidden_state
                link = target_space.link if isinstance(target_space.link, str) else target_space.link[0]
                tea_hs = self._teacher_target_spaces[link][target_name].hidden_state
                # The hidden states are needed to read the shapes.
                assert stu_hs is not None and tea_hs is not None, \
                    f'Please run {self.__class__.__name__}.track_forward(...) first before compress.'
                if stu_hs.shape[-1] == tea_hs.shape[-1]:
                    self.trans_linears[module_name][target_name] = None
                else:
                    linear = torch.nn.Linear(stu_hs.shape[-1], tea_hs.shape[-1]).to(stu_hs.device)
                    self.trans_linears[module_name][target_name] = linear

    def _register_linears_optimization(self, evaluator: Evaluator):
        """
        Pass the parameters of all created linear layers to ``evaluator.patch_optim_param_group``,
        grouped by student module name. Nothing is patched if no linear layer was created.
        """
        linear_params = {}
        for module_name, linears in self.trans_linears.items():
            for linear in linears.values():
                if linear is not None:
                    # Accumulate instead of assign: a module may own several targets with a linear each,
                    # a plain assignment would keep only the parameters of the last one.
                    linear_params.setdefault(module_name, []).extend(linear.parameters())

        if not linear_params:
            return

        evaluator.patch_optim_param_group(linear_params)

    def compute_distill_loss(self):
        """
        Sum the lambda-scaled loss between each student hidden state (passed through its linear layer if one
        was registered) and the hidden state of its linked teacher target, then clean all target spaces.

        Returns
        -------
        The accumulated distillation loss, ``0`` if no loss could be computed.
        """
        distill_loss = 0
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                stu_hs = target_space.hidden_state
                link = target_space.link if isinstance(target_space.link, str) else target_space.link[0]
                tea_hs = self._teacher_target_spaces[link][target_name].hidden_state
                # A pair is skipped when either side has no hidden state.
                if stu_hs is not None and tea_hs is not None:
                    trans_linear = self.trans_linears[module_name][target_name]
                    if trans_linear is not None:
                        # Keep the linear layer on the device of the student hidden state before applying it.
                        trans_linear.to(stu_hs.device)
                        stu_hs = trans_linear(stu_hs)
                    tea_hs = tea_hs.to(stu_hs.device)
                    if target_space.apply_method == 'mse':
                        distill_loss += target_space.lambda_ * F.mse_loss(stu_hs, tea_hs)
                    elif target_space.apply_method == 'kl':
                        # Both inputs are divided by 2 before the (log-)softmax,
                        # the resulting KL divergence is scaled back by 2 ** 2.
                        kl_loss = F.kl_div((stu_hs / 2).log_softmax(dim=-1), (tea_hs / 2).softmax(dim=-1),
                                           reduction='batchmean') * (2 ** 2)
                        distill_loss += target_space.lambda_ * kl_loss

        # Clean the student and the teacher target spaces once the loss is computed.
        for _, ts in self._target_spaces.items():
            for _, target_space in ts.items():
                target_space.clean()
        for _, ts in self._teacher_target_spaces.items():
            for _, target_space in ts.items():
                target_space.clean()
        return distill_loss

    def _fuse_preprocess(self, evaluator: Evaluator):
        """
        Create the linear layers, register their parameters for optimization, then register the loss patch.
        """
        self._register_trans_linear()
        self._register_linears_optimization(evaluator)
        self._register_loss_patch(evaluator)