Source code for nni.algorithms.compression.pytorch.pruning.net_adapt_pruner

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import copy
import json
import torch
from schema import And, Optional

from nni.utils import OptimizeMode

from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.utils.num_param_counter import get_total_num_weights
from .constants_pruner import PRUNER_DICT


_logger = logging.getLogger(__name__)


class NetAdaptPruner(Pruner):
    """
    A PyTorch implementation of the NetAdapt compression algorithm.

    Parameters
    ----------
    model : pytorch model
        The model to be pruned.
    config_list : list
        Supported keys:
            - sparsity : The target overall sparsity.
            - op_types : The operation type to prune.
    short_term_fine_tuner : function
        function to short-term fine tune the masked model.
        This function should include `model` as the only parameter,
        and fine tune the model for a short term after each pruning iteration.
        Example::

            def short_term_fine_tuner(model, epoch=3):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                train_loader = ...
                criterion = torch.nn.CrossEntropyLoss()
                optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
                model.train()
                for _ in range(epoch):
                    for batch_idx, (data, target) in enumerate(train_loader):
                        data, target = data.to(device), target.to(device)
                        optimizer.zero_grad()
                        output = model(data)
                        loss = criterion(output, target)
                        loss.backward()
                        optimizer.step()
    evaluator : function
        function to evaluate the masked model.
        This function should include `model` as the only parameter, and return a scalar value.
        Example::

            def evaluator(model):
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                val_loader = ...
                model.eval()
                correct = 0
                with torch.no_grad():
                    for data, target in val_loader:
                        data, target = data.to(device), target.to(device)
                        output = model(data)
                        # get the index of the max log-probability
                        pred = output.argmax(dim=1, keepdim=True)
                        correct += pred.eq(target.view_as(pred)).sum().item()
                accuracy = correct / len(val_loader.dataset)
                return accuracy
    optimize_mode : str
        optimize mode, `maximize` or `minimize`, by default `maximize`.
    base_algo : str
        Base pruning algorithm. `level`, `l1`, `l2` or `fpgm`, by default `l1`.
        Given the sparsity distribution among the ops, the assigned `base_algo`
        is used to decide which filters/channels/weights to prune.
    sparsity_per_iteration : float
        sparsity to prune in each iteration.
    experiment_data_dir : str
        PATH to save experiment data, including the config_list generated for
        the base pruning algorithm and the performance of the pruned model.
    """

    def __init__(self, model, config_list, short_term_fine_tuner, evaluator,
                 optimize_mode='maximize', base_algo='l1', sparsity_per_iteration=0.05, experiment_data_dir='./'):
        # models used for iterative pruning and evaluation
        self._model_to_prune = copy.deepcopy(model)
        self._base_algo = base_algo

        super().__init__(model, config_list)

        self._short_term_fine_tuner = short_term_fine_tuner
        self._evaluator = evaluator
        self._optimize_mode = OptimizeMode(optimize_mode)

        # hyper parameters for NetAdapt algorithm
        self._sparsity_per_iteration = sparsity_per_iteration

        # overall pruning rate
        self._sparsity = config_list[0]['sparsity']

        # config_list
        self._config_list_generated = []

        self._experiment_data_dir = experiment_data_dir
        if not os.path.exists(self._experiment_data_dir):
            os.makedirs(self._experiment_data_dir)

        self._tmp_model_path = os.path.join(self._experiment_data_dir, 'tmp_model.pth')
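    # Overview (summary of the compress() loop below): each iteration is given a
    # weight budget of roughly `sparsity_per_iteration` times the model's total
    # Conv2d/Linear weights. Every wrapped layer is tentatively pruned just enough
    # to meet that budget, short-term fine-tuned and evaluated; the layer whose
    # pruned model performs best is kept, and the loop repeats until the overall
    # target sparsity is reached.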
    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be pruned
        config_list : list
            List of pruning configs
        """

        if self._base_algo == 'level':
            schema = PrunerSchema([{
                Optional('sparsity'): And(float, lambda n: 0 < n < 1),
                Optional('op_types'): [str],
                Optional('op_names'): [str],
                Optional('exclude'): bool
            }], model, _logger)
        elif self._base_algo in ['l1', 'l2', 'fpgm']:
            schema = PrunerSchema([{
                Optional('sparsity'): And(float, lambda n: 0 < n < 1),
                'op_types': ['Conv2d'],
                Optional('op_names'): [str],
                Optional('exclude'): bool
            }], model, _logger)

        schema.validate(config_list)
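    # Example of a config_list accepted by the 'l1'/'l2'/'fpgm' schema above
    # (illustrative values, not from the original source):
    #     [{'sparsity': 0.5, 'op_types': ['Conv2d']}]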
    def calc_mask(self, wrapper, **kwargs):
        # masks are produced by the base pruner in each iteration of compress()
        # and copied onto the wrappers there, so nothing is calculated here
        return None
    def _update_config_list(self, config_list, op_name, sparsity):
        '''
        update sparsity of op_name in config_list
        '''
        config_list_updated = copy.deepcopy(config_list)

        for idx, item in enumerate(config_list):
            if op_name in item['op_names']:
                config_list_updated[idx]['sparsity'] = sparsity
                return config_list_updated

        # if op_name is not in self._config_list_generated, create a new json item
        if self._base_algo in ['l1', 'l2', 'fpgm']:
            config_list_updated.append(
                {'sparsity': sparsity, 'op_types': ['Conv2d'], 'op_names': [op_name]})
        elif self._base_algo == 'level':
            config_list_updated.append(
                {'sparsity': sparsity, 'op_names': [op_name]})

        return config_list_updated

    def _get_op_num_weights_remained(self, op_name, module):
        '''
        Get the number of weights remaining after channel pruning with the current sparsity

        Returns
        -------
        int
            remaining number of weights of the op
        '''

        # if op is wrapped by the pruner
        for wrapper in self.get_modules_wrapper():
            if wrapper.name == op_name:
                return wrapper.weight_mask.sum().item()

        # if op is not wrapped by the pruner
        return module.weight.data.numel()

    def _get_op_sparsity(self, op_name):
        for config in self._config_list_generated:
            if 'op_names' in config and op_name in config['op_names']:
                return config['sparsity']
        return 0

    def _calc_num_related_weights(self, op_name):
        '''
        Calculate the total number of weights of the op and the next op,
        applicable only for models without dependencies among ops

        Parameters
        ----------
        op_name : str

        Returns
        -------
        int
            total number of all the related (current and the next) op weights
        '''
        num_weights = 0
        flag_found = False
        previous_name = None
        previous_module = None

        for name, module in self._model_to_prune.named_modules():
            if not flag_found and name != op_name and type(module).__name__ in ['Conv2d', 'Linear']:
                previous_name = name
                previous_module = module
            if not flag_found and name == op_name:
                _logger.debug("original module found: %s", name)
                num_weights = module.weight.data.numel()

                # consider related pruning in this op caused by previous op's pruning
                if previous_module:
                    sparsity_previous_op = self._get_op_sparsity(previous_name)
                    if sparsity_previous_op:
                        _logger.debug(
                            "decrease op's weights by %s due to previous op %s's pruning...",
                            sparsity_previous_op, previous_name)
                        num_weights *= (1 - sparsity_previous_op)

                flag_found = True
                continue
            if flag_found and type(module).__name__ in ['Conv2d', 'Linear']:
                _logger.debug("related module found: %s", name)
                # channel/filter pruning crossing is considered here,
                # so only the num_weights after channel pruning is valuable
                num_weights += self._get_op_num_weights_remained(name, module)
                break

        _logger.debug("num related weights of op %s : %d", op_name, num_weights)

        return num_weights
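    # Worked example for _calc_num_related_weights above (illustrative, not from
    # the original source): for two stacked layers Conv2d(16, 32, 3) followed by
    # Conv2d(32, 64, 3), pruning the first conv's output filters also removes
    # input channels of the second conv. With op_name pointing at the first conv
    # and the second conv still unpruned:
    #     num_weights = 32*16*3*3 + 64*32*3*3 = 4608 + 18432 = 23040
    # If a previous op had already been pruned to sparsity s, the first term
    # would be scaled by (1 - s).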
    def compress(self):
        """
        Compress the model.

        Returns
        -------
        torch.nn.Module
            model with specified modules compressed.
        """
        _logger.info('Starting NetAdapt Compression...')

        pruning_iteration = 0
        current_sparsity = 0
        delta_num_weights_per_iteration = \
            int(get_total_num_weights(self._model_to_prune, ['Conv2d', 'Linear']) * self._sparsity_per_iteration)

        # stop condition
        while current_sparsity < self._sparsity:
            _logger.info('Pruning iteration: %d', pruning_iteration)

            # calculate target sparsity of this iteration
            target_sparsity = current_sparsity + self._sparsity_per_iteration

            # variable to store the info of the best layer found in this iteration
            best_op = {}

            for wrapper in self.get_modules_wrapper():
                _logger.debug("op name : %s", wrapper.name)
                _logger.debug("op weights : %d", wrapper.weight_mask.numel())
                _logger.debug("op left weights : %d", wrapper.weight_mask.sum().item())

                current_op_sparsity = 1 - wrapper.weight_mask.sum().item() / wrapper.weight_mask.numel()
                _logger.debug("current op sparsity : %s", current_op_sparsity)

                # sparsity that this layer needs to prune to satisfy the requirement
                target_op_sparsity = current_op_sparsity + \
                    delta_num_weights_per_iteration / self._calc_num_related_weights(wrapper.name)

                if target_op_sparsity >= 1:
                    _logger.info('Layer %s does not have enough remaining weights to prune', wrapper.name)
                    continue

                config_list = self._update_config_list(self._config_list_generated, wrapper.name, target_op_sparsity)
                _logger.debug("config_list used : %s", config_list)

                pruner = PRUNER_DICT[self._base_algo](copy.deepcopy(self._model_to_prune), config_list)
                model_masked = pruner.compress()

                # Short-term fine tune the pruned model
                self._short_term_fine_tuner(model_masked)

                performance = self._evaluator(model_masked)
                _logger.info("Layer : %s, evaluation result after short-term fine tuning : %s",
                             wrapper.name, performance)

                if not best_op \
                        or (self._optimize_mode is OptimizeMode.Maximize and performance > best_op['performance']) \
                        or (self._optimize_mode is OptimizeMode.Minimize and performance < best_op['performance']):
                    _logger.debug("updating best layer to %s...", wrapper.name)
                    # find weight mask of this layer
                    for w in pruner.get_modules_wrapper():
                        if w.name == wrapper.name:
                            masks = {'weight_mask': w.weight_mask,
                                     'bias_mask': w.bias_mask}
                            break
                    best_op = {
                        'op_name': wrapper.name,
                        'sparsity': target_op_sparsity,
                        'performance': performance,
                        'masks': masks
                    }

                    # save model weights
                    pruner.export_model(self._tmp_model_path)

            if not best_op:
                # decrease pruning step
                self._sparsity_per_iteration *= 0.5
                # keep the per-iteration weight budget consistent with the halved
                # step; otherwise target_op_sparsity never drops below 1 and the
                # loop cannot make progress
                delta_num_weights_per_iteration = \
                    int(get_total_num_weights(self._model_to_prune, ['Conv2d', 'Linear'])
                        * self._sparsity_per_iteration)
                _logger.info("No more layers to prune, decrease pruning step to %s", self._sparsity_per_iteration)
                continue

            # Pick the best layer to prune, update iterative information
            # update config_list
            self._config_list_generated = self._update_config_list(
                self._config_list_generated, best_op['op_name'], best_op['sparsity'])

            # update weights parameters
            self._model_to_prune.load_state_dict(torch.load(self._tmp_model_path))

            # update mask of the chosen op
            for wrapper in self.get_modules_wrapper():
                if wrapper.name == best_op['op_name']:
                    for k in best_op['masks']:
                        setattr(wrapper, k, best_op['masks'][k])
                    break

            current_sparsity = target_sparsity
            _logger.info('Pruning iteration %d finished, current sparsity: %s', pruning_iteration, current_sparsity)
            _logger.info('Layer %s selected with sparsity %s, performance after pruning & short term fine-tuning : %s',
                         best_op['op_name'], best_op['sparsity'], best_op['performance'])
            pruning_iteration += 1

        self._final_performance = best_op['performance']

        # load weights parameters
        self.load_model_state_dict(torch.load(self._tmp_model_path))
        os.remove(self._tmp_model_path)

        _logger.info('----------Compression finished--------------')
        _logger.info('config_list generated: %s', self._config_list_generated)
        _logger.info("Performance after pruning: %s", self._final_performance)
        _logger.info("Masked sparsity: %.6f", current_sparsity)

        # save best config found and best performance
        with open(os.path.join(self._experiment_data_dir, 'search_result.json'), 'w') as jsonfile:
            json.dump({
                'performance': self._final_performance,
                'config_list': json.dumps(self._config_list_generated)
            }, jsonfile)

        _logger.info('search history and result saved to folder : %s', self._experiment_data_dir)

        return self.bound_model
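
if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): exercises
    # NetAdaptPruner end to end on a toy CNN with one synthetic batch, so the
    # iterative loop in compress() can be observed. The toy model, data and
    # hyper-parameter values below are illustrative assumptions, not values
    # from the original source; accuracy on random data is meaningless.
    import torch.nn as nn

    toy_model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
        nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10))

    data = torch.randn(32, 3, 8, 8)
    target = torch.randint(0, 10, (32,))

    def short_term_fine_tuner(model):
        # a single quick update on the synthetic batch
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        model.train()
        optimizer.zero_grad()
        torch.nn.functional.cross_entropy(model(data), target).backward()
        optimizer.step()

    def evaluator(model):
        # returns a scalar, as required: accuracy on the synthetic batch
        model.eval()
        with torch.no_grad():
            pred = model(data).argmax(dim=1)
        return (pred == target).float().mean().item()

    config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
    pruner = NetAdaptPruner(toy_model, config_list, short_term_fine_tuner, evaluator,
                            base_algo='l1', sparsity_per_iteration=0.1,
                            experiment_data_dir='./netadapt_example')
    pruner.compress()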