# Source code for nni.compression.pytorch.quantization.native_quantizer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from schema import Or, Optional
from nni.compression.pytorch.utils.config_validation_v1 import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer
logger = logging.getLogger(__name__)
class NaiveQuantizer(Quantizer):
    r"""
    Quantize weight to 8 bits directly.

    Each wrapped layer's weight is divided by a per-layer scale, cast to
    ``torch.int8``, cast back to its original tensor type and multiplied by
    the scale again. The scale of a layer is ``max(|weight|) / 127`` and is
    remembered across calls, so it never decreases for a given layer name.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be quantized.
    config_list : List[Dict]
        List of configurations for quantization. Supported keys:
            - quant_types : List[str]
                Type of quantization you want to apply, currently support 'weight'.
            - quant_bits : Union[int, Dict[str, int]]
                Bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8},
                when the type is int, all quantization types share same bits length.
            - op_types : List[str]
                Types of nn.module you want to apply quantization, eg. 'Conv2d'.
            - op_names : List[str]
                Names of nn.module you want to apply quantization, eg. 'conv1'.
            - exclude : bool
                Set True then the layers setting by op_types and op_names will be excluded from quantization.
    optimizer : optional
        Forwarded unchanged to the ``Quantizer`` base class. Defaults to None.

    Examples
    --------
        >>> from nni.compression.pytorch.quantization import NaiveQuantizer
        >>> model = ...
        >>> config_list = [{'quant_types': ['weight'], 'quant_bits': {'weight': 8}, 'op_types': ['Conv2d']}]
        >>> NaiveQuantizer(model, config_list).compress()
    """

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
        # Largest scale seen so far for each wrapped layer, keyed by wrapper name.
        self.layer_scale = {}

    def validate_config(self, model, config_list):
        """
        Validate ``config_list`` against the schema supported by this quantizer.

        Only 'weight' quantization with a bit width of 8 is accepted, given
        either as the int ``8`` or as ``{'weight': 8}``.

        Parameters
        ----------
        model : torch.nn.Module
            Model to be quantized; passed to ``QuantizerSchema``.
        config_list : List[Dict]
            Configurations to validate.
        """
        schema = QuantizerSchema([{
            Optional('quant_types'): ['weight'],
            Optional('quant_bits'): Or(8, {'weight': 8}),
            Optional('op_types'): [str],
            Optional('op_names'): [str],
            Optional('exclude'): bool
        }], model, logger)

        schema.validate(config_list)

    def quantize_weight(self, wrapper, **kwargs):
        """
        Quantize the weight of the wrapped module and store it back on the module.

        Parameters
        ----------
        wrapper
            Wrapper of the layer to quantize; its ``module.weight`` is read and
            overwritten, and its ``name`` keys the remembered scale.
        **kwargs
            Accepted for interface compatibility; not used here.

        Returns
        -------
        torch.Tensor
            The quantize-dequantized weight, which is also assigned to
            ``wrapper.module.weight``.
        """
        weight = wrapper.module.weight
        new_scale = weight.abs().max() / 127
        # Keep the largest scale observed for this layer so far.
        scale = max(self.layer_scale.get(wrapper.name, 0), new_scale)
        self.layer_scale[wrapper.name] = scale
        orig_type = weight.type()  # TODO: user layer
        if scale == 0:
            # No positive scale has ever been observed for this layer.
            # Dividing by zero would feed NaN/inf into the int8 cast, whose
            # result is undefined; the dequantized value is zero regardless
            # because it is multiplied by the zero scale, so build it directly.
            quantized = torch.zeros_like(weight)
        else:
            # The cast to int8 truncates toward zero (it does not round).
            quantized = weight.div(scale).type(torch.int8).type(orig_type).mul(scale)
        wrapper.module.weight = quantized
        return quantized