1. Prepare model¶
[1]:
import torch
import torch.nn.functional as F
class NaiveModel(torch.nn.Module):
    """Small CNN: two conv -> ReLU6 -> max-pool stages followed by two linear layers.

    With 1x28x28 inputs (the size used throughout this notebook) the second
    pooling stage yields 50 feature maps of 4x4, which is the flattened size
    ``fc1`` is built for. The forward pass returns log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the same order as before so that
        # `named_modules()` and the state dict keep their ordering.
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = torch.nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = torch.nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = torch.nn.Linear(in_features=500, out_features=10)
        self.relu1 = torch.nn.ReLU6()
        self.relu2 = torch.nn.ReLU6()
        self.relu3 = torch.nn.ReLU6()
        self.max_pool1 = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        self.max_pool2 = torch.nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images ``x``."""
        # Stage 1 and stage 2: convolution, clipped ReLU, 2x2 max-pooling.
        x = self.max_pool1(self.relu1(self.conv1(x)))
        x = self.max_pool2(self.relu2(self.conv2(x)))
        # Flatten everything except the batch dimension.
        features_per_sample = x.size()[1:].numel()
        x = x.view(-1, features_per_sample)
        # Classifier head.
        x = self.relu3(self.fc1(x))
        logits = self.fc2(x)
        return F.log_softmax(logits, dim=1)
[2]:
# define model, optimizer, criterion, data_loader, trainer, evaluator.
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model to be compressed, with the optimizer and loss used for pre-training.
model = NaiveModel().to(device)
optimizer = optim.Adadelta(params=model.parameters(), lr=1)
criterion = torch.nn.NLLLoss()

