# Source code for nni.algorithms.compression.pytorch.quantization.observer_quantizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from collections import defaultdict
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward
from nni.compression.pytorch.quantization.observers import default_weight_observer, default_histogram_observer

# Module-level logger; it is also handed to ``QuantizerSchema`` when validating config lists.
logger = logging.getLogger(name=__name__)

class ObserverQuantizer(Quantizer):
    r"""
    Observer quantizer is a framework of post-training quantization.
    It inserts observers into the places where quantization will happen.
    During quantization calibration, each observer records all the tensors it 'sees'.
    These tensors are used to calculate the quantization statistics after calibration.

    The whole process can be divided into three steps:

    1. Observers are registered at the places where quantization would happen (just like registering hooks).
    2. The observers record the tensors' statistics during calibration.
    3. Scale & zero point are obtained after calibration.

    Note that the observer type, tensor dtype and quantization qscheme are hard-coded for now. Their
    customization is under development and will be ready soon.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be quantized. It must already be in evaluation mode (``model.eval()``).
    config_list : List[Dict]
        List of configurations for quantization. Supported keys:
            - quant_types : List[str]
                Type of quantization you want to apply, currently support 'weight', 'input', 'output'.
            - quant_bits : Union[int, Dict[str, int]]
                Bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}.
                When the type is int, all quantization types share the same bits length.
                Only 8 bits is accepted for now.
            - op_types : List[str]
                Types of nn.module you want to apply quantization, eg. 'Conv2d'.
            - op_names : List[str]
                Names of nn.module you want to apply quantization, eg. 'conv1'.
            - exclude : bool
                Set True then the layers selected by op_types and op_names will be excluded from quantization.
    optimizer : torch.optim.Optimizer
        Optimizer is optional in `ObserverQuantizer`.

    Examples
    --------

    >>> from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer
    >>> model = ...  # must already be in evaluation mode
    >>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
    >>> quantizer = ObserverQuantizer(model, config_list)
    >>> # define a calibration function
    >>> def calibration(model, calib_loader):
    ...     model.eval()
    ...     with torch.no_grad():
    ...         for data, _ in calib_loader:
    ...             model(data)
    >>> calibration(model, calib_loader)
    >>> quantizer.compress()

    For a detailed example please refer to
    :githublink:`examples/model_compress/quantization/ <examples/model_compress/quantization/>`.

    .. note::

        This quantizer is still under development for now. Some quantizer settings are hard-coded:

        - weight observer: per_tensor_symmetric, qint8
        - output observer: per_tensor_affine, quint8, reduce_range=True

        Other settings (such as quant_type and op_names) can be configured.

    Notes
    -----

    **About the compress API**

    Before the `compress` API is called, the model only records tensors' statistics and no quantization
    is executed. After the `compress` API is called, the model will NOT record tensors' statistics any more.
    The quantization scale and zero point are generated for each tensor and used to quantize each tensor
    during inference (we call it evaluation mode).

    **About calibration**

    Usually about 100 training/evaluation examples are picked for calibration. If the accuracy is a bit
    low, try reducing the number of calibration examples.
    """

    # Attributes removed from each module by ``_del_simulated_attr`` before exporting.
    _SIMULATED_ATTRS = ('old_weight', 'steps', 'weight_qmax', 'weight_qmin', 'input_qmax', 'input_qmin',
                        'output_qmax', 'output_qmin', 'weight_scale', 'weight_zero_point', 'input_scale',
                        'input_zero_point', 'output_scale', 'output_zero_point')

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        # NOTE: this quantizer is experimental for now. The dtype and qscheme of quantization
        # are hard-coded.
        # TODO:
        # 1. support dtype and qscheme customization through config_list. Current settings:
        #    weight observer : per_tensor_symmetric, qint8
        #    output observer : per_tensor_affine, quint8, reduce_range=True
        # 2. add more kinds of observers, such as Kullback-Leibler divergence.
        # 3. add batch normalization folding
        assert not model.training, "Currently the observer quantizer only works in evaluation mode."
        self.quant_grad = QuantForward()
        # Observers are fed CPU copies of tensors (see ``record``); the resulting scales and
        # zero points are moved back to this device in ``compress``.
        self.device = next(model.parameters()).device
        all_observers = defaultdict(dict)
        weight_qmin, weight_qmax = -127, 127
        output_qmin, output_qmax = 0, 127  # reduce_range is set to True
        self.compressed = False

        for layer, config in self.get_modules_to_compress():
            layer_name = layer.name
            module = layer.module
            quant_types = config.get("quant_types", [])
            if "weight" in quant_types:
                all_observers[layer_name]["weight"] = default_weight_observer()
                setattr(module, "weight_qmax", weight_qmax)
                setattr(module, "weight_qmin", weight_qmin)
            # Inputs are activations as well, so they use the same observer kind and range as outputs.
            if "input" in quant_types:
                all_observers[layer_name]["input"] = default_histogram_observer()
                setattr(module, "input_qmax", output_qmax)
                setattr(module, "input_qmin", output_qmin)
            if "output" in quant_types:
                all_observers[layer_name]["output"] = default_histogram_observer()
                setattr(module, "output_qmax", output_qmax)
                setattr(module, "output_qmin", output_qmin)
        self.all_observers = all_observers

    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` against the keys this quantizer supports.

        Only 8-bit quantization is accepted. ``exclude`` is listed explicitly because a
        ``schema`` dict rejects any key it does not declare.
        """
        schema = QuantizerSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
            Optional('quant_bits'): Or(And(int, lambda n: n == 8), Schema({
                Optional('weight'): And(int, lambda n: n == 8),
                Optional('output'): And(int, lambda n: n == 8),
                Optional('input'): And(int, lambda n: n == 8),
            })),
            Optional('op_types'): [str],
            Optional('op_names'): [str],
            Optional('exclude'): bool,
        }], model, logger)

        schema.validate(config_list)

    def record(self, wrapper, quant_type, tensor):
        """Feed a CPU copy of ``tensor`` to the observer of ``wrapper``'s layer for ``quant_type``."""
        name = wrapper.name
        observer = self.all_observers[name][quant_type]
        observer(tensor.cpu())

    def calculate_qparams(self, name, quant_type):
        """
        Return the ``(scale, zero_point)`` computed by the observer of layer ``name`` for ``quant_type``.
        """
        observer = self.all_observers[name][quant_type]
        scale, zero_point = observer.calculate_qparams()
        return scale, zero_point

    def _quantize(self, x, scale, zero_point, qmin, qmax):
        """
        Fake-quantize ``x``.

        ``x`` is mapped onto the integer grid given by ``scale`` and ``zero_point``, clamped to
        ``[qmin, qmax]``, rounded, then mapped back to floating point.
        """
        x = x / scale + zero_point
        x = torch.clamp(x, qmin, qmax)
        x = torch.round(x)
        x = (x - zero_point) * scale
        return x

    def quantize_input(self, inputs, wrapper, **kwargs):
        """Fake-quantize ``inputs`` after ``compress``; before that, only record them for calibration."""
        if self.compressed:
            module = wrapper.module
            inputs = self._quantize(inputs,
                                    module.input_scale,
                                    module.input_zero_point,
                                    module.input_qmin,
                                    module.input_qmax)
        else:
            self.record(wrapper, 'input', inputs)
        return inputs

    def quantize_weight(self, wrapper, **kwargs):
        """Record the layer's weight for calibration; does nothing once ``compress`` has run."""
        # If ObserverQuantizer.compress is executed, the weight will be set to
        # the pseudo-quantized one. So there is no need to quantize it.
        if self.compressed:
            return
        weight = wrapper.module.weight
        self.record(wrapper, 'weight', weight)

    def quantize_output(self, output, wrapper, **kwargs):
        """Fake-quantize ``output`` after ``compress``; before that, only record it for calibration."""
        if self.compressed:
            module = wrapper.module
            new_output = self._quantize(output,
                                        module.output_scale,
                                        module.output_zero_point,
                                        module.output_qmin,
                                        module.output_qmax)
        else:
            self.record(wrapper, 'output', output)
            new_output = output
        return new_output

    def _register_qparams(self, module, layer_name, quant_type):
        """
        Compute scale and zero point for ``quant_type`` of layer ``layer_name`` and store them on
        ``module`` as the buffers ``<quant_type>_scale`` and ``<quant_type>_zero_point``.
        """
        scale, zero_point = self.calculate_qparams(layer_name, quant_type)
        # Observers were fed CPU tensors, so bring the statistics back to the model's device.
        module.register_buffer(quant_type + '_scale', scale.to(self.device))
        module.register_buffer(quant_type + '_zero_point', zero_point.to(self.device))

    def compress(self):
        """
        Calculate quantization information of each tensor.

        After this call, inference no longer updates the observers. Instead, quantization is
        simulated with the computed scales and zero points, which can be used to test the accuracy
        of the quantized model. Quantized weights are replaced by their fake-quantized values.
        """
        for layer, config in self.get_modules_to_compress():
            module = layer.module
            quant_types = config.get("quant_types", [])
            if "weight" in quant_types:
                self._register_qparams(module, layer.name, 'weight')
                quantized_weight = self._quantize(module.weight,
                                                  module.weight_scale,
                                                  module.weight_zero_point,
                                                  module.weight_qmin,
                                                  module.weight_qmax)
                # Store the fake-quantized weight as the module's ``weight`` buffer.
                delattr(module, 'weight')
                module.register_buffer('weight', quantized_weight)
            if "input" in quant_types:
                self._register_qparams(module, layer.name, 'input')
            if "output" in quant_types:
                self._register_qparams(module, layer.name, 'output')
        self.compressed = True
        super().compress()

    @staticmethod
    def _activation_calibration(module, quant_type):
        """
        Build the calibration entries of an activation (``'input'`` or ``'output'``) of ``module``.

        The tracked float range is the one representable by the module's
        ``[<quant_type>_qmin, <quant_type>_qmax]`` integer range under its scale and zero point.
        """
        scale = getattr(module, quant_type + '_scale')
        zero_point = getattr(module, quant_type + '_zero_point')
        qmin = getattr(module, quant_type + '_qmin')
        qmax = getattr(module, quant_type + '_qmax')
        return {
            quant_type + '_bits': 8,
            'tracked_min_' + quant_type: float(scale * (qmin - zero_point)),
            'tracked_max_' + quant_type: float(scale * (qmax - zero_point)),
            'tracked_qmin_' + quant_type: qmin,
            'tracked_qmax_' + quant_type: qmax,
        }

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters (optional).

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            The tensor is placed on cpu if ``device`` is None.

        Returns
        -------
        Dict
            Calibration parameters, keyed by module name.
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            has_weight = hasattr(module, 'weight_scale')
            has_input = hasattr(module, 'input_scale')
            has_output = hasattr(module, 'output_scale')
            if has_weight or has_input or has_output:
                calibration_config[name] = {}
            if has_weight:
                # The weight range ignores the zero point, i.e. it assumes the symmetric weight
                # observer yields a zero point of 0.
                val = float(module.weight_scale * module.weight_qmax)
                calibration_config[name].update({
                    'weight_bits': 8,
                    'tracked_max_weight': val,
                    'tracked_min_weight': -val,
                    'tracked_qmin_weight': module.weight_qmin,
                    'tracked_qmax_weight': module.weight_qmax,
                })
                quantized_weight = self._quantize(module.weight,
                                                  module.weight_scale,
                                                  module.weight_zero_point,
                                                  module.weight_qmin,
                                                  module.weight_qmax)
                # Store the fake-quantized weight back as a regular ``weight`` parameter.
                delattr(module, 'weight')
                module.register_parameter('weight', torch.nn.Parameter(quantized_weight))
            if has_input:
                calibration_config[name].update(self._activation_calibration(module, 'input'))
            if has_output:
                calibration_config[name].update(self._activation_calibration(module, 'output'))
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
                               input_shape, device)

        return calibration_config

    def _del_simulated_attr(self, module):
        """Delete the quantization-simulation attributes that ``module`` still carries."""
        for attr in self._SIMULATED_ATTRS:
            if hasattr(module, attr):
                delattr(module, attr)