Quantization Quickstart

Quantization reduces model size and speeds up inference by reducing the number of bits used to represent weights or activations.

NNI supports both post-training quantization algorithms and quantization-aware training algorithms. Here we use QATQuantizer (quantization-aware training) as an example to show how to use quantization in NNI.


In this tutorial, we use a simple model and pre-train it on the MNIST dataset. If you are familiar with defining and training a model in PyTorch, you can skip directly to Quantizing Model.

import time
from typing import Callable, Union, Union

import torch
import torch.nn.functional as F
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader
from torch import Tensor

from nni.common.types import SCHEDULER

Define the model

class Mnist(torch.nn.Module):
    """Small LeNet-style CNN that maps 1x28x28 MNIST images to 10-class log-probabilities.

    Layout: conv(1->20, k5) -> BatchNorm -> ReLU6 -> maxpool(2)
            -> conv(20->50, k5) -> ReLU6 -> maxpool(2)
            -> fc(800->500) -> ReLU6 -> fc(500->10) -> log_softmax.

    Each activation is its own ``ReLU6`` module (instead of a functional call) so that
    it can be addressed by name ('relu1', 'relu2', 'relu3') in a quantization config.
    """

    def __init__(self):
        # Module.__init__ must run before any submodule is assigned: it creates the
        # registries that nn.Module.__setattr__ uses to record child modules.
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        # 4 * 4 * 50: a 28x28 input shrinks to 24 -> 12 -> 8 -> 4 through conv/pool/conv/pool.
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)
        self.batchnorm1 = torch.nn.BatchNorm2d(20)

    def forward(self, x):
        """Return log-probabilities of shape (batch, 10) for ``x`` of shape (batch, 1, 28, 28)."""
        x = self.relu1(self.batchnorm1(self.conv1(x)))
        x = self.max_pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.max_pool2(x)
        x = x.view(-1, 4 * 4 * 50)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

Create training and evaluation dataloaders

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Shared preprocessing: convert PIL images to tensors, then normalise with a fixed
# per-channel mean/std pair.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# download=True fetches each split into data/mnist when it is not already present,
# so the dataset objects can be built directly without a separate download pass.
mnist_train = MNIST(root='data/mnist', train=True, transform=transform, download=True)
mnist_test = MNIST(root='data/mnist', train=False, transform=transform, download=True)

# Small batches for training, large batches for the accuracy-only test pass.
train_dataloader = DataLoader(mnist_train, batch_size=64)
test_dataloader = DataLoader(mnist_test, batch_size=1000)

Define training and evaluation functions

# Run on the first GPU when one is present, otherwise fall back to the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda:0"

def training_step(batch, model) -> Tensor:
    """Return the scalar NLL loss of ``model`` on one ``(inputs, targets)`` batch.

    Both tensors are moved to the module-level ``device`` first. ``F.nll_loss`` expects
    log-probabilities, which is what the tutorial's ``Mnist.forward`` returns.
    """
    inputs = batch[0].to(device)
    targets = batch[1].to(device)
    return F.nll_loss(model(inputs), targets)

def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,
                   max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
    """Train ``model`` on the module-level ``train_dataloader``.

    Args:
        model: module to train; it is switched to training mode first.
        optimizer: optimizer over ``model``'s parameters; zeroed and stepped once per batch.
        training_step: callable ``(batch, model) -> loss`` returning a scalar loss tensor.
        scheduler: optional scheduler (presumably an LR scheduler — see ``SCHEDULER``),
            stepped once at the end of every completed epoch.
        max_steps: if set, training returns immediately after this many optimizer steps.
        max_epochs: number of epochs. When not given it defaults to 1 if ``max_steps`` is
            also unset, otherwise to 100 (an upper bound; ``max_steps`` ends training).
    """
    model.train()
    if not max_epochs:
        # No explicit epoch budget: one epoch, or a generous cap when max_steps is the limit.
        max_epochs = 1 if max_steps is None else 100
    current_steps = 0

    # training
    for epoch in range(max_epochs):
        print(f'Epoch {epoch} start!')
        for batch in train_dataloader:
            optimizer.zero_grad()
            loss = training_step(batch, model)
            loss.backward()
            optimizer.step()
            current_steps += 1
            if max_steps and current_steps >= max_steps:
                # Step budget exhausted: stop mid-epoch (the scheduler is not stepped).
                return
        if scheduler is not None:
            scheduler.step()

def evaluating_model(model: torch.nn.Module):
    """Return the top-1 accuracy of ``model`` on the module-level ``test_dataloader``.

    The model is put in eval mode for the pass, so BatchNorm uses its running statistics
    instead of per-batch statistics and does not update them. The caller's previous
    train/eval mode is restored afterwards. The accuracy is ``correct / len(mnist_test)``,
    so ``test_dataloader`` is assumed to iterate over exactly ``mnist_test``.
    """
    was_training = model.training
    model.eval()
    correct = 0
    try:
        # testing
        with torch.no_grad():
            for x, y in test_dataloader:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                preds = torch.argmax(logits, dim=1)
                correct += preds.eq(y.view_as(preds)).sum().item()
    finally:
        model.train(was_training)
    return correct / len(mnist_test)

