# Source code for nni.algorithms.compression.v2.pytorch.base.pruner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Module, Parameter

from .compressor import Compressor, LayerInfo, _setattr

_logger = logging.getLogger(__name__)

__all__ = ['Pruner']


class PrunerModuleWrapper(Module):
    """
    Wrap a module so that its weight (and bias, when present) is multiplied by a mask on every forward call.

    The wrapper owns the `weight` / `bias` parameters and the `weight_mask` / `bias_mask` buffers,
    and hands the masked tensors to the wrapped module inside `forward()`.

    Parameters
    ----------
    module
        The module user wants to compress. Its `weight` attribute is read during construction.
    module_name
        The name of the module to compress, wrapper module shares same name.
    config
        The configurations that users specify for compression.
    """

    def __init__(self, module: Module, module_name: str, config: Dict):
        super().__init__()
        # Information about the original layer and the user configuration.
        self.module = module
        self.name = module_name
        self.config = config

        # Placeholder parameter and an all-ones mask shaped like the wrapped weight.
        # The placeholder only receives real values in `_weight2buffer()`.
        wrapped_weight = self.module.weight
        self.weight = Parameter(torch.empty(wrapped_weight.size()))
        self.register_buffer("weight_mask", torch.ones(wrapped_weight.shape))

        # The bias is optional: without one, `bias_mask` is registered as None
        # and no `bias` parameter is created on the wrapper.
        wrapped_bias = getattr(self.module, 'bias', None)
        if wrapped_bias is None:
            self.register_buffer("bias_mask", None)
        else:
            self.register_buffer("bias_mask", torch.ones(wrapped_bias.shape))
            self.bias = Parameter(torch.empty(wrapped_bias.size()))

    def _weight2buffer(self):
        """
        Turn the wrapped module's `weight` / `bias` parameters into buffers, and let the wrapper's
        own parameters take over their data.

        Call this before using the wrapper for inference; the best place is `Pruner._wrap_model()`.
        """
        inner = self.module

        # The wrapper parameter shares the data of the original weight.
        self.weight.data = inner.weight.data
        del inner.weight
        inner.register_buffer('weight', self.weight.data)

        if getattr(inner, 'bias', None) is not None:
            self.bias.data = inner.bias.data
            del inner.bias
            inner.register_buffer('bias', self.bias.data)

    def _weight2parameter(self):
        """
        Turn the wrapped module's `weight` / `bias` buffers back into parameters whose data is the
        masked value held by the wrapper.

        Call this when the score no longer needs recording or the model is about to be exported;
        the best place is `Pruner._unwrap_model()`.
        """
        inner = self.module

        del inner.weight
        inner.weight = Parameter(torch.empty(self.weight.size()))
        inner.weight.data = torch.mul(self.weight, self.weight_mask)

        if getattr(inner, 'bias', None) is not None:
            del inner.bias
            inner.bias = Parameter(torch.empty(self.bias.size()))
            inner.bias.data = torch.mul(self.bias, self.bias_mask)

    def forward(self, *inputs):
        """
        Apply the masks to the weight and bias, then run the wrapped module on `inputs`.
        """
        inner = self.module
        inner.weight = torch.mul(self.weight, self.weight_mask)
        if getattr(inner, 'bias', None) is not None:
            inner.bias = torch.mul(self.bias, self.bias_mask)
        return inner(*inputs)
