Source code for nni.algorithms.compression.pytorch.pruning.structured_pruning_masker

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import math
import numpy as np
import torch
from .weight_masker import WeightMasker

__all__ = ['L1FilterPrunerMasker', 'L2FilterPrunerMasker', 'FPGMPrunerMasker',
           'TaylorFOWeightFilterPrunerMasker', 'ActivationAPoZRankFilterPrunerMasker',
           'ActivationMeanRankFilterPrunerMasker', 'SlimPrunerMasker', 'AMCWeightMasker']

logger = logging.getLogger('torch filter pruners')


class StructuredWeightMasker(WeightMasker):
    """
    A structured pruning masker base class that prunes convolutional layer filters.

    Parameters
    ----------
    model: nn.Module
        model to be pruned
    pruner: Pruner
        A Pruner instance used to prune the model
    preserve_round: int
        after pruning, round the number of preserved filters/channels up to a multiple
        of `preserve_round`. For example: for a Conv2d layer with 32 output channels and
        sparsity 0.2, if preserve_round is 1 (no rounding), then int(32 * 0.2) = 6 filters
        are pruned and 32 - 6 = 26 filters are preserved. If preserve_round is 4, the number
        of preserved filters is rounded up to 28 (which is divisible by 4) and only 4
        filters are pruned.
    """

    def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
        self.model = model
        self.pruner = pruner
        self.preserve_round = preserve_round
        self.dependency_aware = dependency_aware
        self.global_sort = global_sort
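    # A minimal sketch of the `preserve_round` arithmetic described in the
    # docstring above (standalone Python; the numbers match the docstring
    # example and are not part of the NNI API):
    #
    #     num_total, sparsity, preserve_round = 32, 0.2, 4
    #     num_prune = int(num_total * sparsity)              # 6
    #     num_preserve = num_total - num_prune               # 26
    #     num_preserve = int(math.ceil(num_preserve / preserve_round)
    #                        * preserve_round)               # rounded up to 28
    #     num_prune = num_total - num_preserve               # 4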
    def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
        """
        Calculate the mask for `wrapper`.

        Parameters
        ----------
        sparsity: float/list of float
            The target sparsity of the wrapper. If we calculate the mask in
            the normal way, then sparsity is a float number. In contrast, if
            we calculate the mask in the dependency-aware way, sparsity is a
            list of float numbers, each of which corresponds to the sparsity
            of a layer.
        wrapper: PrunerModuleWrapper/list of PrunerModuleWrappers
            The wrapper of the target layer. If we calculate the mask in the
            normal way, then `wrapper` is an instance of PrunerModuleWrapper,
            else `wrapper` is a list of PrunerModuleWrapper.
        wrapper_idx: int/list of int
            The index of the wrapper.
        depen_kwargs: dict
            The kwargs for the dependency-aware mode.
        """
        if self.global_sort:
            # if the global_sort switch is on, calculate the mask based
            # on global model information
            return self._global_calc_mask(sparsity, wrapper, wrapper_idx)
        elif not self.dependency_aware:
            # calculate the mask in the normal way; each layer calculates its
            # own mask separately
            return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
        else:
            # if the dependency_aware switch is on, then calculate the mask
            # in the dependency-aware way
            return self._dependency_calc_mask(sparsity, wrapper, wrapper_idx, **depen_kwargs)
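    # Hedged usage sketch: the owning Pruner normally calls calc_mask
    # internally; `wrapper` below stands for a hypothetical
    # PrunerModuleWrapper around a Conv2d layer.
    #
    #     masker = L1FilterPrunerMasker(model, pruner)
    #     masks = masker.calc_mask(sparsity=0.5, wrapper=wrapper, wrapper_idx=0)
    #     # masks == {'weight_mask': <tensor>, 'bias_mask': <tensor or None>}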
    def _get_current_state(self, sparsity, wrapper, wrapper_idx=None):
        """
        Some pruners may prune the layers in an iterative way. In each pruning
        iteration, we may get the current state of this wrapper/layer and
        continue to prune this layer based on that state. This function gets
        the current pruning state of the target wrapper/layer.

        Parameters
        ----------
        sparsity: float
            pruning ratio, preserved weight ratio is `1 - sparsity`
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers

        Returns
        -------
        base_mask: dict
            dict object that stores the mask of this wrapper in this iteration.
            If it is the first iteration, we create a new mask with all ones.
            If there is already a mask in this wrapper, we return the existing mask.
        weight: tensor
            the current weight of this layer
        num_prune: int
            how many filters we should prune
        """
        msg = 'module type {} is not supported!'.format(wrapper.type)
        assert wrapper.type == 'Conv2d', msg
        weight = wrapper.module.weight.data
        bias = None
        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
            bias = wrapper.module.bias.data
        if wrapper.weight_mask is None:
            mask_weight = torch.ones(weight.size()).type_as(weight).detach()
        else:
            mask_weight = wrapper.weight_mask.clone()
        if bias is not None:
            if wrapper.bias_mask is None:
                mask_bias = torch.ones(bias.size()).type_as(bias).detach()
            else:
                mask_bias = wrapper.bias_mask.clone()
        else:
            mask_bias = None
        mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}

        num_total = weight.size(0)
        num_prune = int(num_total * sparsity)
        if self.preserve_round > 1:
            num_preserve = num_total - num_prune
            num_preserve = int(
                math.ceil(num_preserve * 1. / self.preserve_round) * self.preserve_round)
            if num_preserve > num_total:
                num_preserve = int(math.floor(
                    num_total * 1. / self.preserve_round) * self.preserve_round)
            num_prune = num_total - num_preserve
        # weight * mask_weight: apply the base mask for iterative pruning
        return mask, weight * mask_weight, num_prune

    def _global_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        num_prune = self._get_global_num_prune(wrapper, wrapper_idx)
        mask, weight, _ = self._get_current_state(
            sparsity, wrapper, wrapper_idx)
        return self.get_mask(mask, weight, num_prune, wrapper, wrapper_idx)

    def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        """
        Calculate the mask of the given layer.

        Parameters
        ----------
        sparsity: float
            pruning ratio, preserved weight ratio is `1 - sparsity`
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers

        Returns
        -------
        dict
            dictionary for storing masks, keys of the dict:
            'weight_mask': weight mask tensor
            'bias_mask': bias mask tensor (optional)
        """
        mask, weight, num_prune = self._get_current_state(
            sparsity, wrapper, wrapper_idx)
        num_total = weight.size(0)
        if num_total < 2 or num_prune < 1:
            return mask
        return self.get_mask(mask, weight, num_prune, wrapper, wrapper_idx)

    def _common_channel_to_prune(self, sparsities, wrappers, wrappers_idx, channel_dsets, groups):
        """
        Calculate the common channels that should be pruned by all the layers in
        this group. This function is for filter pruning of Conv layers. If you
        want to support the dependency-aware mode for other ops, you need to
        inherit this class and overwrite `_common_channel_to_prune`.

        Parameters
        ----------
        sparsities : list
            List of float that specify the sparsity for each conv layer.
        wrappers : list
            List of wrappers
        groups : list
            The number of the filter groups of each layer.
        wrappers_idx : list
            The indexes of the wrappers
        """
        # sparsity configs for each wrapper
        # sparsities = [_w.config['sparsity'] for _w in wrappers]

        # check the type of the input wrappers
        for _w in wrappers:
            msg = 'module type {} is not supported!'.format(_w.type)
            assert _w.type == 'Conv2d', msg

        # Among the dependent layers, the layer with the smallest sparsity
        # determines the final speedup benefit of the module. To better
        # harvest that benefit, we need to ensure that at least a
        # `min_sparsity` fraction of the pruned channels are the same
        # across these dependent layers.
        if len(channel_dsets) == len(wrappers):
            # all the layers in the dependency sets are pruned
            min_sparsity = min(sparsities)
        else:
            # not all the layers in the dependency set
            # are pruned
            min_sparsity = 0
        # do not prune the channels that we cannot harvest speed from
        sparsities = [min_sparsity] * len(sparsities)

        # Find the max number of filter groups among the dependent layers.
        # The group constraint of this dependency set is decided by the
        # layer with the most groups, so we use the least common multiple
        # of all the group counts. max_group is no larger than
        # channel_count, because the number of filters is always divisible
        # by the number of groups.
        max_group = np.lcm.reduce(groups)
        channel_count = wrappers[0].module.weight.data.size(0)
        device = wrappers[0].module.weight.device
        channel_sum = torch.zeros(channel_count).to(device)
        for _w, _w_idx in zip(wrappers, wrappers_idx):
            # calculate the L1/L2 sum for all channels
            c_sum = self.get_channel_sum(_w, _w_idx)
            if c_sum is None:
                # if the channel sum cannot be calculated
                # now, return None
                return None
            channel_sum += c_sum

        # prune the same `min_sparsity` fraction of channels based on
        # channel_sum for all the layers in the dependency set
        target_pruned = int(channel_count * min_sparsity)
        # pruned_per_group may be zero, for example for a dw conv
        pruned_per_group = int(target_pruned / max_group)
        group_step = int(channel_count / max_group)

        channel_masks = []
        for gid in range(max_group):
            _start = gid * group_step
            _end = (gid + 1) * group_step
            if pruned_per_group > 0:
                threshold = torch.topk(
                    channel_sum[_start: _end], pruned_per_group, largest=False)[0].max()
                group_mask = torch.gt(channel_sum[_start:_end], threshold)
            else:
                group_mask = torch.ones(group_step).to(device)
            channel_masks.append(group_mask)
        channel_masks = torch.cat(channel_masks, dim=0)
        pruned_channel_index = (
            channel_masks == False).nonzero().squeeze(1).tolist()
        logger.info('Prune the %s channels for all dependent',
                    ','.join([str(x) for x in pruned_channel_index]))
        return channel_masks

    def _dependency_calc_mask(self, sparsities, wrappers, wrappers_idx, channel_dsets, groups):
        """
        Calculate the masks for the layers in the same dependency sets.
        Similar to the original calc_mask, _dependency_calc_mask prunes the
        target layers based on the L1/L2 norm of the weights. However, while
        StructuredWeightMasker prunes each filter based entirely on the L1/L2
        norm of that filter, _dependency_calc_mask also tries to satisfy the
        channel/group dependencies (see nni.compression.torch.utils.shape_dependency
        for details). Specifically, _dependency_calc_mask tries to prune the
        same channels for layers that have a channel dependency, and it also
        ensures that the number of filters pruned in each group is the same
        (to meet the group dependency).

        Parameters
        ----------
        sparsities : list
            List of float that specify the sparsity for each conv layer.
        wrappers : list
            List of wrappers
        groups : list
            The number of the filter groups of each layer.
        wrappers_idx : list
            The indexes of the wrappers
        """
        channel_masks = self._common_channel_to_prune(
            sparsities, wrappers, wrappers_idx, channel_dsets, groups)
        # Calculate the mask for each layer based on channel_masks. First,
        # every layer prunes the same channels masked in channel_masks.
        # If the sparsity of a layer is larger than min_sparsity, it will
        # continue to prune `sparsity - min_sparsity` of its channels to
        # meet the sparsity config.
        masks = {}
        for _pos, _w in enumerate(wrappers):
            _w_idx = wrappers_idx[_pos]
            sparsity = sparsities[_pos]
            name = _w.name

            # _tmp_mask = self._normal_calc_mask(
            #     sparsity, _w, _w_idx, channel_masks)
            base_mask, current_weight, num_prune = self._get_current_state(
                sparsity, _w, _w_idx)
            num_total = current_weight.size(0)
            if num_total < 2 or num_prune < 1:
                masks[name] = base_mask
                continue
            _tmp_mask = self.get_mask(
                base_mask, current_weight, num_prune, _w, _w_idx, channel_masks)

            if _tmp_mask is None:
                # if the mask calculation fails
                return None
            masks[name] = _tmp_mask
        return masks
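    # Illustration of the group-constraint arithmetic used by
    # `_common_channel_to_prune` above (illustrative numbers, not taken from
    # the NNI source): the LCM of the group counts fixes how many channels
    # each group segment must drop so the group dependency stays satisfied.
    #
    #     groups = [2, 4]                    # filter groups of two dependent convs
    #     max_group = np.lcm.reduce(groups)  # 4
    #     channel_count, min_sparsity = 32, 0.5
    #     target_pruned = int(channel_count * min_sparsity)  # 16
    #     pruned_per_group = int(target_pruned / max_group)  # 4 per group of 8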
    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        """
        Calculate the mask of the given layer.

        Parameters
        ----------
        base_mask: dict
            The basic mask with the same shape as the weight; every item in
            the basic mask is 1.
        weight: tensor
            the module weight to be pruned
        num_prune: int
            Num of filters to prune
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers
        channel_masks: Tensor
            Masks some channels for this layer in advance. In the
            dependency-aware mode, before calculating the masks for each
            layer, we calculate a common mask for all the layers in the
            dependency set. Pruners that do not support the dependency-aware
            mode can just ignore this parameter.

        Returns
        -------
        dict
            dictionary for storing masks
        """
        raise NotImplementedError(
            '{} get_mask is not implemented'.format(self.__class__.__name__))
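    # How `channel_masks` interacts with the per-channel scores in the
    # subclasses below (illustrative only): multiplying the importance
    # scores by a 0/1 channel mask drives pre-masked channels to zero, so
    # topk(..., largest=False) selects them first and the whole dependency
    # set ends up pruning the same channels.
    #
    #     scores = torch.tensor([3., 1., 2., 4.])
    #     channel_masks = torch.tensor([1., 0., 1., 1.])
    #     scores * channel_masks      # tensor([3., 0., 2., 4.])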
    def get_channel_sum(self, wrapper, wrapper_idx):
        """
        Calculate the importance weight for each channel. If you want to
        support the dependency-aware mode for this one-shot pruner, this
        function must be implemented.

        Parameters
        ----------
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers

        Returns
        -------
        tensor
            Tensor that indicates the importance of each channel
        """
        raise NotImplementedError(
            '{} get_channel_sum is not implemented'.format(self.__class__.__name__))
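
# Quick numeric illustration of the per-filter scores that the maskers below
# rank filters by (plain PyTorch, independent of NNI):
#
#     w = torch.randn(8, 4, 3, 3)                        # (out_ch, in_ch, kh, kw)
#     l1 = w.abs().view(8, -1).sum(dim=1)                # L1FilterPrunerMasker score
#     l2 = torch.sqrt((w.view(8, -1) ** 2).sum(dim=1))   # L2FilterPrunerMasker score
#     weakest = torch.topk(l1, k=2, largest=False).indices  # 2 filters to prune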

class L1FilterPrunerMasker(StructuredWeightMasker):
    """
    A structured pruning algorithm that prunes the filters with the smallest
    sum of absolute weights in the convolution layers to achieve a preset
    level of network sparsity.
    Hao Li, Asim Kadav, Igor Durdanovic, Hanan Samet and Hans Peter Graf,
    "PRUNING FILTERS FOR EFFICIENT CONVNETS", 2017 ICLR
    https://arxiv.org/abs/1608.08710
    """

    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        # get the l1-norm sum for each filter
        w_abs_structured = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_masks is not None:
            # if we need to mask some channels in advance
            w_abs_structured = w_abs_structured * channel_masks
        threshold = torch.topk(
            w_abs_structured.view(-1), num_prune, largest=False)[0].max()
        mask_weight = torch.gt(w_abs_structured, threshold)[
            :, None, None, None].expand_as(weight).type_as(weight)
        mask_bias = torch.gt(w_abs_structured, threshold).type_as(
            weight).detach() if base_mask['bias_mask'] is not None else None

        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}

    def get_channel_sum(self, wrapper, wrapper_idx):
        weight = wrapper.module.weight.data
        filters = weight.shape[0]
        w_abs = weight.abs()
        w_abs_structured = w_abs.view(filters, -1).sum(dim=1)
        return w_abs_structured


class L2FilterPrunerMasker(StructuredWeightMasker):
    """
    A structured pruning algorithm that prunes the filters with the
    smallest L2 norm of the weights.
    """

    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        # get the l2-norm for each filter
        w_l2_norm = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_masks is not None:
            # if we need to mask some channels in advance
            w_l2_norm = w_l2_norm * channel_masks
        threshold = torch.topk(
            w_l2_norm.view(-1), num_prune, largest=False)[0].max()
        mask_weight = torch.gt(w_l2_norm, threshold)[
            :, None, None, None].expand_as(weight).type_as(weight)
        mask_bias = torch.gt(w_l2_norm, threshold).type_as(
            weight).detach() if base_mask['bias_mask'] is not None else None

        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}

    def get_channel_sum(self, wrapper, wrapper_idx):
        weight = wrapper.module.weight.data
        filters = weight.shape[0]
        w = weight.view(filters, -1)
        w_l2_norm = torch.sqrt((w ** 2).sum(dim=1))
        return w_l2_norm


class FPGMPrunerMasker(StructuredWeightMasker):
    """
    A filter pruner via geometric median.
    "Filter Pruning via Geometric Median for Deep Convolutional Neural
    Networks Acceleration", https://arxiv.org/pdf/1811.00250.pdf
    """

    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        min_gm_idx = self._get_min_gm_kernel_idx(
            num_prune, wrapper, wrapper_idx, channel_masks)
        for idx in min_gm_idx:
            base_mask['weight_mask'][idx] = 0.
            if base_mask['bias_mask'] is not None:
                base_mask['bias_mask'][idx] = 0.
        return base_mask

    def _get_min_gm_kernel_idx(self, num_prune, wrapper, wrapper_idx, channel_masks):
        channel_dist = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_masks is not None:
            channel_dist = channel_dist * channel_masks
        dist_list = [(channel_dist[i], i)
                     for i in range(channel_dist.size(0))]
        min_gm_kernels = sorted(dist_list, key=lambda x: x[0])[:num_prune]
        return [x[1] for x in min_gm_kernels]

    def _get_distance_sum(self, weight, out_idx):
        """
        Calculate the total distance between a specified filter (given by
        out_idx) and all other filters.

        Parameters
        ----------
        weight: Tensor
            convolutional filter weight
        out_idx: int
            output channel index of the specified filter; this method
            calculates the total distance between this specified filter
            and all other filters.

        Returns
        -------
        float32
            The total distance
        """
        logger.debug('weight size: %s', weight.size())
        assert len(weight.size()) in [3, 4], 'unsupported weight shape'

        w = weight.view(weight.size(0), -1)
        anchor_w = w[out_idx].unsqueeze(0).expand(w.size(0), w.size(1))
        x = w - anchor_w
        x = (x * x).sum(-1)
        x = torch.sqrt(x)
        return x.sum()

    def get_channel_sum(self, wrapper, wrapper_idx):
        weight = wrapper.module.weight.data
        assert len(weight.size()) in [3, 4]
        dist_list = []
        for out_i in range(weight.size(0)):
            dist_sum = self._get_distance_sum(weight, out_i)
            dist_list.append(dist_sum)
        return torch.Tensor(dist_list).to(weight.device)


class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
    """
    A structured pruning algorithm that prunes the filters with the smallest
    importance approximations based on the first-order Taylor expansion of the weights.
    Molchanov, Pavlo and Mallya, Arun and Tyree, Stephen and Frosio, Iuri and Kautz, Jan,
    "Importance Estimation for Neural Network Pruning", CVPR 2019.
    http://jankautz.com/publications/Importance4NNPruning_CVPR19.pdf
    """

    def __init__(self, model, pruner, statistics_batch_num=1):
        super().__init__(model, pruner)
        self.statistics_batch_num = statistics_batch_num
        self.pruner.iterations = 0
        self.pruner.set_wrappers_attribute("contribution", None)
        self.pruner.patch_optimizer(self.calc_contributions)
        self.global_threshold = None

    def _get_global_threshold(self):
        channel_contribution_list = []
        for wrapper_idx, wrapper in enumerate(self.pruner.get_modules_wrapper()):
            channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
            wrapper_size = wrapper.module.weight.size().numel()
            channel_size = wrapper.module.weight.size(0)
            contribution_expand = channel_contribution.expand(
                int(wrapper_size / channel_size), channel_size).reshape(-1)
            channel_contribution_list.append(contribution_expand)
        all_channel_contributions = torch.cat(channel_contribution_list)
        k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
        self.global_threshold = torch.topk(
            all_channel_contributions.view(-1), k, largest=False)[0].max()

    def _get_global_num_prune(self, wrapper, wrapper_idx):
        if self.global_threshold is None:
            self._get_global_threshold()
        weight = wrapper.module.weight.data
        filters = weight.size(0)
        channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
        num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0]
        if num_prune == filters:
            num_prune -= 1
        return num_prune

    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
        if channel_contribution is None:
            # not enough iterations have been collected yet
            return None
        if channel_masks is not None:
            channel_contribution = channel_contribution * channel_masks
        prune_indices = torch.argsort(channel_contribution)[:num_prune]
        for idx in prune_indices:
            base_mask['weight_mask'][idx] = 0.
            if base_mask['bias_mask'] is not None:
                base_mask['bias_mask'][idx] = 0.
        return base_mask

    def calc_contributions(self):
        """
        Calculate the estimated importance of filters as a sum of individual
        contributions based on the first-order Taylor expansion.
""" if self.pruner.iterations >= self.statistics_batch_num: return for wrapper in self.pruner.get_modules_wrapper(): filters = wrapper.module.weight.size(0) contribution = ( wrapper.module.weight * wrapper.module.weight.grad).data.pow(2).view(filters, -1).sum(dim=1) if wrapper.contribution is None: wrapper.contribution = contribution else: wrapper.contribution += contribution self.pruner.iterations += 1 def get_channel_sum(self, wrapper, wrapper_idx): if self.pruner.iterations < self.statistics_batch_num: return None if wrapper.contribution is None: return None return wrapper.contribution class ActivationFilterPrunerMasker(StructuredWeightMasker): def __init__(self, model, pruner, statistics_batch_num=1, activation='relu'): super().__init__(model, pruner) self.statistics_batch_num = statistics_batch_num self.pruner.hook_id = self._add_activation_collector(self.pruner) self.pruner.iterations = 0 self.pruner.patch_optimizer(self._iteration_counter) assert activation in ['relu', 'relu6'] if activation == 'relu': self.pruner.activation = torch.nn.functional.relu elif activation == 'relu6': self.pruner.activation = torch.nn.functional.relu6 else: self.pruner.activation = None def _iteration_counter(self): self.pruner.iterations += 1 def _add_activation_collector(self, pruner): def collector(collected_activation): def hook(module_, input_, output): collected_activation.append( pruner.activation(output.detach().cpu())) return hook pruner.collected_activation = {} pruner._fwd_hook_id += 1 pruner._fwd_hook_handles[pruner._fwd_hook_id] = [] for wrapper_idx, wrapper in enumerate(pruner.get_modules_wrapper()): pruner.collected_activation[wrapper_idx] = [] handle = wrapper.register_forward_hook( collector(pruner.collected_activation[wrapper_idx])) pruner._fwd_hook_handles[pruner._fwd_hook_id].append(handle) return pruner._fwd_hook_id class ActivationAPoZRankFilterPrunerMasker(ActivationFilterPrunerMasker): """ A structured pruning algorithm that prunes the filters with the smallest APoZ(average percentage of zeros) of output activations. Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang, "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016. https://arxiv.org/abs/1607.03250 """ def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None): apoz = self.get_channel_sum(wrapper, wrapper_idx) if apoz is None: # the collected activations are not enough return None if channel_masks is not None: apoz = apoz * channel_masks prune_indices = torch.argsort(apoz)[:num_prune] for idx in prune_indices: base_mask['weight_mask'][idx] = 0. if base_mask['bias_mask'] is not None: base_mask['bias_mask'][idx] = 0. if self.pruner.hook_id in self.pruner._fwd_hook_handles: self.pruner.remove_activation_collector(self.pruner.hook_id) return base_mask def _calc_apoz(self, activations): """ Calculate APoZ(average percentage of zeros) of activations. 

        Parameters
        ----------
        activations : list
            Layer's output activations

        Returns
        -------
        torch.Tensor
            Filter's APoZ (average percentage of zeros) of the activations
        """
        activations = torch.cat(activations, 0)
        _eq_zero = torch.eq(activations, torch.zeros_like(activations))
        _apoz = torch.sum(_eq_zero, dim=(0, 2, 3), dtype=torch.float64) / \
            torch.numel(_eq_zero[:, 0, :, :])
        return torch.ones_like(_apoz) - _apoz

    def get_channel_sum(self, wrapper, wrapper_idx):
        assert wrapper_idx is not None
        activations = self.pruner.collected_activation[wrapper_idx]
        if len(activations) < self.statistics_batch_num:
            # the collected activations are not enough
            return None
        return self._calc_apoz(activations).to(wrapper.module.weight.device)


class ActivationMeanRankFilterPrunerMasker(ActivationFilterPrunerMasker):
    """
    A structured pruning algorithm that prunes the filters with the smallest
    mean value of output activations.
    Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
    "Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
    https://arxiv.org/abs/1611.06440
    """

    def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
        mean_activation = self.get_channel_sum(wrapper, wrapper_idx)
        if mean_activation is None:
            # the collected activations are not enough
            return None
        if channel_masks is not None:
            mean_activation = mean_activation * channel_masks
        prune_indices = torch.argsort(mean_activation)[:num_prune]
        for idx in prune_indices:
            base_mask['weight_mask'][idx] = 0.
            if base_mask['bias_mask'] is not None:
                base_mask['bias_mask'][idx] = 0.
        # if len(activations) < self.statistics_batch_num, the code
        # cannot reach here
        if self.pruner.hook_id in self.pruner._fwd_hook_handles:
            self.pruner.remove_activation_collector(self.pruner.hook_id)

        return base_mask

    def _cal_mean_activation(self, activations):
        """
        Calculate the mean value of the activations.

        Parameters
        ----------
        activations : list
            Layer's output activations

        Returns
        -------
        torch.Tensor
            Filter's mean value of the output activations
        """
        activations = torch.cat(activations, 0)
        mean_activation = torch.mean(activations, dim=(0, 2, 3))
        return mean_activation

    def get_channel_sum(self, wrapper, wrapper_idx):
        assert wrapper_idx is not None
        activations = self.pruner.collected_activation[wrapper_idx]
        if len(activations) < self.statistics_batch_num:
            return None
        # The memory overhead here is acceptable, because only the
        # mean_activation tensor returned by _cal_mean_activation is
        # transferred to the GPU.
        return self._cal_mean_activation(activations).to(wrapper.module.weight.device)


class SlimPrunerMasker(WeightMasker):
    """
    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.
    Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang,
    "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
    https://arxiv.org/pdf/1708.06519.pdf
    """

    def __init__(self, model, pruner, **kwargs):
        super().__init__(model, pruner)
        self.global_threshold = None

    def _get_global_threshold(self):
        weight_list = []
        for (layer, _) in self.pruner.get_modules_to_compress():
            weight_list.append(layer.module.weight.data.abs().clone())
        all_bn_weights = torch.cat(weight_list)
        k = int(all_bn_weights.shape[0] * self.pruner.config_list[0]['sparsity'])
        self.global_threshold = torch.topk(
            all_bn_weights.view(-1), k, largest=False)[0].max()
        print(f'set global threshold to {self.global_threshold}')

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
        assert wrapper.type == 'BatchNorm2d', \
            'SlimPruner only supports 2d batch normalization layer pruning'
        if self.global_threshold is None:
            self._get_global_threshold()

        weight = wrapper.module.weight.data.clone()
        if wrapper.weight_mask is not None:
            # apply the base mask for iterative pruning
            weight = weight * wrapper.weight_mask

        base_mask = torch.ones(weight.size()).type_as(weight).detach()
        mask = {'weight_mask': base_mask.detach(),
                'bias_mask': base_mask.clone().detach()}
        filters = weight.size(0)
        num_prune = int(filters * sparsity)
        if filters >= 2 and num_prune >= 1:
            w_abs = weight.abs()
            mask_weight = torch.gt(
                w_abs, self.global_threshold).type_as(weight)
            mask_bias = mask_weight.clone()
            mask = {'weight_mask': mask_weight.detach(),
                    'bias_mask': mask_bias.detach()}
        return mask


def least_square_sklearn(X, Y):
    from sklearn.linear_model import LinearRegression
    reg = LinearRegression(fit_intercept=False)
    reg.fit(X, Y)
    return reg.coef_


class AMCWeightMasker(WeightMasker):
    """
    Weight masker class for the AMC pruner. Currently, AMCPruner only supports
    pruning 1x1 pointwise Conv2d layers. Before using this class to prune
    kernels, AMCPruner collects the input and output feature maps of each
    layer; the feature maps are flattened and saved into wrapper.input_feat
    and wrapper.output_feat.

    Parameters
    ----------
    model: nn.Module
        model to be pruned
    pruner: Pruner
        A Pruner instance used to prune the model
    preserve_round: int
        after pruning, round the number of preserved filters/channels up to a multiple
        of `preserve_round`. For example: for a Conv2d layer with 32 output channels and
        sparsity 0.2, if preserve_round is 1 (no rounding), then int(32 * 0.2) = 6 filters
        are pruned and 32 - 6 = 26 filters are preserved. If preserve_round is 4, the number
        of preserved filters is rounded up to 28 (which is divisible by 4) and only 4
        filters are pruned.
    """

    def __init__(self, model, pruner, preserve_round=1):
        self.model = model
        self.pruner = pruner
        self.preserve_round = preserve_round

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None, preserve_idx=None):
        """
        Calculate the mask of the given layer.

        Parameters
        ----------
        sparsity: float
            pruning ratio, preserved weight ratio is `1 - sparsity`
        wrapper: PrunerModuleWrapper
            layer wrapper of this layer
        wrapper_idx: int
            index of this wrapper in pruner's all wrappers

        Returns
        -------
        dict
            dictionary for storing masks, keys of the dict:
            'weight_mask': weight mask tensor
            'bias_mask': bias mask tensor (optional)
        """
        msg = 'module type {} is not supported!'.format(wrapper.type)
        assert wrapper.type in ['Conv2d', 'Linear'], msg
        weight = wrapper.module.weight.data
        bias = None
        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
            bias = wrapper.module.bias.data
        if wrapper.weight_mask is None:
            mask_weight = torch.ones(weight.size()).type_as(weight).detach()
        else:
            mask_weight = wrapper.weight_mask.clone()
        if bias is not None:
            if wrapper.bias_mask is None:
                mask_bias = torch.ones(bias.size()).type_as(bias).detach()
            else:
                mask_bias = wrapper.bias_mask.clone()
        else:
            mask_bias = None
        mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}

        num_total = weight.size(1)
        num_prune = int(num_total * sparsity)
        if self.preserve_round > 1:
            num_preserve = num_total - num_prune
            num_preserve = int(
                math.ceil(num_preserve * 1. / self.preserve_round) * self.preserve_round)
            if num_preserve > num_total:
                num_preserve = num_total
            num_prune = num_total - num_preserve

        if (num_total < 2 or num_prune < 1) and preserve_idx is None:
            return mask

        return self.get_mask(mask, weight, num_preserve, wrapper, wrapper_idx, preserve_idx)

    def get_mask(self, base_mask, weight, num_preserve, wrapper, wrapper_idx, preserve_idx):
        w = weight.data.cpu().numpy()
        if wrapper.type == 'Linear':
            w = w[:, :, None, None]

        if preserve_idx is None:
            # sum the magnitude along C_out, K_h, K_w and sort in descending order
            importance = np.abs(w).sum((0, 2, 3))
            sorted_idx = np.argsort(-importance)
            d_prime = num_preserve
            preserve_idx = sorted_idx[:d_prime]  # indexes to preserve
        else:
            d_prime = len(preserve_idx)
        assert len(preserve_idx) == d_prime
        mask = np.zeros(w.shape[1], bool)
        mask[preserve_idx] = True

        # reconstruct, X, Y <= [N, C]
        X, Y = wrapper.input_feat, wrapper.output_feat
        masked_X = X[:, mask]
        if w.shape[2] == 1:  # 1x1 conv or fc
            rec_weight = least_square_sklearn(X=masked_X, Y=Y)
            rec_weight = rec_weight.reshape(-1, 1, 1, d_prime)  # (C_out, K_h, K_w, C_in')
            rec_weight = np.transpose(rec_weight, (0, 3, 1, 2))  # (C_out, C_in', K_h, K_w)
        rec_weight_pad = np.zeros_like(w)  # pylint: disable=all
        rec_weight_pad[:, mask, :, :] = rec_weight
        rec_weight = rec_weight_pad

        if wrapper.type == 'Linear':
            rec_weight = rec_weight.squeeze()
            assert len(rec_weight.shape) == 2

        # now assign the reconstructed weight back to the module
        wrapper.module.weight.data = torch.from_numpy(rec_weight).to(weight.device)

        mask_weight = torch.zeros_like(weight)
        if wrapper.type == 'Linear':
            mask_weight[:, preserve_idx] = 1.
            if base_mask['bias_mask'] is not None and wrapper.module.bias is not None:
                mask_bias = torch.ones_like(wrapper.module.bias)
            else:
                mask_bias = None
        else:
            mask_weight[:, preserve_idx, :, :] = 1.
            mask_bias = None

        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
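
# Sketch of the channel-selection and reconstruction steps in
# AMCWeightMasker.get_mask (hypothetical shapes; X and Y stand in for the
# flattened input/output feature maps that AMCPruner collects into
# wrapper.input_feat / wrapper.output_feat):
#
#     w = np.random.randn(16, 32, 1, 1)          # 1x1 conv weight (C_out, C_in, 1, 1)
#     importance = np.abs(w).sum((0, 2, 3))      # per-input-channel magnitude
#     keep = np.argsort(-importance)[:8]         # preserve the 8 strongest inputs
#     mask = np.zeros(32, bool); mask[keep] = True
#     X = np.random.randn(256, 32)               # N samples, C_in features
#     Y = np.random.randn(256, 16)               # N samples, C_out features
#     rec_w = least_square_sklearn(X=X[:, mask], Y=Y)  # (16, 8) reconstructed weight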