# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import copy
import torch
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .constants import MASKER_DICT
from .dependency_aware_pruner import DependencyAwarePruner
__all__ = ['AGPPruner', 'ADMMPruner', 'SlimPruner', 'TaylorFOWeightFilterPruner', 'ActivationAPoZRankFilterPruner',
'ActivationMeanRankFilterPruner']
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class IterativePruner(DependencyAwarePruner):
"""
Prune model during the training process.
"""
def __init__(self, model, config_list, optimizer=None, pruning_algorithm='slim', trainer=None, criterion=None,
num_iterations=20, epochs_per_iteration=5, dependency_aware=False, dummy_input=None, **algo_kwargs):
"""
Parameters
----------
model: torch.nn.Module
Model to be pruned
config_list: list
List of pruning configs
optimizer: torch.optim.Optimizer
Optimizer used to train model
pruning_algorithm: str
Algorithm used to prune the model
trainer: function
Function used to train the model.
Users should write this function as a normal function to train the PyTorch model
and include `model, optimizer, criterion, epoch` as function arguments.
criterion: function
Function used to calculate the loss between the target and the output.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
num_iterations: int
Total number of iterations in the pruning process. The mask is calculated at the end of each iteration.
epochs_per_iteration: Union[int, list]
The number of training epochs for each iteration. `int` represents the same value for each iteration.
`list` represents the specific value for each iteration.
dependency_aware: bool
Whether to prune the model in a dependency-aware way.
dummy_input: torch.Tensor
The dummy input used to analyze the topology constraints. Note that the
dummy_input should be on the same device as the model.
algo_kwargs: dict
Additional parameters passed to the masker class of the pruning algorithm
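Examples
--------
Below is a minimal sketch of a ``trainer`` that satisfies the expected
``model, optimizer, criterion, epoch`` signature (``train_loader`` and
``device`` are assumed to be defined by the user; they are not part of this API)::

    def trainer(model, optimizer, criterion, epoch):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()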
"""
super().__init__(model, config_list, optimizer, pruning_algorithm, dependency_aware, dummy_input, **algo_kwargs)
if isinstance(epochs_per_iteration, list):
assert len(epochs_per_iteration) == num_iterations, 'num_iterations should equal the length of epochs_per_iteration'
self.epochs_per_iteration = epochs_per_iteration
else:
assert num_iterations > 0, 'num_iterations should be >= 1'
self.epochs_per_iteration = [epochs_per_iteration] * num_iterations
self._validate_iteration_params()
self._trainer = trainer
self._criterion = criterion
def _fresh_calculated(self):
for wrapper in self.get_modules_wrapper():
wrapper.if_calculated = False
def _validate_iteration_params(self):
assert all(num >= 0 for num in self.epochs_per_iteration), 'all epoch numbers need to be >= 0'
def compress(self):
training = self.bound_model.training
self.bound_model.train()
for epochs_num in self.epochs_per_iteration:
self._fresh_calculated()
for epoch in range(epochs_num):
self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch)
# NOTE: workaround for statistics_batch_num bigger than max batch number in one epoch, need refactor
if hasattr(self.masker, 'statistics_batch_num') and hasattr(self, 'iterations'):
if self.iterations < self.masker.statistics_batch_num: # pylint: disable=access-member-before-definition
self.iterations = self.masker.statistics_batch_num
self.update_mask()
self.bound_model.train(training)
return self.bound_model
class AGPPruner(IterativePruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned.
config_list : list
Supported keys:
- sparsity : This is to specify the target sparsity of the operations to be compressed.
- op_types : See the supported types in your specific pruning algorithm.
optimizer: torch.optim.Optimizer
Optimizer used to train model.
trainer: function
Function to train the model
criterion: function
Function used to calculate the loss between the target and the output.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
num_iterations: int
Total number of iterations in the pruning process. The mask is calculated at the end of each iteration.
epochs_per_iteration: int
The number of training epochs for each iteration.
pruning_algorithm: str
Algorithm used to prune the model,
chosen from `['level', 'slim', 'l1', 'l2', 'fpgm', 'taylorfo', 'apoz', 'mean_activation']`, `level` by default
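Examples
--------
A minimal usage sketch (``model``, ``optimizer`` and ``trainer`` are supplied
by the user; ``trainer`` must accept ``model, optimizer, criterion, epoch``)::

    import torch
    from nni.algorithms.compression.pytorch.pruning import AGPPruner

    config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
    pruner = AGPPruner(model, config_list, optimizer, trainer,
                       criterion=torch.nn.CrossEntropyLoss(),
                       num_iterations=10, epochs_per_iteration=1)
    model = pruner.compress()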
"""
def __init__(self, model, config_list, optimizer, trainer, criterion,
num_iterations=10, epochs_per_iteration=1, pruning_algorithm='level'):
super().__init__(model, config_list, optimizer=optimizer, trainer=trainer, criterion=criterion,
num_iterations=num_iterations, epochs_per_iteration=epochs_per_iteration)
assert isinstance(optimizer, torch.optim.Optimizer), "AGP pruner is an iterative pruner, please pass optimizer of the model to it"
self.masker = MASKER_DICT[pruning_algorithm](model, self)
self.now_epoch = 0
self.freq = epochs_per_iteration
self.end_epoch = epochs_per_iteration * num_iterations
self.set_wrappers_attribute("if_calculated", False)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
List of pruning configs
"""
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 <= n <= 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
def _supported_dependency_aware(self):
return False
def calc_mask(self, wrapper, wrapper_idx=None):
"""
Calculate the mask of the given layer.
The target sparsity for the current epoch is computed from the AGP schedule,
and the underlying masker decides which weights to mask.
Parameters
----------
wrapper : Module
the layer to instrument the compression operation
wrapper_idx: int
index of this wrapper in pruner's all wrappers
Returns
-------
dict | None
Dictionary for storing masks, keys of the dict:
'weight_mask': weight mask tensor
'bias_mask': bias mask tensor (optional)
"""
config = wrapper.config
if wrapper.if_calculated:
return None
if self.now_epoch % self.freq != 0:
return None
target_sparsity = self.compute_target_sparsity(config)
new_mask = self.masker.calc_mask(sparsity=target_sparsity, wrapper=wrapper, wrapper_idx=wrapper_idx)
if new_mask is not None:
wrapper.if_calculated = True
return new_mask
def compute_target_sparsity(self, config):
"""
Calculate the sparsity for pruning
Parameters
----------
config : dict
Layer's pruning config
Returns
-------
float
Target sparsity to be pruned
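Notes
-----
The schedule follows the cubic function of AGP (Zhu & Gupta, 2017): the
target sparsity is interpolated from ``initial_sparsity`` to the configured
final sparsity as::

    target = final + (initial - final) * (1 - now_epoch / span) ** 3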
"""
initial_sparsity = 0
self.target_sparsity = final_sparsity = config.get('sparsity', 0)
if initial_sparsity >= final_sparsity:
logger.warning('your initial_sparsity >= final_sparsity')
return final_sparsity
if self.end_epoch == 1 or self.end_epoch <= self.now_epoch:
return final_sparsity
span = ((self.end_epoch - 1) // self.freq) * self.freq
assert span > 0
self.target_sparsity = (final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - (self.now_epoch / span)) ** 3)
return self.target_sparsity
def update_epoch(self, epoch):
"""
Update epoch
Parameters
----------
epoch : int
current training epoch
"""
if epoch > 0:
self.now_epoch = epoch
for wrapper in self.get_modules_wrapper():
wrapper.if_calculated = False
# TODO: need refactor
def compress(self):
training = self.bound_model.training
self.bound_model.train()
for epoch in range(self.end_epoch):
self.update_epoch(epoch)
self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch)
self.update_mask()
logger.info(f'sparsity is {self.target_sparsity:.2f} at epoch {epoch}')
self.get_pruned_weights()
self.bound_model.train(training)
return self.bound_model
class ADMMPruner(IterativePruner):
"""
A PyTorch implementation of the ADMM pruning algorithm.
Parameters
----------
model : torch.nn.Module
Model to be pruned.
config_list : list
List of pruning configs.
trainer : function
Function used for the first subproblem.
Users should write this function as a normal function to train the PyTorch model
and include `model, optimizer, criterion, epoch` as function arguments.
criterion: function
Function used to calculate the loss between the target and the output. By default, we use CrossEntropyLoss in ADMMPruner.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
num_iterations: int
Total number of iterations in the pruning process. In ADMMPruner, the masks are calculated after all iterations finish.
epochs_per_iteration: int
Training epochs of the first subproblem.
row : float
Penalty parameter (the ADMM rho) for ADMM training.
base_algo : str
Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`. Given the sparsity distribution among
the ops, the assigned `base_algo` is used to decide which filters/channels/weights to prune.
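Examples
--------
A minimal usage sketch (``model`` and ``trainer`` are supplied by the user;
``trainer`` must accept ``model, optimizer, criterion, epoch``)::

    from nni.algorithms.compression.pytorch.pruning import ADMMPruner

    config_list = [{'sparsity': 0.8, 'op_types': ['Conv2d']}]
    pruner = ADMMPruner(model, config_list, trainer,
                        num_iterations=30, epochs_per_iteration=5)
    model = pruner.compress()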
"""
def __init__(self, model, config_list, trainer, criterion=torch.nn.CrossEntropyLoss(),
num_iterations=30, epochs_per_iteration=5, row=1e-4, base_algo='l1'):
self._base_algo = base_algo
super().__init__(model, config_list)
self._trainer = trainer
self.optimizer = torch.optim.Adam(
self.bound_model.parameters(), lr=1e-3, weight_decay=5e-5)
self._criterion = criterion
self._num_iterations = num_iterations
self._training_epochs = epochs_per_iteration
self._row = row
self.set_wrappers_attribute("if_calculated", False)
self.masker = MASKER_DICT[self._base_algo](self.bound_model, self)
self.patch_optimizer_before(self._callback)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
List of pruning configs
"""
if self._base_algo == 'level':
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
def _supported_dependency_aware(self):
return False
def _projection(self, weight, sparsity, wrapper):
"""
Return the Euclidean projection of the weight matrix according to the pruning mode.
Parameters
----------
weight : tensor
original matrix
sparsity : float
the ratio of parameters which need to be set to zero
wrapper: PrunerModuleWrapper
layer wrapper of this layer
Returns
-------
tensor
the projected matrix
"""
wrapper_copy = copy.deepcopy(wrapper)
wrapper_copy.module.weight.data = weight
return weight.data.mul(self.masker.calc_mask(sparsity, wrapper_copy)['weight_mask'])
def _callback(self):
# callback function to do additional optimization, refer to the derivatives of Formula (7)
for i, wrapper in enumerate(self.get_modules_wrapper()):
wrapper.module.weight.data -= self._row * \
(wrapper.module.weight.data - self.Z[i] + self.U[i])
def compress(self):
"""
Compress the model with ADMM.
Returns
-------
torch.nn.Module
model with specified modules compressed.
"""
logger.info('Starting ADMM Compression...')
# initialize Z and U
# Z_i^0 = W_i^0
# U_i^0 = 0
self.Z = []
self.U = []
for wrapper in self.get_modules_wrapper():
z = wrapper.module.weight.data
self.Z.append(z)
self.U.append(torch.zeros_like(z))
# Loss = cross_entropy + l2 regularization + \sum_{i=1}^N \rho_i ||W_i - Z_i^k + U_i^k||^2
# optimization iteration
for k in range(self._num_iterations):
logger.info('ADMM iteration : %d', k)
# step 1: optimize W with AdamOptimizer
for epoch in range(self._training_epochs):
self._trainer(self.bound_model, optimizer=self.optimizer, criterion=self._criterion, epoch=epoch)
# step 2: update Z, U
# Z_i^{k+1} = projection(W_i^{k+1} + U_i^k)
# U_i^{k+1} = U_i^k + W_i^{k+1} - Z_i^{k+1}
for i, wrapper in enumerate(self.get_modules_wrapper()):
z = wrapper.module.weight.data + self.U[i]
self.Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
torch.cuda.empty_cache()
self.U[i] = self.U[i] + wrapper.module.weight.data - self.Z[i]
# apply prune
self.update_mask()
logger.info('Compression finished.')
return self.bound_model
class SlimPruner(IterativePruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : This is to specify the target sparsity of the operations to be compressed.
- op_types : Only BatchNorm2d is supported in Slim Pruner.
optimizer : torch.optim.Optimizer
Optimizer used to train model
trainer : function
Function used to sparsify BatchNorm2d scaling factors.
Users should write this function as a normal function to train the PyTorch model
and include `model, optimizer, criterion, epoch` as function arguments.
criterion : function
Function used to calculate the loss between the target and the output.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
sparsifying_training_epochs: int
The number of channel sparsity regularization training epochs before pruning.
scale : float
Penalty parameter for sparsification, i.e., the coefficient of the L1 regularization applied to the scaling factors.
dependency_aware: bool
Whether to prune the model in a dependency-aware way. If `True`, this pruner will
prune the model according to the l2-norm of weights and the channel-dependency or
group-dependency of the model. In this way, the pruner forces the conv layers that
have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if this flag is set to
`True`, the dummy_input cannot be None, because the pruner needs a dummy input to
trace the dependency between the conv layers.
dummy_input : torch.Tensor
The dummy input used to analyze the topology constraints. Note that the
dummy_input should be on the same device as the model.
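Examples
--------
A minimal usage sketch; SlimPruner operates on BatchNorm2d scaling factors, so
``op_types`` must be ``['BatchNorm2d']`` (``model``, ``optimizer``, ``trainer``
and ``criterion`` are supplied by the user)::

    from nni.algorithms.compression.pytorch.pruning import SlimPruner

    config_list = [{'sparsity': 0.7, 'op_types': ['BatchNorm2d']}]
    pruner = SlimPruner(model, config_list, optimizer, trainer, criterion,
                        sparsifying_training_epochs=10)
    model = pruner.compress()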
"""
def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_epochs=10, scale=0.0001,
dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='slim', trainer=trainer, criterion=criterion,
num_iterations=1, epochs_per_iteration=sparsifying_training_epochs, dependency_aware=dependency_aware,
dummy_input=dummy_input)
self.scale = scale
self.patch_optimizer_before(self._callback)
def validate_config(self, model, config_list):
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['BatchNorm2d'],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
if len(config_list) > 1:
logger.warning('Slim pruner only supports 1 configuration')
def _supported_dependency_aware(self):
return True
def _callback(self):
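# network-slimming sparsity regularization: add the subgradient of
# scale * ||gamma||_1, i.e. scale * sign(gamma), to the gradient of every
# BatchNorm scaling factor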
for wrapper in self.get_modules_wrapper():
wrapper.module.weight.grad.data.add_(self.scale * torch.sign(wrapper.module.weight.data))
class TaylorFOWeightFilterPruner(IterativePruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : The percentage of convolutional filters to be pruned.
- op_types : Currently only Conv2d is supported in TaylorFOWeightFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
trainer : function
Function used to train the model.
Users should write this function as a normal function to train the PyTorch model
and include `model, optimizer, criterion, epoch` as function arguments.
criterion : function
Function used to calculate the loss between the target and the output.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
sparsifying_training_batches: int
The number of batches used to collect the contributions. Note that this number needs to be less than the maximum number of batches in one epoch.
dependency_aware: bool
Whether to prune the model in a dependency-aware way. If `True`, this pruner will
prune the model according to the l2-norm of weights and the channel-dependency or
group-dependency of the model. In this way, the pruner forces the conv layers that
have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if this flag is set to
`True`, the dummy_input cannot be None, because the pruner needs a dummy input to
trace the dependency between the conv layers.
dummy_input : torch.Tensor
The dummy input used to analyze the topology constraints. Note that the
dummy_input should be on the same device as the model.
global_sort: bool
Only supported by TaylorFOWeightFilterPruner currently.
Whether to prune the model in a global-sort way. If it is `True`, this pruner will
sort the channel contributions globally across all layers, so whether a specific
channel is pruned depends on its global ranking rather than its per-layer ranking.
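Examples
--------
A minimal usage sketch with global sorting enabled (``model``, ``optimizer``,
``trainer`` and ``criterion`` are supplied by the user)::

    from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion,
                                        sparsifying_training_batches=1, global_sort=True)
    model = pruner.compress()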
"""
def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1,
dependency_aware=False, dummy_input=None, global_sort=False):
super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer,
criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
epochs_per_iteration=1, dependency_aware=dependency_aware,
dummy_input=dummy_input)
self.masker.global_sort = global_sort
def _supported_dependency_aware(self):
return True
class ActivationAPoZRankFilterPruner(IterativePruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : The percentage of convolutional filters to be pruned.
- op_types : Only Conv2d is supported in ActivationAPoZRankFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model
trainer: function
Function used to train the model.
Users should write this function as a normal function to train the PyTorch model
and include `model, optimizer, criterion, epoch` as function arguments.
criterion : function
Function used to calculate the loss between the target and the output.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
activation: str
The activation type whose outputs are used to rank the filters (e.g. `'relu'`).
sparsifying_training_batches: int
The number of batches used to collect the contributions. Note that this number needs to be less than the maximum number of batches in one epoch.
dependency_aware: bool
Whether to prune the model in a dependency-aware way. If `True`, this pruner will
prune the model according to the l2-norm of weights and the channel-dependency or
group-dependency of the model. In this way, the pruner forces the conv layers that
have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if this flag is set to
`True`, the dummy_input cannot be None, because the pruner needs a dummy input to
trace the dependency between the conv layers.
dummy_input : torch.Tensor
The dummy input used to analyze the topology constraints. Note that the
dummy_input should be on the same device as the model.
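Examples
--------
A minimal usage sketch; the APoZ statistic is collected on the outputs of the
given activation type (``model``, ``optimizer``, ``trainer`` and ``criterion``
are supplied by the user)::

    from nni.algorithms.compression.pytorch.pruning import ActivationAPoZRankFilterPruner

    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    pruner = ActivationAPoZRankFilterPruner(model, config_list, optimizer, trainer, criterion,
                                            activation='relu', sparsifying_training_batches=1)
    model = pruner.compress()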
"""
def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu',
sparsifying_training_batches=1, dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, pruning_algorithm='apoz', optimizer=optimizer, trainer=trainer,
criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
activation=activation, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
epochs_per_iteration=1)
self.patch_optimizer(self.update_mask)
def _supported_dependency_aware(self):
return True
class ActivationMeanRankFilterPruner(IterativePruner):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
Supported keys:
- sparsity : The percentage of convolutional filters to be pruned.
- op_types : Only Conv2d is supported in ActivationMeanRankFilterPruner.
optimizer: torch.optim.Optimizer
Optimizer used to train model.
trainer: function
Function used to train the model.
Users should write this function as a normal function to train the PyTorch model
and include `model, optimizer, criterion, epoch` as function arguments.
criterion : function
Function used to calculate the loss between the target and the output.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
activation: str
The activation type whose outputs are used to rank the filters (e.g. `'relu'`).
sparsifying_training_batches: int
The number of batches used to collect the contributions. Note that this number needs to be less than the maximum number of batches in one epoch.
dependency_aware: bool
Whether to prune the model in a dependency-aware way. If `True`, this pruner will
prune the model according to the l2-norm of weights and the channel-dependency or
group-dependency of the model. In this way, the pruner forces the conv layers that
have dependencies to prune the same channels, so the speedup module can better
harvest the speed benefit from the pruned model. Note that if this flag is set to
`True`, the dummy_input cannot be None, because the pruner needs a dummy input to
trace the dependency between the conv layers.
dummy_input : torch.Tensor
The dummy input used to analyze the topology constraints. Note that the
dummy_input should be on the same device as the model.
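Examples
--------
A minimal usage sketch in dependency-aware mode; the dummy input must be on the
same device as the model (``model``, ``optimizer``, ``trainer`` and ``criterion``
are supplied by the user)::

    import torch
    from nni.algorithms.compression.pytorch.pruning import ActivationMeanRankFilterPruner

    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    dummy_input = torch.rand(1, 3, 224, 224).to(next(model.parameters()).device)
    pruner = ActivationMeanRankFilterPruner(model, config_list, optimizer, trainer, criterion,
                                            dependency_aware=True, dummy_input=dummy_input)
    model = pruner.compress()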
"""
def __init__(self, model, config_list, optimizer, trainer, criterion, activation='relu',
sparsifying_training_batches=1, dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, pruning_algorithm='mean_activation', optimizer=optimizer, trainer=trainer,
criterion=criterion, dependency_aware=dependency_aware, dummy_input=dummy_input,
activation=activation, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
epochs_per_iteration=1)
self.patch_optimizer(self.update_mask)
def _supported_dependency_aware(self):
return True