Source code for nni.algorithms.compression.pytorch.pruning.sensitivity_pruner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import csv
import copy
import json
import logging
import torch

from schema import And, Optional
from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis

from .constants_pruner import PRUNER_DICT


MAX_PRUNE_RATIO_PER_ITER = 0.95

_logger = logging.getLogger('Sensitivity_Pruner')
_logger.setLevel(logging.INFO)

class SensitivityPruner(Pruner):
    """
    This pruner prunes the model based on the sensitivity of each layer.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be compressed.
    evaluator : function
        Validation function for the model. This function should return the
        accuracy of the validation dataset. The input parameters of evaluator
        can be specified in the parameters `eval_args` and `eval_kwargs` of the
        compress function if needed. Example:

        >>> def evaluator(model):
        >>>     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>>     val_loader = ...
        >>>     model.eval()
        >>>     correct = 0
        >>>     with torch.no_grad():
        >>>         for data, target in val_loader:
        >>>             data, target = data.to(device), target.to(device)
        >>>             output = model(data)
        >>>             # get the index of the max log-probability
        >>>             pred = output.argmax(dim=1, keepdim=True)
        >>>             correct += pred.eq(target.view_as(pred)).sum().item()
        >>>     accuracy = correct / len(val_loader.dataset)
        >>>     return accuracy

    finetuner : function
        Finetune function for the model. This parameter is optional; if it is
        not None, the sensitivity pruner will finetune the model after pruning
        in each iteration. The input parameters of finetuner can be specified
        in the parameters `finetune_args` and `finetune_kwargs` of compress if
        needed. Example:

        >>> def finetuner(model, epoch=3):
        >>>     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        >>>     train_loader = ...
        >>>     criterion = torch.nn.CrossEntropyLoss()
        >>>     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        >>>     model.train()
        >>>     for _ in range(epoch):
        >>>         for _, (data, target) in enumerate(train_loader):
        >>>             data, target = data.to(device), target.to(device)
        >>>             optimizer.zero_grad()
        >>>             output = model(data)
        >>>             loss = criterion(output, target)
        >>>             loss.backward()
        >>>             optimizer.step()

    base_algo : str
        Base pruning algorithm: `level`, `l1`, `l2` or `fpgm`. Default: `l1`.
    sparsity_proportion_calc : function
        This function generates the sparsity proportion between the conv layers
        according to the sensitivity analysis results. A default function is
        provided that quantifies the sparsity proportion according to the
        sensitivity analysis results; users can also customize this function
        according to their needs. The input of this function is a dict, for
        example: {'conv1': {0.1: 0.9, 0.2: 0.8}, 'conv2': {0.1: 0.9, 0.2: 0.8}},
        in which 'conv1' is the name of the conv layer, and 0.1: 0.9 means that
        when the sparsity of conv1 is 0.1 (10%), the model's validation
        accuracy equals 0.9.
    sparsity_per_iter : float
        The amount of sparsity that the pruner tries to prune in each iteration.
    acc_drop_threshold : float
        The hyperparameter used to quantify the sensitivity of each layer.
    checkpoint_dir : str
        The directory path in which to save the checkpoints during pruning.
    """
""" def __init__(self, model, config_list, evaluator, finetuner=None, base_algo='l1', sparsity_proportion_calc=None, sparsity_per_iter=0.1, acc_drop_threshold=0.05, checkpoint_dir=None): self.base_algo = base_algo self.model = model super(SensitivityPruner, self).__init__(model, config_list) # unwrap the model self._unwrap_model() _logger.debug(str(self.model)) self.evaluator = evaluator self.finetuner = finetuner self.analyzer = SensitivityAnalysis( self.model, self.evaluator, prune_type=base_algo, \ early_stop_mode='dropped', early_stop_value=acc_drop_threshold) # Get the original accuracy of the pretrained model self.ori_acc = None # Copy the original weights before pruning self.ori_state_dict = copy.deepcopy(self.model.state_dict()) self.sensitivities = {} # Save the weight count for each layer self.weight_count = {} self.weight_sum = 0 # Map the layer name to the layer module self.named_module = {} self.Pruner = PRUNER_DICT[self.base_algo] # Count the total weight count of the model for name, submodule in self.model.named_modules(): self.named_module[name] = submodule if name in self.analyzer.target_layer: # Currently, only count the weights in the conv layers # else the fully connected layer (which contains # the most weights) may make the pruner prune the # model too hard # if hasattr(submodule, 'weight'): # Count all the weights of the model self.weight_count[name] = submodule.weight.data.numel() self.weight_sum += self.weight_count[name] # function to generate the sparsity proportion between the conv layers if sparsity_proportion_calc is None: self.sparsity_proportion_calc = self._max_prune_ratio else: self.sparsity_proportion_calc = sparsity_proportion_calc # The ratio of remained weights is 1.0 at the begining self.remained_ratio = 1.0 self.sparsity_per_iter = sparsity_per_iter self.acc_drop_threshold = acc_drop_threshold self.checkpoint_dir = checkpoint_dir
    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs
        """
        if self.base_algo == 'level':
            schema = PrunerSchema([{
                Optional('sparsity'): And(float, lambda n: 0 < n < 1),
                Optional('op_types'): [str],
                Optional('op_names'): [str],
                Optional('exclude'): bool
            }], model, _logger)
        elif self.base_algo in ['l1', 'l2', 'fpgm']:
            schema = PrunerSchema([{
                Optional('sparsity'): And(float, lambda n: 0 < n < 1),
                'op_types': ['Conv2d'],
                Optional('op_names'): [str],
                Optional('exclude'): bool
            }], model, _logger)

        schema.validate(config_list)
    def load_sensitivity(self, filepath):
        """
        Load the sensitivity results exported by the sensitivity analyzer.
        """
        assert os.path.exists(filepath)
        with open(filepath, 'r') as csvf:
            csv_r = csv.reader(csvf)
            header = next(csv_r)
            sparsities = [float(x) for x in header[1:]]
            sensitivities = {}
            for row in csv_r:
                layername = row[0]
                accuracies = [float(x) for x in row[1:]]
                sensitivities[layername] = {}
                for i, accuracy in enumerate(accuracies):
                    sensitivities[layername][sparsities[i]] = accuracy
            return sensitivities
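    # A sketch of the CSV layout that load_sensitivity expects (a hypothetical
    # example consistent with the parsing logic above; the actual files are
    # written by the sensitivity analyzer's export). The header row lists the
    # sparsities, and each following row gives one layer's accuracy at each
    # sparsity:
    #
    #   layername,0.1,0.2,0.3
    #   conv1,0.91,0.88,0.76
    #   conv2,0.90,0.85,0.70
    #
    # Loading such a file returns
    # {'conv1': {0.1: 0.91, 0.2: 0.88, 0.3: 0.76},
    #  'conv2': {0.1: 0.90, 0.2: 0.85, 0.3: 0.70}}.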
    def _max_prune_ratio(self, ori_acc, threshold, sensitivities):
        """
        Find the maximum prune ratio for a single layer whose accuracy
        drop is lower than the threshold.

        Parameters
        ----------
        ori_acc : float
            Original accuracy
        threshold : float
            Accuracy drop threshold
        sensitivities : dict
            The dict object that stores the sensitivity results for each layer.
            For example: {'conv1': {0.1: 0.9, 0.2: 0.8}}

        Returns
        -------
        max_ratios : dict
            The maximum prune ratio for each layer.
            For example: {'conv1': 0.1, 'conv2': 0.2}
        """
        max_ratio = {}
        for layer in sensitivities:
            prune_ratios = sorted(sensitivities[layer].keys())
            last_ratio = 0
            for ratio in prune_ratios:
                last_ratio = ratio
                cur_acc = sensitivities[layer][ratio]
                if cur_acc + threshold < ori_acc:
                    break
            max_ratio[layer] = last_ratio
        return max_ratio
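    # Worked example for _max_prune_ratio (illustrative values, not from the
    # original source): with ori_acc = 0.92, threshold = 0.05 and
    # sensitivities = {'conv1': {0.1: 0.91, 0.2: 0.80, 0.3: 0.60}}, the loop
    # visits 0.1 (0.91 + 0.05 >= 0.92, continue) and then 0.2
    # (0.80 + 0.05 < 0.92, break), so the result is {'conv1': 0.2}: the first
    # ratio at which the accuracy drop exceeds the threshold.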
    def normalize(self, ratios, target_pruned):
        """
        Normalize the prune ratio of each layer according to the total
        already-pruned ratio and the final target total pruning ratio.

        Parameters
        ----------
        ratios : dict
            Dict object that saves the prune ratio for each layer
        target_pruned : float
            The amount of the weights expected to be pruned in this iteration

        Returns
        -------
        new_ratios : dict
            The normalized prune ratios for each layer.
        """
        w_sum = 0
        _Max = 0
        for layername, ratio in ratios.items():
            wcount = self.weight_count[layername]
            w_sum += ratio * wcount * \
                (1 - self.analyzer.already_pruned[layername])
        target_count = self.weight_sum * target_pruned
        for layername in ratios:
            ratios[layername] = ratios[layername] * target_count / w_sum
            _Max = max(_Max, ratios[layername])
        # Cannot prune too much in a single iteration.
        # If a layer's prune ratio is larger than
        # MAX_PRUNE_RATIO_PER_ITER, we rescale all prune
        # ratios under this threshold.
        if _Max > MAX_PRUNE_RATIO_PER_ITER:
            for layername in ratios:
                ratios[layername] = ratios[layername] * \
                    MAX_PRUNE_RATIO_PER_ITER / _Max
        return ratios
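    # Worked example for normalize (illustrative values, not from the original
    # source): suppose weight_count = {'conv1': 100, 'conv2': 300}
    # (weight_sum = 400), nothing has been pruned yet,
    # ratios = {'conv1': 0.2, 'conv2': 0.2} and target_pruned = 0.1.
    # Then w_sum = 0.2*100 + 0.2*300 = 80 weights would be pruned at the input
    # ratios, while target_count = 400 * 0.1 = 40, so every ratio is scaled by
    # 40/80 = 0.5, giving {'conv1': 0.1, 'conv2': 0.1}: exactly 10% of all
    # weights get pruned this iteration, split in proportion to the input ratios.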
    def create_cfg(self, ratios):
        """
        Generate the cfg_list for the pruner according to the prune ratios.

        Parameters
        ----------
        ratios : dict
            For example: {'conv1': 0.2}

        Returns
        -------
        cfg_list : list
            For example:
            [{'sparsity': 0.2, 'op_names': ['conv1'], 'op_types': ['Conv2d']}]
        """
        cfg_list = []
        for layername in ratios:
            prune_ratio = ratios[layername]
            remain = 1 - self.analyzer.already_pruned[layername]
            sparsity = remain * prune_ratio + \
                self.analyzer.already_pruned[layername]
            if sparsity > 0:
                # Pruner does not allow the prune ratio to be zero
                cfg = {'sparsity': sparsity, 'op_names': [
                    layername], 'op_types': ['Conv2d']}
                cfg_list.append(cfg)
        return cfg_list
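    # Example of the sparsity accumulation above (illustrative values): if
    # 'conv1' already has 30% of its weights pruned
    # (already_pruned['conv1'] == 0.3) and the new prune ratio for this
    # iteration is 0.1, the config asks the base pruner for the cumulative
    # sparsity (1 - 0.3) * 0.1 + 0.3 = 0.37, i.e. pruning 10% of the
    # remaining weights on top of what is already pruned.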
    def current_sparsity(self):
        """
        The sparsity of the weights.
        """
        pruned_weight = 0
        for layer_name in self.analyzer.already_pruned:
            w_count = self.weight_count[layer_name]
            prune_ratio = self.analyzer.already_pruned[layer_name]
            pruned_weight += w_count * prune_ratio
        return pruned_weight / self.weight_sum
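    # For instance (illustrative values), with weight_count = {'conv1': 100,
    # 'conv2': 300} and already_pruned = {'conv1': 0.37, 'conv2': 0.1},
    # current_sparsity() returns (100*0.37 + 300*0.1) / 400 = 0.1675.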
    def compress(self, eval_args=None, eval_kwargs=None,
                 finetune_args=None, finetune_kwargs=None, resume_sensitivity=None):
        """
        This function iteratively prunes the model according to the results of
        the sensitivity analysis.

        Parameters
        ----------
        eval_args : list
        eval_kwargs : dict
            Parameters for the val_function; the val_function will be called
            like evaluator(\*eval_args, \*\*eval_kwargs)
        finetune_args : list
        finetune_kwargs : dict
            Parameters for the finetuner function if needed.
        resume_sensitivity : str
            Resume the sensitivity results from this file.
        """
        # pylint suggests not using an empty list or dict
        # as the default input parameter
        if not eval_args:
            eval_args = []
        if not eval_kwargs:
            eval_kwargs = {}
        if not finetune_args:
            finetune_args = []
        if not finetune_kwargs:
            finetune_kwargs = {}
        if self.ori_acc is None:
            self.ori_acc = self.evaluator(*eval_args, **eval_kwargs)
        assert isinstance(self.ori_acc, (float, int))
        if not resume_sensitivity:
            self.sensitivities = self.analyzer.analysis(
                val_args=eval_args, val_kwargs=eval_kwargs)
        else:
            self.sensitivities = self.load_sensitivity(resume_sensitivity)
            self.analyzer.sensitivities = self.sensitivities
        # the final target sparsity of the model
        target_ratio = 1 - self.config_list[0]['sparsity']
        cur_ratio = self.remained_ratio
        ori_acc = self.ori_acc
        iteration_count = 0
        if self.checkpoint_dir is not None:
            os.makedirs(self.checkpoint_dir, exist_ok=True)
        modules_wrapper_final = None
        while cur_ratio > target_ratio:
            iteration_count += 1
            # Each round has three steps:
            # 1) Get the current sensitivity of each layer (the sensitivity
            #    of each layer may change during the pruning)
            # 2) Prune each layer according to the sensitivities
            # 3) Finetune the model
            _logger.info('Current base accuracy %f', ori_acc)
            _logger.info('Remained %f weights', cur_ratio)
            # determine the sparsity proportion between different
            # layers according to the sensitivity result
            proportion = self.sparsity_proportion_calc(
                ori_acc, self.acc_drop_threshold, self.sensitivities)
            new_pruneratio = self.normalize(proportion, self.sparsity_per_iter)
            cfg_list = self.create_cfg(new_pruneratio)
            if not cfg_list:
                _logger.error('The threshold is too small, please set a larger threshold')
                return self.model
            _logger.debug('Pruner Config: %s', str(cfg_list))
            cfg_str = ['%s:%.3f' % (cfg['op_names'][0], cfg['sparsity'])
                       for cfg in cfg_list]
            _logger.info('Current Sparsities: %s', ','.join(cfg_str))

            pruner = self.Pruner(self.model, cfg_list)
            pruner.compress()
            pruned_acc = self.evaluator(*eval_args, **eval_kwargs)
            _logger.info('Accuracy after pruning: %f', pruned_acc)

            finetune_acc = pruned_acc
            if self.finetuner is not None:
                # skip finetuning if no finetune function is given
                self.finetuner(*finetune_args, **finetune_kwargs)
                finetune_acc = self.evaluator(*eval_args, **eval_kwargs)
            _logger.info('Accuracy after finetune: %f', finetune_acc)
            ori_acc = finetune_acc
            # unwrap the pruner
            pruner._unwrap_model()
            # update the already-pruned ratio of each layer before the new
            # sensitivity analysis
            for layer_cfg in cfg_list:
                name = layer_cfg['op_names'][0]
                sparsity = layer_cfg['sparsity']
                self.analyzer.already_pruned[name] = sparsity
            # update cur_ratio
            cur_ratio = 1 - self.current_sparsity()
            modules_wrapper_final = pruner.get_modules_wrapper()
            del pruner
            _logger.info('Currently remained weights: %f', cur_ratio)

            if self.checkpoint_dir is not None:
                checkpoint_name = 'Iter_%d_finetune_acc_%.5f_sparsity_%.4f' % (
                    iteration_count, finetune_acc, cur_ratio)
                checkpoint_path = os.path.join(
                    self.checkpoint_dir, '%s.pth' % checkpoint_name)
                cfg_path = os.path.join(
                    self.checkpoint_dir, '%s_pruner.json' % checkpoint_name)
                sensitivity_path = os.path.join(
                    self.checkpoint_dir, '%s_sensitivity.csv' % checkpoint_name)
                torch.save(self.model.state_dict(), checkpoint_path)
                with open(cfg_path, 'w') as jf:
                    json.dump(cfg_list, jf)
                self.analyzer.export(sensitivity_path)

            if cur_ratio > target_ratio:
                # Re-run the sensitivity analysis only if this is not the last
                # pruning iteration; the analysis is time-consuming, so it is
                # skipped once the target sparsity is reached.
                self.analyzer.load_state_dict(self.model.state_dict())
                self.sensitivities = self.analyzer.analysis(
                    val_args=eval_args, val_kwargs=eval_kwargs)

        _logger.info('After Pruning: %.2f weights remain', cur_ratio)
        self.modules_wrapper = modules_wrapper_final

        self._wrap_model()
        return self.model
    def calc_mask(self, wrapper, **kwargs):
        return None
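# A minimal usage sketch (an assumption-laden example, not part of the original
# source): it presumes a pretrained `model` plus an `evaluator` and `finetuner`
# written as in the class docstring.
#
# >>> config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
# >>> pruner = SensitivityPruner(model, config_list, evaluator,
# >>>                            finetuner=finetuner, sparsity_per_iter=0.1,
# >>>                            acc_drop_threshold=0.05,
# >>>                            checkpoint_dir='./sensitivity_checkpoints')
# >>> pruned_model = pruner.compress()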