# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import csv
import logging
from collections import OrderedDict
import numpy as np
import torch.nn as nn
# FIXME: I don't know where "utils" should be
# Layer types whose pruning sensitivity can be analyzed.
SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d']
# Resolve each name to its torch.nn module class.
SUPPORTED_OP_TYPE = [getattr(nn, op_name) for op_name in SUPPORTED_OP_NAME]

# Module-level logger used by the analysis loop below.
logger = logging.getLogger('Sensitivity_Analysis')
logger.setLevel(logging.INFO)
class SensitivityAnalysis:
    """
    Pruning sensitivity analysis for the conv layers of a model.

    Each supported conv layer is pruned at a sequence of sparsities and the
    model is re-validated after every pruning step; the metric trajectory of
    every layer is collected in ``self.sensitivities``.
    """

    def __init__(self, model, val_func, sparsities=None, prune_type='l1',
                 early_stop_mode=None, early_stop_value=None):
        """
        Perform sensitivity analysis for this model.

        Parameters
        ----------
        model : torch.nn.Module
            the model to perform sensitivity analysis on
        val_func : function
            validation function for the model. Since different models may need
            different datasets/criteria, the user needs to cover this part
            themselves. In val_func, the model should be tested on the
            validation dataset, and the validation accuracy/loss should be
            returned as the output of val_func. There are no restrictions on
            the input parameters of val_func. Users can use the val_args,
            val_kwargs parameters of ``analysis`` to pass all the parameters
            that val_func needs.
        sparsities : list
            The sparsity list provided by users. This parameter is set when the
            user only wants to test some specific sparsities. In the sparsity
            list, each element is a sparsity value which means how much weight
            the pruner should prune. Take [0.25, 0.5, 0.75] for an example: the
            SensitivityAnalysis will prune 25% 50% 75% of the weights gradually
            for each layer.
        prune_type : str
            The pruner type used to prune the conv layers. The default is
            'l1'; 'l2' and 'fine-grained' are also supported.
        early_stop_mode : str
            If this flag is set, the sensitivity analysis for a conv layer
            will early stop when the validation metric (for example,
            accuracy/loss) has already met the threshold. We support four
            different early stop modes: minimize, maximize, dropped, raised.
            The default value is None, which means the analysis won't stop
            until all given sparsities are tested. This option should be used
            together with early_stop_value.

            minimize: The analysis stops when the validation metric returned
            by the val_func is lower than early_stop_value.
            maximize: The analysis stops when the validation metric returned
            by the val_func is larger than early_stop_value.
            dropped: The analysis stops when the validation metric has dropped
            by early_stop_value.
            raised: The analysis stops when the validation metric has raised
            by early_stop_value.
        early_stop_value : float
            This value is used as the threshold for the different early stop
            modes. It is effective only when early_stop_mode is set.
        """
        # Imported lazily so that importing this module does not require the
        # pruner constants to be importable.
        from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT
        self.model = model
        self.val_func = val_func
        # Snapshot of the unpruned weights; used to restore the model after
        # each layer's analysis round.
        self.ori_state_dict = copy.deepcopy(self.model.state_dict())
        # Mapping: layer name -> module, filled by model_parse().
        # (The original code assigned this twice — OrderedDict then a plain
        # dict; a single OrderedDict keeps a deterministic traversal order.)
        self.target_layer = OrderedDict()
        self.sensitivities = {}
        if sparsities is not None:
            self.sparsities = sorted(sparsities)
        else:
            self.sparsities = np.arange(0.1, 1.0, 0.1)
        self.sparsities = [np.round(x, 2) for x in self.sparsities]
        self.Pruner = PRUNER_DICT[prune_type]
        self.early_stop_mode = early_stop_mode
        self.early_stop_value = early_stop_value
        self.ori_metric = None  # original validation metric for the model
        # already_pruned is for the iterative sensitivity analysis.
        # For example, sensitivity_pruner iteratively prunes the target
        # model according to the sensitivity. After each round of
        # pruning, the sensitivity_pruner will test the new sensitivity
        # for each layer.
        self.already_pruned = {}
        self.model_parse()

    @property
    def layers_count(self):
        """Number of conv layers that will be analyzed."""
        return len(self.target_layer)

    def model_parse(self):
        """Collect all supported conv layers (name -> module) of the model."""
        for name, submodel in self.model.named_modules():
            for op_type in SUPPORTED_OP_TYPE:
                if isinstance(submodel, op_type):
                    self.target_layer[name] = submodel
                    self.already_pruned[name] = 0

    def _need_to_stop(self, ori_metric, cur_metric):
        """
        Judge if the early stop condition (minimize, maximize, dropped,
        raised) is met.

        Parameters
        ----------
        ori_metric : float
            original validation metric
        cur_metric : float
            current validation metric

        Returns
        -------
        stop : bool
            whether to stop the sensitivity analysis
        """
        if self.early_stop_mode is None:
            # early stop mode is not enabled
            return False
        assert self.early_stop_value is not None
        if self.early_stop_mode == 'minimize':
            if cur_metric < self.early_stop_value:
                return True
        elif self.early_stop_mode == 'maximize':
            if cur_metric > self.early_stop_value:
                return True
        elif self.early_stop_mode == 'dropped':
            if cur_metric < ori_metric - self.early_stop_value:
                return True
        elif self.early_stop_mode == 'raised':
            if cur_metric > ori_metric + self.early_stop_value:
                return True
        return False

    def analysis(self, val_args=None, val_kwargs=None, specified_layers=None):
        """
        Analyze the sensitivity to pruning for each conv layer in the target
        model. If specified_layers is not set, all the conv layers are
        analyzed by default. Users can specify several layers to analyze, or
        parallelize the analysis process easily through the specified_layers
        parameter.

        Parameters
        ----------
        val_args : list
            args for the val_func
        val_kwargs : dict
            kwargs for the val_func
        specified_layers : list
            list of layer names to analyze sensitivity for. If this variable
            is set, then only the conv layers specified in the list are
            analyzed. Users can also use this option to parallelize the
            sensitivity analysis easily.

        Returns
        -------
        sensitivities : dict
            dict object that stores the trajectory of the accuracy/loss as
            the prune ratio changes
        """
        if val_args is None:
            val_args = []
        if val_kwargs is None:
            val_kwargs = {}
        # Get the original validation metric (accuracy/loss) before pruning;
        # this baseline is used by the 'dropped'/'raised' early stop modes.
        self.ori_metric = self.val_func(*val_args, **val_kwargs)
        namelist = list(self.target_layer.keys())
        if specified_layers is not None:
            # only analyze the specified conv layers
            namelist = list(filter(lambda x: x in specified_layers, namelist))
        for name in namelist:
            self.sensitivities[name] = {}
            for sparsity in self.sparsities:
                # Here the sparsity is relative to the remaining weights;
                # calculate the actual prune ratio based on the ratio that
                # has already been pruned.
                real_sparsity = (
                    1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name]
                # TODO In the current L1/L2 Filter Pruner, 'op_types' is still
                # necessary; the L1/L2 Pruner should infer the op_types
                # automatically from the op_names.
                cfg = [{'sparsity': real_sparsity, 'op_names': [name],
                        'op_types': ['Conv2d']}]
                pruner = self.Pruner(self.model, cfg)
                pruner.compress()
                val_metric = self.val_func(*val_args, **val_kwargs)
                logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f',
                            name, real_sparsity, val_metric)
                self.sensitivities[name][sparsity] = val_metric
                pruner._unwrap_model()
                del pruner
                # check if the current metric meets the stop condition
                if self._need_to_stop(self.ori_metric, val_metric):
                    break
            # Reset the weights pruned by the pruner. Because the input
            # sparsities are sorted, we do not need to reset the weights each
            # time the sparsity changes; we only reset when the layer under
            # analysis changes.
            self.model.load_state_dict(self.ori_state_dict)
        return self.sensitivities

    def export(self, filepath):
        """
        Export the results of the sensitivity analysis to a csv file. The
        first line of the csv file describes the content structure: it is
        constructed from 'layername' and the sparsity list. Each line below
        records the validation metric returned by val_func when the layer is
        under different sparsities. Note that, due to the early_stop option,
        some layers may not have metrics under all sparsities.

        layername, 0.25, 0.5, 0.75
        conv1, 0.6, 0.55
        conv2, 0.61, 0.57, 0.56

        Parameters
        ----------
        filepath : str
            Path of the output file
        """
        str_sparsities = [str(x) for x in self.sparsities]
        header = ['layername'] + str_sparsities
        # newline='' is required by the csv module to avoid spurious blank
        # lines on platforms that translate line endings (latent bug fix).
        with open(filepath, 'w', newline='') as csvf:
            csv_w = csv.writer(csvf)
            csv_w.writerow(header)
            for layername in self.sensitivities:
                row = [layername]
                for sparsity in sorted(self.sensitivities[layername].keys()):
                    row.append(self.sensitivities[layername][sparsity])
                csv_w.writerow(row)

    def update_already_pruned(self, layername, ratio):
        """
        Set the already pruned ratio for the target layer.
        """
        self.already_pruned[layername] = ratio

    def load_state_dict(self, state_dict):
        """
        Update the weights of the model and make them the new restore
        baseline for subsequent analysis rounds.
        """
        self.ori_state_dict = copy.deepcopy(state_dict)
        self.model.load_state_dict(self.ori_state_dict)