nni.compression.pruning.scheduled_pruner 源代码

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

from collections import defaultdict
import logging
from typing import Dict, List

import torch

from ..base.compressor import Pruner
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator, _EVALUATOR_DOCSTRING

from .basic_pruner import LevelPruner, L1NormPruner, L2NormPruner, FPGMPruner
from .slim_pruner import SlimPruner
from .taylor_pruner import TaylorPruner

_logger = logging.getLogger(__name__)


class ScheduledPruner(Pruner):
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator | None = None,
                 existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        super().__init__(model, config_list, evaluator, existed_wrappers)
        self.evaluator: Evaluator

        self.sparse_goals: Dict[str, Dict[str, Dict[str, float]]] = defaultdict(dict)
        self._goals_initialized = False
        self._scheduled_keys = ['sparse_ratio', 'sparse_threshold', 'max_sparse_ratio', 'min_sparse_ratio']

    def _init_sparse_goals(self):
        if self._goals_initialized:
            _logger.warning('Sparse goals have already initialized.')
            return
        for module_name, ts in self._target_spaces.items():
            for target_name, target_space in ts.items():
                self.sparse_goals[module_name][target_name] = {}
                for scheduled_key in self._scheduled_keys:
                    if getattr(target_space, scheduled_key) is not None:
                        self.sparse_goals[module_name][target_name][scheduled_key] = getattr(target_space, scheduled_key)
        self._goals_initialized = True

    def update_sparse_goals(self, current_times: int):
        raise NotImplementedError()

    def _update_sparse_goals_by_ratio(self, ratio: float):
        for module_name, tg in self.sparse_goals.items():
            for target_name, target_goals in tg.items():
                for scheduled_key, goal in target_goals.items():
                    setattr(self._target_spaces[module_name][target_name], scheduled_key, goal * ratio)


class _ComboPruner(ScheduledPruner):
    def __init__(self, pruner: Pruner, interval_steps: int, total_times: int, evaluator: Evaluator | None = None):
        assert isinstance(pruner, Pruner)
        assert hasattr(pruner, 'interval_steps') and hasattr(pruner, 'total_times')
        if not isinstance(pruner, (LevelPruner, L1NormPruner, L2NormPruner, FPGMPruner, SlimPruner, TaylorPruner)):
            warning_msg = f'Compatibility not tested with pruner type {pruner.__class__.__name__}.'
            _logger.warning(warning_msg)
        if pruner._is_wrapped:
            pruner.unwrap_model()

        model = pruner.bound_model
        existed_wrappers = pruner._module_wrappers
        if pruner.evaluator is not None and evaluator is not None:
            _logger.warning('Pruner already has evaluator, the new evaluator passed to this function will be ignored.')
        evaluator = pruner.evaluator if pruner.evaluator else evaluator
        assert isinstance(evaluator, Evaluator)

        super().__init__(model=model, config_list=[], evaluator=evaluator, existed_wrappers=existed_wrappers)
        # skip the pruner passed in
        self.fused_compressors.extend(pruner.fused_compressors[1:])
        self._target_spaces = pruner._target_spaces
        self.interval_steps = interval_steps
        self.total_times = total_times
        self.bound_pruner = pruner

        self._init_sparse_goals()
        self._initial_ratio = 0.0

    @classmethod
    def from_compressor(cls, *args, **kwargs):
        raise NotImplementedError(f'{cls.__name__} can not initialized from any compressor.')

    def _initialize_state(self):
        self._update_sparse_goals_by_ratio(self._initial_ratio)
        self.bound_pruner.interval_steps = self.interval_steps  # type: ignore
        self.bound_pruner.total_times = self.total_times  # type: ignore

    def _register_trigger(self, evaluator: Evaluator):
        self._current_step = 0
        self._remaining_times = self.total_times

        def optimizer_task():
            self._current_step += 1
            if self._current_step == self.interval_steps:
                self._remaining_times -= 1
                self.update_sparse_goals(self.total_times - self._remaining_times)
                debug_msg = f'{self.__class__.__name__} generate masks, remaining times {self._remaining_times}'
                _logger.debug(debug_msg)
                if self._remaining_times > 0:
                    self._current_step = 0

        evaluator.patch_optimizer_step(before_step_tasks=[], after_step_tasks=[optimizer_task])

    def _single_compress(self, max_steps: int | None, max_epochs: int | None):
        self._fusion_compress(max_steps, max_epochs)

    def _fuse_preprocess(self, evaluator: Evaluator) -> None:
        self._initialize_state()
        self._register_trigger(evaluator)
        self.bound_pruner._fuse_preprocess(evaluator)

    def _fuse_postprocess(self, evaluator: Evaluator) -> None:
        self.bound_pruner._fuse_postprocess(evaluator)

    def compress(self, max_steps: int | None, max_epochs: int | None):
        if max_steps is not None:
            assert max_steps >= self.total_times * self.interval_steps
        else:
            warn_msg = f'Using epochs number as training duration, ' + \
                'please make sure the total training steps larger than total_times * interval_steps.'
            _logger.warning(warn_msg)
        return super().compress(max_steps, max_epochs)


[文档] class LinearPruner(_ComboPruner): __doc__ = r""" The sparse ratio or sparse threshold in the bound pruner will increase in a linear way from 0. to final:: current_sparse = (1 - initial_ratio) * current_times / total_times * final_sparse If min/max sparse ratio is also set in target setting, they will also synchronous increase in a linear way. Note that this pruner can not be initialized by ``LinearPruner.from_compressor(...)``. Parameters ---------- pruner The bound pruner. interval_steps A integer number, for each ``interval_steps`` training, the sparse goal will be updated. total_times A integer number, how many times to update the sparse goal in total. evaluator {evaluator_docstring} Examples -------- Please refer to :githublink:`examples/compression/pruning/scheduled_pruning.py <examples/compression/pruning/scheduled_pruning.py>`. """.format(evaluator_docstring=_EVALUATOR_DOCSTRING) def update_sparse_goals(self, current_times: int): ratio = (1 - self._initial_ratio) * current_times / self.total_times self._update_sparse_goals_by_ratio(ratio)
[文档] class AGPPruner(_ComboPruner): __doc__ = r""" The sparse ratio or sparse threshold in the bound pruner will increase in a AGP way from 0. to final:: current_sparse = (1 - (1 - self._initial_ratio) * (1 - current_times / self.total_times) ** 3) * final_sparse If min/max sparse ratio is also set in target setting, they will also synchronous increase in a AGP way. Note that this pruner can not be initialized by ``AGPPruner.from_compressor(...)``. Parameters ---------- pruner The bound pruner. interval_steps A integer number, for each ``interval_steps`` training, the sparse goal will be updated. total_times A integer number, how many times to update the sparse goal in total. evaluator {evaluator_docstring} Examples -------- Please refer to :githublink:`examples/compression/pruning/scheduled_pruning.py <examples/compression/pruning/scheduled_pruning.py>`. """.format(evaluator_docstring=_EVALUATOR_DOCSTRING) def update_sparse_goals(self, current_times: int): ratio = 1 - (1 - self._initial_ratio) * (1 - current_times / self.total_times) ** 3 self._update_sparse_goals_by_ratio(ratio)