# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation_v1 import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantGrad
from ..utils.quantization.literal import (
PER_CHANNEL_QUANT_SCHEME,
QuantScheme,
QuantDtype,
QuantType
)
from ..utils.quantization.literal import BN_FOLD_TAG
from ..utils.quantization.settings import LayerQuantSetting
from ..utils.quantization.utils import (
calculate_qmin_qmax,
get_min_max_value,
get_quant_shape
)
logger = logging.getLogger(__name__)
class QATGrad(QuantGrad):
@staticmethod
def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
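# Straight-through estimator: the gradient is passed through unchanged, except
# where the quantized value falls outside [qmin, qmax]; there the gradient is
# zeroed out.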
tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
mask = (tensor_q < qmin) | (tensor_q > qmax)
grad_output[mask] = 0
return grad_output
def update_quantization_param(bits, rmin, rmax, dtype, scheme):
"""
calculate the `zero_point` and `scale`.
Parameters
----------
bits : int
number of quantization bits
rmin : Tensor
minimum of the real (float) values
rmax : Tensor
maximum of the real (float) values
dtype : QuantDtype
quantized data type
scheme : QuantScheme
quantization scheme to be used
Returns
-------
Tensor, Tensor
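Examples
--------
A rough sketch of the affine (asymmetric) case, assuming 8-bit ``uint``
quantization maps to ``qmin=0, qmax=255`` (values are illustrative only):

>>> rmin, rmax = torch.tensor(-1.0), torch.tensor(1.0)
>>> # scale = (rmax - rmin) / (qmax - qmin) ~= 0.0078
>>> # zero_point = clamp(qmin - round(rmin / scale), qmin, qmax) ~= 128
>>> scale, zero_point = update_quantization_param(
...     8, rmin, rmax, QuantDtype.UINT, QuantScheme.PER_TENSOR_AFFINE)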
"""
# extend the [min, max] interval to ensure that it contains 0.
# Otherwise, we would not meet the requirement that 0 be an exactly
# representable value.
# This is likely needed for activations that are padded during training.
# Since this is also the default behavior of PyTorch's quantization observers,
# we make it the default behavior here as well.
rmin = torch.min(rmin, torch.zeros_like(rmin))
rmax = torch.max(rmax, torch.zeros_like(rmax))
zero_point = torch.zeros_like(rmin)
# todo: there is no need to calculate qmin and qmax again
qmin, qmax = calculate_qmin_qmax(bits, dtype)
if scheme in [QuantScheme.PER_TENSOR_SYMMETRIC, QuantScheme.PER_CHANNEL_SYMMETRIC]:
abs_max = torch.max(torch.abs(rmin), torch.abs(rmax))
scale = abs_max / (float(qmax - qmin) / 2)
if dtype == QuantDtype.UINT:
zero_point_val = (qmin + qmax) // 2
zero_point = zero_point.new_full(zero_point.size(), zero_point_val)
else:
scale = (rmax - rmin) / float(qmax - qmin)
zero_point = qmin - torch.round(rmin / scale)
zero_point = torch.clamp(zero_point, qmin, qmax)
# todo: add these lines
# eps = torch.finfo(torch.float32).eps
# scale = torch.max(scale, eps)
return scale, zero_point
def update_ema(biased_ema, value, decay):
"""
update the exponential moving average of a statistic for the current step
Parameters
----------
biased_ema : float
previous stat value
value : float
current stat value
decay : float
the weight of the previous stat value; a larger decay gives a smoother curve
Returns
-------
float
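Examples
--------
A single EMA step with ``decay=0.99`` (values are illustrative only):

>>> # 1.0 * 0.99 + (1 - 0.99) * 2.0 = 1.01
>>> ema = update_ema(torch.tensor(1.0), torch.tensor(2.0), 0.99)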
"""
biased_ema = biased_ema * decay + (1 - decay) * value
return biased_ema
class QAT_Quantizer(Quantizer):
r"""
Quantizer defined in:
`Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
<http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf>`__
Authors Benoit Jacob and Skirmantas Kligys provide an algorithm to quantize the model with training.
..
We propose an approach that simulates quantization effects in the forward pass of training.
Backpropagation still happens as usual, and all weights and biases are stored in floating point
so that they can be easily nudged by small amounts.
The forward propagation pass however simulates quantized inference as it will happen in the inference engine,
by implementing in floating-point arithmetic the rounding behavior of the quantization scheme:
* Weights are quantized before they are convolved with the input. If batch normalization (see [17]) is used for the layer,
the batch normalization parameters are “folded into” the weights before quantization.
* Activations are quantized at points where they would be during inference,
e.g. after the activation function is applied to a convolutional or fully connected layer’s output,
or after a bypass connection adds or concatenates the outputs of several layers together such as in ResNets.
Parameters
----------
model : torch.nn.Module
Model to be quantized.
config_list : List[Dict]
List of configurations for quantization. Supported keys for dict:
- quant_types : List[str]
Type of quantization you want to apply; currently 'weight', 'input' and 'output' are supported.
- quant_bits : Union[int, Dict[str, int]]
Bit length of quantization; the key is the quantization type and the value is the bit length, e.g. {'weight': 8}.
When an int is given, all quantization types share the same bit length.
- quant_start_step : int
Disable quantization until the model has been run for a certain number of steps. This allows the network to reach a more stable
state, where the output quantization ranges do not exclude a significant fraction of values. Default value is 0.
- op_types : List[str]
Types of nn.Module you want to apply quantization to, e.g. 'Conv2d'.
- op_names : List[str]
Names of nn.Module you want to apply quantization to, e.g. 'conv1'.
- exclude : bool
Set True to exclude the layers selected by op_types and op_names from quantization.
optimizer : torch.optim.Optimizer
An optimizer is required by `QAT_Quantizer`; NNI patches the optimizer to count the number of optimization steps.
dummy_input : Tuple[torch.Tensor]
Inputs to the model, used to trace the graph of the module. The graph is used to find Conv-BN patterns so that
batch normalization folding can be enabled. If dummy_input is not given, batch normalization folding is disabled.
Examples
--------
>>> from nni.compression.pytorch.quantization import QAT_Quantizer
>>> model = ...
>>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
>>> optimizer = ...
>>> dummy_input = torch.rand(...)
>>> quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
>>> quantizer.compress()
>>> # Training Process...
For detailed example please refer to
:githublink:`examples/model_compress/quantization/QAT_torch_quantizer.py <examples/model_compress/quantization/QAT_torch_quantizer.py>`.
Notes
-----
**Batch normalization folding**
Batch normalization folding is supported in QAT quantizer. It can be easily enabled by passing an argument `dummy_input` to
the quantizer, like:
.. code-block:: python
# assume your model takes an input of shape (1, 1, 28, 28)
# and dummy_input must be on the same device as the model
dummy_input = torch.randn(1, 1, 28, 28)
# pass the dummy_input to the quantizer
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
The quantizer will automatically detect Conv-BN patterns and simulate the batch normalization folding process in the training
graph. Note that once quantization aware training is finished, the folded weight/bias are restored by calling
`quantizer.export_model`.
**Quantization dtype and scheme customization**
Different backends on different devices use different quantization strategies (i.e. dtype (int or uint) and
scheme (per-tensor or per-channel and symmetric or affine)). QAT quantizer supports customization of mainstream dtypes and schemes.
There are two ways to set them. One way is to set them globally through the function `set_quant_scheme_dtype`, like:
.. code-block:: python
from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype
# This will make all 'input' quantization use the 'per_tensor_affine' scheme and the 'uint' dtype
set_quant_scheme_dtype('input', 'per_tensor_affine', 'uint')
# This will make all 'output' quantization use the 'per_tensor_symmetric' scheme and the 'int' dtype
set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
# This will make all 'weight' quantization use the 'per_channel_symmetric' scheme and the 'int' dtype
set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')
The other way is more fine-grained: you can customize the dtype and scheme for each entry in the quantization config list, like:
.. code-block:: python
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types':['Conv2d', 'Linear'],
'quant_dtype': 'int',
'quant_scheme': 'per_channel_symmetric'
}, {
'quant_types': ['output'],
'quant_bits': 8,
'quant_start_step': 7000,
'op_types':['ReLU6'],
'quant_dtype': 'uint',
'quant_scheme': 'per_tensor_affine'
}]
"""
def __init__(self, model, config_list, optimizer, dummy_input=None):
assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
super().__init__(model, config_list, optimizer, dummy_input)
self.quant_grad = QATGrad.apply
modules_to_compress = self.get_modules_to_compress()
device = next(model.parameters()).device
self.bound_model.register_buffer("steps", torch.tensor(1))
for layer, config in modules_to_compress:
module = layer.module
name = layer.name
# TODO: may relax this limitation?
assert name in self.all_shapes, "Could not find shapes for layer {}".format(name)
input_shape, output_shape = self.all_shapes[name]
layer_quant_setting = LayerQuantSetting(config)
layer_quant_setting.ema_decay = 0.99
quant_start_step = config.get('quant_start_step', 0)
layer_quant_setting.quant_start_step = quant_start_step
# todo: support other ranks and remove this check
if isinstance(module, torch.nn.Linear):
if "input" in config.get("quant_types", []) and \
layer_quant_setting.input.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
if len(input_shape) != 2:
logger.warning("When quantize torch.nn.Linear, make sure that the rank of the inputs "
"of the layer is 2. Skip quantization of layer %s.", name)
continue
if "output" in config.get("quant_types", []) and \
layer_quant_setting.output.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
if len(output_shape) != 2:
logger.warning("When quantize torch.nn.Linear, make sure that the rank of the outputs "
"of the layer is 2. Skip quantization of layer %s.", name)
continue
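# Register per-layer buffers holding the quantization parameters (scale, zero_point) and,
# for inputs/outputs, the tracked min/max statistics; their shapes depend on the
# quantization scheme (per-tensor vs. per-channel) and are computed by get_quant_shape.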
if "weight" in config.get("quant_types", []):
quant_shape = get_quant_shape(module.weight.shape, QuantType.WEIGHT, layer_quant_setting.weight.quant_scheme)
module.register_buffer('weight_scale', torch.zeros(quant_shape))
module.register_buffer('weight_zero_point', torch.zeros(quant_shape))
if "input" in config.get("quant_types", []):
quant_shape = get_quant_shape(input_shape, QuantType.INPUT, layer_quant_setting.input.quant_scheme)
module.register_buffer('tracked_min_input', torch.zeros(quant_shape))
module.register_buffer('tracked_max_input', torch.zeros(quant_shape))
module.register_buffer('input_scale', torch.zeros(quant_shape))
module.register_buffer('input_zero_point', torch.zeros(quant_shape))
if "output" in config.get("quant_types", []):
quant_shape = get_quant_shape(output_shape, QuantType.OUTPUT, layer_quant_setting.output.quant_scheme)
module.register_buffer('tracked_min_output', torch.zeros(quant_shape))
module.register_buffer('tracked_max_output', torch.zeros(quant_shape))
module.register_buffer('output_scale', torch.zeros(quant_shape))
module.register_buffer('output_zero_point', torch.zeros(quant_shape))
setattr(module, "layer_quant_setting", layer_quant_setting)
self.bound_model.to(device)
def _del_simulated_attr(self, module):
"""
delete redundant simulated-quantization attributes from the module
"""
del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
'tracked_min_input', 'tracked_max_input', 'BN_FOLD_TAG',
'weight_scale', 'weight_zero_point', 'input_scale', 'input_zero_point',
'output_scale', 'output_zero_point', 'layer_quant_setting']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)
def validate_config(self, model, config_list):
"""
Parameters
----------
model : torch.nn.Module
Model to be quantized
config_list : list of dict
List of configurations
"""
SUPPORTED_OPS = ['Conv2d', 'Linear', 'ReLU', 'ReLU6']
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('input'): And(int, lambda n: 0 < n < 32),
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
Optional('quant_scheme'): Or(lambda x: x in QuantScheme, Schema({
Optional('input'): lambda x: x in QuantScheme,
Optional('weight'): lambda x: x in QuantScheme,
Optional('output'): lambda x: x in QuantScheme
})),
Optional('quant_dtype'): Or(lambda x: x in QuantDtype, Schema({
Optional('input'): lambda x: x in QuantDtype,
Optional('weight'): lambda x: x in QuantDtype,
Optional('output'): lambda x: x in QuantDtype
})),
Optional('quant_start_step'): And(int, lambda n: n >= 0),
Optional('op_types'): [And(str, lambda n: n in SUPPORTED_OPS)],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
schema.validate(config_list)
def _quantize(self, real_value, scale, zero_point, qmin, qmax):
"""
quantize real value.
Parameters
----------
real_value : torch.Tensor
the real value to be quantized
scale : torch.Tensor
quantization scale
zero_point : torch.Tensor
quantization zero point
qmin : int
lower bound of the int range
qmax : int
upper bound of the int range
Returns
-------
Tensor
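Examples
--------
A minimal per-tensor sketch with illustrative values, assuming ``quantizer`` is an
instance of this class and ``qmin=0, qmax=255`` (8-bit ``uint``):

>>> scale, zero_point = torch.tensor(0.1), torch.tensor(128.0)
>>> # 0.25 / 0.1 + 128 = 130.5, clamped to [0, 255], then rounded to 130
>>> q = quantizer._quantize(torch.tensor(0.25), scale, zero_point, 0, 255)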
"""
transformed_val = zero_point + real_value / scale
clamped_val = torch.clamp(transformed_val, qmin, qmax)
quantized_val = torch.round(clamped_val)
return quantized_val
def _dequantize(self, quantized_val, scale, zero_point):
"""
dequantize quantized value.
Because we simulate quantization during training, all computations still happen in floating point, which means we
first quantize tensors and then dequantize them. For more details, please refer to the paper.
Parameters
----------
quantized_val : torch.Tensor
the quantized value to be de-quantized
scale : torch.Tensor
quantization scale
zero_point : torch.Tensor
quantization zero point
Returns
-------
Tensor
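Examples
--------
Round-trip (fake) quantization with the same illustrative values as in
``_quantize``, assuming ``quantizer`` is an instance of this class:

>>> scale, zero_point = torch.tensor(0.1), torch.tensor(128.0)
>>> # scale * (q - zero_point) = 0.1 * (130 - 128) = 0.2
>>> real = quantizer._dequantize(torch.tensor(130.0), scale, zero_point)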
"""
real_val = scale * (quantized_val - zero_point)
return real_val
def quantize_weight(self, wrapper, **kwargs):
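# Simulate weight quantization: in training mode, derive scale/zero_point from the
# current min/max of the weight, store them in the module buffers, then quantize and
# immediately dequantize the weight so the forward pass sees the quantization error.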
module = wrapper.module
weight = module.weight
layer_quant_setting = module.layer_quant_setting
tensor_quant_setting = layer_quant_setting.weight
# layer-wise settings
quant_start_step = layer_quant_setting.quant_start_step
# tensor-wise settings
dtype = tensor_quant_setting.quant_dtype
scheme = tensor_quant_setting.quant_scheme
qmin, qmax = tensor_quant_setting.get_qmin_qmax()
bits = tensor_quant_setting.bits
# In evaluation mode, we only quantize weight without updating statistics
if not wrapper.training:
scale, zero_point = module.weight_scale, module.weight_zero_point
weight = self._quantize(weight, scale, zero_point, qmin, qmax)
weight = self._dequantize(weight, scale, zero_point)
module.weight = weight
return weight
if quant_start_step > int(self.bound_model.steps):
return weight
current_min, current_max = get_min_max_value(weight, QuantType.WEIGHT, scheme)
scale, zero_point = update_quantization_param(bits, current_min, current_max, dtype, scheme)
module.weight_scale.copy_(scale)
module.weight_zero_point.copy_(zero_point)
weight = self._quantize(weight, scale, zero_point, qmin, qmax)
weight = self._dequantize(weight, scale, zero_point)
# The weight can not be modified in place, so when torch.nn.DataParallel is used, this update
# is lost after each forward pass. However, the update does take effect on each
# replicated module during the forward pass, so the quantized weight is still used correctly.
wrapper.module.weight = weight
return weight
def quantize_input(self, inputs, wrapper, **kwargs):
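# Simulate input quantization: track the running min/max of the inputs with an
# exponential moving average, derive scale/zero_point from the tracked range, then
# quantize and immediately dequantize the inputs.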
module = wrapper.module
layer_quant_setting = module.layer_quant_setting
tensor_quant_setting = layer_quant_setting.input
# layer-wise settings
quant_start_step = layer_quant_setting.quant_start_step
ema_decay = layer_quant_setting.ema_decay
# tensor-wise settings
dtype = tensor_quant_setting.quant_dtype
scheme = tensor_quant_setting.quant_scheme
qmin, qmax = tensor_quant_setting.get_qmin_qmax()
bits = tensor_quant_setting.bits
if not wrapper.training:
scale = module.input_scale
zero_point = module.input_zero_point
inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
inputs = self._dequantize(inputs, scale, zero_point)
return inputs
current_min, current_max = get_min_max_value(inputs, QuantType.INPUT, scheme)
if int(self.bound_model.steps) == 1:
module.tracked_min_input.copy_(current_min)
module.tracked_max_input.copy_(current_max)
tracked_min_input = update_ema(module.tracked_min_input, current_min, ema_decay)
tracked_max_input = update_ema(module.tracked_max_input, current_max, ema_decay)
module.tracked_min_input.copy_(tracked_min_input)
module.tracked_max_input.copy_(tracked_max_input)
if quant_start_step > int(self.bound_model.steps):
return inputs
scale, zero_point = update_quantization_param(
bits, module.tracked_min_input, module.tracked_max_input, dtype, scheme)
module.input_scale.copy_(scale)
module.input_zero_point.copy_(zero_point)
inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
inputs = self._dequantize(inputs, scale, zero_point)
return inputs
def quantize_output(self, output, wrapper, **kwargs):
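# Simulate output quantization: track the running min/max of the outputs with an
# exponential moving average, derive scale/zero_point from the tracked range, then
# quantize and immediately dequantize the outputs.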
module = wrapper.module
layer_quant_setting = module.layer_quant_setting
tensor_quant_setting = layer_quant_setting.output
# layer-wise settings
quant_start_step = layer_quant_setting.quant_start_step
ema_decay = layer_quant_setting.ema_decay
# tensor-wise settings
dtype = tensor_quant_setting.quant_dtype
scheme = tensor_quant_setting.quant_scheme
qmin, qmax = tensor_quant_setting.get_qmin_qmax()
bits = tensor_quant_setting.bits
if not wrapper.training:
scale = module.output_scale
zero_point = module.output_zero_point
output = self._quantize(output, scale, zero_point, qmin, qmax)
output = self._dequantize(output, scale, zero_point)
return output
current_min, current_max = get_min_max_value(output, QuantType.OUTPUT, scheme)
if int(self.bound_model.steps) == 1:
module.tracked_min_output.copy_(current_min)
module.tracked_max_output.copy_(current_max)
tracked_min_output = update_ema(module.tracked_min_output, current_min, ema_decay)
tracked_max_output = update_ema(module.tracked_max_output, current_max, ema_decay)
module.tracked_min_output.copy_(tracked_min_output)
module.tracked_max_output.copy_(tracked_max_output)
if quant_start_step > int(self.bound_model.steps):
return output
scale, zero_point = update_quantization_param(
bits, module.tracked_min_output, module.tracked_max_output, dtype, scheme)
module.output_scale.copy_(scale)
module.output_zero_point.copy_(zero_point)
output = self._quantize(output, scale, zero_point, qmin, qmax)
output = self._dequantize(output, scale, zero_point)
return output
def load_calibration_config(self, calibration_config):
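# Restore quantization state from an exported calibration config: verify that the
# configured bit widths match the current settings, then copy back the tracked
# input/output min/max statistics.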
modules_to_compress = self.get_modules_to_compress()
for layer, _ in modules_to_compress:
name, module = layer.name, layer.module
if name not in calibration_config:
if module.layer_quant_setting.weight or module.layer_quant_setting.input or module.layer_quant_setting.output:
logger.warning(f"Can not find module {name}'s parameter in input config.")
continue
if module.layer_quant_setting.weight:
assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
f"weight bits of module {name} fail to match"
if module.layer_quant_setting.input:
assert calibration_config[name]['input_bits'] == module.layer_quant_setting.input.bits, \
f"input bits of module {name} fail to match"
module.tracked_min_input.data = torch.tensor([calibration_config[name]['tracked_min_input']])
module.tracked_max_input.data = torch.tensor([calibration_config[name]['tracked_max_input']])
if module.layer_quant_setting.output:
assert calibration_config[name]['output_bits'] == module.layer_quant_setting.output.bits, \
f"output bits of module {name} fail to match"
module.tracked_min_output.data = torch.tensor([calibration_config[name]['tracked_min_output']])
module.tracked_max_output.data = torch.tensor([calibration_config[name]['tracked_max_output']])
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export quantized model weights and calibration parameters (optional)
Parameters
----------
model_path : str
path to save quantized model weight
calibration_path : str
(optional) path to save quantize parameters after calibration
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape of the dummy input used to export the onnx model
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file.
The tensor is placed on CPU if ``device`` is None.
Returns
-------
Dict
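Examples
--------
A minimal usage sketch (paths are placeholders):

>>> calibration_config = quantizer.export_model(
...     model_path='qat_model.pth', calibration_path='qat_calibration.pth')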
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}
modules_to_compress = self.get_modules_to_compress()
for layer, _ in modules_to_compress:
name, module = layer.name, layer.module
if hasattr(module.layer_quant_setting, 'weight') or hasattr(module.layer_quant_setting, 'output'):
calibration_config[name] = {}
if module.layer_quant_setting.weight:
calibration_config[name]['weight_bits'] = int(module.layer_quant_setting.weight.bits)
calibration_config[name]['weight_scale'] = module.weight_scale
calibration_config[name]['weight_zero_point'] = module.weight_zero_point
# Recover weight/bias for batch normalization folding
actual_weight = getattr(module, 'old_weight', None)
if actual_weight is None:
logger.warning("Can not recover weight for layer %s. "
"This may lead to a wrong accuracy performance on the backend.", name)
delattr(module, 'weight')
module.register_parameter('weight', actual_weight)
if hasattr(module, BN_FOLD_TAG):
actual_bias = getattr(module, 'old_bias', None)
delattr(module, 'bias')
if actual_bias is not None:
module.register_parameter('bias', actual_bias)
else:
setattr(module, 'bias', None)
if module.layer_quant_setting.input:
calibration_config[name]['input_bits'] = int(module.layer_quant_setting.input.bits)
calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)
if module.layer_quant_setting.output:
calibration_config[name]['output_bits'] = int(module.layer_quant_setting.output.bits)
calibration_config[name]['tracked_min_output'] = float(module.tracked_min_output)
calibration_config[name]['tracked_max_output'] = float(module.tracked_max_output)
self._del_simulated_attr(module)
self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path, input_shape, device)
return calibration_config
def step_with_optimizer(self):
"""
override the `Compressor` `step` method; quantization only happens after a certain number of steps
"""
self.bound_model.steps.add_(1)