class Pruner(Compressor):
    """
    The abstract class for pruning algorithm. Inherit this class and implement the `_reset_tools` to customize a pruner.
    """

    def reset(self, model: Optional[Module] = None, config_list: Optional[List[Dict]] = None):
        """
        Reset the pruner by forwarding `model` and `config_list` to `Compressor.reset()`.

        Parameters
        ----------
        model
            The model passed on to the parent `reset()`.
        config_list
            The config list passed on to the parent `reset()`.
        """
        super().reset(model=model, config_list=config_list)

    def _wrap_modules(self, layer: LayerInfo, config: Dict):
        """
        Create a wrapper module to replace the original one.

        Parameters
        ----------
        layer
            The layer to instrument the mask.
        config
            The configuration for generating the mask.

        Returns
        -------
        PrunerModuleWrapper
            The wrapper, moved to the device of the wrapped module's weight.

        Raises
        ------
        AttributeError
            If the module has no `weight` attribute.
        """
        _logger.debug("Module detected to compress : %s.", layer.name)
        # Check before building the wrapper: `PrunerModuleWrapper.__init__` reads `module.weight`,
        # so a check placed after construction could never report this message.
        if not hasattr(layer.module, 'weight'):
            raise AttributeError("module %s does not have 'weight' attribute" % layer.name)
        wrapper = PrunerModuleWrapper(layer.module, layer.name, config)
        # move newly registered buffers to the same device of weight
        wrapper.to(layer.module.weight.device)
        return wrapper

    # The following `_wrap_model`, `_unwrap_model`, `get_origin2wrapped_parameter_name_map` can merge to `Compressor`,
    # if quantizer use the similar structure wrapper.
    def _wrap_model(self):
        """
        Wrap all modules that need to be compressed.
        Different from the parent function, call `wrapper._weight2buffer()` after replacing the original module
        with its wrapper. Does nothing when the model is already wrapped.
        """
        if not self.is_wrapped:
            for _, wrapper in reversed(self.get_modules_wrapper().items()):
                _setattr(self.bound_model, wrapper.name, wrapper)
                wrapper._weight2buffer()
            self.is_wrapped = True

    def _unwrap_model(self):
        """
        Unwrap all modules that need to be compressed.
        Different from the parent function, call `wrapper._weight2parameter()` after replacing the wrapper
        with the original module. Does nothing when the model is not wrapped.
        """
        if self.is_wrapped:
            for wrapper in self.get_modules_wrapper().values():
                _setattr(self.bound_model, wrapper.name, wrapper.module)
                wrapper._weight2parameter()
            self.is_wrapped = False

    def get_origin2wrapped_parameter_name_map(self) -> Dict[str, str]:
        """
        Get the name mapping of parameters from original model to wrapped model.

        Returns
        -------
        Dict[str, str]
            Return a dict `{original_model_parameter_name: wrapped_model_parameter_name}`

        Raises
        ------
        Exception
            If the model is not wrapped when this method is called.
        """
        if not self.is_wrapped:
            raise Exception('When only the model is wrapped can get the parameter_name_map.')

        wrapped_param_names = {id(param): name for name, param in self.bound_model.named_parameters()}
        self._unwrap_model()
        try:
            # If the parameter name under a wrapped module is `xxx.weight` or `xxx.bias`, the name will not change after wrap.
            # If it is anything else, the name `xxx.param` will change to `xxx.module.param` after wrap.
            parameter_name_map = {
                name: wrapped_param_names.get(id(param), name)
                for name, param in self.bound_model.named_parameters()
            }
        finally:
            # Always restore the wrapped state, even if building the map failed.
            self._wrap_model()
        return parameter_name_map

    def load_masks(self, masks: Dict[str, Dict[str, Tensor]]):
        """
        Load existing masks onto the wrappers. The model can then be trained with these masks.

        Parameters
        ----------
        masks
            The masks dict with format {'op_name': {'weight': mask, 'bias': mask}}.
        """
        wrappers = self.get_modules_wrapper()
        for name, layer_mask in masks.items():
            assert name in wrappers, '{} is not in wrappers of this pruner, can not apply the mask.'.format(name)
            wrapper = wrappers[name]
            wrapper_weight = getattr(wrapper, 'weight', None)
            for target in ('weight', 'bias'):
                mask = layer_mask.get(target)
                if mask is None:
                    continue
                mask_attr = target + '_mask'
                assert hasattr(wrapper, mask_attr), 'There is no attribute {} in wrapper.'.format(mask_attr)
                # A mask loaded on another device (e.g. from a CPU checkpoint) would fail when it is
                # multiplied with the weight in `forward`, so move it next to the wrapper's weight.
                if wrapper_weight is not None:
                    mask = mask.to(wrapper_weight.device)
                setattr(wrapper, mask_attr, mask)

    def compress(self) -> Tuple[Module, Dict[str, Dict[str, Tensor]]]:
        """
        Returns
        -------
        Tuple[Module, Dict]
            Return the bound model and a mask dict, which is empty in this base class.
        """
        return self.bound_model, {}

    # NOTE: need refactor dim with supporting list
    def show_pruned_weights(self, dim: int = 0):
        """
        Log the simulated prune sparsity.

        Parameters
        ----------
        dim
            The pruned dim. A negative value counts from the last dimension.
        """
        for wrapper in self.get_modules_wrapper().values():
            weight_mask = wrapper.weight_mask
            n_dims = weight_mask.dim()
            if n_dims == 1:
                remained = weight_mask.abs() != 0
            else:
                # Normalise a negative `dim` so it can be removed from the list of reduced dims.
                kept_dim = dim + n_dims if dim < 0 else dim
                reduce_dims = list(range(n_dims))
                reduce_dims.remove(kept_dim)
                # A slice along `dim` remains if any of its mask entries is non-zero.
                remained = weight_mask.abs().sum(reduce_dims) != 0
            _logger.info('simulated prune %s remain/total: %d/%d',
                         wrapper.name, int(remained.sum().item()), weight_mask.size(dim))

    def export_model(self, model_path: str, mask_path: Optional[str] = None):
        """
        Export pruned model weights and, optionally, the masks.

        Parameters
        ----------
        model_path
            Path to save pruned model state_dict. The weight and bias have already multiplied the masks.
        mask_path
            Path to save mask dict. Nothing is saved for the masks when it is None.
        """
        assert self.bound_model is not None, 'The bound model reference has been cleared.'
        assert model_path is not None, 'model_path must be specified.'
        mask_dict = {}
        self._unwrap_model()
        try:
            for name, wrapper in self.get_modules_wrapper().items():
                weight_mask = wrapper.weight_mask
                bias_mask = wrapper.bias_mask
                if weight_mask is not None:
                    mask_sum = weight_mask.sum().item()
                    mask_num = weight_mask.numel()
                    _logger.debug('Layer: %s  Sparsity: %.4f', name, 1 - mask_sum / mask_num)
                    wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
                if bias_mask is not None:
                    wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
                # save mask to dict
                mask_dict[name] = {"weight": weight_mask, "bias": bias_mask}

            torch.save(self.bound_model.state_dict(), model_path)
            _logger.info('Model state_dict saved to %s', model_path)

            if mask_path is not None:
                torch.save(mask_dict, mask_path)
                _logger.info('Mask dict saved to %s', mask_path)
        finally:
            # Re-wrap even when saving fails, so the pruner is not left with an unwrapped model.
            self._wrap_model()