# Convert images to tensors, then normalise them with the given mean / std.
transform = transforms.Compose(transforms=[
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# MNIST splits; only the training split triggers a download.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

# NOTE(review): the training loader is created without `shuffle=True`, so batches
# arrive in dataset order every epoch — confirm this is intended.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1000)
def trainer(model, optimizer, criterion, epoch):
    """Train ``model`` for one pass over the module-level ``train_loader``.

    Args:
        model: network to optimise; it is switched to training mode here.
        optimizer: optimizer whose ``zero_grad()`` / ``step()`` drive each update.
        criterion: callable computing the loss from ``(output, target)``.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    dataset_size = len(train_loader.dataset)
    num_batches = len(train_loader)
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        # Only report progress on every 100th batch.
        if step % 100 != 0:
            continue
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * len(inputs), dataset_size,
            100. * step / num_batches, loss.item()))
def evaluator(model):
    """Evaluate ``model`` on the module-level ``test_loader``.

    Prints the average negative log-likelihood loss and the accuracy, and
    returns the accuracy as a percentage.
    """
    model.eval()
    num_samples = len(test_loader.dataset)
    total_loss = 0
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            log_probs = model(inputs)
            # Sum (rather than average) per batch so the final division is over all samples.
            total_loss += F.nll_loss(log_probs, labels, reduction='sum').item()
            predicted = log_probs.argmax(dim=1, keepdim=True)
            num_correct += predicted.eq(labels.view_as(predicted)).sum().item()
    mean_loss = total_loss / num_samples
    acc = 100 * num_correct / num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        mean_loss, num_correct, num_samples, acc))
    return acc
[3]:
# Pre-train the model for 3 epochs, evaluating after each one.
# StepLR with step_size=1 multiplies the learning rate by `gamma` after every epoch.
num_pretrain_epochs = 3
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
for epoch in range(num_pretrain_epochs):
    trainer(model, optimizer, criterion, epoch)
    evaluator(model)
    scheduler.step()
Train Epoch: 0 [0/60000 (0%)] Loss: 2.313423
Train Epoch: 0 [6400/60000 (11%)] Loss: 0.091786
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.087317
Train Epoch: 0 [19200/60000 (32%)] Loss: 0.036397
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.008173
Train Epoch: 0 [32000/60000 (53%)] Loss: 0.047565
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.122448
Train Epoch: 0 [44800/60000 (75%)] Loss: 0.036732
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.150135
Train Epoch: 0 [57600/60000 (96%)] Loss: 0.109684
Test set: Average loss: 0.0457, Accuracy: 9857/10000 (99%)
Train Epoch: 1 [0/60000 (0%)] Loss: 0.020650
Train Epoch: 1 [6400/60000 (11%)] Loss: 0.091525
Train Epoch: 1 [12800/60000 (21%)] Loss: 0.019602
Train Epoch: 1 [19200/60000 (32%)] Loss: 0.027827
Train Epoch: 1 [25600/60000 (43%)] Loss: 0.019414
Train Epoch: 1 [32000/60000 (53%)] Loss: 0.007640
Train Epoch: 1 [38400/60000 (64%)] Loss: 0.051296
Train Epoch: 1 [44800/60000 (75%)] Loss: 0.012038
Train Epoch: 1 [51200/60000 (85%)] Loss: 0.121057
Train Epoch: 1 [57600/60000 (96%)] Loss: 0.015796
Test set: Average loss: 0.0302, Accuracy: 9902/10000 (99%)
Train Epoch: 2 [0/60000 (0%)] Loss: 0.009903
Train Epoch: 2 [6400/60000 (11%)] Loss: 0.062256
Train Epoch: 2 [12800/60000 (21%)] Loss: 0.013844
Train Epoch: 2 [19200/60000 (32%)] Loss: 0.014133
Train Epoch: 2 [25600/60000 (43%)] Loss: 0.001051
Train Epoch: 2 [32000/60000 (53%)] Loss: 0.006128
Train Epoch: 2 [38400/60000 (64%)] Loss: 0.032162
Train Epoch: 2 [44800/60000 (75%)] Loss: 0.007687
Train Epoch: 2 [51200/60000 (85%)] Loss: 0.092295
Train Epoch: 2 [57600/60000 (96%)] Loss: 0.006266
Test set: Average loss: 0.0259, Accuracy: 9920/10000 (99%)
[4]:
# Show every op_name and op_type in the model.
# A plain loop is used because printing is a side effect: a list comprehension
# would build a throwaway list and echo `[None, None, ...]` as the cell result.
for name, module in model.named_modules():
    print('op_name: {}\nop_type: {}\n'.format(name, type(module)))
op_name:
op_type: <class '__main__.NaiveModel'>
op_name: conv1
op_type: <class 'torch.nn.modules.conv.Conv2d'>
op_name: conv2
op_type: <class 'torch.nn.modules.conv.Conv2d'>
op_name: fc1
op_type: <class 'torch.nn.modules.linear.Linear'>
op_name: fc2
op_type: <class 'torch.nn.modules.linear.Linear'>
op_name: relu1
op_type: <class 'torch.nn.modules.activation.ReLU6'>
op_name: relu2
op_type: <class 'torch.nn.modules.activation.ReLU6'>
op_name: relu3
op_type: <class 'torch.nn.modules.activation.ReLU6'>
op_name: max_pool1
op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>
op_name: max_pool2
op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>
[4]:
[None, None, None, None, None, None, None, None, None, None]
[5]:
# Show the shape of the `conv1` weight tensor.
print(model.conv1.weight.data.shape)
torch.Size([20, 1, 5, 5])
[6]:
# Show the weights of `conv1` before pruning, for comparison with the
# sparsified weights printed later in the notebook.
print(model.conv1.weight.data)
tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],
[-1.8796e-01, -2.9882e-01, 6.9725e-02, 2.1561e-01, 6.5688e-02],
[ 1.5274e-01, -9.8471e-03, 3.2303e-01, 1.3472e-03, 1.7235e-01],
[ 1.1804e-01, 2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],
[-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],
[[[-1.2112e-01, 7.0756e-02, 5.0446e-02, 1.5156e-01, -2.7929e-02],
[-1.9744e-01, -2.1336e-03, 7.2534e-02, 6.2336e-02, 1.6039e-01],
[-6.7510e-02, 1.4636e-01, 7.1972e-02, -8.9118e-02, -4.0895e-02],
[ 2.9499e-02, 2.0788e-01, -1.4989e-01, 1.1668e-01, -2.8503e-01],
[ 8.1894e-02, -1.4489e-01, -4.2038e-02, -1.2794e-01, -5.0379e-02]]],
[[[ 3.8332e-02, -1.4270e-01, -1.9585e-01, 2.2653e-01, 1.0104e-01],
[-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01, 2.6959e-01],
[ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02, 3.1572e-03],
[-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02, 6.6433e-02],
[ 8.9601e-02, 7.1189e-02, -2.4255e-01, 1.5746e-01, -1.4708e-01]]],
[[[-1.1963e-01, -1.7243e-01, -3.5174e-02, 1.4651e-01, -1.1675e-01],
[-1.3518e-01, 1.2830e-02, 7.7188e-02, 2.1060e-01, 4.0924e-02],
[-4.3364e-02, -1.9579e-01, -3.6559e-02, -6.9803e-02, 1.2380e-01],
[ 7.7321e-02, 3.7590e-02, 8.2935e-02, 2.2878e-01, 2.7859e-03],
[-1.3601e-01, -2.1167e-01, -2.3195e-01, -1.2524e-01, 1.0073e-01]]],
[[[-2.7300e-01, 6.8470e-02, 2.8405e-02, -4.5879e-03, -1.3735e-01],
[-8.9789e-02, -2.0209e-03, 5.0950e-03, 2.1633e-01, 2.5554e-01],
[ 5.4389e-02, 1.2262e-01, -1.5514e-01, -1.0416e-01, 1.3606e-01],
[-1.6794e-01, -2.8876e-02, 2.5900e-02, -2.4261e-02, 1.0923e-01],
[ 5.2524e-03, -4.4625e-02, -2.1327e-01, -1.7211e-01, -4.4819e-04]]],
[[[ 7.2378e-02, 1.5122e-01, -1.2964e-01, 4.9105e-02, -2.1639e-01],
[ 3.6547e-02, -1.5518e-02, 3.2059e-02, -3.2820e-02, 6.1231e-02],
[ 1.2514e-01, 8.0623e-02, 1.2686e-02, -1.0074e-01, 2.2836e-02],
[-2.6842e-02, 2.5578e-02, -2.5877e-01, -1.7808e-01, 7.6966e-02],
[-4.2424e-02, 4.7006e-02, -1.5486e-02, -4.2686e-02, 4.8482e-02]]],
[[[ 1.3081e-01, 9.9530e-02, -1.4729e-01, -1.7665e-01, -1.9757e-01],
[ 9.6603e-02, 2.2783e-02, 7.8402e-02, -2.8679e-02, 8.5252e-02],
[-1.5310e-02, 1.1605e-01, -5.8300e-02, 2.4563e-02, 1.7488e-01],
[ 6.5576e-02, -1.6325e-01, -1.1318e-01, -2.9251e-02, 6.2352e-02],
[-1.9084e-03, -1.4005e-01, -1.2363e-01, -9.7985e-02, -2.0562e-01]]],
[[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],
[-5.9877e-02, 9.8567e-02, 2.5186e-01, -1.0280e-01, -2.3416e-01],
[ 8.5760e-02, 1.0896e-01, 1.4898e-01, 2.1579e-01, 8.5297e-02],
[ 5.4720e-02, -1.7226e-01, -7.2518e-02, 6.7099e-03, -1.6011e-03],
[-8.9944e-02, 1.7404e-01, -3.6985e-02, 1.8602e-01, 7.2353e-02]]],
[[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],
[ 6.3310e-02, 1.7866e-01, 1.1694e-01, -1.4464e-01, -2.7711e-01],
[-2.4514e-02, 2.2222e-01, 2.1053e-01, -1.4271e-01, 8.7045e-02],
[-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],
[-2.4006e-02, 2.3780e-02, 1.8988e-01, 2.4734e-01, 4.8097e-02]]],
[[[ 1.1335e-01, -5.8451e-02, 5.2440e-02, -1.3223e-01, -2.5534e-02],
[ 9.1323e-02, -6.0707e-02, 2.3524e-01, 2.4992e-01, 8.7842e-02],
[ 2.9002e-02, 3.5379e-02, -5.9689e-02, -2.8363e-03, 1.8618e-01],
[-2.9671e-01, 8.1830e-03, 1.1076e-01, -5.4118e-02, -6.1685e-02],
[-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],
[[[ 1.1586e-01, -7.5997e-02, -1.4614e-01, 4.8750e-02, 1.8097e-01],
[-6.7027e-02, -1.4901e-01, -1.5614e-02, -1.0379e-02, 9.5526e-02],
[-3.2333e-02, -1.5107e-01, -1.9498e-01, 1.0083e-01, 2.2328e-01],
[-2.0692e-01, -6.3798e-02, -1.2524e-01, 1.9549e-01, 1.9682e-01],
[-2.1494e-01, 1.0475e-01, -2.4858e-02, -9.7831e-02, 1.1551e-01]]],
[[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01, 8.5433e-02],
[ 2.0675e-01, 3.3238e-02, 9.2437e-02, 1.1799e-01, 2.1111e-01],
[-5.2138e-02, 1.5790e-01, 1.8151e-01, 8.0470e-02, 1.0131e-01],
[-4.4786e-02, 1.1771e-01, 2.1706e-02, -1.2563e-01, -2.1142e-01],
[-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],
[[[ 1.9133e-01, 2.4711e-01, 1.0413e-01, -1.9187e-01, -3.0991e-01],
[-1.2382e-01, 8.3641e-03, -5.6734e-02, 5.8376e-02, 2.2880e-02],
[-3.1734e-01, -1.0637e-02, -5.5974e-02, 1.0676e-01, -1.1080e-02],
[-2.2980e-01, 2.0486e-01, 1.0147e-01, 1.4484e-01, 5.2265e-02],
[ 7.4410e-02, 2.2806e-02, 8.5137e-02, -2.1809e-01, 3.1704e-02]]],
[[[-1.1006e-01, -2.5311e-01, 1.8925e-02, 1.0399e-02, 1.1951e-01],
[-2.1116e-01, 1.8409e-01, 3.2172e-02, 1.5962e-01, -7.9457e-02],
[ 1.1059e-01, 9.1966e-02, 1.0777e-01, -9.9132e-02, -4.4586e-02],
[-8.7919e-02, -3.7283e-02, 9.1275e-02, -3.7412e-02, 3.8875e-02],
[-4.3558e-02, 1.6196e-01, -4.7944e-03, -1.7560e-02, -1.2593e-01]]],
[[[ 7.6976e-02, -3.8627e-02, 1.2610e-01, 1.1994e-01, 2.1706e-03],
[ 7.4357e-02, 6.7929e-02, 3.1386e-02, 1.4606e-01, 2.1429e-01],
[-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],
[-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],
[ 1.7636e-02, 1.3744e-01, -1.0306e-01, 8.8370e-02, 7.3258e-02]]],
[[[ 2.0016e-01, 1.0956e-01, -5.9223e-02, 6.4871e-03, -2.4165e-01],
[ 5.6283e-02, 1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],
[ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02, 3.4367e-02],
[-9.1882e-02, -2.9195e-01, -8.7432e-02, 1.0144e-01, -2.0559e-02],
[-2.5668e-01, -9.8016e-02, 1.1103e-01, -3.0233e-02, 1.1076e-01]]],
[[[ 1.0027e-03, -5.7955e-02, -2.1339e-01, -1.6729e-01, -2.0870e-01],
[ 4.2464e-02, 2.3177e-01, -6.1459e-02, -1.0905e-01, 1.7613e-02],
[-1.2282e-01, 2.1762e-01, -1.3553e-02, 2.7476e-01, 1.6703e-01],
[-5.6282e-02, 1.2731e-02, 1.0944e-01, -1.7347e-01, 4.4497e-02],
[ 5.7346e-02, -5.4657e-02, 4.8718e-02, -2.6221e-02, -2.6933e-02]]],
[[[ 6.7697e-02, 1.5692e-01, 2.7050e-01, 1.5936e-02, 1.7659e-01],
[-2.8899e-02, -1.4866e-01, 3.1838e-02, 1.0903e-01, 1.2292e-01],
[-1.3608e-01, -4.3198e-03, -9.8925e-02, -4.5599e-02, 1.3452e-01],
[-5.1435e-02, -2.3815e-01, -2.4151e-01, -4.8556e-02, 1.3825e-01],
[-1.2823e-01, 8.9324e-03, -1.5313e-01, -2.2933e-01, -3.4081e-02]]],
[[[-1.8396e-01, -6.8774e-03, -1.6675e-01, 7.1980e-03, 1.9922e-02],
[ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],
[ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],
[ 2.1772e-01, 8.4073e-02, -2.5264e-01, -2.1428e-01, 1.9537e-01],
[ 1.3124e-01, 7.9532e-02, -2.4044e-01, -1.5717e-01, 1.6562e-01]]],
[[[ 1.1849e-01, -5.0517e-03, -1.8900e-01, 1.8093e-02, 6.4660e-02],
[-1.5309e-01, -2.0106e-01, -8.6551e-02, 5.2692e-03, 1.5448e-01],
[-3.0727e-01, 4.9703e-02, -4.7637e-02, 2.9111e-01, -1.3173e-01],
[-8.5167e-02, -1.3540e-01, 2.9235e-01, 3.7895e-03, -9.4651e-02],
[-6.0694e-02, 9.6936e-02, 1.0533e-01, -6.1769e-02, -1.8086e-01]]]],
device='cuda:0')
2. Prepare config_list for pruning¶
[7]:
# Prune `conv1` to a sparsity of 0.5. With the filter pruner used in this
# notebook, that means half of its output filters are masked out.
conv1_pruning_rule = {
    'sparsity': 0.5,
    'op_types': ['Conv2d'],
    'op_names': ['conv1'],
}
config_list = [conv1_pruning_rule]
3. Choose a pruner and pruning¶
[8]:
# use l1filter pruner to prune the model
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
# Note: if you use a compressor that requires an optimizer, pass it a new optimizer
# rather than the one used above, because NNI might modify the optimizer it is given.
# Such a modified optimizer must not be reused for fine-tuning.
pruner = L1FilterPruner(model, config_list)
[9]:
# `conv1` has now been wrapped; the original layer is available as `conv1.module`.
# In `forward()` the wrapper uses `weight * mask` as the weight of `conv1`.
# The initial mask is a `ones_like(weight)` tensor.
# A plain loop is used because printing is a side effect: a list comprehension
# would build a throwaway list and echo `[None, None, ...]` as the cell result.
for name, module in model.named_modules():
    print('op_name: {}\nop_type: {}\n'.format(name, type(module)))
op_name:
op_type: <class '__main__.NaiveModel'>
op_name: conv1
op_type: <class 'nni.compression.pytorch.compressor.PrunerModuleWrapper'>
op_name: conv1.module
op_type: <class 'torch.nn.modules.conv.Conv2d'>
op_name: conv2
op_type: <class 'torch.nn.modules.conv.Conv2d'>
op_name: fc1
op_type: <class 'torch.nn.modules.linear.Linear'>
op_name: fc2
op_type: <class 'torch.nn.modules.linear.Linear'>
op_name: relu1
op_type: <class 'torch.nn.modules.activation.ReLU6'>
op_name: relu2
op_type: <class 'torch.nn.modules.activation.ReLU6'>
op_name: relu3
op_type: <class 'torch.nn.modules.activation.ReLU6'>
op_name: max_pool1
op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>
op_name: max_pool2
op_type: <class 'torch.nn.modules.pooling.MaxPool2d'>
[9]:
[None, None, None, None, None, None, None, None, None, None, None]
[10]:
# Compress the model; this updates the mask of the wrapped layer.
# The call is the cell's last expression, so the returned model is displayed.
pruner.compress()
[10]:
NaiveModel(
(conv1): PrunerModuleWrapper(
(module): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
)
(conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=800, out_features=500, bias=True)
(fc2): Linear(in_features=500, out_features=10, bias=True)
(relu1): ReLU6()
(relu2): ReLU6()
(relu3): ReLU6()
(max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
[11]:
# Show the shape of the `conv1` mask; it matches the shape of the weight tensor.
print(model.conv1.weight_mask.shape)
torch.Size([20, 1, 5, 5])
[12]:
# Show the mask of `conv1`: an all-zero 5x5 slice marks a pruned filter,
# an all-one slice marks a filter that is kept.
print(model.conv1.weight_mask)
tensor([[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]],
[[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]],
[[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]]], device='cuda:0')
[13]:
# Run one forward pass on a random dummy input so that the mask is applied.
dummy_input = torch.rand(1, 1, 28, 28).to(device)
model(dummy_input)
# The weights of `conv1` are now sparsified: masked filters are all zeros.
print(model.conv1.module.weight.data)
tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],
[-1.8796e-01, -2.9882e-01, 6.9725e-02, 2.1561e-01, 6.5688e-02],
[ 1.5274e-01, -9.8471e-03, 3.2303e-01, 1.3472e-03, 1.7235e-01],
[ 1.1804e-01, 2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],
[-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],
[[[-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],
[-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],
[ 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],
[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],
[[[ 3.8332e-02, -1.4270e-01, -1.9585e-01, 2.2653e-01, 1.0104e-01],
[-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01, 2.6959e-01],
[ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02, 3.1572e-03],
[-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02, 6.6433e-02],
[ 8.9601e-02, 7.1189e-02, -2.4255e-01, 1.5746e-01, -1.4708e-01]]],
[[[-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],
[-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],
[[[-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],
[-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],
[[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],
[ 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],
[[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],
[[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],
[-5.9877e-02, 9.8567e-02, 2.5186e-01, -1.0280e-01, -2.3416e-01],
[ 8.5760e-02, 1.0896e-01, 1.4898e-01, 2.1579e-01, 8.5297e-02],
[ 5.4720e-02, -1.7226e-01, -7.2518e-02, 6.7099e-03, -1.6011e-03],
[-8.9944e-02, 1.7404e-01, -3.6985e-02, 1.8602e-01, 7.2353e-02]]],
[[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],
[ 6.3310e-02, 1.7866e-01, 1.1694e-01, -1.4464e-01, -2.7711e-01],
[-2.4514e-02, 2.2222e-01, 2.1053e-01, -1.4271e-01, 8.7045e-02],
[-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],
[-2.4006e-02, 2.3780e-02, 1.8988e-01, 2.4734e-01, 4.8097e-02]]],
[[[ 1.1335e-01, -5.8451e-02, 5.2440e-02, -1.3223e-01, -2.5534e-02],
[ 9.1323e-02, -6.0707e-02, 2.3524e-01, 2.4992e-01, 8.7842e-02],
[ 2.9002e-02, 3.5379e-02, -5.9689e-02, -2.8363e-03, 1.8618e-01],
[-2.9671e-01, 8.1830e-03, 1.1076e-01, -5.4118e-02, -6.1685e-02],
[-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],
[[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],
[[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01, 8.5433e-02],
[ 2.0675e-01, 3.3238e-02, 9.2437e-02, 1.1799e-01, 2.1111e-01],
[-5.2138e-02, 1.5790e-01, 1.8151e-01, 8.0470e-02, 1.0131e-01],
[-4.4786e-02, 1.1771e-01, 2.1706e-02, -1.2563e-01, -2.1142e-01],
[-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],
[[[ 1.9133e-01, 2.4711e-01, 1.0413e-01, -1.9187e-01, -3.0991e-01],
[-1.2382e-01, 8.3641e-03, -5.6734e-02, 5.8376e-02, 2.2880e-02],
[-3.1734e-01, -1.0637e-02, -5.5974e-02, 1.0676e-01, -1.1080e-02],
[-2.2980e-01, 2.0486e-01, 1.0147e-01, 1.4484e-01, 5.2265e-02],
[ 7.4410e-02, 2.2806e-02, 8.5137e-02, -2.1809e-01, 3.1704e-02]]],
[[[-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],
[-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],
[[[ 7.6976e-02, -3.8627e-02, 1.2610e-01, 1.1994e-01, 2.1706e-03],
[ 7.4357e-02, 6.7929e-02, 3.1386e-02, 1.4606e-01, 2.1429e-01],
[-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],
[-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],
[ 1.7636e-02, 1.3744e-01, -1.0306e-01, 8.8370e-02, 7.3258e-02]]],
[[[ 2.0016e-01, 1.0956e-01, -5.9223e-02, 6.4871e-03, -2.4165e-01],
[ 5.6283e-02, 1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],
[ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02, 3.4367e-02],
[-9.1882e-02, -2.9195e-01, -8.7432e-02, 1.0144e-01, -2.0559e-02],
[-2.5668e-01, -9.8016e-02, 1.1103e-01, -3.0233e-02, 1.1076e-01]]],
[[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
[ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
[ 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00]]],
[[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],
[[[-1.8396e-01, -6.8774e-03, -1.6675e-01, 7.1980e-03, 1.9922e-02],
[ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],
[ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],
[ 2.1772e-01, 8.4073e-02, -2.5264e-01, -2.1428e-01, 1.9537e-01],
[ 1.3124e-01, 7.9532e-02, -2.4044e-01, -1.5717e-01, 1.6562e-01]]],
[[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],
[-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],
[-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00]]]],
device='cuda:0')
[14]:
# Export the sparsified model state to './pruned_naive_mnist_l1filter.pth'
# and the mask to './mask_naive_mnist_l1filter.pth'.
# The mask file is the one loaded by the speed-up step below.
pruner.export_model(model_path='pruned_naive_mnist_l1filter.pth', mask_path='mask_naive_mnist_l1filter.pth')
[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to pruned_naive_mnist_l1filter.pth
[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to mask_naive_mnist_l1filter.pth
4. Speed Up¶
[15]:
# If you use a wrapped model, don't forget to unwrap it before speeding it up.
# NOTE(review): `_unwrap_model` is underscore-prefixed, so it is presumably a
# private NNI method — confirm it is the intended entry point.
pruner._unwrap_model()
# The model has been unwrapped: `conv1` is a plain Conv2d again.
print(model)
NaiveModel(
(conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=800, out_features=500, bias=True)
(fc2): Linear(in_features=500, out_features=10, bias=True)
(relu1): ReLU6()
(relu2): ReLU6()
(relu3): ReLU6()
(max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
[16]:
from nni.compression.pytorch import ModelSpeedup
# Build the speed-up helper from a random dummy batch and the exported mask file,
# then replace the masked layers with smaller ones.
speedup_dummy_input = torch.rand(10, 1, 28, 28).to(device)
m_speedup = ModelSpeedup(
    model,
    dummy_input=speedup_dummy_input,
    masks_file='mask_naive_mnist_l1filter.pth',
)
m_speedup.speedup_model()
<ipython-input-1-0f2a9eb92f42>:22: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
x = x.view(-1, x.size()[1:].numel())
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) start to speed up the model
[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) {'conv1': 1, 'conv2': 1}
[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim0 sparsity: 0.500000
[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim1 sparsity: 0.000000
[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) Dectected conv prune dim" 0
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) infer module masks...
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::view.9
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.jit_translate/MainThread) View Module output size: [-1, 800]
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu3
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::log_softmax.10
[2021-07-26 22:26:18] ERROR (nni.compression.pytorch.speedup.jit_translate/MainThread) aten::log_softmax is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::log_softmax.10
[2021-07-26 22:26:18] WARNING (nni.compression.pytorch.speedup.compressor/MainThread) Note: .aten::log_softmax.10 does not have corresponding mask inference object
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu3
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu3
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::view.9
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::view.9
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv2
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv1
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) resolve the mask conflict
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace compressed modules...
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv1, op_type: Conv2d)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu1, op_type: ReLU6)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool1, op_type: MaxPool2d)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv2, op_type: Conv2d)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu2, op_type: ReLU6)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool2, op_type: MaxPool2d)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::view.9, op_type: aten::view) which is func type
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc1, op_type: Linear)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 800, out_features: 500
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu3, op_type: ReLU6)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc2, op_type: Linear)
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 500, out_features: 10
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::log_softmax.10, op_type: aten::log_softmax) which is func type
[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) speedup done
[17]:
# `conv1` has been replaced: `Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))` became `Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))`,
# and the following layer `conv2` has also changed, because its input channels must match the output channels of `conv1`.
print(model)
NaiveModel(
(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=800, out_features=500, bias=True)
(fc2): Linear(in_features=500, out_features=10, bias=True)
(relu1): ReLU6()
(relu2): ReLU6()
(relu3): ReLU6()
(max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
[18]:
# Fine-tune the sped-up model for one epoch to recover accuracy.
# A fresh SGD optimizer is created over the parameters of the new, smaller layers.
num_finetune_epochs = 1
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(num_finetune_epochs):
    trainer(model, optimizer, criterion, epoch)
    evaluator(model)
Train Epoch: 0 [0/60000 (0%)] Loss: 0.306930
Train Epoch: 0 [6400/60000 (11%)] Loss: 0.045807
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.049293
Train Epoch: 0 [19200/60000 (32%)] Loss: 0.031464
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.005392
Train Epoch: 0 [32000/60000 (53%)] Loss: 0.005652
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.040619
Train Epoch: 0 [44800/60000 (75%)] Loss: 0.016515
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.092886
Train Epoch: 0 [57600/60000 (96%)] Loss: 0.041380
Test set: Average loss: 0.0257, Accuracy: 9917/10000 (99%)
5. Prepare config_list for quantization¶
[19]:
# Quantize both the weights and the inputs of `conv1` and `conv2` to 8 bits.
conv_quantization_rule = {
    'quant_types': ['weight', 'input'],
    'quant_bits': {'weight': 8, 'input': 8},
    'op_names': ['conv1', 'conv2'],
}
config_list = [conv_quantization_rule]
6. Choose a quantizer and quantizing¶
[20]:
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

# Build the QAT quantizer from the model, the quantization rules and the
# optimizer used for fine-tuning.
quantizer = QAT_Quantizer(model=model, config_list=config_list, optimizer=optimizer)
# Kept as the last expression of the cell so the wrapped model is displayed.
quantizer.compress()
[20]:
NaiveModel(
(conv1): QuantizerModuleWrapper(
(module): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
)
(conv2): QuantizerModuleWrapper(
(module): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))
)
(fc1): Linear(in_features=800, out_features=500, bias=True)
(fc2): Linear(in_features=500, out_features=10, bias=True)
(relu1): ReLU6()
(relu2): ReLU6()
(relu3): ReLU6()
(max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
[21]:
# Fine-tune the quantizer-wrapped model for a single epoch for calibration.
# This reuses the `optimizer` that was handed to the quantizer.
for epoch in range(1):
    trainer(model, optimizer, criterion, epoch)
    evaluator(model)
Train Epoch: 0 [0/60000 (0%)] Loss: 0.004960
Train Epoch: 0 [6400/60000 (11%)] Loss: 0.036269
Train Epoch: 0 [12800/60000 (21%)] Loss: 0.018744
Train Epoch: 0 [19200/60000 (32%)] Loss: 0.021916
Train Epoch: 0 [25600/60000 (43%)] Loss: 0.003095
Train Epoch: 0 [32000/60000 (53%)] Loss: 0.003947
Train Epoch: 0 [38400/60000 (64%)] Loss: 0.032094
Train Epoch: 0 [44800/60000 (75%)] Loss: 0.017358
Train Epoch: 0 [51200/60000 (85%)] Loss: 0.083886
Train Epoch: 0 [57600/60000 (96%)] Loss: 0.040433
Test set: Average loss: 0.0247, Accuracy: 9917/10000 (99%)
[22]:
# Export the quantized model state to './quantized_naive_mnist_l1filter.pth'
# and the calibration config to './calibration_naive_mnist_l1filter.pth'.
# Keep the returned calibration config in `calibration_config`: the TensorRT
# speedup step reads it via `config=calibration_config`, and without this
# assignment that name would be undefined.
calibration_config = quantizer.export_model(
    model_path='quantized_naive_mnist_l1filter.pth',
    calibration_path='calibration_naive_mnist_l1filter.pth',
)
# Last expression of the cell, so the calibration config is still displayed.
calibration_config
[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to quantized_naive_mnist_l1filter.pth
[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to calibration_naive_mnist_l1filter.pth
[22]:
{'conv1': {'weight_bit': 8,
'tracked_min_input': -0.42417848110198975,
'tracked_max_input': 2.8212687969207764},
'conv2': {'weight_bit': 8,
'tracked_min_input': 0.0,
'tracked_max_input': 4.246923446655273}}
7. Speed Up¶
[ ]:
# Speed up with TensorRT.
# Imported in this cell so that `ModelSpeedupTensorRT` is defined even when no
# earlier cell has imported it.
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

# NOTE: `calibration_config` must hold the calibration dict returned by
# `quantizer.export_model(...)`.
# The shape tuple (32, 1, 28, 28) uses the same leading size as `batchsize=32`.
engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=calibration_config, batchsize=32)
engine.compress()