Quantization Quickstart

Quantization reduces model size and speeds up inference by reducing the number of bits required to represent weights or activations.

NNI supports both post-training quantization algorithms and quantization-aware training (QAT) algorithms. Here we use QAT_Quantizer as an example to show how to apply quantization in NNI.
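To make this concrete, here is a minimal sketch (illustrative only, not NNI internals) of 8-bit affine quantization: real values are mapped onto integers in [0, 255] and back, so each value needs only 8 bits instead of 32.

import torch

def fake_quantize(x, bits=8):
    # map [x.min(), x.max()] onto the integer range [0, 2**bits - 1],
    # round, then map back to float ("quantize-dequantize")
    qmax = 2 ** bits - 1
    scale = (x.max() - x.min()) / qmax
    zero_point = torch.round(-x.min() / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, 0, qmax)
    return (q - zero_point) * scale

x = torch.randn(4, 4)
print((fake_quantize(x) - x).abs().max())  # error is bounded by about scale / 2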

Preparation

In this tutorial, we use a simple model and pre-train it on the MNIST dataset. If you are familiar with defining a model and training it in PyTorch, you can skip directly to Quantizing Model.

import torch
import torch.nn.functional as F
from torch.optim import SGD

from scripts.compression_mnist_model import TorchModel, trainer, evaluator, device, test_trt

# define the model
model = TorchModel().to(device)

# define the optimizer and criterion for pre-training
optimizer = SGD(model.parameters(), 1e-2)
criterion = F.nll_loss

# pre-train and evaluate the model on MNIST dataset
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

Out:

Average test loss: 0.7073, Accuracy: 7624/10000 (76%)
Average test loss: 0.2776, Accuracy: 9122/10000 (91%)
Average test loss: 0.1907, Accuracy: 9412/10000 (94%)

Quantizing Model

Initialize a config_list. For details on how to write a config_list, please refer to the compression config specification.

config_list = [{
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_types': ['Conv2d']
}, {
    'quant_types': ['output'],
    'quant_bits': {'output': 8},
    'op_types': ['ReLU']
}, {
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_names': ['fc1', 'fc2']
}]
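As a further illustration of the schema (hypothetical values; see the specification above for the exact matching and override rules), a config_list can also assign different bit widths to different layers:

# hypothetical variant: 8-bit weights everywhere, but 4-bit weights for fc2
config_list_mixed = [{
    'quant_types': ['weight'],
    'quant_bits': {'weight': 8},
    'op_types': ['Conv2d', 'Linear']
}, {
    'quant_types': ['weight'],
    'quant_bits': {'weight': 4},
    'op_names': ['fc2']
}]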

Fine-tune the model with QAT.

from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
dummy_input = torch.rand(32, 1, 28, 28).to(device)
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)
quantizer.compress()

Out:

TorchModel(
  (conv1): QuantizerModuleWrapper(
    (module): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  )
  (conv2): QuantizerModuleWrapper(
    (module): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  )
  (fc1): QuantizerModuleWrapper(
    (module): Linear(in_features=256, out_features=120, bias=True)
  )
  (fc2): QuantizerModuleWrapper(
    (module): Linear(in_features=120, out_features=84, bias=True)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu2): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu3): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu4): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)

The model has now been wrapped, and the quantization targets (the 'quant_types' setting in config_list) will be quantized and dequantized in the wrapped layers to simulate quantization. QAT is a training-aware quantizer: it updates the scale and zero point during training.
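Conceptually, the forward pass of each wrapped layer applies a quantize-dequantize ("fake quantization") step so that training sees the rounding error, while gradients pass straight through. A minimal sketch of that idea, assuming an affine scheme with a straight-through estimator (an illustration, not NNI's internal implementation):

class FakeQuant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, zero_point, bits=8):
        # quantize-dequantize: the output carries the rounding error
        qmax = 2 ** bits - 1
        q = torch.clamp(torch.round(x / scale) + zero_point, 0, qmax)
        return (q - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: ignore the rounding in the backward pass
        return grad_output, None, None, None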

for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)

Out:

Average test loss: 0.1542, Accuracy: 9529/10000 (95%)
Average test loss: 0.1133, Accuracy: 9664/10000 (97%)
Average test loss: 0.0919, Accuracy: 9726/10000 (97%)

Export the model and get the calibration_config.

model_path = "./log/mnist_model.pth"
calibration_path = "./log/mnist_calibration.pth"
calibration_config = quantizer.export_model(model_path, calibration_path)

print("calibration_config: ", calibration_config)

Out:

calibration_config:  {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0031], device='cuda:0'), 'weight_zero_point': tensor([76.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0018], device='cuda:0'), 'weight_zero_point': tensor([113.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 12.42452621459961}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0011], device='cuda:0'), 'weight_zero_point': tensor([124.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 31.650196075439453}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0013], device='cuda:0'), 'weight_zero_point': tensor([122.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 25.805370330810547}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 12.499907493591309}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 32.0243034362793}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 26.491384506225586}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 17.662996292114258}}
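The tracked input/output ranges are what determine the activation quantization parameters. As a rough sketch (the exact formula NNI uses may differ), an 8-bit affine scale and zero point for conv1's input can be recovered from its tracked range like this:

tracked_min, tracked_max = -0.4242, 2.8215  # conv1 input range from the output above
qmax = 2 ** 8 - 1
scale = (tracked_max - tracked_min) / qmax  # step size between quantization levels
zero_point = round(-tracked_min / scale)    # integer that represents real 0.0
print(scale, zero_point)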

Build a TensorRT engine to get a real speedup. For more information about speedup, please refer to SpeedUp Model with Calibration Config.

from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
input_shape = (32, 1, 28, 28)
engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32)
engine.compress()
test_trt(engine)

Out:

Loss: 0.09358334274291992  Accuracy: 97.21%
Inference elapsed_time (whole dataset): 0.04445981979370117s
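To run a single batch through the engine yourself, ModelSpeedupTensorRT provides an inference method; the sketch below assumes it accepts a batch tensor and returns the output together with the elapsed time, as in the NNI speedup tutorial:

# assumed (output, elapsed_time) return; input placement is handled by the engine
test_data = torch.rand(32, 1, 28, 28)
output, time_span = engine.inference(test_data)
print(output.argmax(dim=1)[:10], time_span)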
