# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import defaultdict
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation_v1 import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward
from ..utils.quantization.observers import default_weight_observer, default_histogram_observer
logger = logging.getLogger(__name__)
class ObserverQuantizer(Quantizer):
r"""
Observer quantizer is a framework for post-training quantization.
It inserts observers at the places where quantization will happen.
During calibration, each observer records all the tensors it 'sees'.
These tensors are used to calculate the quantization statistics after calibration.
The whole process can be divided into three steps:
1. Observers are registered at the places where quantization will happen (just like registering hooks).
2. The observers record tensors' statistics during calibration.
3. Scale and zero point are obtained after calibration.
Note that the observer type, tensor dtype and quantization qscheme are hard-coded for now. Their customization
is under development and will be ready soon.
Parameters
----------
model : torch.nn.Module
Model to be quantized.
config_list : List[Dict]
List of configurations for quantization. Supported keys:
- quant_types : List[str]
Types of quantization to apply; 'weight', 'input' and 'output' are currently supported.
- quant_bits : Union[int, Dict[str, int]]
Bit width of quantization. If a dict is given, the key is the quantization type and the value is the bit width, e.g. {'weight': 8}.
If an int is given, all quantization types share the same bit width.
- op_types : List[str]
Types of nn.Module to which quantization is applied, e.g. 'Conv2d'.
- op_names : List[str]
Names of nn.Module instances to which quantization is applied, e.g. 'conv1'.
- exclude : bool
If True, the layers selected by op_types and op_names will be excluded from quantization.
optimizer : torch.optim.Optimizer
Optimizer is optional in `ObserverQuantizer`.
Examples
--------
>>> from nni.compression.pytorch.quantization import ObserverQuantizer
>>> model = ...
>>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
>>> quantizer = ObserverQuantizer(model, config_list)
>>> # define a calibration function
>>> def calibration(model, calib_loader):
>>>     model.eval()
>>>     with torch.no_grad():
>>>         for data, _ in calib_loader:
>>>             model(data)
>>> calibration(model, calib_loader)
>>> quantizer.compress()
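>>> # (optional) a minimal sketch of exporting the simulated-quantized weights and the
>>> # calibration parameters; the file paths here are only placeholders
>>> quantizer.export_model('quantized_model.pth', 'calibration_config.pth')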
For detailed example please refer to
:githublink:`examples/model_compress/quantization/observer_quantizer.py <examples/model_compress/quantization/observer_quantizer.py>`.
.. note::
This quantizer is still under development. Some quantizer settings are currently hard-coded:
- weight observer: per_tensor_symmetric, qint8
- output observer: per_tensor_affine, quint8, reduce_range=True
Other settings (such as quant_types and op_names) can be configured.
Notes
-----
**About the compress API**
Before the `compress` API is called, the model only records tensors' statistics; no quantization is performed.
After the `compress` API is called, the model no longer records tensors' statistics. Instead, a quantization scale and zero point
are generated for each tensor and are used to quantize that tensor during inference (we call this evaluation mode).
**About calibration**
Usually we pick about 100 training/evaluation examples for calibration. If you find that the accuracy is a bit low, try
reducing the number of calibration examples.
"""
def __init__(self, model, config_list, optimizer=None):
super().__init__(model, config_list, optimizer)
# NOTE: this quantizer is experimental for now. The dtype and qscheme of quantization
# is hard-coded.
# TODO:
# 1. support dtype and qscheme customization through config_list. Current settings:
# weight observer : per_tensor_symmetric, qint8
# output observer : per_tensor_affine, quint8, reduce_range=True
# 2. add more kinds of observers, such as Kullback-Leibler divergence.
# 3. add batch normalization folding
assert not model.training, "Currently the observer quantizer only works in evaluation mode."
self.quant_grad = QuantForward()
self.device = next(model.parameters()).device
modules_to_compress = self.get_modules_to_compress()
all_observers = defaultdict(dict)
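# Integer ranges are fixed for now: weights use the symmetric qint8 range [-127, 127],
# while inputs and outputs use quint8 with reduce_range=True, i.e. [0, 127].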
weight_qmin, weight_qmax = -127, 127
output_qmin, output_qmax = 0, 127 # reduce_range is set to True
self.compressed = False
for layer, config in modules_to_compress:
layer_name = layer.name
module = layer.module
if "weight" in config.get("quant_types", []):
all_observers[layer_name]["weight"] = default_weight_observer()
setattr(module, "weight_qmax", weight_qmax)
setattr(module, "weight_qmin", weight_qmin)
if "input" in config.get("quant_types", []):
all_observers[layer_name]["input"] = default_histogram_observer()
setattr(module, "input_qmax", output_qmax)
setattr(module, "input_qmin", output_qmin)
if "output" in config.get("quant_types", []):
all_observers[layer_name]["output"] = default_histogram_observer()
setattr(module, "output_qmax", output_qmax)
setattr(module, "output_qmin", output_qmin)
self.all_observers = all_observers
self.bound_model.to(self.device)
def validate_config(self, model, config_list):
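# Only 8-bit quantization is accepted; the schema rejects any other bit width.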
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({
Optional('weight'): And(int, lambda n: n == 8),
Optional('output'): And(int, lambda n: n == 8),
Optional('input'): And(int, lambda n: n == 8),
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
}], model, logger)
schema.validate(config_list)
def record(self, wrapper, quant_type, tensor):
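# Record the tensor with the observer registered for this layer and quantization type;
# the tensor is moved to CPU before being observed.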
name = wrapper.name
observer = self.all_observers[name][quant_type]
observer(tensor.cpu())
def calculate_qparams(self, name, quant_type):
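# Derive the quantization scale and zero point from the statistics the observer has collected.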
observer = self.all_observers[name][quant_type]
scale, zero_point = observer.calculate_qparams()
return scale, zero_point
def _quantize(self, x, scale, zero_point, qmin, qmax):
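# Simulated (fake) quantization: map x onto the integer grid [qmin, qmax] with the given
# affine parameters, then map it back to floating point.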
x = x / scale + zero_point
x = torch.clamp(x, qmin, qmax)
x = torch.round(x)
x = (x - zero_point) * scale
return x
def quantize_input(self, inputs, wrapper, **kwargs):
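# Before compress() is called, only record input statistics; afterwards, simulate quantization of the input.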
if self.compressed:
module = wrapper.module
inputs = self._quantize(inputs,
module.input_scale,
module.input_zero_point,
module.input_qmin,
module.input_qmax)
else:
self.record(wrapper, 'input', inputs)
return inputs
def quantize_weight(self, wrapper, **kwargs):
# If ObserverQuantizer.compress has been executed, the weight has already been replaced
# by its pseudo-quantized version, so there is no need to quantize it again.
if self.compressed:
return
weight = wrapper.module.weight
self.record(wrapper, 'weight', weight)
def quantize_output(self, output, wrapper, **kwargs):
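# Before compress() is called, only record output statistics; afterwards, simulate quantization of the output.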
if self.compressed:
module = wrapper.module
new_output = self._quantize(output,
module.output_scale,
module.output_zero_point,
module.output_qmin,
module.output_qmax)
else:
self.record(wrapper, 'output', output)
new_output = output
return new_output
def compress(self):
"""
Calculate the quantization parameters (scale and zero point) of each observed tensor. Note that after this call,
inference with the compressed model no longer updates the recorded statistics. Instead, the quantization
process is simulated, which can be used to test the accuracy of the quantized model.
"""
modules_to_compress = self.get_modules_to_compress()
for layer, config in modules_to_compress:
module = layer.module
if "weight" in config.get("quant_types", []):
scale, zero_point = self.calculate_qparams(layer.name, 'weight')
module.register_buffer('weight_scale', scale.to(self.device))
module.register_buffer('weight_zero_point', zero_point.to(self.device))
weight = module.weight
quantized_weight = self._quantize(weight,
module.weight_scale,
module.weight_zero_point,
module.weight_qmin,
module.weight_qmax)
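# Replace the original weight parameter with a buffer holding the simulated-quantized
# weight, so that later inference runs on the quantized values.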
delattr(module, 'weight')
module.register_buffer('weight', quantized_weight)
if "input" in config.get("quant_types", []):
scale, zero_point = self.calculate_qparams(layer.name, 'input')
module.register_buffer('input_scale', scale.to(self.device))
module.register_buffer('input_zero_point', zero_point.to(self.device))
if "output" in config.get("quant_types", []):
scale, zero_point = self.calculate_qparams(layer.name, 'output')
module.register_buffer('output_scale', scale.to(self.device))
module.register_buffer('output_zero_point', zero_point.to(self.device))
self.compressed = True
super().compress()
def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
"""
Export the quantized model weights and, optionally, the calibration parameters.
Parameters
----------
model_path : str
path to save quantized model weight
calibration_path : str
(optional) path to save quantize parameters after calibration
onnx_path : str
(optional) path to save onnx model
input_shape : list or tuple
input shape to onnx model
device : torch.device
device of the model, used to place the dummy input tensor when exporting the onnx file.
The tensor is placed on CPU if ``device`` is None.
Returns
-------
Dict
    Calibration configuration of each quantized layer, keyed by module name.
"""
assert model_path is not None, 'model_path must be specified'
self._unwrap_model()
calibration_config = {}
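# Collect per-layer calibration information (bit widths and tracked value ranges), keyed by module name.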
for name, module in self.bound_model.named_modules():
if hasattr(module, 'weight_scale') or hasattr(module, 'input_scale') or hasattr(module, 'output_scale'):
calibration_config[name] = {}
if hasattr(module, 'weight_scale'):
calibration_config[name]['weight_bits'] = 8
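# With symmetric weight quantization the tracked range is [-scale * qmax, scale * qmax].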
val = float(module.weight_scale * module.weight_qmax)
calibration_config[name]['tracked_max_weight'] = val
calibration_config[name]['tracked_min_weight'] = -val
calibration_config[name]['tracked_qmin_weight'] = -127
calibration_config[name]['tracked_qmax_weight'] = 127
weight = module.weight
quantized_weight = self._quantize(weight,
module.weight_scale,
module.weight_zero_point,
module.weight_qmin,
module.weight_qmax)
delattr(module, 'weight')
module.register_parameter('weight', torch.nn.Parameter(quantized_weight))
# TODO: refactor these magic numbers once customization of dtype and qscheme is supported.
if hasattr(module, 'input_scale'):
calibration_config[name]['input_bits'] = 8
max_input = float(module.input_scale * (module.input_qmax - module.input_zero_point))
min_input = float(module.input_scale * (module.input_qmin - module.input_zero_point))
calibration_config[name]['tracked_min_input'] = min_input
calibration_config[name]['tracked_max_input'] = max_input
calibration_config[name]['tracked_qmin_input'] = 0
calibration_config[name]['tracked_qmax_input'] = 127
if hasattr(module, 'output_scale'):
calibration_config[name]['output_bits'] = 8
max_output = float(module.output_scale * (module.output_qmax - module.output_zero_point))
min_output = float(module.output_scale * (module.output_qmin - module.output_zero_point))
calibration_config[name]['tracked_min_output'] = min_output
calibration_config[name]['tracked_max_output'] = max_output
calibration_config[name]['tracked_qmin_output'] = 0
calibration_config[name]['tracked_qmax_output'] = 127
self._del_simulated_attr(module)
self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
input_shape, device)
return calibration_config
def _del_simulated_attr(self, module):
"""
Delete the simulated quantization attributes from the wrapped module.
"""
del_attr_list = ['old_weight', 'steps', 'weight_qmax', 'weight_qmin', 'input_qmax', 'input_qmin',
'output_qmax', 'output_qmin', 'weight_scale', 'weight_zero_point', 'input_scale',
'input_zero_point', 'output_scale', 'output_zero_point']
for attr in del_attr_list:
if hasattr(module, attr):
delattr(module, attr)