Source code for nni.algorithms.compression.pytorch.pruning.transformer_pruner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from schema import And, Optional
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import AttentionWeightDependency
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.compressor import Pruner
from . import L1WeightHeadMasker, L2WeightHeadMasker, L1ActivationHeadMasker, L2ActivationHeadMasker, TaylorFOHeadMasker
__all__ = ['TransformerHeadPruner']
MASKER_DICT = {
'l1_weight': L1WeightHeadMasker,
'l2_weight': L2WeightHeadMasker,
'l1_activation': L1ActivationHeadMasker,
'l2_activation': L2ActivationHeadMasker,
'taylorfo': TaylorFOHeadMasker
}
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class TransformerHeadPruner(Pruner):
"""
A pruner specialized for pruning attention heads in models belonging to the transformer family.
Parameters
----------
model : torch.nn.Module
Model to be pruned. Expect a model from transformers library (e.g., BertModel).
This pruner can work with other customized transformer models, but some ranking modes might fail.
config_list : list
Supported keys:
- sparsity : Target sparsity of the operations to be compressed.
- op_types : Optional. Operation types to prune. (Should be 'Linear' for this pruner.)
- op_names : Optional. Operation names to prune.
head_hidden_dim : int
Hidden dimension of each attention head (e.g., 64 for BERT).
We assume that head_hidden_dim is constant across the entire model.
attention_name_groups : list (Optional)
List of groups of names for weights of each attention layer. Each element should be a four-element list, with
the first three corresponding to Q_proj, K_proj, V_proj (in any order) and the last one being output_proj.
dummy_input : torch.Tensor (Optional)
Input to model's forward method, used to infer module grouping if attention_name_groups is not specified.
This tensor is used by the underlying torch.jit.trace to infer the module graph.
ranking_criterion : str
The criterion for ranking attention heads. Currently we support:
- l1_weight: l1 norm of Q_proj, K_proj, and V_proj
- l2_weight: l2 norm of Q_proj, K_proj, and V_proj
- l1_activation: l1 norm of the output of attention computation
- l2_activation: l2 norm of the output of attention computation
- taylorfo: l1 norm of the output of attention computation * gradient for this output
(see the masker documentation for more details)
global_sort : bool
Whether to rank the heads globally or locally before deciding which heads to prune.
num_iterations : int
Number of pruning iterations. Defaults to 1 (one-shot pruning). If num_iterations > 1, the pruner will split
the sparsity specified in config_list uniformly and assign a fraction to each pruning iteration.
epochs_per_iteration : int
Number of finetuning epochs before the next pruning iteration.
Only used when num_iterations > 1.
If num_iterations is 1, then no finetuning is performed by the pruner after pruning.
optimizer: torch.optim.Optimizer
Optimizer used to train the model.
trainer: function
Function used to finetune the model between pruning iterations.
Only used when num_iterations > 1 or ranking_criterion is 'taylorfo'.
Users should write this function as a normal function to train the PyTorch model and include
`model, optimizer, criterion, epoch` as function arguments. Note that the trainer is also used for collecting
gradients for pruning if ranking_criterion is 'taylorfo'. In that case, ``epoch=None`` will be passed.
criterion: function
Function used to calculate the loss between the target and the output.
Only used when num_iterations > 1 or ranking_criterion is 'taylorfo'.
For example, you can use ``torch.nn.CrossEntropyLoss()`` as input.
forward_runner: function
Function used to perform a "dry run" on the model on the entire train/validation dataset in order to collect
data for pruning required by the criteria 'l1_activation' or 'l2_activation'.
Only used when ranking_criterion is 'l1_activation' or 'l2_activation'.
Users should write this function as a normal function that accepts a PyTorch model and runs forward on the model
using the entire train/validation dataset. This function is not expected to perform any backpropagation or
parameter updates.
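Examples
--------
A minimal usage sketch for one-shot, weight-based pruning of a HuggingFace ``BertModel`` (illustrative only;
the module names below assume the standard 12-layer BERT-base layout and are not part of NNI)::

    from transformers import BertModel
    from nni.algorithms.compression.pytorch.pruning.transformer_pruner import TransformerHeadPruner

    model = BertModel.from_pretrained('bert-base-uncased')
    # one [Q_proj, K_proj, V_proj, output_proj] name group per attention layer
    attention_name_groups = list(zip(
        ['encoder.layer.{}.attention.self.query'.format(i) for i in range(12)],
        ['encoder.layer.{}.attention.self.key'.format(i) for i in range(12)],
        ['encoder.layer.{}.attention.self.value'.format(i) for i in range(12)],
        ['encoder.layer.{}.attention.output.dense'.format(i) for i in range(12)]))
    config_list = [{'sparsity': 0.5, 'op_types': ['Linear'],
                    'op_names': [name for group in attention_name_groups for name in group]}]

    pruner = TransformerHeadPruner(model, config_list, head_hidden_dim=64,
                                   attention_name_groups=attention_name_groups,
                                   ranking_criterion='l1_weight')
    pruner.compress()

Activation- and gradient-based criteria additionally require ``forward_runner`` or
``trainer``/``optimizer``/``criterion``, as described above.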
"""
def __init__(self, model, config_list, head_hidden_dim, attention_name_groups=None, dummy_input=None,
ranking_criterion='l1_weight', global_sort=False, num_iterations=1, epochs_per_iteration=1,
optimizer=None, trainer=None, criterion=None, forward_runner=None,
**algo_kwargs):
super().__init__(model, config_list)
self.head_hidden_dim = int(head_hidden_dim)
self.attention_name_groups = attention_name_groups
self.dummy_input = dummy_input
self.ranking_criterion = ranking_criterion
assert self.ranking_criterion in ['l1_weight', 'l2_weight', 'l1_activation', 'l2_activation', 'taylorfo'], \
"Unsupported ranking criteria."
self.global_sort = global_sort
self.num_iterations = int(num_iterations)
assert self.num_iterations >= 1, "num_iterations must be greater than or equal to 1"
self.epochs_per_iteration = int(epochs_per_iteration)
self._optimizer = optimizer
self._trainer = trainer
self._criterion = criterion
self._forward_runner = forward_runner
if self.ranking_criterion in ['taylorfo'] or num_iterations > 1:
assert self._trainer is not None
assert self._optimizer is not None
if self.ranking_criterion in ['l1_activation', 'l2_activation']:
assert self._forward_runner is not None
# Group generation: one group per attention layer, four weights per group
self.masking_groups = []
if self.attention_name_groups is not None:
logger.info("Note: weights for the same attention layer are grouped using the given attention_name_groups.")
self.group_weights_by_name()
else:
assert self.dummy_input is not None
logger.info("Note: weights for the same attention layer are grouped using model graph.")
self._unwrap_model()
self.group_weight_names_by_graph()
self._wrap_model()
# Group sanity check
self.validate_weight_groups()
# Remove any mistakenly captured ungrouped modules
self._unwrap_model()
self.remove_ungrouped_modules()
self._wrap_model()
self.masker = MASKER_DICT[ranking_criterion](model, self, self.head_hidden_dim, **algo_kwargs)
self.pruned_heads = {i: set() for i in range(len(self.masking_groups))}
def group_weights_by_name(self):
"""
Populate self.masking_groups using the groups specified by user in attention_name_groups.
"""
assert len(self.masking_groups) == 0
# build up masking groups
name2group = {}
for layer_idx, layer in enumerate(self.attention_name_groups):
errmsg = 'Each name group must contain 4 weights, with the first three corresponding to Q_proj, K_proj, ' \
'V_proj (in any order) and the last one being output_proj.'
assert len(layer) == 4, errmsg
self.masking_groups.append([])
for weight in layer:
name2group[weight] = layer_idx
# group wrappers
for wrapper in self.get_modules_wrapper():
if wrapper.name in name2group:
wrapper.group_idx = name2group[wrapper.name]
self.masking_groups[name2group[wrapper.name]].append(wrapper)
logger.info('Grouping updated:')
logger.info([[x.name for x in group] for group in self.masking_groups])
def group_weight_names_by_graph(self):
"""
Populate self.attention_name_groups by running inference on the module graph.
Currently, the group inferred by AttentionWeightDependency is limited to a set of four weights, with the first
three corresponding to Q_proj, K_proj, V_proj (in any order) and the last one being output_proj.
"""
try:
module_graph = TorchModuleGraph(self.bound_model, self.dummy_input)
dependency_tracer = AttentionWeightDependency(traced_model=module_graph.trace)
self.attention_name_groups = dependency_tracer.dependency_sets
self.group_weights_by_name()
except Exception as e:
raise RuntimeError('Graph trace failed: please check dummy_input, or specify attention_name_groups.\n'
'Exception message: ' + str(e))
def validate_weight_groups(self):
"""
Sanity checks:
- Q, K, V projection weights in each group must have the same shape
- output projection weight shape must match total hidden dimension (inferred from Q, K, V projection)
- Four weights in a group must have the same sparsity in their config
- If global_sort is specified, all weights must have the same sparsity
- head_hidden_dim must be a divisor of the output dimension of the projection weights (i.e., the resulting
head number must be an integer)
"""
errmsg = 'Attention weight group sanity check not passed'
sparsity = None
for group in self.masking_groups:
# allow empty groups - may be caused by config list filtering
if len(group) == 0:
continue
assert len(group) == 4, errmsg + ': each group must have four weights'
assert group[0].module.weight.size() == group[1].module.weight.size() and \
group[1].module.weight.size() == group[2].module.weight.size(), \
errmsg + ': the dimensions of the Q, K, V projection matrices must be the same'
assert group[0].module.weight.size()[0] == group[3].module.weight.size()[1], \
errmsg + ': the output dimension of the Q/K/V projections must match the input dimension of the output projection'
assert group[0].config['sparsity'] == group[1].config['sparsity'] == \
group[2].config['sparsity'] == group[3].config['sparsity'], \
errmsg + ': the sparsity of matrices in the same layer must be the same'
if sparsity is None:
sparsity = group[0].config['sparsity']
if self.global_sort:
assert sparsity == group[0].config['sparsity'], \
errmsg + ': for global_sort=True, the sparsity for all modules must be the same'
assert group[0].module.weight.size(0) % self.head_hidden_dim == 0, \
errmsg + ': head_hidden_dim must be a divisor of the output dimension of the projection weights'
def remove_ungrouped_modules(self):
"""
Remove non-attention weights that might be mistakenly captured by a simplified config_list.
Also update the corresponding list of layer information (self.modules_to_compress)
"""
care_of_modules = set([x for layer in self.masking_groups for x in layer])
modules_wrapper_new, modules_to_compress_new = [], []
for wrapper, layer_info in zip(self.modules_wrapper, self.modules_to_compress):
if wrapper in care_of_modules:
modules_wrapper_new.append(wrapper)
modules_to_compress_new.append(layer_info)
self.modules_wrapper = modules_wrapper_new
self.modules_to_compress = modules_to_compress_new
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.Module
Model to be pruned
config_list : list
List of pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def compress(self):
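"""
Calculate and apply the head masks for ``self.num_iterations`` iterations. For the activation-based criteria,
``forward_runner`` is first run on the model to collect activations; for 'taylorfo', one pass of ``trainer``
is run to collect gradients. For iterative pruning (num_iterations > 1), the model is finetuned with
``trainer`` for ``epochs_per_iteration`` epochs between iterations and the masker is reset.
"""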
for pruning_iter in range(self.num_iterations):
if self.ranking_criterion in ['l1_activation', 'l2_activation']:
training = self.bound_model.training
self.bound_model.eval()
self._forward_runner(self.bound_model) # dry run, forward only
self.update_mask()
self.bound_model.train(training)
elif self.ranking_criterion in ['taylorfo']:
self._trainer(self.bound_model, optimizer=self._optimizer, criterion=self._criterion, epoch=None)
self.update_mask()
else:
self.update_mask()
# for iterative pruning, if not the last iteration, finetune before next iteration
# Then, reset the maskers (may create additional hooks)
if self.num_iterations > 1 and pruning_iter != self.num_iterations - 1:
for e in range(self.epochs_per_iteration):
self._trainer(self.bound_model, optimizer=self._optimizer, criterion=self._criterion, epoch=e+1)
self.masker.reset()
logger.info('Pruned heads after iteration %i', pruning_iter)
logger.info(self.pruned_heads)
def update_mask(self):
"""
Calculate and update masks for each masking group. If global_sort is set, the masks for all groups are
calculated altogether, and then the groups are updated individually.
"""
masks_for_all_groups = None
if self.global_sort:
masks_for_all_groups = self._calc_mask_global()
assert len(masks_for_all_groups) == len(self.masking_groups)
for group_idx, layer_weight_group in enumerate(self.masking_groups):
if self.global_sort:
masks = masks_for_all_groups[group_idx]
else:
masks = self._calc_mask(layer_weight_group)
if masks is not None:
for i, mask in enumerate(masks):
for mask_type in mask:
assert hasattr(layer_weight_group[i], mask_type), \
"there is no attribute '%s' in wrapper on %s" % (mask_type, layer_weight_group[i])
setattr(layer_weight_group[i], mask_type, mask[mask_type])
logger.debug(f'mask updated: {layer_weight_group[i].name} {mask_type}')
def _calc_mask(self, weight_group):
"""
Calculate mask for each group using only layer-local information.
When global_sort is set for the pruner, _calc_mask_global should be called instead of this function.
Parameters
----------
weight_group : list
A list of four wrappers generated by self.group_weights_by_name().
Returns
-------
masks : list
A four element list corresponding to the masks for each element in the four-element weight group.
Each element in masks is a dict with keys "weight_mask" and "bias_mask" (optional).
masks can be None if the underlying masker returns None, which means that the mask calculation failed.
The calling function can try to recalculate the mask at a later time. Note that the calling function might
need to call masker.reset() before attempting to recalculate the mask.
"""
iter_sparsity = weight_group[0].config['sparsity'] / self.num_iterations
masks = self.masker.calc_mask(sparsity=iter_sparsity, weight_group=weight_group)
return masks
def _calc_mask_global(self):
"""
Calculate mask for all groups using global information.
Returns
-------
masks_list : list
A list corresponding to the masks for each weight group in self.masking_groups. Each element in the
returned masks_list is a four-element list corresponding to the masks for each element in a four-element
weight group.
"""
if len(self.get_modules_wrapper()) == 0:
return []
overall_sparsity = self.get_modules_wrapper()[0].config['sparsity'] / self.num_iterations
n_heads_total = 0
for group in self.masking_groups:
if len(group) != 0:
q_proj, _, _, _ = group
n_heads_total += int(q_proj.module.weight.size()[0] / self.head_hidden_dim)
n_heads_to_prune = int(n_heads_total * overall_sparsity)
return self.masker.calc_mask_global(n_heads_to_prune)
def calc_mask(self, wrapper, **kwargs):
raise RuntimeError("Applications should directly call TransformerHeadPruner's update_mask() method.")