# Source code for nni.algorithms.compression.pytorch.quantization.qat_quantizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantGrad
from nni.compression.pytorch.quantization.literal import (
    PER_CHANNEL_QUANT_SCHEME,
    QuantScheme,
    QuantDtype,
    QuantType
)
from nni.compression.pytorch.quantization.settings import LayerQuantSetting
from nni.compression.pytorch.quantization.utils import (
    calculate_qmin_qmax,
    get_min_max_value,
    get_quant_shape
)


# Module-level logger; the quantizer reports skipped layers and weight-recovery problems through it.
logger = logging.getLogger(name=__name__)


class QATGrad(QuantGrad):
    """
    Gradient rule used by the QAT quantizer.

    The incoming gradient is passed through unchanged, except for elements whose
    quantized value falls outside ``[qmin, qmax]``; those receive a zero gradient.
    """

    @staticmethod
    def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
        """
        Zero the gradient of elements that would be clipped by quantization.

        Parameters
        ----------
        tensor : Tensor
            input of the quantization op
        grad_output : Tensor
            gradient w.r.t. the output of the quantization op; it is NOT modified
        quant_type : QuantType
            type of the quantized tensor (not used by this rule, kept for interface compatibility)
        scale : Tensor
            quantization scale
        zero_point : Tensor
            quantization zero point
        qmin : int
            lower bound of the integer range
        qmax : int
            upper bound of the integer range

        Returns
        -------
        Tensor
            a new tensor equal to ``grad_output`` with out-of-range positions set to 0
        """
        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
        mask = (tensor_q < qmin) | (tensor_q > qmax)
        # Out-of-place on purpose: autograd may pass the very same gradient tensor to
        # several consumers (e.g. both inputs of an addition), so writing into
        # ``grad_output`` could silently corrupt gradients flowing elsewhere.
        return grad_output.masked_fill(mask, 0)


def update_quantization_param(bits, rmin, rmax, dtype, scheme):
    """
    calculate the `zero_point` and `scale`.

    Parameters
    ----------
    bits : int
        quantization bits length
    rmin : Tensor
        min value of real value
    rmax : Tensor
        max value of real value
    dtype : QuantDtype
        quantized data type
    scheme : QuantScheme
        quantization scheme to be used

    Returns
    -------
    Tensor, Tensor
        ``scale`` and ``zero_point``, both shaped like ``rmin``. ``scale`` is never
        smaller than ``torch.finfo(torch.float32).eps``.
    """

    # Extend the [min, max] interval so that it contains 0. Otherwise 0 would not
    # be an exactly representable value. Presumably this matters for activations
    # that are zero-padded during training; PyTorch's quantization observers
    # behave this way by default, so we do too.
    rmin = torch.min(rmin, torch.zeros_like(rmin))
    rmax = torch.max(rmax, torch.zeros_like(rmax))
    zero_point = torch.zeros_like(rmin)

    # todo: there is no need to calculate qmin and qmax again
    qmin, qmax = calculate_qmin_qmax(bits, dtype)

    # Lower bound for the scale (same default as PyTorch observers). Without it a
    # degenerate range (rmin == rmax == 0) yields scale == 0, a NaN zero point in the
    # affine branch, and divisions by zero when the scale is applied.
    eps = torch.finfo(torch.float32).eps

    if scheme in [QuantScheme.PER_TENSOR_SYMMETRIC, QuantScheme.PER_CHANNEL_SYMMETRIC]:
        abs_max = torch.max(torch.abs(rmin), torch.abs(rmax))
        scale = torch.clamp(abs_max / (float(qmax - qmin) / 2), min=eps)
        if dtype == QuantDtype.UINT:
            # Unsigned symmetric: place the zero point in the middle of the range.
            zero_point_val = (qmin + qmax) // 2
            zero_point = zero_point.new_full(zero_point.size(), zero_point_val)
    else:
        # Clamp before dividing so rmin / scale is always finite.
        scale = torch.clamp((rmax - rmin) / float(qmax - qmin), min=eps)
        zero_point = qmin - torch.round(rmin / scale)

    zero_point = torch.clamp(zero_point, qmin, qmax)

    return scale, zero_point


