Quantization Quickstart
Quantization reduces model size and speeds up inference by lowering the number of bits required to represent weights or activations.
NNI supports both post-training quantization and quantization-aware training algorithms. Here we use QAT_Quantizer as an example to show how to apply quantization in NNI.
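To make the idea concrete, below is a minimal sketch of 8-bit affine quantization applied to a single tensor: a value range is mapped onto 256 integer levels via a scale and zero point and then mapped back to floats. The fake_quantize helper is hypothetical and written only for illustration; it is not NNI's internal implementation.

# Minimal sketch of 8-bit affine quantization, for illustration only.
# This is NOT NNI's internal implementation.
import torch

def fake_quantize(x, bits=8):
    qmin, qmax = 0, 2 ** bits - 1
    # derive scale and zero point from the observed value range
    x_min, x_max = x.min(), x.max()
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = torch.round(qmin - x_min / scale)
    # quantize to integer levels, then dequantize back to floats
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = torch.randn(4, 4)
print(fake_quantize(x))  # close to x, but limited to 256 distinct levels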
Preparation
In this tutorial, we use a simple model and pre-train it on the MNIST dataset. If you are familiar with defining a model and training it in PyTorch, you can skip directly to Quantizing Model.
import torch
import torch.nn.functional as F
from torch.optim import SGD
from nni_assets.compression.mnist_model import TorchModel, trainer, evaluator, device, test_trt
# define the model
model = TorchModel().to(device)
# define the optimizer and criterion for pre-training
optimizer = SGD(model.parameters(), 1e-2)
criterion = F.nll_loss
# pre-train and evaluate the model on MNIST dataset
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)
Out:
Average test loss: 0.6440, Accuracy: 8230/10000 (82%)
Average test loss: 0.2512, Accuracy: 9272/10000 (93%)
Average test loss: 0.1569, Accuracy: 9542/10000 (95%)
Quantizing Model
Initialize a config_list.
For details about how to write a config_list,
please refer to the compression config specification.
config_list = [{
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_types': ['Conv2d']
}, {
    'quant_types': ['output'],
    'quant_bits': {'output': 8},
    'op_types': ['ReLU']
}, {
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_names': ['fc1', 'fc2']
}]
Finetune the model using QAT.
from nni.compression.pytorch.quantization import QAT_Quantizer
dummy_input = torch.rand(32, 1, 28, 28).to(device)
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)
quantizer.compress()
Out:
TorchModel(
  (conv1): QuantizerModuleWrapper(
    (module): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  )
  (conv2): QuantizerModuleWrapper(
    (module): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  )
  (fc1): QuantizerModuleWrapper(
    (module): Linear(in_features=256, out_features=120, bias=True)
  )
  (fc2): QuantizerModuleWrapper(
    (module): Linear(in_features=120, out_features=84, bias=True)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu2): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu3): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu4): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
The model has now been wrapped, and the quantization targets (the 'quant_types' settings in config_list) will be quantized and dequantized in the wrapped layers to simulate quantization. QAT is a training-aware quantizer; it updates the scale and zero point during training.
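As a conceptual aside, the sketch below shows the kind of quantize-dequantize ("fake quantization") with a straight-through gradient that simulated quantization relies on. The qat_fake_quant helper is hypothetical and is not QAT_Quantizer's actual code; it only illustrates the mechanism.

# Conceptual sketch of fake quantization applied to a quantization target
# inside a wrapped layer during QAT-style training (illustration only).
import torch

def qat_fake_quant(x, bits=8):
    qmin, qmax = 0, 2 ** bits - 1
    # in QAT the range is tracked across training batches; here we just use the batch min/max
    x_min, x_max = x.detach().min(), x.detach().max()
    scale = (x_max - x_min).clamp(min=1e-8) / (qmax - qmin)
    zero_point = torch.clamp(torch.round(qmin - x_min / scale), qmin, qmax)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    x_dq = (q - zero_point) * scale
    # straight-through estimator: gradients pass through as if quantization were identity
    return x + (x_dq - x).detach()

During finetuning, the wrapped layers apply this kind of simulated quantization in the forward pass while the weights are still updated in floating point.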
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)
Out:
Average test loss: 0.1209, Accuracy: 9629/10000 (96%)
Average test loss: 0.1032, Accuracy: 9696/10000 (97%)
Average test loss: 0.0909, Accuracy: 9736/10000 (97%)
Export the model and get the calibration_config.
model_path = "./log/mnist_model.pth"
calibration_path = "./log/mnist_calibration.pth"
calibration_config = quantizer.export_model(model_path, calibration_path)
print("calibration_config: ", calibration_config)
Out:
calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0032], device='cuda:0'), 'weight_zero_point': tensor([92.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0022], device='cuda:0'), 'weight_zero_point': tensor([110.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 11.599255561828613}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0010], device='cuda:0'), 'weight_zero_point': tensor([113.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 26.364503860473633}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0013], device='cuda:0'), 'weight_zero_point': tensor([124.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 26.364498138427734}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 11.658699989318848}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 26.645591735839844}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 26.877971649169922}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 16.9318904876709}}
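As a hedged illustration of what these values mean, conv1's tracked input range from the printout above can be converted into an 8-bit asymmetric scale and zero point with the standard affine formula. The exact formula NNI and TensorRT use downstream may differ; this only shows the relationship between a tracked range and a quantization parameter pair.

# Values copied from the calibration_config printout above (conv1 input range).
tracked_min_input = -0.4242129623889923
tracked_max_input = 2.821486711502075
qmin, qmax = 0, 255

scale = (tracked_max_input - tracked_min_input) / (qmax - qmin)
zero_point = round(qmin - tracked_min_input / scale)
print(scale, zero_point)  # roughly 0.0127 and 33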
Build a TensorRT engine to achieve a real speedup. For more information about speedup, please refer to Speed Up Quantized Model with TensorRT.
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
input_shape = (32, 1, 28, 28)
engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32)
engine.compress()
test_trt(engine)
Out:
Loss: 0.09197621383666992 Accuracy: 97.29%
Inference elapsed_time (whole dataset): 0.036701202392578125s