1. Prepare model

[1]:
import torch
import torch.nn.functional as F

class NaiveModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
        self.fc2 = torch.nn.Linear(500, 10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
        self.max_pool2 = torch.nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.max_pool1(x)
        x = self.relu2(self.conv2(x))
        x = self.max_pool2(x)
        x = x.view(-1, x.size()[1:].numel())
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
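
A quick sanity check (a sketch, not one of the executed cells below): pushing a dummy MNIST-sized batch through the model confirms the flatten size of 4 * 4 * 50 = 800 and the 10-class output.

# quick shape check (illustrative sketch):
# 28x28 -> conv1 -> 24x24 -> max_pool1 -> 12x12 -> conv2 -> 8x8 -> max_pool2 -> 4x4,
# so fc1 expects 4 * 4 * 50 = 800 input features.
_check_model = NaiveModel()
_check_out = _check_model(torch.rand(2, 1, 28, 28))
print(_check_out.shape)  # torch.Size([2, 10])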
[2]:
# define model, optimizer, criterion, data_loader, trainer, evaluator.

import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = NaiveModel().to(device)

optimizer = optim.Adadelta(model.parameters(), lr=1)

criterion = torch.nn.NLLLoss()

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)

def trainer(model, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def evaluator(model):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset), acc))

    return acc
[3]:
# pre-train the model for 3 epochs.

scheduler = StepLR(optimizer, step_size=1, gamma=0.7)

for epoch in range(0, 3):
    trainer(model, optimizer, criterion, epoch)
    evaluator(model)
    scheduler.step()
Train Epoch: 0 [0/60000 (0%)]   Loss: 2.313423
Train Epoch: 0 [6400/60000 (11%)]       Loss: 0.091786
Train Epoch: 0 [12800/60000 (21%)]      Loss: 0.087317
Train Epoch: 0 [19200/60000 (32%)]      Loss: 0.036397
Train Epoch: 0 [25600/60000 (43%)]      Loss: 0.008173
Train Epoch: 0 [32000/60000 (53%)]      Loss: 0.047565
Train Epoch: 0 [38400/60000 (64%)]      Loss: 0.122448
Train Epoch: 0 [44800/60000 (75%)]      Loss: 0.036732
Train Epoch: 0 [51200/60000 (85%)]      Loss: 0.150135
Train Epoch: 0 [57600/60000 (96%)]      Loss: 0.109684

Test set: Average loss: 0.0457, Accuracy: 9857/10000 (99%)

Train Epoch: 1 [0/60000 (0%)]   Loss: 0.020650
Train Epoch: 1 [6400/60000 (11%)]       Loss: 0.091525
Train Epoch: 1 [12800/60000 (21%)]      Loss: 0.019602
Train Epoch: 1 [19200/60000 (32%)]      Loss: 0.027827
Train Epoch: 1 [25600/60000 (43%)]      Loss: 0.019414
Train Epoch: 1 [32000/60000 (53%)]      Loss: 0.007640
Train Epoch: 1 [38400/60000 (64%)]      Loss: 0.051296
Train Epoch: 1 [44800/60000 (75%)]      Loss: 0.012038
Train Epoch: 1 [51200/60000 (85%)]      Loss: 0.121057
Train Epoch: 1 [57600/60000 (96%)]      Loss: 0.015796

Test set: Average loss: 0.0302, Accuracy: 9902/10000 (99%)

Train Epoch: 2 [0/60000 (0%)]   Loss: 0.009903
Train Epoch: 2 [6400/60000 (11%)]       Loss: 0.062256
Train Epoch: 2 [12800/60000 (21%)]      Loss: 0.013844
Train Epoch: 2 [19200/60000 (32%)]      Loss: 0.014133
Train Epoch: 2 [25600/60000 (43%)]      Loss: 0.001051
Train Epoch: 2 [32000/60000 (53%)]      Loss: 0.006128
Train Epoch: 2 [38400/60000 (64%)]      Loss: 0.032162
Train Epoch: 2 [44800/60000 (75%)]      Loss: 0.007687
Train Epoch: 2 [51200/60000 (85%)]      Loss: 0.092295
Train Epoch: 2 [57600/60000 (96%)]      Loss: 0.006266

Test set: Average loss: 0.0259, Accuracy: 9920/10000 (99%)

[4]:
# show all op_name and op_type in the model.

[print('op_name: {}\nop_type: {}\n'.format(name, type(module))) for name, module in model.named_modules()]
op_name:
op_type: <class '__main__.NaiveModel'>

op_name: conv1
op_type: <class 'torch.nn.modules.conv.Conv2d'>

op_name: conv2
op_type: <class 'torch.nn.modules.conv.Conv2d'>

op_name: fc1
op_type: <class 'torch.nn.modules.linear.Linear'>

op_name: fc2
op_type: <class 'torch.nn.modules.linear.Linear'>

op_name: relu1
op_type: <class 'torch.nn.modules.activation.ReLU6'>

op_name: relu2
op_type: <class 'torch.nn.modules.activation.ReLU6'>

op_name: relu3
op_type: <class 'torch.nn.modules.activation.ReLU6'>

op_name: max_pool1
op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>

op_name: max_pool2
op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>

[4]:
[None, None, None, None, None, None, None, None, None, None]
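
An equivalent listing (a small sketch) uses a plain loop, which avoids the side-effect list comprehension and the trailing list of `None` values shown above:

for name, module in model.named_modules():
    print('op_name: {}\nop_type: {}\n'.format(name, type(module)))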
[5]:
# show the weight size of `conv1`.

print(model.conv1.weight.data.size())
torch.Size([20, 1, 5, 5])
[6]:
# show the weight of `conv1`.

print(model.conv1.weight.data)
tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],
          [-1.8796e-01, -2.9882e-01,  6.9725e-02,  2.1561e-01,  6.5688e-02],
          [ 1.5274e-01, -9.8471e-03,  3.2303e-01,  1.3472e-03,  1.7235e-01],
          [ 1.1804e-01,  2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],
          [-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],


        [[[-1.2112e-01,  7.0756e-02,  5.0446e-02,  1.5156e-01, -2.7929e-02],
          [-1.9744e-01, -2.1336e-03,  7.2534e-02,  6.2336e-02,  1.6039e-01],
          [-6.7510e-02,  1.4636e-01,  7.1972e-02, -8.9118e-02, -4.0895e-02],
          [ 2.9499e-02,  2.0788e-01, -1.4989e-01,  1.1668e-01, -2.8503e-01],
          [ 8.1894e-02, -1.4489e-01, -4.2038e-02, -1.2794e-01, -5.0379e-02]]],


        [[[ 3.8332e-02, -1.4270e-01, -1.9585e-01,  2.2653e-01,  1.0104e-01],
          [-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01,  2.6959e-01],
          [ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02,  3.1572e-03],
          [-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02,  6.6433e-02],
          [ 8.9601e-02,  7.1189e-02, -2.4255e-01,  1.5746e-01, -1.4708e-01]]],


        [[[-1.1963e-01, -1.7243e-01, -3.5174e-02,  1.4651e-01, -1.1675e-01],
          [-1.3518e-01,  1.2830e-02,  7.7188e-02,  2.1060e-01,  4.0924e-02],
          [-4.3364e-02, -1.9579e-01, -3.6559e-02, -6.9803e-02,  1.2380e-01],
          [ 7.7321e-02,  3.7590e-02,  8.2935e-02,  2.2878e-01,  2.7859e-03],
          [-1.3601e-01, -2.1167e-01, -2.3195e-01, -1.2524e-01,  1.0073e-01]]],


        [[[-2.7300e-01,  6.8470e-02,  2.8405e-02, -4.5879e-03, -1.3735e-01],
          [-8.9789e-02, -2.0209e-03,  5.0950e-03,  2.1633e-01,  2.5554e-01],
          [ 5.4389e-02,  1.2262e-01, -1.5514e-01, -1.0416e-01,  1.3606e-01],
          [-1.6794e-01, -2.8876e-02,  2.5900e-02, -2.4261e-02,  1.0923e-01],
          [ 5.2524e-03, -4.4625e-02, -2.1327e-01, -1.7211e-01, -4.4819e-04]]],


        [[[ 7.2378e-02,  1.5122e-01, -1.2964e-01,  4.9105e-02, -2.1639e-01],
          [ 3.6547e-02, -1.5518e-02,  3.2059e-02, -3.2820e-02,  6.1231e-02],
          [ 1.2514e-01,  8.0623e-02,  1.2686e-02, -1.0074e-01,  2.2836e-02],
          [-2.6842e-02,  2.5578e-02, -2.5877e-01, -1.7808e-01,  7.6966e-02],
          [-4.2424e-02,  4.7006e-02, -1.5486e-02, -4.2686e-02,  4.8482e-02]]],


        [[[ 1.3081e-01,  9.9530e-02, -1.4729e-01, -1.7665e-01, -1.9757e-01],
          [ 9.6603e-02,  2.2783e-02,  7.8402e-02, -2.8679e-02,  8.5252e-02],
          [-1.5310e-02,  1.1605e-01, -5.8300e-02,  2.4563e-02,  1.7488e-01],
          [ 6.5576e-02, -1.6325e-01, -1.1318e-01, -2.9251e-02,  6.2352e-02],
          [-1.9084e-03, -1.4005e-01, -1.2363e-01, -9.7985e-02, -2.0562e-01]]],


        [[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],
          [-5.9877e-02,  9.8567e-02,  2.5186e-01, -1.0280e-01, -2.3416e-01],
          [ 8.5760e-02,  1.0896e-01,  1.4898e-01,  2.1579e-01,  8.5297e-02],
          [ 5.4720e-02, -1.7226e-01, -7.2518e-02,  6.7099e-03, -1.6011e-03],
          [-8.9944e-02,  1.7404e-01, -3.6985e-02,  1.8602e-01,  7.2353e-02]]],


        [[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],
          [ 6.3310e-02,  1.7866e-01,  1.1694e-01, -1.4464e-01, -2.7711e-01],
          [-2.4514e-02,  2.2222e-01,  2.1053e-01, -1.4271e-01,  8.7045e-02],
          [-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],
          [-2.4006e-02,  2.3780e-02,  1.8988e-01,  2.4734e-01,  4.8097e-02]]],


        [[[ 1.1335e-01, -5.8451e-02,  5.2440e-02, -1.3223e-01, -2.5534e-02],
          [ 9.1323e-02, -6.0707e-02,  2.3524e-01,  2.4992e-01,  8.7842e-02],
          [ 2.9002e-02,  3.5379e-02, -5.9689e-02, -2.8363e-03,  1.8618e-01],
          [-2.9671e-01,  8.1830e-03,  1.1076e-01, -5.4118e-02, -6.1685e-02],
          [-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],


        [[[ 1.1586e-01, -7.5997e-02, -1.4614e-01,  4.8750e-02,  1.8097e-01],
          [-6.7027e-02, -1.4901e-01, -1.5614e-02, -1.0379e-02,  9.5526e-02],
          [-3.2333e-02, -1.5107e-01, -1.9498e-01,  1.0083e-01,  2.2328e-01],
          [-2.0692e-01, -6.3798e-02, -1.2524e-01,  1.9549e-01,  1.9682e-01],
          [-2.1494e-01,  1.0475e-01, -2.4858e-02, -9.7831e-02,  1.1551e-01]]],


        [[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01,  8.5433e-02],
          [ 2.0675e-01,  3.3238e-02,  9.2437e-02,  1.1799e-01,  2.1111e-01],
          [-5.2138e-02,  1.5790e-01,  1.8151e-01,  8.0470e-02,  1.0131e-01],
          [-4.4786e-02,  1.1771e-01,  2.1706e-02, -1.2563e-01, -2.1142e-01],
          [-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],


        [[[ 1.9133e-01,  2.4711e-01,  1.0413e-01, -1.9187e-01, -3.0991e-01],
          [-1.2382e-01,  8.3641e-03, -5.6734e-02,  5.8376e-02,  2.2880e-02],
          [-3.1734e-01, -1.0637e-02, -5.5974e-02,  1.0676e-01, -1.1080e-02],
          [-2.2980e-01,  2.0486e-01,  1.0147e-01,  1.4484e-01,  5.2265e-02],
          [ 7.4410e-02,  2.2806e-02,  8.5137e-02, -2.1809e-01,  3.1704e-02]]],


        [[[-1.1006e-01, -2.5311e-01,  1.8925e-02,  1.0399e-02,  1.1951e-01],
          [-2.1116e-01,  1.8409e-01,  3.2172e-02,  1.5962e-01, -7.9457e-02],
          [ 1.1059e-01,  9.1966e-02,  1.0777e-01, -9.9132e-02, -4.4586e-02],
          [-8.7919e-02, -3.7283e-02,  9.1275e-02, -3.7412e-02,  3.8875e-02],
          [-4.3558e-02,  1.6196e-01, -4.7944e-03, -1.7560e-02, -1.2593e-01]]],


        [[[ 7.6976e-02, -3.8627e-02,  1.2610e-01,  1.1994e-01,  2.1706e-03],
          [ 7.4357e-02,  6.7929e-02,  3.1386e-02,  1.4606e-01,  2.1429e-01],
          [-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],
          [-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],
          [ 1.7636e-02,  1.3744e-01, -1.0306e-01,  8.8370e-02,  7.3258e-02]]],


        [[[ 2.0016e-01,  1.0956e-01, -5.9223e-02,  6.4871e-03, -2.4165e-01],
          [ 5.6283e-02,  1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],
          [ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02,  3.4367e-02],
          [-9.1882e-02, -2.9195e-01, -8.7432e-02,  1.0144e-01, -2.0559e-02],
          [-2.5668e-01, -9.8016e-02,  1.1103e-01, -3.0233e-02,  1.1076e-01]]],


        [[[ 1.0027e-03, -5.7955e-02, -2.1339e-01, -1.6729e-01, -2.0870e-01],
          [ 4.2464e-02,  2.3177e-01, -6.1459e-02, -1.0905e-01,  1.7613e-02],
          [-1.2282e-01,  2.1762e-01, -1.3553e-02,  2.7476e-01,  1.6703e-01],
          [-5.6282e-02,  1.2731e-02,  1.0944e-01, -1.7347e-01,  4.4497e-02],
          [ 5.7346e-02, -5.4657e-02,  4.8718e-02, -2.6221e-02, -2.6933e-02]]],


        [[[ 6.7697e-02,  1.5692e-01,  2.7050e-01,  1.5936e-02,  1.7659e-01],
          [-2.8899e-02, -1.4866e-01,  3.1838e-02,  1.0903e-01,  1.2292e-01],
          [-1.3608e-01, -4.3198e-03, -9.8925e-02, -4.5599e-02,  1.3452e-01],
          [-5.1435e-02, -2.3815e-01, -2.4151e-01, -4.8556e-02,  1.3825e-01],
          [-1.2823e-01,  8.9324e-03, -1.5313e-01, -2.2933e-01, -3.4081e-02]]],


        [[[-1.8396e-01, -6.8774e-03, -1.6675e-01,  7.1980e-03,  1.9922e-02],
          [ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],
          [ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],
          [ 2.1772e-01,  8.4073e-02, -2.5264e-01, -2.1428e-01,  1.9537e-01],
          [ 1.3124e-01,  7.9532e-02, -2.4044e-01, -1.5717e-01,  1.6562e-01]]],


        [[[ 1.1849e-01, -5.0517e-03, -1.8900e-01,  1.8093e-02,  6.4660e-02],
          [-1.5309e-01, -2.0106e-01, -8.6551e-02,  5.2692e-03,  1.5448e-01],
          [-3.0727e-01,  4.9703e-02, -4.7637e-02,  2.9111e-01, -1.3173e-01],
          [-8.5167e-02, -1.3540e-01,  2.9235e-01,  3.7895e-03, -9.4651e-02],
          [-6.0694e-02,  9.6936e-02,  1.0533e-01, -6.1769e-02, -1.8086e-01]]]],
       device='cuda:0')

2. Prepare config_list for pruning

[7]:
# we will prune 50% of the weights in `conv1` (L1FilterPruner removes whole filters, so half of the 20 filters will be masked).

config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d'],
    'op_names': ['conv1']
}]
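
A config_list can contain several entries. As a hypothetical variant (same schema as above, not used in this notebook), all Conv2d layers could be pruned at once by dropping the `op_names` filter:

# hypothetical variant: prune every Conv2d layer at 40% sparsity.
config_list_all_convs = [{
    'sparsity': 0.4,
    'op_types': ['Conv2d']
}]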

3. Choose a pruner and prune the model

[8]:
# use L1FilterPruner to prune the model

from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

# Note that if you use a compressor that requires an optimizer,
# pass a new optimizer rather than the one created above, because NNI might modify it.
# The modified optimizer also cannot be reused for finetuning.
pruner = L1FilterPruner(model, config_list)
[9]:
# we can see that `conv1` has been wrapped; the original `conv1` is now `conv1.module`.
# in `forward()`, the weight of `conv1` is replaced by `weight * mask`. The initial mask is a `ones_like(weight)` tensor.

[print('op_name: {}\nop_type: {}\n'.format(name, type(module))) for name, module in model.named_modules()]
op_name:
op_type: <class '__main__.NaiveModel'>

op_name: conv1
op_type: <class 'nni.compression.pytorch.compressor.PrunerModuleWrapper'>

op_name: conv1.module
op_type: <class 'torch.nn.modules.conv.Conv2d'>

op_name: conv2
op_type: <class 'torch.nn.modules.conv.Conv2d'>

op_name: fc1
op_type: <class 'torch.nn.modules.linear.Linear'>

op_name: fc2
op_type: <class 'torch.nn.modules.linear.Linear'>

op_name: relu1
op_type: <class 'torch.nn.modules.activation.ReLU6'>

op_name: relu2
op_type: <class 'torch.nn.modules.activation.ReLU6'>

op_name: relu3
op_type: <class 'torch.nn.modules.activation.ReLU6'>

op_name: max_pool1
op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>

op_name: max_pool2
op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>

[9]:
[None, None, None, None, None, None, None, None, None, None, None]
[10]:
# compress the model; this computes and updates the masks.

pruner.compress()
[10]:
NaiveModel(
  (conv1): PrunerModuleWrapper(
    (module): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  )
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
  (relu1): ReLU6()
  (relu2): ReLU6()
  (relu3): ReLU6()
  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
[11]:
# show the mask size of `conv1`

print(model.conv1.weight_mask.size())
torch.Size([20, 1, 5, 5])
[12]:
# show the mask of `conv1`

print(model.conv1.weight_mask)
tensor([[[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]],


        [[[1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1.]]],


        [[[0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0.]]]], device='cuda:0')
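
To see why these particular filters were kept, the per-filter L1 norms can be compared with the mask; L1FilterPruner masks the filters with the smallest L1 norm (a sketch using the wrapper attributes shown above):

# per-filter L1 norm of the (still unmasked) conv1 weights.
l1_norms = model.conv1.module.weight.data.abs().sum(dim=(1, 2, 3))
# a filter is kept if its mask slice is all ones.
kept = model.conv1.weight_mask.view(20, -1).sum(dim=1) > 0
print(l1_norms)
print(kept)  # the kept filters should be the 10 with the largest L1 norm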
[13]:
# run a dummy input through the model to apply the masks and sparsify the weights.

model(torch.rand(1, 1, 28, 28).to(device))

# the weights of `conv1` have been sparsified.

print(model.conv1.module.weight.data)
tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],
          [-1.8796e-01, -2.9882e-01,  6.9725e-02,  2.1561e-01,  6.5688e-02],
          [ 1.5274e-01, -9.8471e-03,  3.2303e-01,  1.3472e-03,  1.7235e-01],
          [ 1.1804e-01,  2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],
          [-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],


        [[[-0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],


        [[[ 3.8332e-02, -1.4270e-01, -1.9585e-01,  2.2653e-01,  1.0104e-01],
          [-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01,  2.6959e-01],
          [ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02,  3.1572e-03],
          [-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02,  6.6433e-02],
          [ 8.9601e-02,  7.1189e-02, -2.4255e-01,  1.5746e-01, -1.4708e-01]]],


        [[[-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00]]],


        [[[-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
          [ 0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],


        [[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],
          [-5.9877e-02,  9.8567e-02,  2.5186e-01, -1.0280e-01, -2.3416e-01],
          [ 8.5760e-02,  1.0896e-01,  1.4898e-01,  2.1579e-01,  8.5297e-02],
          [ 5.4720e-02, -1.7226e-01, -7.2518e-02,  6.7099e-03, -1.6011e-03],
          [-8.9944e-02,  1.7404e-01, -3.6985e-02,  1.8602e-01,  7.2353e-02]]],


        [[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],
          [ 6.3310e-02,  1.7866e-01,  1.1694e-01, -1.4464e-01, -2.7711e-01],
          [-2.4514e-02,  2.2222e-01,  2.1053e-01, -1.4271e-01,  8.7045e-02],
          [-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],
          [-2.4006e-02,  2.3780e-02,  1.8988e-01,  2.4734e-01,  4.8097e-02]]],


        [[[ 1.1335e-01, -5.8451e-02,  5.2440e-02, -1.3223e-01, -2.5534e-02],
          [ 9.1323e-02, -6.0707e-02,  2.3524e-01,  2.4992e-01,  8.7842e-02],
          [ 2.9002e-02,  3.5379e-02, -5.9689e-02, -2.8363e-03,  1.8618e-01],
          [-2.9671e-01,  8.1830e-03,  1.1076e-01, -5.4118e-02, -6.1685e-02],
          [-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],


        [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00]]],


        [[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01,  8.5433e-02],
          [ 2.0675e-01,  3.3238e-02,  9.2437e-02,  1.1799e-01,  2.1111e-01],
          [-5.2138e-02,  1.5790e-01,  1.8151e-01,  8.0470e-02,  1.0131e-01],
          [-4.4786e-02,  1.1771e-01,  2.1706e-02, -1.2563e-01, -2.1142e-01],
          [-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],


        [[[ 1.9133e-01,  2.4711e-01,  1.0413e-01, -1.9187e-01, -3.0991e-01],
          [-1.2382e-01,  8.3641e-03, -5.6734e-02,  5.8376e-02,  2.2880e-02],
          [-3.1734e-01, -1.0637e-02, -5.5974e-02,  1.0676e-01, -1.1080e-02],
          [-2.2980e-01,  2.0486e-01,  1.0147e-01,  1.4484e-01,  5.2265e-02],
          [ 7.4410e-02,  2.2806e-02,  8.5137e-02, -2.1809e-01,  3.1704e-02]]],


        [[[-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],


        [[[ 7.6976e-02, -3.8627e-02,  1.2610e-01,  1.1994e-01,  2.1706e-03],
          [ 7.4357e-02,  6.7929e-02,  3.1386e-02,  1.4606e-01,  2.1429e-01],
          [-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],
          [-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],
          [ 1.7636e-02,  1.3744e-01, -1.0306e-01,  8.8370e-02,  7.3258e-02]]],


        [[[ 2.0016e-01,  1.0956e-01, -5.9223e-02,  6.4871e-03, -2.4165e-01],
          [ 5.6283e-02,  1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],
          [ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02,  3.4367e-02],
          [-9.1882e-02, -2.9195e-01, -8.7432e-02,  1.0144e-01, -2.0559e-02],
          [-2.5668e-01, -9.8016e-02,  1.1103e-01, -3.0233e-02,  1.1076e-01]]],


        [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00]]],


        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],


        [[[-1.8396e-01, -6.8774e-03, -1.6675e-01,  7.1980e-03,  1.9922e-02],
          [ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],
          [ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],
          [ 2.1772e-01,  8.4073e-02, -2.5264e-01, -2.1428e-01,  1.9537e-01],
          [ 1.3124e-01,  7.9532e-02, -2.4044e-01, -1.5717e-01,  1.6562e-01]]],


        [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00]]]],
       device='cuda:0')
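
A small check (sketch) confirms that the applied masks zeroed out half of the `conv1` weights, matching the configured 50% sparsity:

zero_fraction = (model.conv1.module.weight.data == 0).float().mean().item()
print('conv1 weight sparsity: {:.2f}'.format(zero_fraction))  # ~0.50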
[14]:
# export the sparsified model state to './pruned_naive_mnist_l1filter.pth'.
# export the mask to './mask_naive_mnist_l1filter.pth'.

pruner.export_model(model_path='pruned_naive_mnist_l1filter.pth', mask_path='mask_naive_mnist_l1filter.pth')
[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to pruned_naive_mnist_l1filter.pth
[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to mask_naive_mnist_l1filter.pth
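
Both exported files are ordinary PyTorch checkpoints, so they can be inspected with `torch.load` (a sketch; judging from the export above, the mask file is a dict keyed by layer name):

exported_state = torch.load('pruned_naive_mnist_l1filter.pth')
exported_masks = torch.load('mask_naive_mnist_l1filter.pth')
print(exported_state.keys())
print(exported_masks.keys())  # expected to contain 'conv1'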

4. Speed Up

[15]:
# the model is still wrapped; don't forget to unwrap it before speeding it up.

pruner._unwrap_model()

# the model has been unwrapped.

print(model)
NaiveModel(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
  (relu1): ReLU6()
  (relu2): ReLU6()
  (relu3): ReLU6()
  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
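
After unwrapping, `conv1` is a plain `Conv2d` again, but the zeroed weights left by the applied masks remain in place, which is what lets ModelSpeedup physically remove those filters (a quick sketch check):

print((model.conv1.weight.data == 0).float().mean().item())  # still ~0.50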
[16]:
from nni.compression.pytorch import ModelSpeedup

m_speedup = ModelSpeedup(model, dummy_input=torch.rand(10, 1, 28, 28).to(device), masks_file='mask_naive_mnist_l1filter.pth')
m_speedup.speedup_model()
<ipython-input-1-0f2a9eb92f42>:22: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  x = x.view(-1, x.size()[1:].numel())
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) start to speed up the model
[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) {'conv1': 1, 'conv2': 1}
[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim0 sparsity: 0.500000
[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim1 sparsity: 0.000000
[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) Dectected conv prune dim" 0
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) infer module masks...
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::view.9
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.jit_translate/MainThread) View Module output size: [-1, 800]
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu3
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::log_softmax.10
[2021-07-26 22:26:18] ERROR (nni.compression.pytorch.speedup.jit_translate/MainThread) aten::log_softmax is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::log_softmax.10
[2021-07-26 22:26:18] WARNING (nni.compression.pytorch.speedup.compressor/MainThread) Note: .aten::log_softmax.10 does not have corresponding mask inference object
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu3
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu3
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::view.9
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::view.9
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) resolve the mask conflict
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace compressed modules...
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv1, op_type: Conv2d)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu1, op_type: ReLU6)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool1, op_type: MaxPool2d)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv2, op_type: Conv2d)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu2, op_type: ReLU6)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool2, op_type: MaxPool2d)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::view.9, op_type: aten::view) which is func type
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc1, op_type: Linear)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 800, out_features: 500
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu3, op_type: ReLU6)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc2, op_type: Linear)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 500, out_features: 10
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::log_softmax.10, op_type: aten::log_softmax) which is func type
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) speedup done
[17]:
# `conv1` has been replaced: `Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))` became `Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))`,
# and the following layer `conv2` changed as well, because its input channels must match the output channels of `conv1`.

print(model)
NaiveModel(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
  (relu1): ReLU6()
  (relu2): ReLU6()
  (relu3): ReLU6()
  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
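
A quick check (sketch) that the sped-up model still runs on the original input shape and produces 10-way outputs:

with torch.no_grad():
    speedup_out = model(torch.rand(10, 1, 28, 28).to(device))
print(speedup_out.shape)  # torch.Size([10, 10])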
[18]:
# finetune the model to recover the accuracy.

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(0, 1):
    trainer(model, optimizer, criterion, epoch)
    evaluator(model)
Train Epoch: 0 [0/60000 (0%)]   Loss: 0.306930
Train Epoch: 0 [6400/60000 (11%)]       Loss: 0.045807
Train Epoch: 0 [12800/60000 (21%)]      Loss: 0.049293
Train Epoch: 0 [19200/60000 (32%)]      Loss: 0.031464
Train Epoch: 0 [25600/60000 (43%)]      Loss: 0.005392
Train Epoch: 0 [32000/60000 (53%)]      Loss: 0.005652
Train Epoch: 0 [38400/60000 (64%)]      Loss: 0.040619
Train Epoch: 0 [44800/60000 (75%)]      Loss: 0.016515
Train Epoch: 0 [51200/60000 (85%)]      Loss: 0.092886
Train Epoch: 0 [57600/60000 (96%)]      Loss: 0.041380

Test set: Average loss: 0.0257, Accuracy: 9917/10000 (99%)

5. Prepare config_list for quantization

[19]:
config_list = [{
    'quant_types': ['weight', 'input'],
    'quant_bits': {'weight': 8, 'input': 8},
    'op_names': ['conv1', 'conv2']
}]
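
As a hypothetical variant (same config schema; `quant_start_step` is a QAT_Quantizer option that delays quantization for the first training steps), layers could also be selected by type:

# hypothetical variant: quantize all Conv2d layers, starting after 100 steps.
config_list_by_type = [{
    'quant_types': ['weight', 'input'],
    'quant_bits': {'weight': 8, 'input': 8},
    'quant_start_step': 100,
    'op_types': ['Conv2d']
}]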

6. Choose a quantizer and quantize the model

[20]:
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
[20]:
NaiveModel(
  (conv1): QuantizerModuleWrapper(
    (module): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  )
  (conv2): QuantizerModuleWrapper(
    (module): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))
  )
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
  (relu1): ReLU6()
  (relu2): ReLU6()
  (relu3): ReLU6()
  (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
[21]:
# finetune the model for one epoch so the quantizer can calibrate (quantization-aware training).

for epoch in range(0, 1):
    trainer(model, optimizer, criterion, epoch)
    evaluator(model)
Train Epoch: 0 [0/60000 (0%)]   Loss: 0.004960
Train Epoch: 0 [6400/60000 (11%)]       Loss: 0.036269
Train Epoch: 0 [12800/60000 (21%)]      Loss: 0.018744
Train Epoch: 0 [19200/60000 (32%)]      Loss: 0.021916
Train Epoch: 0 [25600/60000 (43%)]      Loss: 0.003095
Train Epoch: 0 [32000/60000 (53%)]      Loss: 0.003947
Train Epoch: 0 [38400/60000 (64%)]      Loss: 0.032094
Train Epoch: 0 [44800/60000 (75%)]      Loss: 0.017358
Train Epoch: 0 [51200/60000 (85%)]      Loss: 0.083886
Train Epoch: 0 [57600/60000 (96%)]      Loss: 0.040433

Test set: Average loss: 0.0247, Accuracy: 9917/10000 (99%)

[22]:
# export the quantized model state to './quantized_naive_mnist_l1filter.pth'.
# export the calibration config to './calibration_naive_mnist_l1filter.pth'.

quantizer.export_model(model_path='quantized_naive_mnist_l1filter.pth', calibration_path='calibration_naive_mnist_l1filter.pth')
[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to quantized_naive_mnist_l1filter.pth
[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to calibration_naive_mnist_l1filter.pth
[22]:
{'conv1': {'weight_bit': 8,
  'tracked_min_input': -0.42417848110198975,
  'tracked_max_input': 2.8212687969207764},
 'conv2': {'weight_bit': 8,
  'tracked_min_input': 0.0,
  'tracked_max_input': 4.246923446655273}}
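
The dict above is the return value of `quantizer.export_model(...)`. If it was not captured, it can be re-obtained before the TensorRT step (a sketch; this re-saves the same files):

calibration_config = quantizer.export_model(
    model_path='quantized_naive_mnist_l1filter.pth',
    calibration_path='calibration_naive_mnist_l1filter.pth')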

7. Speed up with TensorRT

[ ]:
# speed up with TensorRT. `calibration_config` is the dict returned by
# `quantizer.export_model(...)` above.

from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=calibration_config, batchsize=32)
engine.compress()