# Source code for nni.algorithms.compression.pytorch.quantization.native_quantizer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch
from schema import Or, Optional
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer


logger = logging.getLogger(__name__)


class NaiveQuantizer(Quantizer):
    """Quantize module weights to 8 bits (int8) using a per-layer abs-max scale.

    The scale for each layer is tracked across calls as a running maximum,
    so a scale established by earlier calibration batches is never shrunk.
    """

    def __init__(self, model, config_list, optimizer=None):
        """
        Parameters
        ----------
        model : torch.nn.Module
            The model to be quantized.
        config_list : list of dict
            Quantization configuration; see ``validate_config`` for the schema.
        optimizer : torch.optim.Optimizer, optional
            Optimizer used during training, forwarded to the base ``Quantizer``.
        """
        super().__init__(model, config_list, optimizer)
        # Maps layer name -> running-max quantization scale for that layer.
        self.layer_scale = {}

    def validate_config(self, model, config_list):
        """Validate ``config_list`` against the schema this quantizer supports.

        Only weight quantization at exactly 8 bits is accepted; raises a
        schema error on any other configuration.
        """
        schema = QuantizerSchema([{
            Optional('quant_types'): ['weight'],
            Optional('quant_bits'): Or(8, {'weight': 8}),
            Optional('op_types'): [str],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }], model, logger)
        schema.validate(config_list)

    def quantize_weight(self, wrapper, **kwargs):
        """Fake-quantize the wrapped module's weight to int8 resolution.

        The weight is divided by the layer scale, truncated to ``torch.int8``,
        cast back to its original dtype, and rescaled. The returned tensor
        therefore keeps the original dtype but carries only 8-bit resolution.

        Parameters
        ----------
        wrapper : module wrapper
            Wrapper exposing ``module`` (with a ``weight`` tensor) and ``name``.

        Returns
        -------
        torch.Tensor
            The fake-quantized weight; it is also written back to
            ``wrapper.module.weight``.
        """
        weight = wrapper.module.weight
        new_scale = weight.abs().max() / 127
        # Running maximum: never let the scale shrink between calls.
        scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
        self.layer_scale[wrapper.name] = scale
        if scale == 0:
            # All-zero weight: dividing by scale would yield NaNs. An
            # all-zero tensor is already exactly representable in int8,
            # so return it unchanged.
            return weight
        orig_type = weight.type()  # TODO: use layer dtype
        weight = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
        wrapper.module.weight = weight
        return weight