Quantization Quickstart

Quantization reduces model size and speeds up inference by reducing the number of bits required to represent weights or activations.

In NNI, both post-training quantization algorithms and quantization-aware training algorithms are supported. Here we use QATQuantizer as an example to show the usage of quantization in NNI.

Preparation

In this tutorial, we use a simple model and pre-train it on the MNIST dataset. If you are familiar with defining a model and training it in PyTorch, you can skip directly to Quantizing Model.

import time
from typing import Callable, Union, Union

import torch
import torch.nn.functional as F
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader
from torch import Tensor

from nni.common.types import SCHEDULER

Define the model

class Mnist(torch.nn.Module):
    """Small convolutional classifier used throughout this tutorial.

    Two 5x5 convolutions (the first followed by batch normalization), each
    with a ReLU6 activation and 2x2 max pooling, feed two fully connected
    layers. The forward pass returns log-probabilities over 10 classes.
    The hard-coded ``4 * 4 * 50`` flatten size corresponds to single-channel
    28x28 inputs (28 -> 24 -> 12 -> 8 -> 4 through the conv/pool stack).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # Layers that carry weights; they are addressed by these attribute
        # names elsewhere (e.g. in the quantization config).
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

        # One activation module per use site so that each one has its own name.
        self.relu1 = nn.ReLU6()
        self.relu2 = nn.ReLU6()
        self.relu3 = nn.ReLU6()

        self.max_pool1 = nn.MaxPool2d(2, 2)
        self.max_pool2 = nn.MaxPool2d(2, 2)
        self.batchnorm1 = nn.BatchNorm2d(20)

    def forward(self, x):
        """Return per-class log-probabilities for the input batch ``x``."""
        # First stage: conv -> batchnorm -> ReLU6 -> pool.
        stage1 = self.conv1(x)
        stage1 = self.batchnorm1(stage1)
        stage1 = self.max_pool1(self.relu1(stage1))

        # Second stage: conv -> ReLU6 -> pool (no batchnorm here).
        stage2 = self.conv2(stage1)
        stage2 = self.max_pool2(self.relu2(stage2))

        # Flatten to (N, 800) and classify.
        flat = stage2.view(-1, 4 * 4 * 50)
        hidden = self.relu3(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

Create training and evaluation dataloaders

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Construct both splits once with download=True and discard the result;
# the datasets used below are then built without the download flag.
for _is_train_split in (True, False):
    MNIST(root='data/mnist', train=_is_train_split, download=True)

# Convert images to tensors, then normalize with the given mean / std.
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Training split, served in batches of 64.
mnist_train = MNIST(root='data/mnist', train=True, transform=transform)
train_dataloader = DataLoader(mnist_train, batch_size=64)

# Test split, served in batches of 1000.
mnist_test = MNIST(root='data/mnist', train=False, transform=transform)
test_dataloader = DataLoader(mnist_test, batch_size=1000)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/9912422 [00:00<?, ?it/s]
100%|##########| 9912422/9912422 [00:00<00:00, 110174318.21it/s]
Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/28881 [00:00<?, ?it/s]
100%|##########| 28881/28881 [00:00<00:00, 91839040.05it/s]
Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/1648877 [00:00<?, ?it/s]
100%|##########| 1648877/1648877 [00:00<00:00, 26703211.30it/s]
Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/4542 [00:00<?, ?it/s]
100%|##########| 4542/4542 [00:00<00:00, 63081221.09it/s]
Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw

Define training and evaluation functions

# Use the first CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"


def training_step(batch, model) -> Tensor:
    """Compute the NLL loss of ``model`` on one batch.

    ``batch`` is indexed as ``batch[0]`` (inputs) and ``batch[1]`` (targets);
    both are moved to the module-level ``device`` before the forward pass.
    The model output is passed directly to ``F.nll_loss``.
    """
    inputs = batch[0].to(device)
    targets = batch[1].to(device)
    return F.nll_loss(model(inputs), targets)


def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,
                   max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
    """Train ``model`` on the module-level ``train_dataloader``.

    Parameters
    ----------
    model
        Module to train; it is switched to training mode first.
    optimizer
        Optimizer stepped once per batch.
    training_step
        Callable invoked as ``training_step(batch, model)``; its return value
        is back-propagated.
    scheduler
        Optional scheduler, stepped once after each completed epoch.
    max_steps
        If truthy, training returns as soon as this many batches have been
        processed.
    max_epochs
        Number of epochs to run. When falsy, it falls back to 1 if
        ``max_steps`` is None, otherwise to 100.
    """
    model.train()

    # Resolve the epoch budget when the caller did not give one.
    if not max_epochs:
        max_epochs = 1 if max_steps is None else 100

    steps_done = 0
    for epoch in range(max_epochs):
        print(f'Epoch {epoch} start!')
        for minibatch in train_dataloader:
            optimizer.zero_grad()
            training_step(minibatch, model).backward()
            optimizer.step()

            steps_done += 1
            # Step budget reached: stop immediately, skipping the scheduler.
            if max_steps and steps_done == max_steps:
                return

        if scheduler is not None:
            scheduler.step()


def evaluating_model(model: torch.nn.Module):
    """Return the accuracy of ``model`` on the module-level test set.

    The model is put in eval mode and run without gradient tracking over
    ``test_dataloader``; the number of correct arg-max predictions is divided
    by ``len(mnist_test)``.
    """
    model.eval()
    num_correct = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            # Predicted class = index of the largest output along dim 1.
            predictions = model(images).argmax(dim=1)
            matches = predictions == labels.view_as(predictions)
            num_correct += matches.sum().item()
    return num_correct / len(mnist_test)

Pre-train and evaluate the model on the MNIST dataset

# Build the model on the selected device and a plain SGD optimizer for it.
model = Mnist().to(device)
optimizer = SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)

# Time five epochs of ordinary (non-quantized) training.
start = time.time()
training_model(
    model,
    optimizer,
    training_step,
    scheduler=None,
    max_steps=None,
    max_epochs=5,
)
print(f'pure training 5 epochs: {time.time() - start}s')

# Time one pass over the test set and report the baseline accuracy.
start = time.time()
acc = evaluating_model(model)
print(f'pure evaluating: {time.time() - start}s    Acc.: {acc}')
Epoch 0 start!
Epoch 1 start!
Epoch 2 start!
Epoch 3 start!
Epoch 4 start!
pure training 5 epochs: 62.24345350265503s
pure evaluating: 1.5607831478118896s    Acc.: 0.9906

Quantizing Model

Initialize a config_list. For details on how to write a config_list, please refer to the Config Specification.

import nni
from nni.compression.quantization import QATQuantizer
from nni.compression.utils import TorchEvaluator


# The optimizer class is wrapped with nni.trace before being instantiated,
# and is then bundled with the training function and training step.
optimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
evaluator = TorchEvaluator(training_model, optimizer, training_step)  # type: ignore

# Quantization settings shared by both groups of operations below.
int8_affine_settings = {
    'quant_dtype': 'int8',
    'quant_scheme': 'affine',
    'granularity': 'default',
}

config_list = [
    # Conv / linear layers: quantize inputs, weights and outputs.
    {
        'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],
        'target_names': ['_input_', 'weight', '_output_'],
        **int8_affine_settings,
    },
    # Activations: quantize outputs only.
    {
        'op_names': ['relu1', 'relu2', 'relu3'],
        'target_names': ['_output_'],
        **int8_affine_settings,
    },
]

# NOTE(review): the fourth argument is the number of batches in one epoch;
# presumably the step at which quantization starts - confirm against the
# QATQuantizer documentation.
quantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))

# Run one real input batch through the quantizer before compressing.
first_batch = next(iter(train_dataloader))
real_input = first_batch[0].to(device)
quantizer.track_forward(real_input)

# Quantization-aware training: no step limit, five epochs.
start = time.time()
_, calibration_config = quantizer.compress(None, max_epochs=5)
print(f'pure training 5 epochs: {time.time() - start}s')

print(calibration_config)

# Re-evaluate the same model object after quantization-aware training.
start = time.time()
acc = evaluating_model(model)
print(f'quantization evaluating: {time.time() - start}s    Acc.: {acc}')
Epoch 0 start!
Epoch 1 start!
Epoch 2 start!
Epoch 3 start!
Epoch 4 start!
pure training 5 epochs: 94.30406522750854s
defaultdict(<class 'dict'>, {'fc1': {'weight': {'scale': tensor(0.0007), 'zero_point': tensor(6.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.0897), 'tracked_min': tensor(-0.0992)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.0648), 'zero_point': tensor(3.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(8.0606), 'tracked_min': tensor(-8.4004)}}, 'fc2': {'weight': {'scale': tensor(0.0018), 'zero_point': tensor(-5.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.2388), 'tracked_min': tensor(-0.2198)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.1514), 'zero_point': tensor(-35.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(24.4862), 'tracked_min': tensor(-13.9780)}}, 'conv1': {'weight': {'scale': tensor(0.0027), 'zero_point': tensor(11.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.3176), 'tracked_min': tensor(-0.3750)}, '_input_0': {'scale': tensor(0.0128), 'zero_point': tensor(-94.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(2.8215), 'tracked_min': tensor(-0.4242)}, '_output_0': {'scale': tensor(0.0261), 'zero_point': tensor(4.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(3.2271), 'tracked_min': tensor(-3.4134)}}, 'conv2': {'weight': {'scale': tensor(0.0011), 'zero_point': tensor(-24.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.1707), 'tracked_min': 
tensor(-0.1165)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(5.9999), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.0900), 'zero_point': tensor(1.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(11.3434), 'tracked_min': tensor(-11.5140)}}, 'relu2': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}}, 'relu1': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.0000), 'tracked_min': tensor(0.)}}, 'relu3': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}}})
quantization evaluating: 1.3835649490356445s    Acc.: 0.9912

Total running time of the script: ( 2 minutes 40.255 seconds)

Gallery generated by Sphinx-Gallery