Pre-train and evaluate the model on the MNIST dataset

# Build the float baseline and a plain SGD optimizer for it.
model = Mnist().to(device)
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Baseline: 5 full epochs, no scheduler and no step limit; time the run.
start = time.time()
training_model(model, optimizer, training_step, scheduler=None, max_steps=None, max_epochs=5)
print(f'pure training 5 epochs: {time.time() - start}s')

# Time one pass over the test set and report the float model's accuracy.
start = time.time()
acc = evaluating_model(model)
print(f'pure evaluating: {time.time() - start}s    Acc.: {acc}')
Epoch 0 start!
Epoch 1 start!
Epoch 2 start!
Epoch 3 start!
Epoch 4 start!
pure training 5 epochs: 71.90893840789795s
pure evaluating: 1.6302893161773682s    Acc.: 0.9908

Quantizing Model

Initialize a config_list. For details on how to write a config_list, please refer to Config Specification.

import nni
from nni.contrib.compression.quantization import QATQuantizer
from nni.contrib.compression.utils import TorchEvaluator

# The optimizer is wrapped with nni.trace so its constructor arguments are recorded;
# presumably NNI needs them to rebuild the optimizer for the wrapped model (see NNI
# evaluator docs).
optimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
evaluator = TorchEvaluator(training_model, optimizer, training_step)  # type: ignore

# Two separate entries:
# 1) compute layers: quantize their inputs, weights and outputs;
# 2) ReLU6 activations: quantize only their outputs.
config_list = [{
    'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],
    'target_names': ['_input_', 'weight', '_output_'],
    'quant_dtype': 'int8',
    'quant_scheme': 'affine',
    'granularity': 'default',
}, {
    'op_names': ['relu1', 'relu2', 'relu3'],
    'target_names': ['_output_'],
    'quant_dtype': 'int8',
    'quant_scheme': 'affine',
    'granularity': 'default',
}]

# NOTE(review): the 4th argument is assumed to be the step at which quantization starts
# (here: after one epoch of batches) — confirm against the QATQuantizer API docs.
quantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))
# Run one real batch through the model so the quantizer can trace its forward graph.
real_input = next(iter(train_dataloader))[0].to(device)
quantizer.track_forward(real_input)

start = time.time()
_, calibration_config = quantizer.compress(None, max_epochs=5)
print(f'pure training 5 epochs: {time.time() - start}s')

# Per-module scales / zero points / tracked ranges collected during QAT.
print(calibration_config)
start = time.time()
acc = evaluating_model(model)
print(f'quantization evaluating: {time.time() - start}s    Acc.: {acc}')
Epoch 0 start!
Epoch 1 start!
Epoch 2 start!
Epoch 3 start!
Epoch 4 start!
pure training 5 epochs: 117.75990748405457s
defaultdict(<class 'dict'>, {'fc2': {'weight': {'scale': tensor(0.0020), 'zero_point': tensor(-8.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.2640), 'tracked_min': tensor(-0.2319)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.1541), 'zero_point': tensor(-39.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(25.6346), 'tracked_min': tensor(-13.5170)}}, 'conv1': {'weight': {'scale': tensor(0.0023), 'zero_point': tensor(-12.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.3128), 'tracked_min': tensor(-0.2606)}, '_input_0': {'scale': tensor(0.0128), 'zero_point': tensor(-94.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(2.8215), 'tracked_min': tensor(-0.4242)}, '_output_0': {'scale': tensor(0.0265), 'zero_point': tensor(-5.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(3.4957), 'tracked_min': tensor(-3.2373)}}, 'fc1': {'weight': {'scale': tensor(0.0007), 'zero_point': tensor(3.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.0894), 'tracked_min': tensor(-0.0943)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.0678), 'zero_point': tensor(-8.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(9.1579), 'tracked_min': tensor(-8.0707)}}, 'conv2': {'weight': {'scale': tensor(0.0012), 'zero_point': tensor(-35.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.1927), 'tracked_min': 
tensor(-0.1097)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(5.9995), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.0893), 'zero_point': tensor(2.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(11.1702), 'tracked_min': tensor(-11.5212)}}, 'relu3': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}}, 'relu2': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}}, 'relu1': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(5.9996), 'tracked_min': tensor(0.)}}})
quantization evaluating: 1.6024222373962402s    Acc.: 0.9915

Total running time of the script: ( 3 minutes 22.673 seconds)

Gallery generated by Sphinx-Gallery