# Source code for nni.algorithms.compression.pytorch.quantization.native_quantizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from schema import Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer

logger = logging.getLogger(__name__)

class NaiveQuantizer(Quantizer):
    r"""
    Quantize weight to 8 bits directly.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be quantized.
    config_list : List[Dict]
        List of configurations for quantization. Supported keys:
            - quant_types : List[str]
                Type of quantization you want to apply, currently support 'weight'.
            - quant_bits : Union[int, Dict[str, int]]
                Bits length of quantization, key is the quantization type, value is the length,
                eg. {'weight': 8}, when the type is int, all quantization types share same bits length.
            - op_types : List[str]
                Types of nn.module you want to apply quantization, eg. 'Conv2d'.
            - op_names : List[str]
                Names of nn.module you want to apply quantization, eg. 'conv1'.
            - exclude : bool
                Set True then the layers setting by op_types and op_names will be excluded from quantization.
    optimizer : optional
        Forwarded unchanged to the ``Quantizer`` base class. Defaults to ``None``.

    Examples
    --------
        >>> from nni.algorithms.compression.pytorch.quantization import NaiveQuantizer
        >>> model = ...
        >>> config_list = [{'quant_types': ['weight'], 'quant_bits': {'weight': 8}, 'op_types': ['Conv2d']}]
        >>> NaiveQuantizer(model, config_list).compress()
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        # Largest scale observed so far for each wrapped layer, keyed by layer name.
        self.layer_scale = {}

    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` against the schema accepted by this quantizer.

        Only weight quantization with exactly 8 bits is accepted; every key is optional.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized, passed through to ``QuantizerSchema``.
        config_list : List[Dict]
            List of configurations to validate.
        """
        schema = QuantizerSchema([{
            Optional('quant_types'): ['weight'],
            Optional('quant_bits'): Or(8, {'weight': 8}),
            Optional('op_types'): [str],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }], model, logger)

        schema.validate(config_list)

    def quantize_weight(self, wrapper, **kwargs):
        """
        Quantize the weight of the wrapped module to the int8 grid and write it back.

        The scale is ``max(|weight|) / 127``, and it never shrinks below the largest
        scale previously recorded for the same layer in ``self.layer_scale``.

        Parameters
        ----------
        wrapper
            Wrapper object exposing ``module.weight`` (the tensor to quantize) and
            ``name`` (used as the per-layer key).
        **kwargs
            Accepted for interface compatibility; unused here.

        Returns
        -------
        torch.Tensor
            The de-quantized weight (original tensor type, values snapped to the
            int8 grid). If the scale is zero the weight is returned unchanged.
        """
        weight = wrapper.module.weight
        new_scale = weight.abs().max() / 127
        # NOTE(review): the dictionary key was missing in the extracted source;
        # ``wrapper.name`` is assumed to be the per-layer key -- confirm against upstream.
        scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
        self.layer_scale[wrapper.name] = scale
        if scale == 0:
            # Every weight seen so far is zero: dividing by the scale would yield
            # NaN values, and there is nothing to quantize anyway.
            return weight
        orig_type = weight.type()  # TODO: user layer
        # Scale down, cast to int8 (drops the fractional part), cast back to the
        # original tensor type and scale up again.
        weight = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
        wrapper.module.weight = weight
        return weight