# Source code for nni.compression.pruning.movement_pruner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from collections import defaultdict
import logging
from typing import Dict, List, overload

import torch
from torch.optim import Adam

from .scheduled_pruner import ScheduledPruner
from .tools import is_active_target, generate_sparsity
from ..base.compressor import Compressor
from ..base.target_space import TargetType
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator, _EVALUATOR_DOCSTRING

MOVEMENT_SCORE_PNAME = '{}_mvp_score'
_logger = logging.getLogger(__name__)

class MovementPruner(ScheduledPruner):
    __doc__ = r"""
    Movement pruner is an implementation of movement pruning.
    This is a "fine-pruning" algorithm, which means the masks may change during each fine-tuning step.
    Each weight element will be scored by the opposite of the sum of the product of weight and its gradient
    during each step. This means the weight elements moving towards zero will accumulate negative scores,
    and the weight elements moving away from zero will accumulate positive scores.
    The weight elements with low scores will be masked during inference.

    The following figure from the paper shows the weight pruning by movement pruning.

    .. image:: ../../../img/movement_pruning.png
        :target: ../../../img/movement_pruning.png
        :alt:

    For more details, please refer to
    `Movement Pruning: Adaptive Sparsity by Fine-Tuning <https://arxiv.org/abs/2005.07683>`__.

    Parameters
    ----------
    model
        Model to be pruned.
    config_list
        A list of dict, each dict configures which module needs to be pruned, and how to prune.
        Please refer to :doc:`Compression Config Specification </compression/config_list>` for more information.
    evaluator
        {evaluator_docstring}
    warmup_step
        The total number of ``optimizer.step()`` calls before pruning starts (warm up).
        Make sure ``warmup_step`` is smaller than ``cooldown_begin_step``.
    cooldown_begin_step
        The number of steps at which sparsity stops growing. Note that sparsity no longer growing
        does not mean the masks stop changing.
        The sparse ratio or sparse threshold after each ``optimizer.step()`` is::

            final_sparse * (1 - (1 - (current_step - warmup_step) / (cooldown_begin_step - warmup_step)) ** 3)

    regular_scale
        A scale factor used to control the movement score regularization loss.
        This factor only works on pruning targets controlled by ``sparse_threshold``;
        pruning targets controlled by ``sparse_ratio`` will not be regularized.

    Examples
    --------
    Please refer to :githublink:`examples/tutorials/ <examples/tutorials/>`.
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
                 cooldown_begin_step: int, regular_scale: float = 1.):
        ...

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
                 cooldown_begin_step: int, regular_scale: float = 1.,
                 existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        ...

    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator, warmup_step: int,
                 cooldown_begin_step: int, regular_scale: float = 1.,
                 existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        super().__init__(model, config_list, evaluator, existed_wrappers)
        self.evaluator: Evaluator
        assert 0 <= warmup_step < cooldown_begin_step, \
            f'expected 0 <= warmup_step < cooldown_begin_step, got {warmup_step} and {cooldown_begin_step}'
        self.warmup_step = warmup_step
        self.cooldown_begin_step = cooldown_begin_step
        self.regular_scale = regular_scale
        self._init_sparse_goals()
        self._set_apply_method()

        # Sparse goals are updated once every `interval_steps` optimizer steps between warmup and cooldown,
        # giving `total_times` updates in total (>= 1 thanks to the assert above).
        self.interval_steps = 1
        self.total_times = (self.cooldown_begin_step - self.warmup_step) // self.interval_steps
        # Set in `_register_trigger`; read by the loss patch to derive the regularization ramp.
        self._remaining_times: int
        # module_name -> target_name -> movement score parameter registered on the wrapper.
        self.scores: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict)

    @classmethod
    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], warmup_step: int,
                        cooldown_begin_step: int, regular_scale: float = 1., evaluator: Evaluator | None = None):
        return super().from_compressor(compressor, new_config_list, warmup_step=warmup_step,
                                       cooldown_begin_step=cooldown_begin_step, regular_scale=regular_scale,
                                       evaluator=evaluator)

    def _set_apply_method(self):
        """Switch every 'mul' / 'add' target space to its movement-aware counterpart."""
        remap = {'mul': 'movement_mul', 'add': 'movement_add'}
        for ts in self._target_spaces.values():
            for target_space in ts.values():
                if target_space.apply_method in remap:
                    target_space.apply_method = remap[target_space.apply_method]

    def _register_movement_scores(self):
        """Register a zero-initialized movement score parameter for every active parameter target.

        Raises
        ------
        NotImplementedError
            If an active target is not a parameter (inputs / outputs are not supported yet).
        """
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                if not is_active_target(target_space):
                    continue
                # TODO: add input / output
                if target_space.type is not TargetType.PARAMETER:
                    raise NotImplementedError(
                        f'{self.__class__.__name__} only supports parameter targets, got {target_space.type}.')
                # TODO: here using a shrinked score to save memory, but need to test the speed.
                score_val = torch.zeros_like(target_space.target)  # type: ignore
                if target_space._scaler is not None:
                    score_val = target_space._scaler.shrink(score_val, keepdim=True)
                param_name = MOVEMENT_SCORE_PNAME.format(target_name)
                target_space._wrapper.register_parameter(param_name, torch.nn.Parameter(score_val))
                self.scores[module_name][target_name] = target_space._get_wrapper_attr(param_name)

    def _register_scores_optimization(self, evaluator: Evaluator):
        """Optimize the movement scores with a dedicated Adam optimizer (lr=1e-2).

        The score optimizer is stepped (and its gradients zeroed) right before each model ``optimizer.step()``.
        Nothing is patched when no score has been registered.
        """
        scores = [score for target_scores in self.scores.values() for score in target_scores.values()]
        if not scores:
            return
        optimizer = Adam([{"params": scores}], 1e-2)

        def optimizer_task():
            optimizer.step()
            optimizer.zero_grad()

        evaluator.patch_optimizer_step(before_step_tasks=[optimizer_task], after_step_tasks=[])

    def _patch_loss(self, evaluator: Evaluator):
        """Add the movement-score regularization term to the training loss.

        For each target controlled by ``sparse_threshold``, the L1 norm of ``sigmoid(score)`` divided by the
        number of score elements is averaged over those targets, then scaled by ``regular_scale`` and by a cubic
        ramp that rises from 0 to 1 as the remaining sparse-goal updates are consumed.
        """
        def loss_patch(original_loss, batch):
            reg_loss = 0.
            count = 0
            for module_name, target_scores in self.scores.items():
                for target_name, score in target_scores.items():
                    target_space = self._target_spaces[module_name][target_name]
                    if target_space.sparse_threshold is not None:
                        reg_loss += torch.norm(score.sigmoid(), p=1) / score.numel()  # type: ignore
                        count += 1
            # Same cubic schedule as `update_sparse_goals`, expressed through the remaining update count.
            ratio = max(0., min(1., 1 - (self._remaining_times / self.total_times) ** 3))
            if count > 0:
                reg_loss = self.regular_scale * ratio * reg_loss / count
            return original_loss + reg_loss

        evaluator.patch_loss(loss_patch)

    def _register_trigger(self, evaluator: Evaluator):
        """Hook sparse-goal and mask updates onto the end of every ``optimizer.step()``."""
        self._current_step = 0
        self._iterial_step = 0
        self._remaining_times = self.total_times

        def optimizer_task():
            self._current_step += 1
            # Sparse goals only grow inside (warmup_step, cooldown_begin_step].
            if self.warmup_step < self._current_step <= self.cooldown_begin_step:
                self._iterial_step += 1
                if self._iterial_step == self.interval_steps:
                    self._remaining_times -= 1
                    self.update_sparse_goals(self.total_times - self._remaining_times)
                    debug_msg = f'{self.__class__.__name__} generate masks, remaining times {self._remaining_times}'
                    _logger.debug(debug_msg)
                    if self._remaining_times > 0:
                        self._iterial_step = 0
            # Masks keep being regenerated after warmup, even once sparsity has stopped growing.
            if self.warmup_step < self._current_step:
                self.update_masks(self.generate_masks())

        evaluator.patch_optimizer_step(before_step_tasks=[], after_step_tasks=[optimizer_task])

    def update_sparse_goals(self, current_times: int):
        """Set sparse goals to ``1 - (1 - current_times / total_times) ** 3``, clamped to [0, 1]."""
        ratio = max(0., min(1., 1 - (1 - current_times / self.total_times) ** 3))
        self._update_sparse_goals_by_ratio(ratio)

    def _collect_data(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Return detached copies of the registered movement scores, keyed by module and target name."""
        data = defaultdict(dict)
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                score: torch.Tensor = getattr(target_space._wrapper, MOVEMENT_SCORE_PNAME.format(target_name), None)  # type: ignore
                if score is not None:
                    data[module_name][target_name] = score.clone().detach()
        return data

    def _calculate_metrics(self, data: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, Dict[str, torch.Tensor]]:
        """Use ``sigmoid(score)`` for threshold-controlled targets and the raw score for all other targets."""
        metrics = defaultdict(dict)
        for module_name, td in data.items():
            for target_name, target_data in td.items():
                if self._target_spaces[module_name][target_name].sparse_threshold is not None:
                    metrics[module_name][target_name] = target_data.sigmoid()
                else:
                    metrics[module_name][target_name] = target_data
        return metrics

    def _generate_sparsity(self, metrics: Dict[str, Dict[str, torch.Tensor]]) -> Dict[str, Dict[str, torch.Tensor]]:
        return generate_sparsity(metrics=metrics, target_spaces=self._target_spaces)

    def _single_compress(self, max_steps: int | None, max_epochs: int | None):
        self._fusion_compress(max_steps, max_epochs)

    def _fuse_preprocess(self, evaluator: Evaluator):
        """Reset sparse goals to 0, then register scores, the loss patch, the score optimizer and the step trigger."""
        self._update_sparse_goals_by_ratio(0.)
        self._register_movement_scores()
        self._patch_loss(evaluator)
        self._register_scores_optimization(evaluator)
        self._register_trigger(evaluator)

    def _fuse_postprocess(self, evaluator: Evaluator):
        pass

    def compress(self, max_steps: int | None, max_epochs: int | None):
        """Run compression.

        When ``max_steps`` is given it must be at least ``cooldown_begin_step`` so the sparsity schedule can
        complete; when only ``max_epochs`` is given this cannot be checked, so a warning is logged instead.
        """
        if max_steps is not None:
            assert max_steps >= self.cooldown_begin_step, \
                f'max_steps ({max_steps}) must be >= cooldown_begin_step ({self.cooldown_begin_step})'
        else:
            warn_msg = 'Using epochs number as training duration, ' \
                'please make sure the total training steps larger than `cooldown_begin_step`.'
            _logger.warning(warn_msg)
        return super().compress(max_steps, max_epochs)