# Quantization Quickstart

Quantization reduces model size and speeds up inference by reducing the number of bits required to represent weights or activations.

In NNI, both post-training quantization algorithms and quantization-aware training algorithms are supported. Here we use QATQuantizer as an example to show the usage of quantization in NNI.

## Preparation

In this tutorial, we use a simple model and pre-train it on the MNIST dataset. If you are familiar with defining and training a model in PyTorch, you can skip directly to Quantizing Model.

```import time
from typing import Callable, Union, Union

import torch
import torch.nn.functional as F
from torch.optim import Optimizer, SGD
from torch import Tensor

from nni.common.types import SCHEDULER
```

Define the model

```class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
self.relu1 = torch.nn.ReLU6()
self.relu2 = torch.nn.ReLU6()
self.relu3 = torch.nn.ReLU6()
self.max_pool1 = torch.nn.MaxPool2d(2, 2)
self.max_pool2 = torch.nn.MaxPool2d(2, 2)
self.batchnorm1 = torch.nn.BatchNorm2d(20)

def forward(self, x):
x = self.relu1(self.batchnorm1(self.conv1(x)))
x = self.max_pool1(x)
x = self.relu2(self.conv2(x))
x = self.max_pool2(x)
x = x.view(-1, 4 * 4 * 50)
x = self.relu3(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
```

```from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Convert each image to a tensor, then normalize it with mean 0.1307 and std 0.3081.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# download=True fetches the files into `root` when they are not there yet,
# so the tutorial also runs on a machine that has never seen the dataset.
mnist_train = MNIST(root='data/mnist', train=True, download=True, transform=transform)
mnist_test = MNIST(root='data/mnist', train=False, download=True, transform=transform)
# Batched view of the training set; the quantization step further down
# reads len(train_dataloader), i.e. the number of batches per epoch.
train_dataloader = DataLoader(mnist_train, batch_size=64)
```

Define training and evaluation functions

```device = "cuda:0" if torch.cuda.is_available() else "cpu"

def training_step(batch, model) -> Tensor:
    """Compute the loss of ``model`` on a single batch.

    Args:
        batch: Indexable pair where ``batch[0]`` holds the inputs and
            ``batch[1]`` the target labels. Both are moved to ``device``.
        model: Module called on the inputs. ``F.nll_loss`` is applied to its
            output, so it is expected to return log-probabilities (the
            ``Mnist`` model in this tutorial ends with ``log_softmax``).

    Returns:
        The negative log-likelihood loss tensor for the batch.
    """
    x, y = batch[0].to(device), batch[1].to(device)
    logits = model(x)
    loss: torch.Tensor = F.nll_loss(logits, y)
    return loss

def training_model(model: torch.nn.Module, optimizer: Optimizer, training_step: Callable, scheduler: Union[SCHEDULER, None] = None,
                   max_steps: Union[int, None] = None, max_epochs: Union[int, None] = None):
    """Train ``model`` on the module-level ``train_dataloader``.

    Args:
        model: Module to train; switched to training mode first.
        optimizer: Optimizer that updates the parameters of ``model``.
        training_step: Callable ``(batch, model) -> loss`` whose result is
            back-propagated.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
        max_steps: Optional cap on the total number of optimizer steps;
            training returns as soon as it is reached.
        max_epochs: Number of epochs. When falsy it defaults to 1 if
            ``max_steps`` is ``None``, otherwise to 100 so that ``max_steps``
            is the limit that ends training.
    """
    model.train()
    max_epochs = max_epochs if max_epochs else 1 if max_steps is None else 100
    current_steps = 0

    # training
    for epoch in range(max_epochs):
        print(f'Epoch {epoch} start!')
        for batch in train_dataloader:
            # Clear gradients left over from the previous step; without this
            # they would accumulate across batches.
            optimizer.zero_grad()
            loss = training_step(batch, model)
            loss.backward()
            optimizer.step()
            current_steps += 1
            if max_steps and current_steps == max_steps:
                return
        if scheduler is not None:
            scheduler.step()

def evaluating_model(model: torch.nn.Module):
    """Return the accuracy of ``model`` on ``mnist_test``.

    Args:
        model: Module to evaluate; switched to evaluation mode first.

    Returns:
        Fraction of test samples whose arg-max prediction equals the label,
        i.e. ``correct / len(mnist_test)``.
    """
    model.eval()
    # testing
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building the graph for every batch.
    with torch.no_grad():
        for x, y in DataLoader(mnist_test, batch_size=1000):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = torch.argmax(logits, dim=1)
            correct += preds.eq(y.view_as(preds)).sum().item()
    return correct / len(mnist_test)
```

Pre-train and evaluate the model on the MNIST dataset

