SpeedUp Model with Calibration Config
Introduction
Deep learning networks are computationally intensive and memory intensive, which makes deep neural network models difficult to deploy. Quantization is a fundamental technique that is widely used to reduce the memory footprint and speed up inference. Many frameworks have begun to support quantization, but few of them support mixed precision quantization with a real speedup. Frameworks such as HAQ: Hardware-Aware Automated Quantization with Mixed Precision only support simulated mixed precision quantization, which does not speed up inference. To obtain a real speedup from mixed precision quantization and let users get real feedback from the hardware, we designed a general framework with a simple interface that connects NNI quantization algorithms to different DL model optimization backends (e.g., TensorRT, NNFusion). This gives users an end-to-end experience: after quantizing a model with a quantization algorithm, they can speed up the quantized model directly with the connected optimization backend. At this stage NNI connects to TensorRT, and more backends will be supported in the future.
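To make the memory argument concrete, the sketch below counts the parameters of the small MNIST model printed in the Usage section (layer shapes copied from that printout) and compares weight storage at 32 and 8 bits. It is a back-of-the-envelope calculation under a stated simplification: every parameter is assumed to use the same bit width, and activations and per-tensor scale/zero-point overhead are ignored.

```python
# (weights, biases) per layer, from the TorchModel printout in the Usage section.
layers = {
    "conv1": (1 * 6 * 5 * 5, 6),      # Conv2d(1, 6, kernel_size=(5, 5))
    "conv2": (6 * 16 * 5 * 5, 16),    # Conv2d(6, 16, kernel_size=(5, 5))
    "fc1": (256 * 120, 120),          # Linear(256 -> 120)
    "fc2": (120 * 84, 84),            # Linear(120 -> 84)
    "fc3": (84 * 10, 10),             # Linear(84 -> 10)
}


def num_params(layers):
    """Total number of weights and biases."""
    return sum(w + b for w, b in layers.values())


def storage_bytes(n_params, bits):
    """Bytes needed to store n_params values at one uniform bit width."""
    return n_params * bits // 8


n = num_params(layers)
fp32 = storage_bytes(n, 32)
int8 = storage_bytes(n, 8)
print(f"params={n}, fp32={fp32} bytes, int8={int8} bytes, ratio={fp32 / int8:.0f}x")
```

Uniform 8-bit storage gives a 4x reduction; a mixed precision config keeps some layers at higher precision and lands between the two extremes.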
Design and Implementation
To support speeding up mixed precision quantization, we divide the framework into two parts: frontend and backend. The frontend can be a popular training framework such as PyTorch or TensorFlow. The backend can be an inference framework for different hardware, such as TensorRT. At present, we support PyTorch as the frontend and TensorRT as the backend. To convert a PyTorch model to a TensorRT engine, we use ONNX as the intermediate graph representation: the PyTorch model is first converted to an ONNX model, and TensorRT then parses the ONNX model to generate the inference engine.
Quantization aware training combines the NNI quantization algorithm 'QAT' with the NNI quantization speedup tool. Users set a config to train a quantized model with the QAT algorithm (please refer to NNI Quantization Algorithms). After quantization aware training, users get a new config with calibration parameters and a model with quantized weights. By passing the new config and the model to the quantization speedup tool, users get a real mixed precision speedup engine for inference.
After getting the mixed precision engine, users can run inference on input data.
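During quantization aware training, quantization is only simulated: each value is quantized onto an integer grid and immediately dequantized, so the network learns to tolerate the rounding and clipping error. The pure-Python sketch below shows this quantize-dequantize ("fake quantization") step for an asymmetric 8-bit grid defined by a scale and a zero point, the same kind of parameters that later appear as weight_scale and weight_zero_point in the exported calibration config. It illustrates the general technique, not NNI's QAT code; the function name fake_quantize is hypothetical, and the backward-pass trick QAT usually relies on (treating rounding as identity, the straight-through estimator) is not modeled.

```python
def fake_quantize(x, scale, zero_point, bits=8):
    """Quantize x onto an unsigned `bits`-bit grid, then map it back to a float.

    q     = clamp(round(x / scale) + zero_point, 0, 2**bits - 1)
    x_hat = (q - zero_point) * scale
    """
    qmin, qmax = 0, 2 ** bits - 1
    q = round(x / scale) + zero_point   # Python's round() rounds halves to even
    q = max(qmin, min(qmax, q))         # clip values outside the representable range
    return (q - zero_point) * scale


# In-range values snap to the nearest grid point; out-of-range values are clipped.
scale, zero_point = 0.01, 128
for x in (0.123, -0.5, 2.0, -5.0):
    print(f"{x:6.3f} -> {fake_quantize(x, scale, zero_point):.2f}")
```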
Note
- We recommend using "cpu" (host) as the data device (for both inference data and calibration data), since data is initially on the host and is transferred to the device before inference. If the data is not on "cpu" (host), this tool moves it to "cpu" first, which may add unnecessary overhead.
- Users can also do post-training quantization with TensorRT directly (a calibration dataset must be provided).
- Not all op types are supported yet. At present, NNI supports Conv, Linear, ReLU and MaxPool. More op types will be supported in future releases.
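Post-training quantization needs a calibration dataset because activation ranges are not learned during training; they are measured by feeding representative inputs through the model and recording statistics. TensorRT's calibrators use more elaborate statistics (for example entropy-based calibration), so the sketch below swaps in the simplest alternative, running min/max tracking, purely to show what calibration collects. The class RunningMinMax is hypothetical, and the symmetric range it reports assumes a symmetric INT8 scheme.

```python
class RunningMinMax:
    """Track the smallest and largest value seen across calibration batches."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        # batch: any iterable of floats, e.g. a flattened activation tensor
        values = list(batch)
        if values:
            self.min_val = min(self.min_val, min(values))
            self.max_val = max(self.max_val, max(values))

    def dynamic_range(self):
        """Symmetric range (-a, a) with a = max(|min|, |max|)."""
        a = max(abs(self.min_val), abs(self.max_val))
        return -a, a


observer = RunningMinMax()
for batch in ([0.1, 0.7, -0.2], [1.5, 0.0], [-0.9, 0.3]):
    observer.observe(batch)
print(observer.min_val, observer.max_val, observer.dynamic_range())
```

Min/max tracking is sensitive to outliers, which is one reason calibrators based on histograms or entropy are usually preferred in practice.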
Prerequisite
- CUDA version >= 11.0
- TensorRT version >= 7.2
Note
If you have not installed TensorRT before, or are using an older version, please refer to the TensorRT Installation Guide.
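One way to confirm these prerequisites is to compare the installed version strings numerically rather than as plain strings (as strings, "11.0" sorts before "7.2"). The sketch below assumes PyTorch and TensorRT's Python package are installed; torch.version.cuda and tensorrt.__version__ are their standard version attributes, while the helpers parse_version and meets_minimum are hypothetical. Note that torch.version.cuda reports the CUDA version PyTorch was built with, which may differ from the system toolkit.

```python
import re


def parse_version(version):
    """'7.2.3.4' -> (7, 2, 3, 4); stops at the first non-numeric component."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(version, minimum):
    """Compare versions numerically, padding the shorter one with zeros."""
    a, b = parse_version(version), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    try:
        import torch
        import tensorrt
    except ImportError as err:
        print("prerequisite missing:", err)
    else:
        # torch.version.cuda is None on CPU-only builds.
        cuda = torch.version.cuda or "0"
        print("CUDA", cuda, "ok" if meets_minimum(cuda, "11.0") else "too old")
        print("TensorRT", tensorrt.__version__,
              "ok" if meets_minimum(tensorrt.__version__, "7.2") else "too old")
```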
Usage
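```python
import torch
import torch.nn.functional as F
from torch.optim import SGD
from nni_assets.compression.mnist_model import TorchModel, device, trainer, evaluator, test_trt

# Quantization config: 8-bit inputs and weights for every Conv2d layer,
# 8-bit outputs for every ReLU, and 8-bit inputs and weights for fc1 and fc2
# (selected by name). Modules not matched by any entry stay in full precision.
config_list = [{
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_types': ['Conv2d']
}, {
    'quant_types': ['output'],
    'quant_bits': {'output': 8},
    'op_types': ['ReLU']
}, {
    'quant_types': ['input', 'weight'],
    'quant_bits': {'input': 8, 'weight': 8},
    'op_names': ['fc1', 'fc2']
}]

model = TorchModel().to(device)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = F.nll_loss
# Dummy input with the real input shape (batch, channel, height, width),
# used by the quantizer to trace the model graph.
dummy_input = torch.rand(32, 1, 28, 28).to(device)

from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer
# Wrap the selected modules so that quantization is simulated during training.
quantizer = QAT_Quantizer(model, config_list, optimizer, dummy_input)
quantizer.compress()
```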
Out:
```
TorchModel(
  (conv1): QuantizerModuleWrapper(
    (module): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  )
  (conv2): QuantizerModuleWrapper(
    (module): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  )
  (fc1): QuantizerModuleWrapper(
    (module): Linear(in_features=256, out_features=120, bias=True)
  )
  (fc2): QuantizerModuleWrapper(
    (module): Linear(in_features=120, out_features=84, bias=True)
  )
  (fc3): Linear(in_features=84, out_features=10, bias=True)
  (relu1): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu2): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu3): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (relu4): QuantizerModuleWrapper(
    (module): ReLU()
  )
  (pool1): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (pool2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
```
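The printout shows which modules the quantizer wrapped: conv1 and conv2 (matched by op_types Conv2d), fc1 and fc2 (matched by op_names), and relu1 to relu4 (matched by op_types ReLU), while fc3 and the pooling layers stay unwrapped because no config entry selects them. The sketch below reproduces that selection in plain Python from the module names and types in the printout. The matching rule (an entry selects a module if its type is in op_types or its name is in op_names, and later entries override earlier ones) is a simplified assumption, not NNI's actual implementation; select_modules is a hypothetical helper.

```python
# (module name, module type) pairs copied from the printout above.
modules = [
    ("conv1", "Conv2d"), ("conv2", "Conv2d"),
    ("fc1", "Linear"), ("fc2", "Linear"), ("fc3", "Linear"),
    ("relu1", "ReLU"), ("relu2", "ReLU"), ("relu3", "ReLU"), ("relu4", "ReLU"),
    ("pool1", "MaxPool2d"), ("pool2", "MaxPool2d"),
]

config_list = [
    {"quant_types": ["input", "weight"], "quant_bits": {"input": 8, "weight": 8}, "op_types": ["Conv2d"]},
    {"quant_types": ["output"], "quant_bits": {"output": 8}, "op_types": ["ReLU"]},
    {"quant_types": ["input", "weight"], "quant_bits": {"input": 8, "weight": 8}, "op_names": ["fc1", "fc2"]},
]


def select_modules(modules, config_list):
    """Map each selected module name to the quant_bits of the entry that matched it last."""
    selected = {}
    for name, op_type in modules:
        for entry in config_list:
            if op_type in entry.get("op_types", []) or name in entry.get("op_names", []):
                selected[name] = entry["quant_bits"]
    return selected


selected = select_modules(modules, config_list)
print(sorted(selected))
```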
Fine-tune the model with QAT:
```python
# Fine-tune for 3 epochs and evaluate on the test set after each epoch.
for epoch in range(3):
    trainer(model, optimizer, criterion)
    evaluator(model)
```
Out:
```
Average test loss: 0.6058, Accuracy: 8534/10000 (85%)
Average test loss: 0.1585, Accuracy: 9508/10000 (95%)
Average test loss: 0.0920, Accuracy: 9717/10000 (97%)
```
Export the model and get the calibration config:
```python
import os

os.makedirs('log', exist_ok=True)
model_path = "./log/mnist_model.pth"
calibration_path = "./log/mnist_calibration.pth"
# Save the model and the calibration parameters; the calibration config is also returned.
calibration_config = quantizer.export_model(model_path, calibration_path)
print("calibration_config: ", calibration_config)
```
Out:
```
calibration_config: {'conv1': {'weight_bits': 8, 'weight_scale': tensor([0.0029], device='cuda:0'), 'weight_zero_point': tensor([97.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': -0.4242129623889923, 'tracked_max_input': 2.821486711502075}, 'conv2': {'weight_bits': 8, 'weight_scale': tensor([0.0017], device='cuda:0'), 'weight_zero_point': tensor([115.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 7.800363063812256}, 'fc1': {'weight_bits': 8, 'weight_scale': tensor([0.0010], device='cuda:0'), 'weight_zero_point': tensor([121.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 13.914573669433594}, 'fc2': {'weight_bits': 8, 'weight_scale': tensor([0.0012], device='cuda:0'), 'weight_zero_point': tensor([125.], device='cuda:0'), 'input_bits': 8, 'tracked_min_input': 0.0, 'tracked_max_input': 11.657418251037598}, 'relu1': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 7.897384166717529}, 'relu2': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 14.337020874023438}, 'relu3': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 11.884227752685547}, 'relu4': {'output_bits': 8, 'tracked_min_output': 0.0, 'tracked_max_output': 9.330422401428223}}
```
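Each weight entry stores a scale and a zero point for an unsigned 8-bit grid, and each activation entry stores the range tracked during training. Under the usual asymmetric (affine) convention, real = scale * (q - zero_point) with q in [0, 255], so the conv1 weight entry implies a representable weight range of roughly [-0.28, 0.46], and the same convention turns a tracked input range into a scale and zero point. The sketch below does both calculations with the printed conv1 values (weight_scale is rounded in the printout). The convention is an assumption for illustration; NNI's and TensorRT's internal rounding, clamping and symmetric/asymmetric choices may differ, and weight_range and affine_params are hypothetical helpers.

```python
def weight_range(scale, zero_point, bits=8):
    """Real-valued range covered by an unsigned affine grid: scale * (q - zero_point)."""
    qmin, qmax = 0, 2 ** bits - 1
    return scale * (qmin - zero_point), scale * (qmax - zero_point)


def affine_params(rmin, rmax, bits=8):
    """Scale and zero point mapping [rmin, rmax] (widened to include 0) onto [0, 2**bits - 1].

    Assumes the widened range is not degenerate (rmin < rmax).
    """
    qmin, qmax = 0, 2 ** bits - 1
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    return scale, max(qmin, min(qmax, zero_point))


# conv1 values from the printed calibration_config.
w_lo, w_hi = weight_range(0.0029, 97)
print(f"conv1 weights: [{w_lo:.4f}, {w_hi:.4f}]")

in_scale, in_zp = affine_params(-0.4242129623889923, 2.821486711502075)
print(f"conv1 input: scale={in_scale:.6f}, zero_point={in_zp}")
```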
Build a TensorRT engine to get a real speedup:
```python
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

# Fixed input shape (batch, channel, height, width); here batchsize equals
# the batch dimension of input_shape.
input_shape = (32, 1, 28, 28)
engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=32)
engine.compress()  # export to ONNX, then build the TensorRT engine
test_trt(engine)
```
Out:
```
Loss: 0.09235906448364258 Accuracy: 97.19%
Inference elapsed_time (whole dataset): 0.03632998466491699s
```
NNI also supports post-training quantization directly. For complete examples, including post-training quantization, please refer to the code.
For more parameters of the class ModelSpeedupTensorRT, please refer to the Model Compression API Reference.
MNIST test
Tested on one GTX2080 GPU with input tensor torch.randn(128, 1, 28, 28).

| Quantization strategy | Latency | Accuracy |
|---|---|---|
| all in 32bit | 0.001199961 | 96% |
| mixed precision (average bit 20.4) | 0.000753688 | 96% |
| all in 8bit | 0.000229869 | 93.7% |
CIFAR-10 ResNet18 test (trained for one epoch)
Tested on one GTX2080 GPU with input tensor torch.randn(128, 3, 32, 32).

| Quantization strategy | Latency | Accuracy |
|---|---|---|
| all in 32bit | 0.003286268 | 54.21% |
| mixed precision (average bit 11.55) | 0.001358022 | 54.78% |
| all in 8bit | 0.000859139 | 52.81% |
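The latency columns are easier to compare as ratios against the 32-bit baseline: roughly 1.6x (mixed precision) and 5.2x (all 8-bit) on MNIST, and 2.4x and 3.8x on CIFAR-10 ResNet18. The snippet below only divides the numbers reported in the two tables; it adds no new measurements.

```python
# Latencies copied from the two tables above (same units within each table).
results = {
    "MNIST": {"all in 32bit": 0.001199961,
              "mixed precision": 0.000753688,
              "all in 8bit": 0.000229869},
    "CIFAR-10 ResNet18": {"all in 32bit": 0.003286268,
                          "mixed precision": 0.001358022,
                          "all in 8bit": 0.000859139},
}


def speedups(latencies, baseline="all in 32bit"):
    """Speedup of each strategy relative to the baseline latency."""
    base = latencies[baseline]
    return {name: base / value for name, value in latencies.items()}


for task, latencies in results.items():
    for name, ratio in speedups(latencies).items():
        print(f"{task:18s} {name:16s} {ratio:.2f}x")
```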
Total running time of the script: 1 minute 13.658 seconds.