Source code for nni.algorithms.compression.pytorch.pruning.lottery_ticket

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import logging
import torch
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.compressor import Pruner
from .finegrained_pruning_masker import LevelPrunerMasker

logger = logging.getLogger('torch pruner')

[docs]class LotteryTicketPruner(Pruner): """ Parameters ---------- model : pytorch model The model to be pruned config_list : list Supported keys: - prune_iterations : The number of rounds for the iterative pruning. - sparsity : The final sparsity when the compression is done. optimizer : pytorch optimizer The optimizer for the model lr_scheduler : pytorch lr scheduler The lr scheduler for the model if used reset_weights : bool Whether reset weights and optimizer at the beginning of each round. """ def __init__(self, model, config_list, optimizer=None, lr_scheduler=None, reset_weights=True): # save init weights and optimizer self.reset_weights = reset_weights if self.reset_weights: self._model = model self._optimizer = optimizer self._model_state = copy.deepcopy(model.state_dict()) self._optimizer_state = copy.deepcopy(optimizer.state_dict()) self._lr_scheduler = lr_scheduler if lr_scheduler is not None: self._scheduler_state = copy.deepcopy(lr_scheduler.state_dict()) super().__init__(model, config_list, optimizer) self.curr_prune_iteration = None self.prune_iterations = config_list[0]['prune_iterations'] self.masker = LevelPrunerMasker(model, self)
[docs] def validate_config(self, model, config_list): """ Parameters ---------- model : torch.nn.Module Model to be pruned config_list : list Supported keys: - prune_iterations : The number of rounds for the iterative pruning. - sparsity : The final sparsity when the compression is done. """ schema = PrunerSchema([{ Optional('sparsity'): And(float, lambda n: 0 < n < 1), 'prune_iterations': And(int, lambda n: n > 0), Optional('op_types'): [str], Optional('op_names'): [str], Optional('exclude'): bool }], model, logger) schema.validate(config_list) assert len(set([x['prune_iterations'] for x in config_list])) == 1, 'The values of prune_iterations must be equal in your config'
def _calc_sparsity(self, sparsity): keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations) curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration return max(1 - curr_keep_ratio, 0) def _calc_mask(self, wrapper, sparsity): weight = wrapper.module.weight.data if self.curr_prune_iteration == 0: mask = {'weight_mask': torch.ones(weight.shape).type_as(weight)} else: curr_sparsity = self._calc_sparsity(sparsity) mask = self.masker.calc_mask(sparsity=curr_sparsity, wrapper=wrapper) return mask
[docs] def calc_mask(self, wrapper, **kwargs): """ Generate mask for the given ``weight``. Parameters ---------- wrapper : Module The layer to be pruned Returns ------- tensor The mask for this weight, it is ```None``` because this pruner calculates and assigns masks in ```prune_iteration_start```, no need to do anything in this function. """ return None
[docs] def get_prune_iterations(self): """ Return the range for iterations. In the first prune iteration, masks are all one, thus, add one more iteration Returns ------- list A list for pruning iterations """ return range(self.prune_iterations + 1)
[docs] def prune_iteration_start(self): """ Control the pruning procedure on updated epoch number. Should be called at the beginning of the epoch. """ if self.curr_prune_iteration is None: self.curr_prune_iteration = 0 else: self.curr_prune_iteration += 1 assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations' modules_wrapper = self.get_modules_wrapper() modules_to_compress = self.get_modules_to_compress() for layer, config in modules_to_compress: module_wrapper = None for wrapper in modules_wrapper: if wrapper.name == layer.name: module_wrapper = wrapper break assert module_wrapper is not None sparsity = config.get('sparsity') mask = self._calc_mask(module_wrapper, sparsity) # TODO: directly use weight_mask is not good module_wrapper.weight_mask = mask['weight_mask'] # there is no mask for bias # reinit weights back to original after new masks are generated if self.reset_weights: # should use this member function to reset model weights self.load_model_state_dict(self._model_state) self._optimizer.load_state_dict(self._optimizer_state) if self._lr_scheduler is not None: self._lr_scheduler.load_state_dict(self._scheduler_state)