Quantizer

Naive Quantizer

class nni.algorithms.compression.pytorch.quantization.NaiveQuantizer(model, config_list, optimizer=None)

Quantize weight to 8 bits directly.

Parameters
  • model (torch.nn.Module) – Model to be quantized.

  • config_list (List[Dict]) –

    List of configurations for quantization. Supported keys:
    • quant_types : List[str]

      Types of quantization to apply. Currently supported: ‘weight’, ‘input’, ‘output’.

    • quant_bits : Union[int, Dict[str, int]]

      Bit width of quantization. If a dict, each key is a quantization type and each value is its bit width, e.g. {‘weight’: 8}. If an int, all quantization types share the same bit width.

    • op_types : List[str]

      Types of nn.Module to which quantization is applied, e.g. ‘Conv2d’.

    • op_names : List[str]

      Names of the nn.Module instances to which quantization is applied, e.g. ‘conv1’.

    • exclude : bool

      If True, the layers selected by op_types and op_names are excluded from quantization.
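How these keys select layers can be sketched in plain Python. This is a minimal illustration, not NNI's implementation: the helper names normalize_bits and select_config are hypothetical, and the rule that the last matching config entry wins (with exclude removing the layer) is an assumption about how overlapping entries are resolved.

```python
from typing import Dict, List, Optional, Union


def normalize_bits(quant_types: List[str],
                   quant_bits: Union[int, Dict[str, int]]) -> Dict[str, int]:
    """Expand quant_bits into one bit width per quantization type."""
    if isinstance(quant_bits, int):
        # An int means all listed quantization types share the same bit width.
        return {t: quant_bits for t in quant_types}
    # A dict maps each quantization type to its own bit width.
    return {t: quant_bits[t] for t in quant_types}


def select_config(layer_name: str, layer_type: str,
                  config_list: List[dict]) -> Optional[dict]:
    """Return the config entry that applies to a layer, or None if it is not quantized.

    Assumption: entries are scanned in order and the last matching entry wins;
    if that entry has exclude=True, the layer is skipped.
    """
    selected = None
    for config in config_list:
        if "op_types" in config and layer_type not in config["op_types"]:
            continue
        if "op_names" in config and layer_name not in config["op_names"]:
            continue
        selected = config
    if selected is None or selected.get("exclude", False):
        return None
    return selected


# Example: quantize all Conv2d/Linear weights to 8 bits except the layer named 'fc_out'.
example_configs = [
    {"quant_types": ["weight"], "quant_bits": 8, "op_types": ["Conv2d", "Linear"]},
    {"exclude": True, "op_names": ["fc_out"]},
]
conv1_config = select_config("conv1", "Conv2d", example_configs)
fc_out_config = select_config("fc_out", "Linear", example_configs)
```

Under this assumption, putting an exclude entry after a broad op_types entry carves individual layers out of an otherwise global rule.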

Examples

>>> from nni.algorithms.compression.pytorch.quantization import NaiveQuantizer
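>>> model = ...
>>> config_list = [{'quant_types': ['weight'], 'quant_bits': {'weight': 8}, 'op_types': ['Conv2d']}]
>>> NaiveQuantizer(model, config_list).compress()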
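As a rough illustration of what quantizing weights directly to 8 bits means, the sketch below maps a list of weights to signed 8-bit integers with a single per-tensor scale. The scale formula max|w| / 127 and round-to-nearest are assumptions for illustration; NNI's NaiveQuantizer may compute or round differently, and the function names are hypothetical.

```python
from typing import List, Tuple


def naive_quantize_weight(weights: List[float], bits: int = 8) -> Tuple[List[int], float]:
    """Symmetric per-tensor quantization of weights to signed integers.

    Assumptions: scale = max|w| / (2**(bits - 1) - 1) and round-to-nearest
    (Python's round() sends exact halves to the nearest even integer).
    """
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0  # avoid dividing by zero
    q = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Map the integers back to floating point."""
    return [v * scale for v in q]


w = [0.5, -1.27, 0.0, 1.27]
q, scale = naive_quantize_weight(w)
w_hat = dequantize(q, scale)
```

With round-to-nearest, each dequantized weight differs from the original by at most half a quantization step.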

QAT Quantizer

class nni.algorithms.compression.pytorch.quantization.QAT_Quantizer(model, config_list, optimizer, dummy_input=None)

Quantizer defined in: Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference

Benoit Jacob, Skirmantas Kligys et al. provide an algorithm to quantize the model during training.

We propose an approach that simulates quantization effects in the forward pass of training. Backpropagation still happens as usual, and all weights and biases are stored in floating point so that they can be easily nudged by small amounts. The forward propagation pass however simulates quantized inference as it will happen in the inference engine, by implementing in floating-point arithmetic the rounding behavior of the quantization scheme:

  • Weights are quantized before they are convolved with the input. If batch normalization (see [17]) is used for the layer, the batch normalization parameters are “folded into” the weights before quantization.

  • Activations are quantized at points where they would be during inference, e.g. after the activation function is applied to a convolutional or fully connected layer’s output, or after a bypass connection adds or concatenates the outputs of several layers together such as in ResNets.
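The simulated quantization described above is often called fake quantization: values are rounded onto the integer grid and immediately mapped back to floating point, so training sees the rounding error while the weights stay in float. The sketch below assumes an unsigned affine 8-bit scheme whose range always includes zero; the gradient side (typically a straight-through estimator) is not shown, and the function names are hypothetical.

```python
from typing import List, Tuple


def affine_qparams(x_min: float, x_max: float, bits: int = 8) -> Tuple[float, int]:
    """Scale and zero point for the unsigned integer range [0, 2**bits - 1]."""
    qmax = 2 ** bits - 1
    # Extend the range to include 0 so that zero is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / qmax if x_max > x_min else 1.0
    zero_point = int(round(-x_min / scale))
    return scale, zero_point


def fake_quantize(xs: List[float], scale: float, zero_point: int, bits: int = 8) -> List[float]:
    """Quantize, then immediately dequantize, keeping values in floating point."""
    qmax = 2 ** bits - 1
    out = []
    for x in xs:
        q = min(qmax, max(0, round(x / scale) + zero_point))  # clamp onto the integer grid
        out.append((q - zero_point) * scale)                   # back to float
    return out


xs = [-1.0, 0.0, 0.5, 3.0]
scale, zero_point = affine_qparams(min(xs), max(xs))
xs_fq = fake_quantize(xs, scale, zero_point)
```

Values outside the observed range are clamped, which is why a stable range (see quant_start_step) matters for output quantization.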

Parameters
  • model (torch.nn.Module) – Model to be quantized.

  • config_list (List[Dict]) –

    List of configurations for quantization. Supported keys for each dict:
    • quant_types : List[str]

      Types of quantization to apply. Currently supported: ‘weight’, ‘input’, ‘output’.

    • quant_bits : Union[int, Dict[str, int]]

      Bit width of quantization. If a dict, each key is a quantization type and each value is its bit width, e.g. {‘weight’: 8}. If an int, all quantization types share the same bit width.

    • quant_start_step : int

      Disable quantization until the model has run a certain number of steps. This allows the network to reach a more stable state where output quantization ranges do not exclude a significant fraction of values. The default value is 0.

    • op_types : List[str]

      Types of nn.Module to which quantization is applied, e.g. ‘Conv2d’.

    • op_names : List[str]

      Names of the nn.Module instances to which quantization is applied, e.g. ‘conv1’.

    • exclude : bool

      If True, the layers selected by op_types and op_names are excluded from quantization.

  • optimizer (torch.optim.Optimizer) – Required by QAT_Quantizer. NNI patches the optimizer to count the number of optimization steps.

  • dummy_input (Tuple[torch.Tensor]) – Inputs to the model, used to trace the module’s graph. The graph is used to find Conv-BN patterns, for which batch normalization folding is then enabled. If dummy_input is not given, batch normalization folding is disabled.
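The optimizer patching and the quant_start_step key work together: each optimizer step increments a counter, and quantization stays off until the counter reaches quant_start_step. Below is a minimal sketch of that idea; StepCounter, quantization_enabled and DummyOptimizer are hypothetical names, and NNI's actual patching mechanism may differ.

```python
class DummyOptimizer:
    """Stand-in for torch.optim.Optimizer so the sketch runs without PyTorch."""

    def step(self):
        return None


class StepCounter:
    """Hypothetical helper: wraps optimizer.step() so each call increments a counter."""

    def __init__(self, optimizer) -> None:
        self.steps = 0
        original_step = optimizer.step  # bound method of this optimizer instance

        def patched_step(*args, **kwargs):
            result = original_step(*args, **kwargs)
            self.steps += 1
            return result

        optimizer.step = patched_step  # shadow the method on this instance only


def quantization_enabled(steps_taken: int, quant_start_step: int = 0) -> bool:
    """Quantization stays disabled until quant_start_step optimizer steps have run."""
    return steps_taken >= quant_start_step


opt = DummyOptimizer()
counter = StepCounter(opt)
for _ in range(3):
    opt.step()
```

Patching the instance rather than the class keeps other optimizers in the same process unaffected.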

Examples

>>> import torch
>>> from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
>>> model = ...
>>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
>>> optimizer = ...
>>> dummy_input = torch.rand(...)
>>> quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
>>> quantizer.compress()
>>> # Training Process...

For a detailed example, please refer to examples/model_compress/quantization/QAT_torch_quantizer.py.

Notes

Batch normalization folding

Batch normalization folding is supported in the QAT quantizer. It is enabled by passing a dummy_input argument to the quantizer:

import torch
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

# assume your model takes an input of shape (1, 1, 28, 28);
# dummy_input must be on the same device as the model
dummy_input = torch.randn(1, 1, 28, 28).to(next(model.parameters()).device)

# pass the dummy_input to the quantizer
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)

The quantizer automatically detects Conv-BN patterns and simulates the batch normalization folding process in the training graph. Note that once quantization-aware training is finished, the folded weights and biases are restored when quantizer.export_model is called.
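Folding a batch normalization layer into the preceding convolution rescales each output channel's weights by gamma / sqrt(running_var + eps) and adjusts the bias accordingly, so that a single convolution reproduces convolution followed by BN at inference time. The sketch below checks that identity on flattened per-channel weights. It assumes running statistics only; training-time folding simulations may use batch statistics instead, and the helper names are hypothetical.

```python
import math
from typing import List, Tuple


def fold_bn(weight: List[List[float]], bias: List[float],
            gamma: List[float], beta: List[float],
            running_mean: List[float], running_var: List[float],
            eps: float = 1e-5) -> Tuple[List[List[float]], List[float]]:
    """Fold BN parameters into per-output-channel weights and bias.

    weight[c] holds the flattened weights of output channel c.
    """
    folded_w, folded_b = [], []
    for c, w_c in enumerate(weight):
        factor = gamma[c] / math.sqrt(running_var[c] + eps)  # per-channel rescale
        folded_w.append([w * factor for w in w_c])
        folded_b.append(beta[c] + (bias[c] - running_mean[c]) * factor)
    return folded_w, folded_b


def linear(weight: List[List[float]], bias: List[float], x: List[float]) -> List[float]:
    """A 1x1 convolution at one input position: one dot product per output channel."""
    return [sum(w * xi for w, xi in zip(w_c, x)) + b for w_c, b in zip(weight, bias)]


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BN using running statistics."""
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, g, bt, m, s in zip(y, gamma, beta, mean, var)]


W, b = [[0.5, -1.0], [2.0, 0.25]], [0.1, -0.2]
gamma, beta, mean, var = [1.5, 0.8], [0.0, 0.3], [0.2, -0.1], [4.0, 0.25]
x = [1.0, 2.0]
reference = batch_norm(linear(W, b, x), gamma, beta, mean, var)
Wf, bf = fold_bn(W, b, gamma, beta, mean, var)
folded = linear(Wf, bf, x)
```

Because the folded weights are what would actually be quantized for inference, quantizing them (rather than the unfolded ones) keeps training consistent with deployment.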

Quantization dtype and scheme customization

Different backends on different devices use different quantization strategies, i.e. dtype (int or uint) and scheme (per-tensor or per-channel, symmetric or affine). The QAT quantizer supports customization of the mainstream dtypes and schemes, which can be set in two ways. The first is to set them globally with the set_quant_scheme_dtype function:

from nni.compression.pytorch.quantization.settings import set_quant_scheme_dtype

# Quantize all 'input' tensors with the 'per_tensor_affine' scheme and 'uint' dtype
set_quant_scheme_dtype('input', 'per_tensor_affine', 'uint')
# Quantize all 'output' tensors with the 'per_tensor_symmetric' scheme and 'int' dtype
set_quant_scheme_dtype('output', 'per_tensor_symmetric', 'int')
# Quantize all 'weight' tensors with the 'per_channel_symmetric' scheme and 'int' dtype
set_quant_scheme_dtype('weight', 'per_channel_symmetric', 'int')

The second way is more fine-grained: the dtype and scheme can be customized in each entry of the quantization config list:

config_list = [{
    'quant_types': ['weight'],
    'quant_bits': 8,
    'op_types': ['Conv2d', 'Linear'],
    'quant_dtype': 'int',
    'quant_scheme': 'per_channel_symmetric'
}, {
    'quant_types': ['output'],
    'quant_bits': 8,
    'quant_start_step': 7000,
    'op_types': ['ReLU6'],
    'quant_dtype': 'uint',
    'quant_scheme': 'per_tensor_affine'
}]
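To make the dtype and scheme options more concrete, the sketch below computes a scale and zero point for each combination. It follows the convention used by PyTorch's observers (symmetric scale = max|x| / ((qmax - qmin) / 2), zero point 0 for int and 128 for 8-bit uint); treat it as an assumption about the formulas, not a description of NNI's internals. The function names are hypothetical.

```python
from typing import List, Tuple


def qrange(dtype: str, bits: int = 8) -> Tuple[int, int]:
    """Integer range of a signed ('int') or unsigned ('uint') dtype."""
    if dtype == "int":
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    if dtype == "uint":
        return 0, 2 ** bits - 1
    raise ValueError(f"unknown dtype: {dtype}")


def tensor_qparams(values: List[float], symmetric: bool, dtype: str,
                   bits: int = 8) -> Tuple[float, int]:
    """Scale and zero point for one tensor (or one channel)."""
    qmin, qmax = qrange(dtype, bits)
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)  # keep 0 representable
    if symmetric:
        bound = max(-lo, hi)
        scale = bound / ((qmax - qmin) / 2) if bound > 0 else 1.0
        zero_point = 0 if dtype == "int" else (qmin + qmax + 1) // 2  # 128 for uint8
    else:
        scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
        zero_point = max(qmin, min(qmax, qmin - round(lo / scale)))
    return scale, zero_point


def qparams(values, scheme: str, dtype: str, bits: int = 8):
    """scheme: 'per_tensor_affine', 'per_tensor_symmetric',
    'per_channel_affine' or 'per_channel_symmetric'.

    Per-channel schemes take a list of rows (one per output channel)
    and return one (scale, zero_point) pair per row.
    """
    symmetric = scheme.endswith("symmetric")
    if scheme.startswith("per_channel"):
        return [tensor_qparams(row, symmetric, dtype, bits) for row in values]
    return tensor_qparams(values, symmetric, dtype, bits)
```

Per-channel schemes give each output channel its own scale, which usually helps weights whose channels have very different ranges.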

Multi-GPU training

The QAT quantizer natively supports multi-GPU training (DataParallel and DistributedDataParallel). Note that the quantizer must be instantiated before the model is wrapped with DataParallel or DistributedDataParallel. For example:

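from torch.nn.parallel import DistributedDataParallel as DDP
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

model = define_your_model()

# QAT_Quantizer instantiation happens BEFORE wrapping the model
# (config_list, optimizer and dummy_input are defined elsewhere)
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input=dummy_input)
quantizer.compress()

# assumes torch.distributed.init_process_group() has already been called
model = DDP(model)

for epoch in range(epochs):
    train(model)     # your training loop
    evaluate(model)  # your evaluation loop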

DoReFa Quantizer

class nni.algorithms.compression.pytorch.quantization.DoReFaQuantizer(model, config_list, optimizer)

Quantizer using the DoReFa scheme, as defined in: DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. Shuchang Zhou, Yuxin Wu et al. provide an algorithm named DoReFa that quantizes weights, activations and gradients during training.
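For k-bit weights (k > 1), the DoReFa paper squashes weights with tanh, rescales them into [0, 1], quantizes uniformly, and maps the result back to [-1, 1]. The sketch below implements only that weight path in plain Python; activation and gradient quantization are omitted, and quantize_k and dorefa_weight are illustrative names rather than NNI functions.

```python
import math
from typing import List


def quantize_k(x: float, k: int) -> float:
    """Uniformly quantize x in [0, 1] onto 2**k levels."""
    n = 2 ** k - 1
    return round(x * n) / n


def dorefa_weight(weights: List[float], k: int) -> List[float]:
    """DoReFa-style k-bit weight quantization (k > 1); outputs lie in [-1, 1]."""
    t = [math.tanh(w) for w in weights]
    max_abs = max(abs(v) for v in t) or 1.0  # guard against all-zero weights
    # Map tanh(w) into [0, 1], quantize to k bits, then map back to [-1, 1].
    return [2 * quantize_k(v / (2 * max_abs) + 0.5, k) - 1 for v in t]


w_q = dorefa_weight([-3.0, -0.1, 0.0, 0.2, 3.0], k=2)
```

With k = 2, every output falls on one of the four levels -1, -1/3, 1/3 and 1; the tanh step limits the influence of large-magnitude weights on the scale.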

Parameters
  • model (torch.nn.Module) – Model to be quantized.

  • config_list (List[Dict]) –

    List of configurations for quantization. Supported keys for each dict:
    • quant_types : List[str]

      Types of quantization to apply. Currently supported: ‘weight’, ‘input’, ‘output’.

    • quant_bits : Union[int, Dict[str, int]]

      Bit width of quantization. If a dict, each key is a quantization type and each value is its bit width, e.g. {‘weight’: 8}. If an int, all quantization types share the same bit width.

    • op_types : List[str]

      Types of nn.Module to which quantization is applied, e.g. ‘Conv2d’.

    • op_names : List[str]

      Names of the nn.Module instances to which quantization is applied, e.g. ‘conv1’.

    • exclude : bool

      If True, the layers selected by op_types and op_names are excluded from quantization.

  • optimizer (torch.optim.Optimizer) – Required by DoReFaQuantizer. NNI patches the optimizer to count the number of optimization steps.

Examples

>>> from nni.algorithms.compression.pytorch.quantization import DoReFaQuantizer
>>> model = ...
>>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
>>> optimizer = ...
>>> quantizer = DoReFaQuantizer(model, config_list, optimizer)
>>> quantizer.compress()
>>> # Training Process...

For a detailed example, please refer to examples/model_compress/quantization/DoReFaQuantizer_torch_mnist.py.

BNN Quantizer

class nni.algorithms.compression.pytorch.quantization.BNNQuantizer(model, config_list, optimizer)

Binarized Neural Networks, as defined in: Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1.

We introduce a method to train Binarized Neural Networks (BNNs) - neural networks with binary weights and activations at run-time. At training-time the binary weights and activations are used for computing the parameters gradients. During the forward pass, BNNs drastically reduce memory size and accesses, and replace most arithmetic operations with bit-wise operations, which is expected to substantially improve power-efficiency.
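The core operations can be sketched in a few lines: the deterministic sign function used for binarization, a straight-through estimator that passes gradients only where |x| <= 1, and the XOR/popcount identity (equivalently XNOR/popcount) that lets a dot product of ±1 vectors be computed with bit-wise operations. This illustrates the paper's ideas and is not NNI's BNNQuantizer code; all names are hypothetical, and the bit encoding (+1 as bit 1, -1 as bit 0) is an assumption.

```python
from typing import List


def binarize(xs: List[float]) -> List[float]:
    """Deterministic binarization: +1 if x >= 0, else -1."""
    return [1.0 if x >= 0 else -1.0 for x in xs]


def ste_grad(xs: List[float], upstream: List[float]) -> List[float]:
    """Straight-through estimator: keep the gradient where |x| <= 1, zero it elsewhere."""
    return [g if abs(x) <= 1.0 else 0.0 for x, g in zip(xs, upstream)]


def pack_bits(signs: List[float]) -> int:
    """Encode a +/-1 vector as an integer bit mask (+1 -> bit 1, -1 -> bit 0)."""
    return sum(1 << i for i, s in enumerate(signs) if s > 0)


def binary_dot(a: List[float], b: List[float]) -> int:
    """Dot product of two +/-1 vectors: n - 2 * (number of mismatched signs)."""
    mismatches = bin(pack_bits(a) ^ pack_bits(b)).count("1")
    return len(a) - 2 * mismatches


a = binarize([0.3, -0.7, 1.2, 0.0])
b = binarize([0.9, 0.4, -0.2, 0.5])
```

Each matching sign contributes +1 and each mismatch -1 to the dot product, which is why counting set bits after an XOR is enough.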

Parameters
  • model (torch.nn.Module) – Model to be quantized.

  • config_list (List[Dict]) –

    List of configurations for quantization. Supported keys for each dict:
    • quant_types : List[str]

      Types of quantization to apply. Currently supported: ‘weight’, ‘input’, ‘output’.

    • quant_bits : Union[int, Dict[str, int]]

      Bit width of quantization. If a dict, each key is a quantization type and each value is its bit width, e.g. {‘weight’: 8}. If an int, all quantization types share the same bit width.

    • op_types : List[str]

      Types of nn.Module to which quantization is applied, e.g. ‘Conv2d’.

    • op_names : List[str]

      Names of the nn.Module instances to which quantization is applied, e.g. ‘conv1’.

    • exclude : bool

      If True, the layers selected by op_types and op_names are excluded from quantization.

  • optimizer (torch.optim.Optimizer) – Required by BNNQuantizer. NNI patches the optimizer to count the number of optimization steps.

Examples

>>> from nni.algorithms.compression.pytorch.quantization import BNNQuantizer
>>> model = ...
>>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
>>> optimizer = ...
>>> quantizer = BNNQuantizer(model, config_list, optimizer)
>>> quantizer.compress()
>>> # Training Process...

For a detailed example, please refer to examples/model_compress/quantization/BNN_quantizer_cifar10.py.

Notes

Results

We implemented one of the experiments in Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1, quantizing the VGGNet for CIFAR-10 used in the paper. Our results are as follows:

Model    | Accuracy
VGGNet   | 86.93%

The experiment code can be found at examples/model_compress/quantization/BNN_quantizer_cifar10.py.

LSQ Quantizer

class nni.algorithms.compression.pytorch.quantization.LsqQuantizer(model, config_list, optimizer, dummy_input=None)

Quantizer defined in: Learned Step Size Quantization. Steven K. Esser, Jeffrey L. McKinstry et al. provide an algorithm to train the quantization scales (step sizes) with gradients.

The authors introduce a novel means to estimate and scale the task loss gradient at each weight and activation layer’s quantizer step size, such that it can be learned in conjunction with other network parameters.
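The forward quantizer and the step-size gradient from the LSQ paper can be written as scalar functions. In the sketch below, Q_N and Q_P are the numbers of negative and positive quantization levels (for example Q_N = 128 and Q_P = 127 for signed 8-bit), rounding is treated with a straight-through estimator, and the initialization 2 * mean(|v|) / sqrt(Q_P) and gradient scale 1 / sqrt(N * Q_P) follow the paper. Function names are hypothetical and this is not NNI's LsqQuantizer code.

```python
import math
from typing import List


def lsq_forward(v: float, s: float, q_n: int, q_p: int) -> float:
    """Quantize v with step size s: clip(v / s, -Q_N, Q_P), round, rescale."""
    return round(min(max(v / s, -q_n), q_p)) * s


def lsq_step_grad(v: float, s: float, q_n: int, q_p: int) -> float:
    """Gradient of the quantized value with respect to s (round passed straight through)."""
    r = v / s
    if r <= -q_n:
        return float(-q_n)
    if r >= q_p:
        return float(q_p)
    return -r + round(r)


def lsq_init_step(values: List[float], q_p: int) -> float:
    """Initial step size 2 * mean(|v|) / sqrt(Q_P)."""
    return 2 * sum(abs(v) for v in values) / len(values) / math.sqrt(q_p)


def lsq_grad_scale(num_elements: int, q_p: int) -> float:
    """Scale g = 1 / sqrt(N * Q_P) applied to the step-size gradient."""
    return 1.0 / math.sqrt(num_elements * q_p)
```

Inside the clipping range the gradient depends on the rounding residual, so the step size receives a signal even when the rounded value itself does not change.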

Parameters
  • model (torch.nn.Module) – The model to be quantized.

  • config_list (List[Dict]) –

    List of configurations for quantization. Supported keys for each dict:
    • quant_types : List[str]

      Types of quantization to apply. Currently supported: ‘weight’, ‘input’, ‘output’.

    • quant_bits : Union[int, Dict[str, int]]

      Bit width of quantization. If a dict, each key is a quantization type and each value is its bit width, e.g. {‘weight’: 8}. If an int, all quantization types share the same bit width.

    • op_types : List[str]

      Types of nn.Module to which quantization is applied, e.g. ‘Conv2d’.

    • op_names : List[str]

      Names of the nn.Module instances to which quantization is applied, e.g. ‘conv1’.

    • exclude : bool

      If True, the layers selected by op_types and op_names are excluded from quantization.

  • optimizer (torch.optim.Optimizer) – Required by LsqQuantizer. NNI patches the optimizer to count the number of optimization steps.

  • dummy_input (Tuple[torch.Tensor]) – Inputs to the model, used to trace the module’s graph. The graph is used to find Conv-BN patterns, for which batch normalization folding is then enabled. If dummy_input is not given, batch normalization folding is disabled.

Examples

>>> import torch
>>> from nni.algorithms.compression.pytorch.quantization import LsqQuantizer
>>> model = ...
>>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
>>> optimizer = ...
>>> dummy_input = torch.rand(...)
>>> quantizer = LsqQuantizer(model, config_list, optimizer, dummy_input=dummy_input)
>>> quantizer.compress()
>>> # Training Process...

For a detailed example, please refer to examples/model_compress/quantization/LSQ_torch_quantizer.py.

Observer Quantizer

class nni.algorithms.compression.pytorch.quantization.ObserverQuantizer(model, config_list, optimizer=None)

The observer quantizer is a post-training quantization framework. It inserts observers at the places where quantization will happen. During calibration, each observer records the tensors it ‘sees’; these records are used to calculate the quantization statistics after calibration.

The whole process can be divided into three steps:

  1. Register observers at the places where quantization will happen (similar to registering hooks).

  2. The observers record tensor statistics during calibration.

  3. The scale and zero point are computed after calibration.
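The three steps can be sketched with a minimal min/max observer. This plain-Python class is hypothetical (it is neither NNI's nor PyTorch's observer class); it assumes a per-tensor affine unsigned 8-bit scheme, where reduce_range=True shrinks the integer range to [0, 127], mirroring the reduce_range setting of the output observer.

```python
from typing import Iterable, Tuple


class MinMaxRecorder:
    """Hypothetical observer that tracks the running min/max of the tensors it sees."""

    def __init__(self) -> None:
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, values: Iterable[float]) -> None:
        # Step 2: record statistics during calibration.
        for v in values:
            self.min_val = min(self.min_val, v)
            self.max_val = max(self.max_val, v)

    def calculate_qparams(self, reduce_range: bool = False) -> Tuple[float, int]:
        # Step 3: per-tensor affine, unsigned 8-bit; reduce_range drops one bit.
        qmin, qmax = 0, (127 if reduce_range else 255)
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)  # keep 0 representable
        scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
        zero_point = max(qmin, min(qmax, qmin - round(lo / scale)))
        return scale, zero_point


# Step 1 would attach one recorder per quantized tensor; here it is fed directly.
recorder = MinMaxRecorder()
for batch in ([-1.0, 0.5], [3.0, 2.0]):
    recorder.observe(batch)
scale, zero_point = recorder.calculate_qparams()
scale_rr, zero_point_rr = recorder.calculate_qparams(reduce_range=True)
```

Keeping only the running min and max makes the observer cheap, at the cost of being sensitive to outliers in the calibration data.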

Note that the observer type, tensor dtype and quantization qscheme are currently hard-coded; their customization is under development.

Parameters
  • model (torch.nn.Module) – Model to be quantized.

  • config_list (List[Dict]) –

    List of configurations for quantization. Supported keys:
    • quant_types : List[str]

      Types of quantization to apply. Currently supported: ‘weight’, ‘input’, ‘output’.

    • quant_bits : Union[int, Dict[str, int]]

      Bit width of quantization. If a dict, each key is a quantization type and each value is its bit width, e.g. {‘weight’: 8}. If an int, all quantization types share the same bit width.

    • op_types : List[str]

      Types of nn.Module to which quantization is applied, e.g. ‘Conv2d’.

    • op_names : List[str]

      Names of the nn.Module instances to which quantization is applied, e.g. ‘conv1’.

    • exclude : bool

      If True, the layers selected by op_types and op_names are excluded from quantization.

  • optimizer (torch.optim.Optimizer) – Optimizer is optional in ObserverQuantizer.

Examples

>>> import torch
>>> from nni.algorithms.compression.pytorch.quantization import ObserverQuantizer
>>> model = ...
>>> calib_loader = ...
>>> config_list = [{'quant_types': ['weight', 'input'], 'quant_bits': {'weight': 8, 'input': 8}, 'op_types': ['Conv2d']}]
>>> quantizer = ObserverQuantizer(model, config_list)
>>> # define a calibration function
>>> def calibration(model, calib_loader):
...     model.eval()
...     with torch.no_grad():
...         for data, _ in calib_loader:
...             model(data)
...
>>> calibration(model, calib_loader)
>>> quantizer.compress()

For a detailed example, please refer to examples/model_compress/quantization/observer_quantizer.py.

Note

This quantizer is still under development. Some quantizer settings are hard-coded:

  • weight observer: per_tensor_symmetric, qint8

  • output observer: per_tensor_affine, quint8, reduce_range=True

Other settings (such as quant_types and op_names) can be configured.

Notes

About the compress API

Before the compress API is called, the model only records tensor statistics and no quantization is performed. After the compress API is called, the model no longer records tensor statistics. Instead, a quantization scale and zero point are generated for each tensor and used to quantize that tensor during inference (we call this evaluation mode).
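The behavior change around compress can be sketched as a two-phase object for a single tensor site: before compress it only records a running range and passes values through unchanged, and afterwards it freezes the scale and zero point and quantizes. The class name and the unsigned 8-bit affine formula are assumptions for illustration, not NNI's internals.

```python
from typing import List


class CalibratedTensorQuantizer:
    """Hypothetical two-phase quantizer for one tensor site."""

    def __init__(self, bits: int = 8) -> None:
        self.qmax = 2 ** bits - 1
        self.lo, self.hi = 0.0, 0.0  # the range always includes 0
        self.compressed = False
        self.scale, self.zero_point = 1.0, 0

    def __call__(self, xs: List[float]) -> List[float]:
        if not self.compressed:
            # Calibration: record statistics only, no quantization.
            self.lo = min([self.lo, *xs])
            self.hi = max([self.hi, *xs])
            return list(xs)
        # Evaluation mode: statistics are frozen; quantize then dequantize.
        out = []
        for x in xs:
            q = min(self.qmax, max(0, round(x / self.scale) + self.zero_point))
            out.append((q - self.zero_point) * self.scale)
        return out

    def compress(self) -> None:
        # Freeze the scale and zero point from the recorded range.
        if self.hi > self.lo:
            self.scale = (self.hi - self.lo) / self.qmax
        self.zero_point = round(-self.lo / self.scale)
        self.compressed = True


site = CalibratedTensorQuantizer()
passthrough = site([0.1, 2.0])
site([-0.55])
site.compress()
quantized = site([1.234])
```

Because statistics stop updating after compress, inputs outside the calibrated range are clamped rather than widening the range.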

About calibration

Usually about 100 training/evaluation examples are used for calibration. If the accuracy is lower than expected, try reducing the number of calibration examples.