Source code for nni.compression.pruning.slim_pruner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from collections import defaultdict
import logging
from typing import Dict, List, Literal, Tuple, overload

import torch
from torch.optim import Adam

from .tools import _METRICS, _MASKS, generate_sparsity, is_active_target
from ..base.compressor import Compressor, Pruner
from ..base.target_space import TargetType
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator, _EVALUATOR_DOCSTRING

_logger = logging.getLogger(__name__)


SLIM_SCALING_FACTOR_PNAME = '{}_slim_factor'
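# e.g., SLIM_SCALING_FACTOR_PNAME.format('weight') -> 'weight_slim_factor': the name under which a
# target's learnable scaling factor is registered on its ModuleWrapper (see `_register_scaling_facotrs`).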


class SlimPruner(Pruner):
    __doc__ = r"""
    Slim pruner adds sparsity regularization on the scaling factors of batch normalization (BN) layers
    during training to identify unimportant channels. The channels with small scaling factor values
    will be pruned.

    For more details, please refer to
    `Learning Efficient Convolutional Networks through Network Slimming <https://arxiv.org/abs/1708.06519>`__.

    Parameters
    ----------
    model
        Model to be pruned.
    config_list
        A list of dicts; each dict configures which modules need to be pruned and how to prune them.
        Please refer to :doc:`Compression Config Specification </compression/config_list>` for more information.
    evaluator
        {evaluator_docstring}
    training_steps
        An integer controlling how many steps the model and the scaling factors are trained.
        Masks will be generated after ``training_steps``.
    regular_scale
        ``regular_scale`` controls the strength of the penalty on the scaling factors.

    Examples
    --------
    Please refer to
    :githublink:`examples/compression/pruning/slim_pruning.py <examples/compression/pruning/slim_pruning.py>`.
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, training_steps: int,
                 regular_scale: float = 1.):
        ...

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, training_steps: int,
                 regular_scale: float = 1., existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        ...

    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, training_steps: int,
                 regular_scale: float = 1., existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        super().__init__(model, config_list, evaluator, existed_wrappers)
        self.evaluator: Evaluator
        self.training_steps = training_steps
        self.regular_scale = regular_scale

        # trigger masks generation when self._current_step == self.training_steps
        self._current_step = 0
        # `interval_steps` and `total_times` are used by `register_trigger`.
        # `interval_steps` is the optimize step interval for generating masks.
        # `total_times` is the total generation times of masks.
        self.interval_steps = training_steps
        self.total_times: int | Literal['unlimited'] = 1

        self._set_apply_method()

    @classmethod
    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], training_steps: int,
                        regular_scale: float = 1., evaluator: Evaluator | None = None):
        return super().from_compressor(compressor, new_config_list, training_steps=training_steps,
                                       regular_scale=regular_scale, evaluator=evaluator)

    def _set_apply_method(self):
        for _, ts in self._target_spaces.items():
            for _, target_space in ts.items():
                if target_space.apply_method == 'mul':
                    target_space.apply_method = 'slim_mul'
                else:
                    assert target_space.apply_method == 'slim_mul'

    def _register_scaling_facotrs(self):
        self.scaling_factors = defaultdict(dict)
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                if is_active_target(target_space):
                    # TODO: add input / output
                    if target_space.type is TargetType.PARAMETER:
                        # TODO: a shrunk score is used here to save memory, but the speed still needs to be tested.
                        scaling_factor = torch.ones_like(target_space.target)  # type: ignore
                        if target_space._scaler is not None:
                            scaling_factor = target_space._scaler.shrink(scaling_factor, keepdim=True)
                        target_space._wrapper.register_parameter(SLIM_SCALING_FACTOR_PNAME.format(target_name),
                                                                 torch.nn.Parameter(scaling_factor))
                        scaling_factor = target_space._get_wrapper_attr(SLIM_SCALING_FACTOR_PNAME.format(target_name))
                        self.scaling_factors[module_name][target_name] = scaling_factor
                    else:
                        raise NotImplementedError()

    def _register_factors_optimization(self, evaluator: Evaluator):
        scaling_factors = []
        for _, target_scaling_factor in self.scaling_factors.items():
            for _, scaling_factor in target_scaling_factor.items():
                scaling_factors.append(scaling_factor)
        if not scaling_factors:
            return
        params = [{"params": scaling_factors}]
        optimizer = Adam(params, 1e-2)
        evaluator.patch_optimizer_step(before_step_tasks=[optimizer.step], after_step_tasks=[optimizer.zero_grad])

    def _patch_loss(self, evaluator: Evaluator):
        def loss_patch(original_loss, batch):
            # L1 regularization on the scaling factors, averaged over targets and scaled by `regular_scale`.
            reg_loss = torch.tensor(0., device=original_loss.device)
            count = 0
            for _, target_scaling_factor in self.scaling_factors.items():
                for _, scaling_factor in target_scaling_factor.items():
                    reg_loss = reg_loss + scaling_factor.norm(p=1)  # type: ignore
                    count += 1
            if count > 0:
                reg_loss = self.regular_scale * reg_loss / count
            return original_loss + reg_loss

        evaluator.patch_loss(loss_patch)

    def _collect_data(self) -> Dict[str, Dict[str, torch.Tensor]]:
        data = defaultdict(dict)
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                scaling_factor: torch.Tensor = \
                    getattr(target_space._wrapper, SLIM_SCALING_FACTOR_PNAME.format(target_name), None)  # type: ignore
                if scaling_factor is not None:
                    data[module_name][target_name] = scaling_factor.clone().detach()
        return data

    def _calculate_metrics(self, data: Dict[str, Dict[str, torch.Tensor]]) -> _METRICS:
        return {k: {p: q.abs() for p, q in v.items()} for k, v in data.items()}

    def _generate_sparsity(self, metrics: _METRICS) -> _MASKS:
        return generate_sparsity(metrics, self._target_spaces)

    def _register_trigger(self, evaluator: Evaluator):
        assert self.interval_steps >= self.training_steps or self.interval_steps < 0
        self._remaining_times = self.total_times

        def optimizer_task():
            self._current_step += 1
            if self._current_step == self.training_steps:
                masks = self.generate_masks()
                self.update_masks(masks)
                if isinstance(self._remaining_times, int):
                    self._remaining_times -= 1
                debug_msg = f'{self.__class__.__name__} generate masks, remaining times {self._remaining_times}'
                _logger.debug(debug_msg)
            if self._current_step == self.interval_steps and \
                    (self._remaining_times == 'unlimited' or self._remaining_times > 0):  # type: ignore
                self._current_step = 0

        evaluator.patch_optimizer_step(before_step_tasks=[], after_step_tasks=[optimizer_task])

    def _single_compress(self, max_steps: int | None, max_epochs: int | None):
        assert max_steps is None and max_epochs is None
        self._fusion_compress(self.training_steps, None)

    def _fuse_preprocess(self, evaluator: Evaluator) -> None:
        self._register_scaling_facotrs()
        self._register_factors_optimization(evaluator)
        self._patch_loss(evaluator)
        self._register_trigger(evaluator)

    def _fuse_postprocess(self, evaluator: Evaluator) -> None:
        pass

    @overload
    def compress(self) -> Tuple[torch.nn.Module, _MASKS]:
        ...

    @overload
    def compress(self, max_steps: int | None, max_epochs: int | None) -> Tuple[torch.nn.Module, _MASKS]:
        ...

    def compress(self, max_steps: int | None = None, max_epochs: int | None = None):
        return super().compress(max_steps, max_epochs)
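

# A minimal usage sketch of driving SlimPruner with an NNI-style evaluator, assuming an NNI 3.x
# TorchEvaluator. The TorchEvaluator import path, the config keys ('op_types', 'sparse_ratio') and
# all helper names below are assumptions that may differ between NNI versions; the maintained
# end-to-end example is examples/compression/pruning/slim_pruning.py.
if __name__ == '__main__':
    import nni
    import torch.nn.functional as F
    from nni.compression import TorchEvaluator  # assumption: may be nni.compression.utils in some versions
    from nni.compression.pruning import SlimPruner

    # A toy convolutional model with a BN layer whose scaling factors Slim regularizes.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3, padding=1), torch.nn.BatchNorm2d(16), torch.nn.ReLU(),
        torch.nn.AdaptiveAvgPool2d(1), torch.nn.Flatten(), torch.nn.Linear(16, 10),
    )

    def training_step(batch, model):
        x, y = batch
        return F.cross_entropy(model(x), y)

    def training_func(model, optimizers, training_step, lr_schedulers=None,
                      max_steps=None, max_epochs=None, *args, **kwargs):
        # Random-data training loop; the evaluator passes max_steps=training_steps when compressing.
        model.train()
        steps = max_steps if max_steps is not None else 100
        for _ in range(steps):
            batch = (torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
            optimizers.zero_grad()
            loss = training_step(batch, model)
            loss.backward()
            optimizers.step()

    # NNI evaluators expect the optimizer to be created through nni.trace so it can be patched.
    optimizer = nni.trace(torch.optim.SGD)(model.parameters(), lr=0.01)
    evaluator = TorchEvaluator(training_func, optimizer, training_step)

    # Prune 50% of BN channels after 100 sparsification-training steps.
    config_list = [{'op_types': ['BatchNorm2d'], 'sparse_ratio': 0.5}]
    pruner = SlimPruner(model, config_list, evaluator, training_steps=100, regular_scale=1.0)
    pruned_model, masks = pruner.compress()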