Source code for nni.compression.quantization.ptq_quantizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations
from typing import List, Dict, Union, overload

import torch
from torch import Tensor

from ..base.compressor import Compressor, Quantizer
from ..base.wrapper import ModuleWrapper
from ..base.target_space import QuantizationTargetSpace
from ..utils import Evaluator, _EVALUATOR_DOCSTRING


class PtqQuantizer(Quantizer):
    __doc__ = r'''
    Post Training Quantization

    Parameters
    ----------
    model
        Model to be quantized.
    config_list
        A list of dict, each dict configures which modules need to be quantized and how to quantize them.
        Please refer to :doc:`Compression Config Specification </compression/config_list>` for more information.
    evaluator
        {evaluator_docstring}

    Examples
    --------
        >>> from nni.compression.quantization import PtqQuantizer
        >>> from nni.compression.utils import TorchEvaluator
        >>> model = ...
        >>> optimizer = ...
        >>> max_steps, max_epochs = ..., ...
        >>> evaluator = TorchEvaluator(train, optimizer, training_step)
        >>> quantizer = PtqQuantizer(model, config_list, evaluator)
        >>> _, calibration_config = quantizer.compress(max_steps, max_epochs)
    '''.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator):
        ...

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator,
                 existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        ...

    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator,
                 existed_wrappers: Dict[str, ModuleWrapper] | None = None, is_bias_correction: bool = False):
        super().__init__(model, config_list, evaluator, existed_wrappers)
        self.evaluator: Evaluator
        self.is_compressed = False
        self.is_bias_correction = is_bias_correction
        self.register_ptq_apply_method()
        self.register_track_func()

    @classmethod
    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
        return super().from_compressor(compressor, new_config_list, evaluator=evaluator)

    def register_ptq_apply_method(self):
        # Before calibration finishes, targets pass through unchanged ('bypass');
        # after compression, values are clamped and rounded to the quantized grid ('clamp_round').
        for _, ts in self._target_spaces.items():
            for _, target_space in ts.items():
                target_space.apply_method = 'clamp_round' if self.is_compressed else 'bypass'

    def register_track_func(self):
        for module_name, _ in self._target_spaces.items():
            wrapper = self._module_wrappers[module_name]
            wrapper.register_track_func(self.track_min_max_val)
            # wrapper.register_track_func(self.update_scale_zp_for_bias_correction)

    def track_min_max_val(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
        def amin_reduce_func(converted_target: Tensor):
            return converted_target.detach().amin(dim=-1)

        def amax_reduce_func(converted_target: Tensor):
            return converted_target.detach().amax(dim=-1)

        if target_name not in wrapper.quantization_target_spaces:
            return
        target_space = wrapper.quantization_target_spaces[target_name]
        # TODO: sync the collection of data info when using ddp
        if target_space._scaler:
            current_amin = target_space._scaler.shrink(target, amin_reduce_func, keepdim=True)
            current_amax = target_space._scaler.shrink(target, amax_reduce_func, keepdim=True)
        else:
            current_amin = target.detach().reshape(-1).amin(-1)
            current_amax = target.detach().reshape(-1).amax(-1)
        # update the running min/max observed during calibration
        target_space.tracked_max = update_tracked_value(target_space.tracked_max, current_amax, "max")
        target_space.tracked_min = update_tracked_value(target_space.tracked_min, current_amin, "min")

    def update_scale_zp(self):
        for _, ts in self._target_spaces.items():
            for _, target_space in ts.items():
                scale, zero_point = compute_scale_zp(target_space)  # type: ignore
                target_space.scale, target_space.zero_point = scale, zero_point

    def _single_compress(self, max_steps: int | None, max_epochs: int | None):
        self._fusion_compress(max_steps, max_epochs)

    def _fuse_preprocess(self, evaluator: Evaluator) -> None:
        module_name_param_dict = self.patch_optimizer_param_group()
        if len(module_name_param_dict) > 0:
            evaluator.patch_optim_param_group(module_name_param_dict)

    def _fuse_postprocess(self, evaluator: Evaluator) -> None:
        self.evaluator.evaluate()
        # compute and update scale and zero point from the tracked min/max
        self.update_scale_zp()
        self.is_compressed = True
        self.register_ptq_apply_method()
        # bias correction
        if self.is_bias_correction:
            self.bias_correction()

    def bias_correction(self):
        assert self.is_bias_correction, \
            f"is_bias_correction should be True, but got {self.is_bias_correction}"
        for module_name, _ in self._target_spaces.items():
            wrapper = self._module_wrappers[module_name]
            setattr(wrapper, "is_bias_correction", self.is_bias_correction)
        # run the bias correction process
        # TODO: add a warning for the user to change the evaluation dataset
        self.evaluator.evaluate()
        for module_name, _ in self._target_spaces.items():
            wrapper = self._module_wrappers[module_name]
            wrapper.update_bias()
            delattr(wrapper, "is_bias_correction")
            delattr(wrapper, "bias_correction")
            delattr(wrapper, "bias_element_num")
        self.evaluator.evaluate()
        self.update_scale_zp()
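
# Illustrative usage sketch (not part of the module): enabling the optional
# bias-correction pass. `train`, `training_step`, `optimizer`, and `model` are
# user-supplied, as in the class docstring example; passing
# `is_bias_correction=True` makes `_fuse_postprocess` run `bias_correction()`
# after the calibration pass.
# >>> evaluator = TorchEvaluator(train, optimizer, training_step)
# >>> quantizer = PtqQuantizer(model, config_list, evaluator, is_bias_correction=True)
# >>> _, calibration_config = quantizer.compress(max_steps, max_epochs)
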
def compute_scale_zp(target_space: QuantizationTargetSpace):
    if target_space.tracked_max is None or target_space.tracked_min is None:
        return
    # make sure zero is contained in the [tracked_min, tracked_max] range
    tracked_min = torch.min(target_space.tracked_min, torch.zeros_like(target_space.tracked_min))
    tracked_max = torch.max(target_space.tracked_max, torch.zeros_like(target_space.tracked_max))
    zero_point = torch.zeros_like(tracked_min)
    if target_space.quant_scheme in ['symmetric', None]:
        abs_max = torch.max(torch.abs(tracked_min), torch.abs(tracked_max))
        scale = abs_max / (float(target_space.qmax - target_space.qmin) / 2)
        scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
        # NOTE: needs double-checking; +1 because in PyTorch the symmetric qint8 zero point is 0
        # and the quint8 zero point is 128.
        zero_point_val = (target_space.qmax + target_space.qmin + 1) // 2
        zero_point = torch.full_like(zero_point, zero_point_val)
    elif target_space.quant_scheme == 'affine':
        scale = (tracked_max - tracked_min) / float(target_space.qmax - target_space.qmin)
        scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
        zero_point = target_space.qmin - torch.round(tracked_min / scale)
    else:
        raise RuntimeError(f'Unknown quant_scheme {target_space.quant_scheme}')
    zero_point = torch.clamp(zero_point, target_space.qmin, target_space.qmax)
    return scale, zero_point


def update_tracked_value(original_val: Union[Tensor, None], current_val: Tensor, mode: str = "max"):
    if original_val is None:
        return current_val
    assert current_val is not None
    assert original_val.shape == current_val.shape
    if mode == "max":
        return torch.max(original_val, current_val)
    elif mode == "min":
        return torch.min(original_val, current_val)
    else:
        raise TypeError(f"Type: {mode} is not supported")
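
# Worked example (illustrative, not part of the module): symmetric qint8 with
# qmin=-128, qmax=127 and a tracked range of [-0.5, 1.0]:
#   abs_max    = max(|-0.5|, |1.0|) = 1.0
#   scale      = 1.0 / ((127 - (-128)) / 2) = 1.0 / 127.5 ≈ 0.00784
#   zero_point = (127 + (-128) + 1) // 2 = 0
# For quint8 (qmin=0, qmax=255) the same formula gives zero_point = 128,
# matching the PyTorch convention referenced in the NOTE above.
# Under the 'affine' scheme with the same range:
#   scale      = (1.0 - (-0.5)) / (127 - (-128)) = 1.5 / 255 ≈ 0.00588
#   zero_point = -128 - round(-0.5 / 0.00588) = -128 + 85 = -43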