# Source code for nni.algorithms.compression.pytorch.quantization.lsq_quantizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from nni.compression.pytorch.compressor import BN_FOLD_TAG, Quantizer, QuantForward
from nni.compression.pytorch.quantization.utils import get_bits_length


# Logger scoped to this module's import path so its messages can be filtered per module.
logger = logging.getLogger(name=__name__)


class LsqQuantizer(Quantizer):
    r"""
    Quantizer defined in: `LEARNED STEP SIZE QUANTIZATION <https://arxiv.org/pdf/1902.08153.pdf>`__,
    authors Steven K. Esser and Jeffrey L. McKinstry provide an algorithm to train the scales with gradients.

    ..

        The authors introduce a novel means to estimate and scale the task loss gradient at each weight and activation
        layer's quantizer step size, such that it can be learned in conjunction with other network parameters.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be quantized.
    config_list : List[Dict]
        List of configurations for quantization. Supported keys for dict:
            - quant_types : List[str]
                Type of quantization you want to apply, currently support 'weight', 'input', 'output'.
            - quant_bits : Union[int, Dict[str, int]]
                Bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8}.
                When the type is int, all quantization types share the same bits length.
            - op_types : List[str]
                Types of nn.module you want to apply quantization, eg. 'Conv2d'.
            - op_names : List[str]
                Names of nn.module you want to apply quantization, eg. 'conv1'.
            - exclude : bool
                If True, the layers selected by op_types and op_names are excluded from quantization.
    optimizer : torch.optim.Optimizer
        Optimizer is required in `LsqQuantizer`, NNI will patch the optimizer and count the optimize step number.
    dummy_input : Tuple[torch.Tensor]
        Inputs to the model, which are used to get the graph of the module. The graph is used to find Conv-Bn patterns.
        And then the batch normalization folding would be enabled. If dummy_input is not given,
        the batch normalization folding would be disabled.

    Examples
    --------
        >>> from nni.algorithms.compression.pytorch.quantization import LsqQuantizer
        >>> model = ...
        >>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
        >>> optimizer = ...
        >>> dummy_input = torch.rand(...)
        >>> quantizer = LsqQuantizer(model, config_list, optimizer, dummy_input=dummy_input)
        >>> quantizer.compress()
        >>> # Training Process...

    For detailed example please refer to
    :githublink:`examples/model_compress/quantization/LSQ_torch_quantizer.py <examples/model_compress/quantization/LSQ_torch_quantizer.py>`.
    """

    def __init__(self, model, config_list, optimizer, dummy_input=None):
        assert isinstance(optimizer, torch.optim.Optimizer), "unrecognized optimizer type"
        super().__init__(model, config_list, optimizer, dummy_input)
        # New scale parameters / bit buffers below are created on CPU; the model is moved back
        # to its original device at the end.
        device = next(model.parameters()).device
        self.quant_grad = QuantForward()
        modules_to_compress = self.get_modules_to_compress()
        # Global step counter; input/output scales are initialized from data while it equals 1.
        self.bound_model.register_buffer("steps", torch.Tensor([1]))
        for layer, config in modules_to_compress:
            module = layer.module
            quant_types = config.get("quant_types", [])
            # Order (weight, output, input) is the order in which the scale parameters are added
            # to the optimizer as new param groups.
            for quant_type in ("weight", "output", "input"):
                if quant_type not in quant_types:
                    continue
                self._register_quant_attrs(module, config, quant_type)
                if quant_type == "weight":
                    # todo: support per-channel quantization for weight since TensorRT use it for conv weight
                    module.weight_scale = torch.nn.Parameter(self._init_scale(module.weight, module.weight_qmax))
                # output/input scales keep the placeholder value here; they are initialized
                # from the first batch in quantize_output / quantize_input.
                self.optimizer.add_param_group({"params": getattr(module, f"{quant_type}_scale")})

        self.bound_model.to(device)

    @staticmethod
    def _register_quant_attrs(module, config, quant_type):
        """
        Register ``<quant_type>_scale`` (learnable, placeholder value 1.0), the ``<quant_type>_bits``
        buffer and the signed integer range ``<quant_type>_qmin`` / ``<quant_type>_qmax`` on ``module``.
        """
        module.register_parameter(f"{quant_type}_scale", torch.nn.Parameter(torch.Tensor([1.0])))
        bits = get_bits_length(config, quant_type)
        module.register_buffer(f"{quant_type}_bits", torch.Tensor([bits]))
        # symmetric signed range, e.g. [-128, 127] for 8 bits
        setattr(module, f"{quant_type}_qmax", 2 ** (bits - 1) - 1)
        setattr(module, f"{quant_type}_qmin", -2 ** (bits - 1))

    @staticmethod
    def _init_scale(tensor, qmax):
        """Initial LSQ step size computed from ``tensor``: ``2 * mean(|tensor|) / sqrt(qmax)``."""
        return tensor.data.detach().abs().mean() * 2 / (qmax ** 0.5)

    @staticmethod
    def grad_scale(x, scale):
        r"""
        Used to scale the gradient. Given tensor `x`, we have `y=grad_scale(x, scale)=x` in the forward pass,
        which means that this function will not change the value of `x`. In the backward pass, we have:

        .. math::

            \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x}
            = scale \cdot \frac{\partial L}{\partial y}

        This means that the original gradient of x is scaled by a factor of `scale`. Applying this function
        to a nn.Parameter will scale its gradient without changing its value.
        """
        y = x
        y_grad = x * scale
        # forward value: y; gradient flows only through y_grad (the detached difference is constant)
        return (y - y_grad).detach() + y_grad

    @staticmethod
    def round_pass(x):
        """
        A simple way to achieve STE operation: forward returns ``x.round()``, backward passes the
        gradient of ``x`` through unchanged.
        """
        y = x.round()
        y_grad = x
        return (y - y_grad).detach() + y_grad

    def quantize(self, x, scale, qmin, qmax):
        """
        Fake-quantize ``x``: divide by ``scale``, clamp to ``[qmin, qmax]``, round (with STE) and
        multiply back by ``scale``. The gradient reaching ``scale`` is multiplied by
        ``1 / sqrt(qmax * x.numel())``, the gradient scale factor proposed by LSQ.
        """
        grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
        scale = self.grad_scale(scale, grad_scale_factor)
        x = x / scale
        x = torch.clamp(x, qmin, qmax)
        x = self.round_pass(x)
        x = x * scale
        return x

    def quantize_weight(self, wrapper, **kwargs):
        """Fake-quantize ``wrapper.module.weight`` with its learnable ``weight_scale`` and store it back."""
        module = wrapper.module
        weight = wrapper.module.weight

        # todo: add support for quantize bias. If we use TensorRT as backend, there is no need to quantize
        # bias
        weight = self.quantize(weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
        module.weight = weight
        return weight

    def quantize_output(self, output, wrapper, **kwargs):
        """
        Fake-quantize the output of ``wrapper.module`` with its learnable ``output_scale``.
        While the global step counter equals 1, the scale is (re)initialized from ``output``.
        """
        module = wrapper.module
        if self.bound_model.steps == 1:
            module.output_scale.data = self._init_scale(output, module.output_qmax)
        return self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)

    def quantize_input(self, inputs, wrapper, **kwargs):
        """
        Fake-quantize the input of ``wrapper.module`` with its learnable ``input_scale``.
        While the global step counter equals 1, the scale is (re)initialized from ``inputs``.
        """
        module = wrapper.module
        if self.bound_model.steps == 1:
            module.input_scale.data = self._init_scale(inputs, module.input_qmax)
        return self.quantize(inputs, module.input_scale, module.input_qmin, module.input_qmax)

    def load_calibration_config(self, calibration_config):
        """
        Restore input/output scales from a calibration config of the form produced by ``export_model``.

        For every compressed layer found in ``calibration_config`` the recorded bit widths must match
        the configured ones (AssertionError otherwise). Each scale is set to ``tracked_max / qmax``,
        the inverse of the ``scale * qmax`` stored by ``export_model``. The new value is created on the
        scale parameter's own device/dtype so a model living on GPU keeps its scales there.
        """
        modules_to_compress = self.get_modules_to_compress()
        for layer, _ in modules_to_compress:
            name, module = layer.name, layer.module
            if name not in calibration_config:
                if hasattr(module, 'weight_bits') or hasattr(module, 'output_bits') or hasattr(module, 'input_bits'):
                    logger.warning(f"Can not find module {name}'s parameter in input config.")
                continue
            layer_config = calibration_config[name]
            if hasattr(module, 'weight_bits'):
                assert layer_config['weight_bits'] == int(module.weight_bits), f"weight bits of module {name} fail to match"
            if hasattr(module, 'input_bits'):
                assert layer_config['input_bits'] == int(module.input_bits), f"input bits of module {name} fail to match"
                scale = module.input_scale
                scale.data = torch.tensor([float(layer_config['tracked_max_input'] / module.input_qmax)],
                                          dtype=scale.dtype, device=scale.device)
            if hasattr(module, 'output_bits'):
                assert layer_config['output_bits'] == int(module.output_bits), f"output bits of module {name} fail to match"
                scale = module.output_scale
                scale.data = torch.tensor([float(layer_config['tracked_max_output'] / module.output_qmax)],
                                          dtype=scale.dtype, device=scale.device)

    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
        """
        Export quantized model weights and calibration parameters(optional)

        Parameters
        ----------
        model_path : str
            path to save quantized model weight
        calibration_path : str
            (optional) path to save quantize parameters after calibration
        onnx_path : str
            (optional) path to save onnx model
        input_shape : list or tuple
            input shape to onnx model
        device : torch.device
            device of the model, used to place the dummy input tensor for exporting onnx file.
            the tensor is placed on cpu if ``device`` is None

        Returns
        -------
        Dict
            Per-layer calibration config: bit widths and symmetric ranges
            ``tracked_min_* = -scale * qmax`` / ``tracked_max_* = scale * qmax``.
        """
        assert model_path is not None, 'model_path must be specified'
        self._unwrap_model()
        calibration_config = {}

        for name, module in self.bound_model.named_modules():
            if hasattr(module, 'input_bits') or hasattr(module, 'weight_bits') or hasattr(module, 'output_bits'):
                calibration_config[name] = {}
            if hasattr(module, 'weight_bits'):
                calibration_config[name]['weight_bits'] = int(module.weight_bits)
                # The input range is written by the ``input_bits`` branch below; reading
                # ``input_scale`` here would raise AttributeError for weight-only layers.

                # Recover the original weight (and bias, when batch-norm folding was applied).
                actual_weight = getattr(module, 'old_weight', None)
                if actual_weight is None:
                    logger.warning("Can not recover weight for layer %s. "
                                   "This may lead to a wrong accuracy performance on the backend.", name)
                delattr(module, 'weight')
                module.register_parameter('weight', actual_weight)
                if hasattr(module, BN_FOLD_TAG):
                    actual_bias = getattr(module, 'old_bias', None)
                    delattr(module, 'bias')
                    if actual_bias is not None:
                        module.register_parameter('bias', actual_bias)
                    else:
                        setattr(module, 'bias', None)
            if hasattr(module, 'input_bits'):
                calibration_config[name]['input_bits'] = int(module.input_bits)
                abs_max_input = float(module.input_scale * module.input_qmax)
                calibration_config[name]['tracked_min_input'] = -abs_max_input
                calibration_config[name]['tracked_max_input'] = abs_max_input
            if hasattr(module, 'output_bits'):
                calibration_config[name]['output_bits'] = int(module.output_bits)
                abs_max_output = float(module.output_scale * module.output_qmax)
                calibration_config[name]['tracked_min_output'] = -abs_max_output
                calibration_config[name]['tracked_max_output'] = abs_max_output
            self._del_simulated_attr(module)

        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
                               input_shape, device)

        return calibration_config

    def _del_simulated_attr(self, module):
        """
        delete redundant parameters in quantize module
        """
        # ``old_weight`` / ``old_bias`` have already been re-registered as ``weight`` / ``bias`` by
        # export_model, so keeping them would leave duplicate entries in the exported state dict.
        del_attr_list = ['old_weight', 'old_bias', 'tracked_min_input', 'tracked_max_input', 'tracked_min_output',
                         'tracked_max_output', 'output_scale', 'input_scale', 'weight_scale', 'weight_bits',
                         'output_bits', 'input_bits', BN_FOLD_TAG]
        for attr in del_attr_list:
            if hasattr(module, attr):
                delattr(module, attr)

    def step_with_optimizer(self):
        """
        Override of the compressor hook called along with the (patched) optimizer step: increments the
        global ``steps`` counter. Input/output scales are initialized from data only while it equals 1.
        """
        self.bound_model.steps += 1