```model = Mnist().to(device)
# SGD with momentum and weight decay over all parameters of the model.
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Time ordinary (non-quantized) training.
# Positional arguments after training_step: scheduler=None, max_steps=None, max_epochs=5.
start = time.time()
training_model(model, optimizer, training_step, None, None, 5)
print(f'pure training 5 epochs: {time.time() - start}s')
# Time one evaluation run and report the accuracy it returns.
start = time.time()
acc = evaluating_model(model)
print(f'pure evaluating: {time.time() - start}s    Acc.: {acc}')
```
```Epoch 0 start!
Epoch 1 start!
Epoch 2 start!
Epoch 3 start!
Epoch 4 start!
pure training 5 epochs: 71.90893840789795s
pure evaluating: 1.6302893161773682s    Acc.: 0.9908
```

## Quantizing Model

Initialize a `config_list`. For details on how to write a `config_list`, please refer to Config Specification.

```import nni
from nni.contrib.compression.quantization import QATQuantizer
from nni.contrib.compression.utils import TorchEvaluator

# The optimizer class is wrapped with nni.trace before it is instantiated and
# handed to TorchEvaluator together with the training loop and the step function.
# NOTE(review): NNI documents that the evaluator needs a traced optimizer - confirm for your NNI version.
optimizer = nni.trace(SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
evaluator = TorchEvaluator(training_model, optimizer, training_step)  # type: ignore

config_list = [{
    # Conv / linear layers: quantize input, weight and output.
    'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],
    'target_names': ['_input_', 'weight', '_output_'],
    'quant_dtype': 'int8',
    'quant_scheme': 'affine',
    'granularity': 'default',
}, {
    # Activation layers: quantize the output only.
    'op_names': ['relu1', 'relu2', 'relu3'],
    'target_names': ['_output_'],
    'quant_dtype': 'int8',
    'quant_scheme': 'affine',
    'granularity': 'default',
}]

# The fourth argument is the number of batches in one epoch.
# NOTE(review): presumably the step at which quantization starts - confirm against the QATQuantizer signature.
quantizer = QATQuantizer(model, config_list, evaluator, len(train_dataloader))
# track_forward needs an actual input tensor: take the images of the first
# training batch and move them to the device the model lives on.
real_input = next(iter(train_dataloader))[0].to(device)
quantizer.track_forward(real_input)

start = time.time()
_, calibration_config = quantizer.compress(None, max_epochs=5)
print(f'pure training 5 epochs: {time.time() - start}s')

print(calibration_config)
start = time.time()
acc = evaluating_model(model)
print(f'quantization evaluating: {time.time() - start}s    Acc.: {acc}')
```
```Epoch 0 start!
Epoch 1 start!
Epoch 2 start!
Epoch 3 start!
Epoch 4 start!
pure training 5 epochs: 117.75990748405457s
defaultdict(<class 'dict'>, {'fc2': {'weight': {'scale': tensor(0.0020), 'zero_point': tensor(-8.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.2640), 'tracked_min': tensor(-0.2319)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.1541), 'zero_point': tensor(-39.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(25.6346), 'tracked_min': tensor(-13.5170)}}, 'conv1': {'weight': {'scale': tensor(0.0023), 'zero_point': tensor(-12.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.3128), 'tracked_min': tensor(-0.2606)}, '_input_0': {'scale': tensor(0.0128), 'zero_point': tensor(-94.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(2.8215), 'tracked_min': tensor(-0.4242)}, '_output_0': {'scale': tensor(0.0265), 'zero_point': tensor(-5.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(3.4957), 'tracked_min': tensor(-3.2373)}}, 'fc1': {'weight': {'scale': tensor(0.0007), 'zero_point': tensor(3.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.0894), 'tracked_min': tensor(-0.0943)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.0678), 'zero_point': tensor(-8.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(9.1579), 'tracked_min': tensor(-8.0707)}}, 'conv2': {'weight': {'scale': tensor(0.0012), 'zero_point': tensor(-35.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(0.1927), 'tracked_min': 
tensor(-0.1097)}, '_input_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(5.9995), 'tracked_min': tensor(0.)}, '_output_0': {'scale': tensor(0.0893), 'zero_point': tensor(2.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(11.1702), 'tracked_min': tensor(-11.5212)}}, 'relu3': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}}, 'relu2': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(6.), 'tracked_min': tensor(0.)}}, 'relu1': {'_output_0': {'scale': tensor(0.0236), 'zero_point': tensor(-127.), 'quant_dtype': 'int8', 'quant_scheme': 'affine', 'quant_bits': 8, 'tracked_max': tensor(5.9996), 'tracked_min': tensor(0.)}}})
quantization evaluating: 1.6024222373962402s    Acc.: 0.9915
```

Total running time of the script: ( 3 minutes 22.673 seconds)

Gallery generated by Sphinx-Gallery