Source code for nni.compression.quantization.dorefa_quantizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations
import logging
from typing import List, Dict, Union, overload

import torch
import torch.nn as nn
from torch import Tensor

from nni.common.version import torch_version_is_2

from ..base.compressor import Compressor, Quantizer
from ..base.wrapper import ModuleWrapper
from ..utils import Evaluator, _EVALUATOR_DOCSTRING
from ..base.target_space import TargetType, QuantizationTargetSpace


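# Activation modules that DoReFa treats as valid output-quantization targets:
# check_validation below only allows output quantization when the wrapped module
# is one of these activations, or is fused with one of them.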
ACTIVATION_LIST = [
    nn.ReLU, nn.RReLU, nn.LeakyReLU, nn.PReLU, nn.Softplus, nn.ELU, nn.CELU, nn.SELU, nn.GELU,
    nn.ReLU6, nn.Sigmoid, nn.Tanh, nn.Softsign, nn.Hardtanh, nn.Threshold, nn.Tanhshrink,
    nn.Softshrink, nn.Hardshrink, nn.LogSigmoid, nn.Softmin, nn.Softmax, nn.LogSoftmax, nn.Hardswish,
]

_logger = logging.getLogger(__name__)
is_proper_torch_version = torch_version_is_2()
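
# Gradient quantization below relies on `Module.register_full_backward_pre_hook`,
# which was introduced in PyTorch 2.0, hence the version gate above.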


class DoReFaQuantizer(Quantizer):
    __doc__ = r'''
    DoReFa quantizer, as defined in
    `DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
    <https://arxiv.org/abs/1606.06160>`__.
    The authors Shuchang Zhou and Yuxin Wu propose the DoReFa algorithm to quantize
    weights, activations and gradients during training.

    Parameters
    ----------
    model
        Model to be quantized.
    config_list
        A list of dicts; each dict configures which modules need to be quantized and how to quantize them.
        Please refer to :doc:`Compression Config Specification </compression/config_list>` for more information.
    evaluator
        {evaluator_docstring}

    Examples
    --------
        >>> from nni.compression.quantization import DoReFaQuantizer
        >>> from nni.compression.utils import TorchEvaluator
        >>> model = ...
        >>> optimizer = ...
        >>> max_steps, max_epochs = ..., ...
        >>> evaluator = TorchEvaluator(train, optimizer, training_step)
        >>> quantizer = DoReFaQuantizer(model, configure_list, evaluator)
        >>> _, calibration_config = quantizer.compress(max_steps, max_epochs)
    '''.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator):
        ...

    @overload
    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator,
                 existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        ...

    def __init__(self, model: torch.nn.Module, config_list: List[Dict], evaluator: Evaluator,
                 existed_wrappers: Dict[str, ModuleWrapper] | None = None):
        super().__init__(model, config_list, evaluator, existed_wrappers=existed_wrappers)
        self.evaluator: Evaluator
        self.is_init = False

        self.check_validation()
        self.register_dorefa_apply_method()
        self.register_track_func()

    @classmethod
    def from_compressor(cls, compressor: Compressor, new_config_list: List[Dict], evaluator: Evaluator | None = None):
        return super().from_compressor(compressor, new_config_list, evaluator=evaluator)

    def check_validation(self) -> None:
        for ts in self._target_spaces.values():
            for target_space in ts.values():
                assert target_space.quant_scheme is not None
                if target_space.type is TargetType.PARAMETER and target_space.quant_scheme != 'affine':
                    warn_msg = f'Only supports affine mode for weight quantization, but got {target_space.quant_scheme}'
                    _logger.warning(warn_msg)
                elif target_space.type is TargetType.OUTPUT:
                    module = target_space._wrapper.module
                    # case 1: activation module
                    # case 2: module with activation fused_modules
                    fused_modules = target_space._wrapper.fused_modules
                    if not isinstance(module, tuple(ACTIVATION_LIST)) and not (fused_modules and  # type: ignore
                            any([isinstance(item, tuple(ACTIVATION_LIST)) for item in fused_modules[1:]])):  # type: ignore
                        raise ValueError('Output quantization is only supported for activation function or '
                                         f'activation module fusion, but got {type(module)}')
                    if target_space.quant_scheme != 'affine':
                        warn_msg = f'Only supports affine mode for output quantization, but got {target_space.quant_scheme}'
                        _logger.warning(warn_msg)
                if target_space._scaler is not None:
                    raise ValueError('DoReFa quantizer does not support granularity, please set it to False')

    def _quant_dequant_gradient_hook(self, target_space: QuantizationTargetSpace) -> None:
        def quant_dequant_gradient(module: nn.Module, grad_output):
            # the tracked range is widened by half a quantization step on each side
            tracked_max = torch.tensor(1.0 + 0.5 / (2 ** target_space.quant_bits - 1)).to(grad_output[0].device)
            tracked_min = torch.tensor(0 - 0.5 / (2 ** target_space.quant_bits - 1)).to(grad_output[0].device)
            scale, zero_point = init_scale_zp(tracked_max, tracked_min, target_space.qmax,
                                              target_space.qmin, 'affine')
            new_grad_output = []
            for g_o in grad_output:
                grad_o = torch.abs(g_o.clone().detach())
                dim_lis = list(range(len(grad_o.shape)))
                dim_lis.pop(0)
                max_grad = torch.amax(grad_o, dim=dim_lis, keepdim=True)
                # generate uniform noise
                uniform_k = torch.zeros_like(max_grad).to(g_o.device)
                N_k = uniform_k.uniform_(-0.5, 0.5) / (2 ** target_space.quant_bits - 1)
                q_grad_o = g_o / (2 * max_grad) + 0.5 + N_k
                quantized_grad = zero_point + q_grad_o / scale
                quantized_grad = torch.round(torch.clamp(quantized_grad, target_space.qmin, target_space.qmax))
                dequantized_grad = (quantized_grad - zero_point) * scale
                new_grad_output.append((dequantized_grad - 0.5) * 2 * max_grad)
            return tuple(new_grad_output)

        target_space._wrapper.module.register_full_backward_pre_hook(quant_dequant_gradient)  # type: ignore

    def register_output_backward_hook(self):
        for ts in self._target_spaces.values():
            is_output = any([target_space.type is TargetType.OUTPUT for target_space in ts.values()])
            is_param = any([target_space.type is TargetType.PARAMETER for target_space in ts.values()])
            if is_param and not is_output:
                if is_proper_torch_version:  # torch version >= 2.0.0
                    for target_space in ts.values():
                        if target_space.type is TargetType.PARAMETER:
                            self._quant_dequant_gradient_hook(target_space)
                            break
                else:
                    warn_msg = 'Gradient quantization is only supported for torch version >= 2.0.0'
                    _logger.warning(warn_msg)

    def register_dorefa_apply_method(self):
        for _, ts in self._target_spaces.items():
            for _, target_space in ts.items():
                if target_space.type is TargetType.PARAMETER:
                    target_space.apply_method = 'dorefa_clamp_round_weight'
                elif target_space.type is TargetType.INPUT:
                    target_space.apply_method = 'clamp_round'
                elif target_space.type is TargetType.OUTPUT:
                    target_space.apply_method = 'dorefa_clamp_round_output'

    def register_track_func(self):
        for module_name, _ in self._target_spaces.items():
            wrapper = self._module_wrappers[module_name]
            wrapper.register_track_func(self.initialize_scale_zp)
            wrapper.register_track_func(self.update_scale_zp)

    def update_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor) -> None:
        if not self.check_target(wrapper, target_name):
            return
        target_space = wrapper.quantization_target_spaces[target_name]
        if target_space.type is not TargetType.INPUT:
            return
        # track min/max values
        current_amin = target.detach().reshape(-1).amin(-1)
        current_amax = target.detach().reshape(-1).amax(-1)
        # update scale and zero_point
        tracked_min = torch.min(current_amin, torch.zeros_like(current_amin))
        tracked_max = torch.max(current_amax, torch.zeros_like(current_amax))
        zero_point = torch.zeros_like(tracked_min)
        qmin, qmax = target_space.qmin, target_space.qmax
        assert isinstance(qmin, int) and isinstance(qmax, int)
        if target_space.quant_scheme in ['symmetric', None]:
            abs_max = torch.max(torch.abs(tracked_min), torch.abs(tracked_max))
            scale = abs_max / (float(qmax - qmin) / 2)
            scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
            # NOTE: +1 because in pytorch, symmetric qint8 zero_point is 0, quint8 zero_point is 128.
            zero_point_val = (qmax + qmin + 1) // 2
            zero_point = torch.full_like(zero_point, zero_point_val)
        elif target_space.quant_scheme == 'affine':
            scale = (tracked_max - tracked_min) / float(qmax - qmin)
            scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
            zero_point = qmin - torch.round(tracked_min / scale)
        else:
            raise RuntimeError(f'Unknown quant_scheme {target_space.quant_scheme}')
        zero_point = torch.clamp(zero_point, qmin, qmax)
        target_space.scale, target_space.zero_point = scale, zero_point

    def initialize_scale_zp(self, wrapper: ModuleWrapper, target_name: str, target: Tensor):
        if self.is_init or not self.check_target(wrapper, target_name):
            return
        target_space = wrapper.quantization_target_spaces[target_name]
        if target_space.type is TargetType.INPUT:
            return
        elif target_space.type in [TargetType.OUTPUT, TargetType.PARAMETER]:
            tracked_max = torch.tensor(1.0).to(target.device)
            tracked_min = torch.tensor(0.0).to(target.device)
            scale, zero_point = init_scale_zp(tracked_max, tracked_min, target_space.qmax,
                                              target_space.qmin, 'affine')
        else:
            raise RuntimeError(f'Unknown target_name {target_name}')
        target_space.scale, target_space.zero_point = scale, zero_point

    def register_trigger(self, evaluator: Evaluator):
        def optimizer_task():
            self.is_init = True

        evaluator.patch_optimizer_step(before_step_tasks=[], after_step_tasks=[optimizer_task])

    def _single_compress(self, max_steps: int | None, max_epochs: int | None):
        self._fusion_compress(max_steps, max_epochs)

    def _fuse_preprocess(self, evaluator: Evaluator) -> None:
        self.register_output_backward_hook()
        module_name_param_dict = self.patch_optimizer_param_group()
        if len(module_name_param_dict) > 0:
            evaluator.patch_optim_param_group(module_name_param_dict)
        self.register_trigger(evaluator)

    def _fuse_postprocess(self, evaluator: Evaluator) -> None:
        pass
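

# A minimal, self-contained sketch of the gradient transform implemented by
# `_quant_dequant_gradient_hook` above: the incoming gradient is rescaled into
# roughly [0, 1] per sample, dithered with uniform noise N_k, affine-quantized to
# k bits via `init_scale_zp` (defined below), then mapped back to its original
# range. The unsigned range qmin=0, qmax=2**k - 1 is an assumption for
# illustration; the hook takes qmin/qmax from its target space. This helper is
# not used by the quantizer itself.
def _dorefa_gradient_sketch(grad: Tensor, k: int) -> Tensor:
    qmin, qmax = 0, 2 ** k - 1
    # per-sample max over all non-batch dimensions, as in the hook above
    max_grad = torch.amax(torch.abs(grad), dim=list(range(1, grad.dim())), keepdim=True)
    noise = torch.zeros_like(max_grad).uniform_(-0.5, 0.5) / (2 ** k - 1)
    tracked_max = torch.tensor(1.0 + 0.5 / (2 ** k - 1))
    tracked_min = torch.tensor(-0.5 / (2 ** k - 1))
    scale, zero_point = init_scale_zp(tracked_max, tracked_min, qmax, qmin, 'affine')
    q = grad / (2 * max_grad) + 0.5 + noise                    # rescale and dither
    q = torch.round(torch.clamp(zero_point + q / scale, qmin, qmax))
    return ((q - zero_point) * scale - 0.5) * 2 * max_grad    # map back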

def init_scale_zp(tracked_max: Tensor, tracked_min: Tensor, qmax: int, qmin: int, quant_scheme: Union[str, None] = None):
    tracked_min = torch.min(tracked_min, torch.zeros_like(tracked_min))
    tracked_max = torch.max(tracked_max, torch.zeros_like(tracked_max))
    zero_point = torch.zeros_like(tracked_min)
    if quant_scheme == 'affine':
        scale = (tracked_max - tracked_min) / float(qmax - qmin)
        scale = torch.max(scale, torch.full_like(scale, torch.finfo(torch.float32).eps))
        zero_point = qmin - torch.round(tracked_min / scale)
    elif quant_scheme in ['symmetric', None]:
        raise ValueError(f'Unsupported quant_scheme {quant_scheme}')
    else:
        raise RuntimeError(f'Unknown quant_scheme {quant_scheme}')
    zero_point = torch.clamp(zero_point, qmin, qmax)
    return scale, zero_point
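

# Quick sanity check for `init_scale_zp` (illustrative only): an observed range of
# [0, 1] quantized with an 8-bit unsigned range (qmin=0, qmax=255) in affine mode
# gives scale = 1/255 and zero_point = 0.
if __name__ == '__main__':
    scale, zero_point = init_scale_zp(torch.tensor(1.0), torch.tensor(0.0), 255, 0, 'affine')
    print(scale.item())       # ~0.0039216 (= 1 / 255)
    print(zero_point.item())  # 0.0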