Source code for nni.algorithms.compression.pytorch.pruning.one_shot_pruner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from schema import And, Optional

from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .dependency_aware_pruner import DependencyAwarePruner

__all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class OneshotPruner(DependencyAwarePruner):
    """
    Prune the model to an exact pruning level in one shot.
    """

    def __init__(self, model, config_list, pruning_algorithm='level', dependency_aware=False, dummy_input=None,
                 **algo_kwargs):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned.
        config_list : list
            List of pruning configs.
        pruning_algorithm : str
            Algorithm used to prune the model.
        dependency_aware : bool
            Whether to prune the model in a dependency-aware way.
        dummy_input : torch.Tensor
            The dummy input used to analyze the topology constraints. Note that
            the dummy_input should be on the same device as the model.
        algo_kwargs : dict
            Additional parameters passed to the pruning algorithm's masker class.
        """
        super().__init__(model, config_list, None, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned.
        config_list : list
            List of pruning configs.
        """
        schema = PrunerSchema([{
            Optional('sparsity'): And(float, lambda n: 0 < n < 1),
            Optional('op_types'): [str],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }], model, logger)

        schema.validate(config_list)
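

# A minimal sketch of a config_list that satisfies the schema above: 'sparsity'
# must be a float strictly between 0 and 1, and every key is optional. The
# layer name 'fc' is hypothetical, and whether an exclude-only entry passes any
# extra checks inside PrunerSchema is an assumption here.
def _example_config_list():
    return [
        {'sparsity': 0.5, 'op_types': ['Conv2d']},   # prune all Conv2d layers to 50% sparsity
        {'op_names': ['fc'], 'exclude': True},       # leave the (hypothetical) 'fc' layer untouched
    ]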


class LevelPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to which the operations will be compressed.
            - op_types : Operation types to prune.
    """

    def __init__(self, model, config_list):
        super().__init__(model, config_list, pruning_algorithm='level')

    def _supported_dependency_aware(self):
        return False
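

# A minimal usage sketch for LevelPruner. The two-layer model is made up for
# illustration; compress() and export_model() are the standard nni Pruner
# interface, and the output paths are placeholders.
def _example_level_pruner():
    import torch
    from nni.algorithms.compression.pytorch.pruning import LevelPruner

    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU(), torch.nn.Linear(32, 10))
    config_list = [{'sparsity': 0.8, 'op_types': ['default']}]
    pruner = LevelPruner(model, config_list)
    model = pruner.compress()                                           # wrap layers and compute masks
    pruner.export_model(model_path='pruned.pth', mask_path='mask.pth')  # persist weights and masks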


class L1FilterPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to which the operations will be compressed.
            - op_types : Only Conv2d is supported in L1FilterPruner.
    dependency_aware : bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this
        pruner will prune the model according to the l1-norm of weights and the
        channel-dependency or group-dependency of the model. In this way, the pruner
        will force the conv layers that have dependencies to prune the same channels,
        so the speedup module can better harvest the speed benefit from the pruned
        model. Note that, if this flag is set to True, the dummy_input cannot be None,
        because the pruner needs a dummy input to trace the dependency between the
        conv layers.
    dummy_input : torch.Tensor
        The dummy input used to analyze the topology constraints. Note that the
        dummy_input should be on the same device as the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l1', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True
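

# A minimal sketch of dependency-aware mode. torchvision's resnet18 is used
# only as a convenient model with channel dependencies (an assumption, not
# part of this module); dummy_input must be on the same device as the model.
def _example_dependency_aware_l1():
    import torch
    from torchvision.models import resnet18
    from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

    model = resnet18()
    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    dummy_input = torch.rand(1, 3, 224, 224)  # CPU here, matching the model's device
    pruner = L1FilterPruner(model, config_list, dependency_aware=True, dummy_input=dummy_input)
    model = pruner.compress()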


class L2FilterPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to which the operations will be compressed.
            - op_types : Only Conv2d is supported in L2FilterPruner.
    dependency_aware : bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this
        pruner will prune the model according to the l2-norm of weights and the
        channel-dependency or group-dependency of the model. In this way, the pruner
        will force the conv layers that have dependencies to prune the same channels,
        so the speedup module can better harvest the speed benefit from the pruned
        model. Note that, if this flag is set to True, the dummy_input cannot be None,
        because the pruner needs a dummy input to trace the dependency between the
        conv layers.
    dummy_input : torch.Tensor
        The dummy input used to analyze the topology constraints. Note that the
        dummy_input should be on the same device as the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='l2', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True
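

# An illustrative sketch (not the masker used by this module) of the l2
# criterion L2FilterPruner delegates to: filters with the smallest l2-norm
# are masked first. The 4-filter conv weight is made up for demonstration.
def _example_l2_filter_ranking():
    import torch

    weight = torch.randn(4, 3, 3, 3)               # (out_channels, in_channels, kH, kW)
    scores = weight.view(4, -1).norm(p=2, dim=1)   # one l2-norm per filter
    num_prune = int(scores.numel() * 0.5)          # 50% sparsity -> mask 2 of 4 filters
    pruned = scores.argsort()[:num_prune]          # indices of the weakest filters
    print('filters that would be masked:', pruned.tolist())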


class FPGMPruner(OneshotPruner):
    """
    Parameters
    ----------
    model : torch.nn.Module
        Model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target sparsity to which the operations will be compressed.
            - op_types : Only Conv2d is supported in FPGM Pruner.
    dependency_aware : bool
        Whether to prune the model in a dependency-aware way. If it is `True`, this
        pruner will prune the model according to the FPGM criterion (the distance of
        each filter to the geometric median of the filters) and the channel-dependency
        or group-dependency of the model. In this way, the pruner will force the conv
        layers that have dependencies to prune the same channels, so the speedup
        module can better harvest the speed benefit from the pruned model. Note that,
        if this flag is set to True, the dummy_input cannot be None, because the
        pruner needs a dummy input to trace the dependency between the conv layers.
    dummy_input : torch.Tensor
        The dummy input used to analyze the topology constraints. Note that the
        dummy_input should be on the same device as the model.
    """

    def __init__(self, model, config_list, dependency_aware=False, dummy_input=None):
        super().__init__(model, config_list, pruning_algorithm='fpgm', dependency_aware=dependency_aware,
                         dummy_input=dummy_input)

    def _supported_dependency_aware(self):
        return True
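

# An illustrative sketch (not the masker used by this module) of the FPGM
# criterion: a filter whose summed distance to all other filters is smallest
# lies closest to their geometric median, is considered the most replaceable,
# and is pruned first. The weight tensor is made up for demonstration.
def _example_fpgm_ranking():
    import torch

    weight = torch.randn(4, 3, 3, 3)
    flat = weight.view(4, -1)
    dist = torch.cdist(flat, flat, p=2)  # pairwise l2 distances between flattened filters
    scores = dist.sum(dim=1)             # distance sum per filter (geometric-median proxy)
    pruned = scores.argsort()[:2]        # 50% sparsity -> mask the 2 most replaceable filters
    print('filters that would be masked:', pruned.tolist())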