def update_ema(biased_ema, value, decay):
    """
    Blend a new observation into a running exponential moving average.

    Only the biased running value is produced; no bias correction is applied.

    Parameters
    ----------
    biased_ema : float or Tensor
        running average accumulated so far
    value : float or Tensor
        newly observed statistic
    decay : float
        weight kept for the running average; values closer to 1 give a smoother curve

    Returns
    -------
    float or Tensor
        ``biased_ema * decay + (1 - decay) * value``
    """
    retained = biased_ema * decay
    incoming = (1 - decay) * value
    return retained + incoming


class QAT_Quantizer(Quantizer):
    r"""
    Quantizer defined in:
    `Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
    <http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf>`__

    Authors Benoit Jacob and Skirmantas Kligys provide an algorithm to quantize the model with training.

    ..

        We propose an approach that simulates quantization effects in the forward pass of training.
        Backpropagation still happens as usual, and all weights and biases are stored in floating point
        so that they can be easily nudged by small amounts.
        The forward propagation pass however simulates quantized inference as it will happen in the inference engine,
        by implementing in floating-point arithmetic the rounding behavior of the quantization scheme:

        * Weights are quantized before they are convolved with the input. If batch normalization (see [17]) is used
          for the layer, the batch normalization parameters are “folded into” the weights before quantization.

        * Activations are quantized at points where they would be during inference,
          e.g. after the activation function is applied to a convolutional or fully connected layer’s output,
          or after a bypass connection adds or concatenates the outputs of several layers together such as in ResNets.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be quantized.
    config_list : List[Dict]
        List of configurations for quantization. Supported keys for dict:
            - quant_types : List[str]
                Type of quantization you want to apply, currently support 'weight', 'input', 'output'.
            - quant_bits : Union[int, Dict[str, int]]
                Bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}.
                When the type is int, all quantization types share same bits length.
            - quant_start_step : int
                Disable quantization until the model has run a certain number of steps. This allows the network
                to enter a more stable state where output quantization ranges do not exclude a significant
                fraction of values. Default value is 0.
            - quant_dtype : Union[str, Dict[str, str]]
                Quantized data type ('int' or 'uint'), either shared or per quantization type.
            - quant_scheme : Union[str, Dict[str, str]]
                Quantization scheme (e.g. 'per_tensor_affine'), either shared or per quantization type.
            - op_types : List[str]
                Types of nn.module you want to apply quantization, eg. 'Conv2d'.
            - op_names : List[str]
                Names of nn.module you want to apply quantization, eg. 'conv1'.
            - exclude : bool
                Set True then the layers setting by op_types and op_names will be excluded from quantization.
    optimizer : torch.optim.Optimizer
        Optimizer is required in `QAT_Quantizer`, NNI will patch the optimizer and count the optimize step number.
    dummy_input : Tuple[torch.Tensor]
        Inputs to the model, which are used to get the graph of the module. The graph is used to find
        Conv-Bn patterns. And then the batch normalization folding would be enabled. If dummy_input is not
        given, the batch normalization folding would be disabled.

    Examples
    --------
        >>> from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
        >>> model = ...
        >>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
        >>> optimizer = ...
        >>> dummy_input = torch.rand(...)
        >>> quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
        >>> quantizer.compress()
        >>> # Training Process...

    For detailed example please refer to
    :githublink:`examples/model_compress/quantization/QAT_torch_quantizer.py <examples/model_compress/quantization/QAT_torch_quantizer.py>`.

    Notes
    -----

    **Batch normalization folding**

    Batch normalization folding is supported in QAT quantizer. It can be easily enabled by passing an argument
    `dummy_input` to the quantizer, like:

    .. code-block:: python

        # assume your model takes an input of shape (1, 1, 28, 28)
        # and dummy_input must be on the same device as the model
        dummy_input = torch.randn(1, 1, 28, 28)

        # pass the dummy_input to the quantizer
        quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)

    The quantizer will automatically detect Conv-BN patterns and simulate batch normalization folding process
    in the training graph. Note that when the quantization aware training process is finished, the folded
    weight/bias would be restored after calling `quantizer.export_model`.

    **Quantization dtype and scheme customization**

    Different backends on different devices use different quantization strategies (i.e. dtype (int or uint)
    and scheme (per-tensor or per-channel and symmetric or affine)). QAT quantizer supports customization of
    mainstream dtypes and schemes. There are two ways to set them. One way is setting them globally through a
    function named `set_quant_scheme_dtype` like:

    .. code-block:: python

        from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype

        # This will set all the quantization of 'input' in 'per_tensor_affine' and 'uint' manner
        set_quant_scheme_dtype('input', 'per_tensor_affine', 'uint')
        # This will set all the quantization of 'output' in 'per_tensor_symmetric' and 'int' manner
        set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
        # This will set all the quantization of 'weight' in 'per_channel_symmetric' and 'int' manner
        set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')

    The other way is more detailed. You can customize the dtype and scheme in each quantization config list like:

    .. code-block:: python

        config_list = [{
            'quant_types': ['weight'],
            'quant_bits': 8,
            'op_types': ['Conv2d', 'Linear'],
            'quant_dtype': 'int',
            'quant_scheme': 'per_channel_symmetric'
        }, {
            'quant_types': ['output'],
            'quant_bits': 8,
            'quant_start_step': 7000,
            'op_types': ['ReLU6'],
            'quant_dtype': 'uint',
            'quant_scheme': 'per_tensor_affine'
        }]

    **Multi-GPU training**

    QAT quantizer natively supports multi-gpu training (DataParallel and DistributedDataParallel). Note that the
    quantizer instantiation should happen before you wrap your model with DataParallel or DistributedDataParallel.
    For example:

    .. code-block:: python

        from torch.nn.parallel import DistributedDataParallel as DDP
        from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

        model = define_your_model()

        quantizer = QAT_Quantizer(model, **other_params)  # <--- QAT_Quantizer instantiation

        model = DDP(model)

        for i in range(epochs):
            train(model)
            eval(model)
    """

    def __init__(self, model, config_list, optimizer, dummy_input=None):
        assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
        super().__init__(model, config_list, optimizer, dummy_input)
        self.quant_grad = QATGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        device = next(model.parameters()).device
        # Global step counter, advanced by `step_with_optimizer`; starts at 1.
        self.bound_model.register_buffer("steps", torch.tensor(1))
        for layer, config in modules_to_compress:
            module = layer.module
            name = layer.name
            # TODO: may relax this limitation?
            assert name in self.all_shapes, "Could not found shapes for layer {}".format(name)
            input_shape, output_shape = self.all_shapes[name]
            layer_quant_setting = LayerQuantSetting(config)
            layer_quant_setting.ema_decay = 0.99
            quant_start_step = config.get('quant_start_step', 0)
            layer_quant_setting.quant_start_step = quant_start_step
            # todo: support other ranks and remove this check
            # NOTE(review): a skipped layer gets no `layer_quant_setting`; confirm the quantize_* hooks
            # and export/load paths are never reached for it.
            if isinstance(module, torch.nn.Linear):
                if "input" in config.get("quant_types", []) and \
                        layer_quant_setting.input.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
                    if len(input_shape) != 2:
                        logger.warning("When quantize torch.nn.Linear, make sure that the rank of the inputs "
                                       "of the layer is 2. Skip quantization of layer %s.", name)
                        continue
                if "output" in config.get("quant_types", []) and \
                        layer_quant_setting.output.quant_scheme in PER_CHANNEL_QUANT_SCHEME:
                    if len(output_shape) != 2:
                        logger.warning("When quantize torch.nn.Linear, make sure that the rank of the outputs "
                                       "of the layer is 2. Skip quantization of layer %s.", name)
                        continue

            # Buffers hold the statistics / quantization parameters so they are saved in
            # the state dict and follow the module across devices.
            if "weight" in config.get("quant_types", []):
                quant_shape = get_quant_shape(module.weight.shape, QuantType.WEIGHT,
                                              layer_quant_setting.weight.quant_scheme)
                module.register_buffer('weight_scale', torch.zeros(quant_shape))
                module.register_buffer('weight_zero_point', torch.zeros(quant_shape))

            if "input" in config.get("quant_types", []):
                quant_shape = get_quant_shape(input_shape, QuantType.INPUT, layer_quant_setting.input.quant_scheme)
                module.register_buffer('tracked_min_input', torch.zeros(quant_shape))
                module.register_buffer('tracked_max_input', torch.zeros(quant_shape))
                module.register_buffer('input_scale', torch.zeros(quant_shape))
                module.register_buffer('input_zero_point', torch.zeros(quant_shape))

            if "output" in config.get("quant_types", []):
                quant_shape = get_quant_shape(output_shape, QuantType.OUTPUT, layer_quant_setting.output.quant_scheme)
                module.register_buffer('tracked_min_output', torch.zeros(quant_shape))
                module.register_buffer('tracked_max_output', torch.zeros(quant_shape))
                module.register_buffer('output_scale', torch.zeros(quant_shape))
                module.register_buffer('output_zero_point', torch.zeros(quant_shape))

            setattr(module, "layer_quant_setting", layer_quant_setting)
        # Newly registered buffers are created on CPU; move everything back to the model's device.
        self.bound_model.to(device)

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        del_attr_list = ['old_weight', 'old_bias', 'ema_decay', 'tracked_min_output', 'tracked_max_output',
                         'tracked_min_input', 'tracked_max_input', BN_FOLD_TAG,
                         'weight_scale', 'weight_zero_point', 'input_scale', 'input_zero_point',
                         'output_scale', 'output_zero_point', 'layer_quant_setting']
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def validate_config(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized
        config_list : list of dict
            List of configurations
        """
        SUPPORTED_OPS = ['Conv2d', 'Linear', 'ReLU', 'ReLU6']
        # NOTE(review): if QuantScheme/QuantDtype are Enum classes, `str in EnumClass` raises TypeError
        # on Python < 3.12 — confirm how `literal.py` defines them.
        schema = QuantizerSchema([{
            Optional('quant_types'): Schema([lambda x: x in ['weight', 'output', 'input']]),
            Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
                Optional('input'): And(int, lambda n: 0 < n < 32),
                Optional('weight'): And(int, lambda n: 0 < n < 32),
                Optional('output'): And(int, lambda n: 0 < n < 32),
            })),
            Optional('quant_scheme'): Or(lambda x: x in QuantScheme, Schema({
                Optional('input'): lambda x: x in QuantScheme,
                Optional('weight'): lambda x: x in QuantScheme,
                Optional('output'): lambda x: x in QuantScheme
            })),
            Optional('quant_dtype'): Or(lambda x: x in QuantDtype, Schema({
                Optional('input'): lambda x: x in QuantDtype,
                Optional('weight'): lambda x: x in QuantDtype,
                Optional('output'): lambda x: x in QuantDtype
            })),
            Optional('quant_start_step'): And(int, lambda n: n >= 0),
            Optional('op_types'): [And(str, lambda n: n in SUPPORTED_OPS)],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }], model, logger)

        schema.validate(config_list)

    def _quantize(self, real_value, scale, zero_point, qmin, qmax):
        """
        quantize real value.

        Parameters
        ----------
        real_value : torch.Tensor
            the real value to be quantized
        scale : torch.Tensor
            quantization scale
        zero_point : torch.Tensor
            quantization zero point
        qmin : int
            lower bound of the int range
        qmax : int
            upper bound of the int range

        Returns
        -------
        Tensor
        """
        transformed_val = zero_point + real_value / scale
        clamped_val = torch.clamp(transformed_val, qmin, qmax)
        quantized_val = torch.round(clamped_val)
        return quantized_val

    def _dequantize(self, quantized_val, scale, zero_point):
        """
        dequantize quantized value.
        Because we simulate quantization in training process, all the computations still happen as float point
        computations, which means we first quantize tensors then dequantize them. For more details, please refer
        to the paper.

        Parameters
        ----------
        quantized_val : torch.Tensor
            the quantized value to be de-quantized
        scale : torch.Tensor
            quantization scale
        zero_point : torch.Tensor
            quantization zero point

        Returns
        -------
        Tensor
        """
        real_val = scale * (quantized_val - zero_point)
        return real_val

    def quantize_weight(self, wrapper, **kwargs):
        """
        Fake-quantize the weight of the wrapped module.

        In eval mode the stored `weight_scale`/`weight_zero_point` are used. In training mode the weight is
        returned unchanged until `quant_start_step` is reached; afterwards scale and zero point are recomputed
        from the current weight range, stored in the buffers, and applied.
        """
        module = wrapper.module
        weight = module.weight
        layer_quant_setting = module.layer_quant_setting
        tensor_quant_setting = layer_quant_setting.weight

        # layer-wise settings
        quant_start_step = layer_quant_setting.quant_start_step

        # tensor-wise settings
        dtype = tensor_quant_setting.quant_dtype
        scheme = tensor_quant_setting.quant_scheme
        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
        bits = tensor_quant_setting.bits

        # In evaluation mode, we only quantize weight without updating statistics
        if not wrapper.training:
            scale, zero_point = module.weight_scale, module.weight_zero_point
            weight = self._quantize(weight, scale, zero_point, qmin, qmax)
            weight = self._dequantize(weight, scale, zero_point)
            module.weight = weight
            return weight

        if quant_start_step > int(self.bound_model.steps):
            return weight

        current_min, current_max = get_min_max_value(weight, QuantType.WEIGHT, scheme)
        scale, zero_point = update_quantization_param(bits, current_min, current_max, dtype, scheme)
        module.weight_scale.copy_(scale)
        module.weight_zero_point.copy_(zero_point)
        weight = self._quantize(weight, scale, zero_point, qmin, qmax)
        weight = self._dequantize(weight, scale, zero_point)
        # Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
        # will be lost after each forward process. However, this update takes effect on each
        # replicated module during each forward process, which will make the quantized weight
        # be used correctly.
        wrapper.module.weight = weight
        return weight

    def quantize_input(self, inputs, wrapper, **kwargs):
        """
        Fake-quantize the input of the wrapped module.

        In eval mode the stored `input_scale`/`input_zero_point` are used. In training mode the EMA of the
        input range is updated on every call (seeded with the current range at step 1); quantization itself
        only starts once `quant_start_step` is reached.
        """
        module = wrapper.module

        layer_quant_setting = module.layer_quant_setting
        tensor_quant_setting = layer_quant_setting.input

        # layer-wise settings
        quant_start_step = layer_quant_setting.quant_start_step
        ema_decay = layer_quant_setting.ema_decay

        # tensor-wise settings
        dtype = tensor_quant_setting.quant_dtype
        scheme = tensor_quant_setting.quant_scheme
        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
        bits = tensor_quant_setting.bits

        if not wrapper.training:
            scale = module.input_scale
            zero_point = module.input_zero_point
            inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
            inputs = self._dequantize(inputs, scale, zero_point)
            return inputs

        current_min, current_max = get_min_max_value(inputs, QuantType.INPUT, scheme)

        if int(self.bound_model.steps) == 1:
            # Seed the running statistics with the first observation.
            module.tracked_min_input.copy_(current_min)
            module.tracked_max_input.copy_(current_max)

        tracked_min_input = update_ema(module.tracked_min_input, current_min, ema_decay)
        tracked_max_input = update_ema(module.tracked_max_input, current_max, ema_decay)
        module.tracked_min_input.copy_(tracked_min_input)
        module.tracked_max_input.copy_(tracked_max_input)

        if quant_start_step > int(self.bound_model.steps):
            return inputs

        scale, zero_point = update_quantization_param(
            bits, module.tracked_min_input, module.tracked_max_input, dtype, scheme)
        module.input_scale.copy_(scale)
        module.input_zero_point.copy_(zero_point)

        inputs = self._quantize(inputs, scale, zero_point, qmin, qmax)
        inputs = self._dequantize(inputs, scale, zero_point)
        return inputs

    def quantize_output(self, output, wrapper, **kwargs):
        """
        Fake-quantize the output of the wrapped module.

        Mirrors `quantize_input`, using the `*_output` buffers and the layer's output settings.
        """
        module = wrapper.module
        layer_quant_setting = module.layer_quant_setting
        tensor_quant_setting = layer_quant_setting.output

        # layer-wise settings
        quant_start_step = layer_quant_setting.quant_start_step
        ema_decay = layer_quant_setting.ema_decay

        # tensor-wise settings
        dtype = tensor_quant_setting.quant_dtype
        scheme = tensor_quant_setting.quant_scheme
        qmin, qmax = tensor_quant_setting.get_qmin_qmax()
        bits = tensor_quant_setting.bits

        if not wrapper.training:
            scale = module.output_scale
            zero_point = module.output_zero_point
            output = self._quantize(output, scale, zero_point, qmin, qmax)
            output = self._dequantize(output, scale, zero_point)
            return output

        current_min, current_max = get_min_max_value(output, QuantType.OUTPUT, scheme)

        if int(self.bound_model.steps) == 1:
            # Seed the running statistics with the first observation.
            module.tracked_min_output.copy_(current_min)
            module.tracked_max_output.copy_(current_max)

        tracked_min_output = update_ema(module.tracked_min_output, current_min, ema_decay)
        tracked_max_output = update_ema(module.tracked_max_output, current_max, ema_decay)
        module.tracked_min_output.copy_(tracked_min_output)
        module.tracked_max_output.copy_(tracked_max_output)

        if quant_start_step > int(self.bound_model.steps):
            return output

        scale, zero_point = update_quantization_param(
            bits, module.tracked_min_output, module.tracked_max_output, dtype, scheme)
        module.output_scale.copy_(scale)
        module.output_zero_point.copy_(zero_point)

        output = self._quantize(output, scale, zero_point, qmin, qmax)
        output = self._dequantize(output, scale, zero_point)
        return output

    def load_calibration_config(self, calibration_config):
        """
        Load tracked input/output ranges from a calibration config (as returned by `export_model`).

        Bit widths in the config are asserted to match the current layer settings.
        """
        modules_to_compress = self.get_modules_to_compress()
        for layer, _ in modules_to_compress:
            name, module = layer.name, layer.module
            if name not in calibration_config:
                if module.layer_quant_setting.weight or module.layer_quant_setting.input or \
                        module.layer_quant_setting.output:
                    logger.warning("Can not find module %s's parameter in input config.", name)
                continue
            if module.layer_quant_setting.weight:
                assert calibration_config[name]['weight_bits'] == module.layer_quant_setting.weight.bits, \
                    f"weight bits of module {name} fail to match"
            # `new_tensor` keeps the buffer's device and dtype; a bare `torch.tensor` would land on
            # CPU and break later updates of GPU-resident statistics.
            if module.layer_quant_setting.input:
                assert calibration_config[name]['input_bits'] == module.layer_quant_setting.input.bits, \
                    f"input bits of module {name} fail to match"
                module.tracked_min_input.data = module.tracked_min_input.new_tensor(
                    [calibration_config[name]['tracked_min_input']])
                module.tracked_max_input.data = module.tracked_max_input.new_tensor(
                    [calibration_config[name]['tracked_max_input']])

            if module.layer_quant_setting.output:
                assert calibration_config[name]['output_bits'] == module.layer_quant_setting.output.bits, \
                    f"output bits of module {name} fail to match"
                module.tracked_min_output.data = module.tracked_min_output.new_tensor(
                    [calibration_config[name]['tracked_min_output']])
                module.tracked_max_output.data = module.tracked_max_output.new_tensor(
                    [calibration_config[name]['tracked_max_output']])

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters(optional)

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        modules_to_compress = self.get_modules_to_compress()
        for layer, _ in modules_to_compress:
            name, module = layer.name, layer.module
            setting = module.layer_quant_setting
            # 'input' is included so that input-only layers also get an entry before it is filled below.
            if hasattr(setting, 'weight') or hasattr(setting, 'input') or hasattr(setting, 'output'):
                calibration_config[name] = {}
            if setting.weight:
                calibration_config[name]['weight_bits'] = int(setting.weight.bits)
                calibration_config[name]['weight_scale'] = module.weight_scale
                calibration_config[name]['weight_zero_point'] = module.weight_zero_point

                # Recover weight/bias for batch normalization folding
                actual_weight = getattr(module, 'old_weight', None)
                if actual_weight is None:
                    # Keep the current weight instead of replacing it with None, which would
                    # leave the layer without any weight at all.
                    logger.warning("Can not recover weight for layer %s. "
                                   "This may lead to a wrong accuracy performance on the backend.", name)
                else:
                    delattr(module, 'weight')
                    module.register_parameter('weight', actual_weight)
                if hasattr(module, BN_FOLD_TAG):
                    actual_bias = getattr(module, 'old_bias', None)
                    delattr(module, 'bias')
                    if actual_bias is not None:
                        module.register_parameter('bias', actual_bias)
                    else:
                        setattr(module, 'bias', None)
            # NOTE(review): `float()` needs a single-element tensor, so per-channel input/output
            # ranges cannot be exported by the two branches below — confirm this is intended.
            if setting.input:
                calibration_config[name]['input_bits'] = int(setting.input.bits)
                calibration_config[name]['tracked_min_input'] = float(module.tracked_min_input)
                calibration_config[name]['tracked_max_input'] = float(module.tracked_max_input)

            if setting.output:
                calibration_config[name]['output_bits'] = int(setting.output.bits)
                calibration_config[name]['tracked_min_output'] = float(module.tracked_min_output)
                calibration_config[name]['tracked_max_output'] = float(module.tracked_max_output)
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
                               input_shape, device)

        return calibration_config

    def step_with_optimizer(self):
        """
        override `compressor` `step` method, quantization only happens after certain number of steps
        """
        self.bound_model.steps.add_(1)