Pruning Bert on Task MNLI

Workable Pruning Process

Here we show an effective transformer pruning process that the NNI team has tried; users can also use NNI to discover better processes.

The entire pruning process can be divided into the following steps:

  1. Finetune the pre-trained model on the downstream task. In our experience, pruning the finetuned model gives better final performance than pruning the pre-trained model directly. The finetuned model obtained in this step is also used as the teacher model for the distillation training in later steps.

  2. Prune the attention layers first. Here we apply block-sparse masking to the attention weights and directly prune a head (condensing its weights) if the head is fully masked. If a head is only partially masked, we do not prune it and restore its weights.

  3. Retrain the head-pruned model with distillation to recover model accuracy before pruning the FFN layers.

  4. Prune the FFN layers. Here we prune the output channels of the first FFN layer; the input channels of the second FFN layer are then pruned accordingly, because they consume the first layer's output channels.

  5. Retrain the final pruned model with distillation.
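The mask-to-structure rules in steps 2 and 4 can be sketched in pure Python. This is a toy illustration, not NNI's implementation: the helper names (`find_prunable_heads`, `restore_partial_heads`, `condense_ffn`) are hypothetical, the attention mask is assumed to be laid out like a Linear weight of shape (out_features, in_features) with consecutive blocks of head_dim rows per head, and the FFN weights are assumed to use the same (out, in) layout, so dropping an output row of the first layer removes the matching input column of the second.

```python
from typing import List

Matrix = List[List[float]]


def find_prunable_heads(mask: Matrix, num_heads: int) -> List[int]:
    """Return the heads whose rows in a 0/1 mask are entirely zero (hypothetical helper)."""
    head_dim = len(mask) // num_heads
    prunable = []
    for h in range(num_heads):
        rows = mask[h * head_dim:(h + 1) * head_dim]
        if all(v == 0 for row in rows for v in row):
            prunable.append(h)
    return prunable


def restore_partial_heads(mask: Matrix, num_heads: int) -> Matrix:
    """Keep fully masked heads at zero; reset every other head (including partially masked ones) to all ones."""
    head_dim = len(mask) // num_heads
    prunable = set(find_prunable_heads(mask, num_heads))
    restored = []
    for i, row in enumerate(mask):
        head = i // head_dim
        restored.append([0] * len(row) if head in prunable else [1] * len(row))
    return restored


def condense_ffn(w1: Matrix, w2: Matrix, keep: List[int]):
    """Drop output channels of FFN layer 1 and the matching input channels of FFN layer 2.

    w1 is (intermediate, hidden): one row per output channel.
    w2 is (hidden, intermediate): one column per input channel.
    """
    w1_small = [w1[i] for i in keep]
    w2_small = [[row[i] for i in keep] for row in w2]
    return w1_small, w2_small


# Toy mask: 8 output rows (4 heads x head_dim 2) and 3 input columns.
mask = [
    [1, 1, 1], [1, 1, 1],  # head 0: unmasked
    [0, 0, 0], [0, 0, 0],  # head 1: fully masked, so it is pruned
    [1, 0, 1], [0, 0, 0],  # head 2: partially masked, so it is restored
    [1, 1, 1], [1, 1, 1],  # head 3: unmasked
]
print(find_prunable_heads(mask, num_heads=4))  # [1]

# Toy FFN: hidden=2, intermediate=3; keep intermediate channels 0 and 2.
w1 = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
w2 = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
w1_small, w2_small = condense_ffn(w1, w2, keep=[0, 2])
print(w2_small)  # [[1.0, 3.0], [4.0, 6.0]]
```

The point of restoring partially masked heads is that only a fully masked head can be physically removed; a partially masked one would still need its full weight block, so keeping its mask brings no speedup.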

While pruning the transformer, we learned the following lessons:

  • We use Movement Pruner in step 2 and Taylor FO Weight Pruner in step 4. Movement Pruner performs well on attention layers, and Taylor FO Weight Pruner performs well on FFN layers. Both are gradient-based pruning algorithms; we also tried weight-based pruning algorithms such as L1 Norm Pruner, but they did not seem to work well in this scenario.

  • Distillation is an effective way to recover model accuracy. In our results, it usually improves accuracy by 1 to 2% when pruning BERT on MNLI.

  • It is necessary to gradually increase the sparsity rather than reaching a very high sparsity all at once.
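Gradually increasing sparsity is usually implemented as a schedule that maps the training step to a target sparsity. The sketch below uses the cubic schedule of Zhu and Gupta (2017) as one common choice; this tutorial does not say which schedule NNI's pruners use, so treat the helper `cubic_sparsity` and its parameters as illustrative assumptions. Sparsity rises quickly early on and flattens as it approaches the target, giving the model time to adapt to the last, most damaging prunes.

```python
def cubic_sparsity(step: int, total_steps: int, final_sparsity: float,
                   initial_sparsity: float = 0.0, warmup_steps: int = 0) -> float:
    """Target sparsity at `step` under a cubic schedule (hypothetical helper, not an NNI API)."""
    if step <= warmup_steps:
        return initial_sparsity
    if step >= total_steps:
        return final_sparsity
    # fraction of the pruning window that has elapsed, in (0, 1)
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3


# Usage: ramp from 0% to 80% sparsity over 100 steps, sampled every 25 steps.
schedule = [cubic_sparsity(t, total_steps=100, final_sparsity=0.8) for t in range(0, 101, 25)]
print(schedule)
```

At the halfway point the target is already 87.5% of the final sparsity (0.7 for a final value of 0.8), which is the front-loaded shape the cubic form is chosen for.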

Experiment

The complete pruning process takes about 8 hours on a single A100 GPU.

Preparation

This section mainly obtains a finetuned model on the downstream task. If you are familiar with finetuning BERT on the GLUE dataset, you can skip this section.

Note

Please set dev_mode to False to run this tutorial. dev_mode is True by default here only for generating the documentation.

dev_mode = True

Some basic settings.

from pathlib import Path
from typing import Callable, Dict

pretrained_model_name_or_path = 'bert-base-uncased'
task_name = 'mnli'
experiment_id = 'pruning_bert_mnli'

# heads_num and layers_num should align with pretrained_model_name_or_path
heads_num = 12
layers_num = 12

# used to save the experiment log
log_dir = Path(f'./pruning_log/{pretrained_model_name_or_path}/{task_name}/{experiment_id}')
log_dir.mkdir(parents=True, exist_ok=True)

# used to save the finetuned model and share it between different experiments with the same pretrained_model_name_or_path and task_name
model_dir = Path(f'./models/{pretrained_model_name_or_path}/{task_name}')
model_dir.mkdir(parents=True, exist_ok=True)

# used to save GLUE data
data_dir = Path('./data')
data_dir.mkdir(parents=True, exist_ok=True)

# set seed
from transformers import set_seed
set_seed(1024)

import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Create dataloaders.

from torch.utils.data import DataLoader

from datasets import load_dataset
from transformers import BertTokenizerFast, DataCollatorWithPadding

task_to_keys = {
    'cola': ('sentence', None),
    'mnli': ('premise', 'hypothesis'),
    'mrpc': ('sentence1', 'sentence2'),
    'qnli': ('question', 'sentence'),
    'qqp': ('question1', 'question2'),
    'rte': ('sentence1', 'sentence2'),
    'sst2': ('sentence', None),
    'stsb': ('sentence1', 'sentence2'),
    'wnli': ('sentence1', 'sentence2'),
}

def prepare_dataloaders(cache_dir=data_dir, train_batch_size=32, eval_batch_size=32):
    tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path)
    sentence1_key, sentence2_key = task_to_keys[task_name]
    data_collator = DataCollatorWithPadding(tokenizer)

    # used to preprocess the raw data
    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=False, max_length=128, truncation=True)

        if 'label' in examples:
            # In all cases, rename the column to labels because the model will expect that.
            result['labels'] = examples['label']
        return result

    raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
    for key in list(raw_datasets.keys()):
        if 'test' in key:
            raw_datasets.pop(key)

    processed_datasets = raw_datasets.map(preprocess_function, batched=True,
                                          remove_columns=raw_datasets['train'].column_names)

    train_dataset = processed_datasets['train']
    if task_name == 'mnli':
        validation_datasets = {
            'validation_matched': processed_datasets['validation_matched'],
            'validation_mismatched': processed_datasets['validation_mismatched']
        }
    else:
        validation_datasets = {
            'validation': processed_datasets['validation']
        }

    train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=data_collator, batch_size=train_batch_size)
    validation_dataloaders = {
        val_name: DataLoader(val_dataset, collate_fn=data_collator, batch_size=eval_batch_size)
        for val_name, val_dataset in validation_datasets.items()
    }

    return train_dataloader, validation_dataloaders


train_dataloader, validation_dataloaders = prepare_dataloaders()
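`prepare_dataloaders` tokenizes with `padding=False` and lets `DataCollatorWithPadding` pad each batch, so a batch is only as long as its longest example instead of the fixed `max_length=128`. The pure-Python sketch below illustrates that dynamic-padding idea; it is a simplified stand-in, not the transformers implementation. It assumes the feature keys a BERT tokenizer produces for sentence pairs (`input_ids`, `token_type_ids`), the `labels` key added in `preprocess_function`, and a pad id of 0, which is the `[PAD]` id in the bert-base-uncased vocabulary. The helper name `pad_batch` is hypothetical.

```python
from typing import Dict, List


def pad_batch(features: List[Dict], pad_token_id: int = 0) -> Dict[str, List]:
    """Pad every sequence in a batch to the length of the longest one (hypothetical helper)."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List] = {"input_ids": [], "token_type_ids": [], "attention_mask": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # padded positions get segment id 0; the attention mask hides them from the model
        batch["token_type_ids"].append(f["token_type_ids"] + [0] * n_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
    if all("labels" in f for f in features):
        batch["labels"] = [f["labels"] for f in features]
    return batch


# Two toy premise/hypothesis pairs: 101 and 102 are the [CLS] and [SEP] ids in
# bert-base-uncased; the other token ids are made up.
features = [
    {"input_ids": [101, 7, 8, 102, 11, 12, 102], "token_type_ids": [0, 0, 0, 0, 1, 1, 1], "labels": 0},
    {"input_ids": [101, 9, 102, 10, 102], "token_type_ids": [0, 0, 0, 1, 1], "labels": 2},
]
batch = pad_batch(features)
print(batch["attention_mask"])  # [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0, 0]]
```

Padding per batch rather than to a global maximum avoids spending compute on padding tokens, which matters for the many passes over MNLI's roughly 393k training examples during pruning and distillation.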

Training function & evaluation function.

import functools
import time

import torch.nn.functional as F
from datasets import load_metric
from transformers.modeling_outputs import SequenceClassifierOutput


def training(model: torch.nn.Module,
             optimizer: torch.optim.Optimizer,
             criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
             lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
             max_steps: int = None,
             max_epochs: int = None,
             train_dataloader: DataLoader = None,
             distillation: bool = False,
             teacher_model: torch.nn.Module = None,
             distil_func: Callable = None,
             log_path: str = Path(log_dir) / 'training.log',
             save_best_model: bool = False,
             save_path: str = None,
             evaluation_func: Callable = None,
             eval_per_steps: int = 1000,
             device=None):

    assert train_dataloader is not None

    model.train()
    if teacher_model is not None:
        teacher_model.eval()
    current_step = 0
    best_result = 0

    total_epochs = max_steps // len(train_dataloader) + 1 if max_steps else max_epochs if max_epochs else 3
    total_steps = max_steps if max_steps else total_epochs * len(train_dataloader)

    print(f'Training {total_epochs} epochs, {total_steps} steps...')

    for current_epoch in range(total_epochs):
        for batch in train_dataloader:
            if current_step >= total_steps:
                return
            batch.to(device)
            outputs = model(**batch)
            loss = outputs.loss

            if distillation:
                assert teacher_model is not None
                with torch.no_grad():
                    teacher_outputs = teacher_model(**batch)
                distil_loss = distil_func(outputs, teacher_outputs)
                loss = 0.1 * loss + 0.9 * distil_loss

            loss = criterion(loss, None)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # per step schedule
            if lr_scheduler:
                lr_scheduler.step()

            current_step += 1

            if current_step % eval_per_steps == 0 or current_step % len(train_dataloader) == 0:
                result = evaluation_func(model) if evaluation_func else None
                with (log_path).open('a+') as f:
                    msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())), current_epoch, current_step, result)
                    f.write(msg)
                # if it's the best model, save it.
                if save_best_model and (result is None or best_result < result['default']):
                    assert save_path is not None
                    torch.save(model.state_dict(), save_path)
                    best_result = None if result is None else result['default']


def distil_loss_func(stu_outputs: SequenceClassifierOutput, tea_outputs: SequenceClassifierOutput, encoder_layer_idxs=[]):
    encoder_hidden_state_loss = []
    for i, idx in enumerate(encoder_layer_idxs[:-1]):
        encoder_hidden_state_loss.append(F.mse_loss(stu_outputs.hidden_states[i], tea_outputs.hidden_states[idx]))
    logits_loss = F.kl_div(F.log_softmax(stu_outputs.logits / 2, dim=-1), F.softmax(tea_outputs.logits / 2, dim=-1), reduction='batchmean') * (2 ** 2)

    distil_loss = 0
    for loss in encoder_hidden_state_loss:
        distil_loss += loss
    distil_loss += logits_loss
    return distil_loss


def evaluation(model: torch.nn.Module, validation_dataloaders: Dict[str, DataLoader] = None, device=None):
    assert validation_dataloaders is not None
    # Remember the current mode so it can be restored (named so it doesn't shadow the `training` function).
    was_training = model.training
    model.eval()

    is_regression = task_name == 'stsb'
    metric = load_metric('glue', task_name)

    result = {}
    default_result = 0
    for val_name, validation_dataloader in validation_dataloaders.items():
        for batch in validation_dataloader:
            batch.to(device)
            # Gradients are not needed for evaluation; skipping them saves memory and time.
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1) if not is_regression else outputs.logits.squeeze()
            metric.add_batch(
                predictions=predictions,
                references=batch['labels'],
            )
        result[val_name] = metric.compute()
        # Report F1 when the task provides it, otherwise accuracy.
        default_result += result[val_name].get('f1', result[val_name].get('accuracy', 0))
    result['default'] = default_result / len(result)

    model.train(was_training)
    return result


evaluation_func = functools.partial(evaluation, validation_dataloaders=validation_dataloaders, device=device)


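# The model already computes its loss (`outputs.loss`), so the criterion passed to
# `training` and `TorchEvaluator` just returns that loss unchanged.
def fake_criterion(loss, _):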
    return loss
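The logits term of `distil_loss_func` is temperature-scaled knowledge distillation: both logit vectors are softened with temperature T = 2, the KL divergence KL(p_teacher || p_student) is taken, and the result is multiplied by T^2 so its gradient scale stays comparable to the task loss. `training` then mixes it with the task loss as 0.1 * task + 0.9 * distillation. The plain-Python sketch below reproduces that arithmetic for a single example; it omits the hidden-state MSE terms, and the helper names are illustrative, not part of NNI or PyTorch.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of logits / temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_norm for x in scaled]


def kd_logits_loss(student: Sequence[float], teacher: Sequence[float], temperature: float = 2.0) -> float:
    """KL(p_teacher || p_student) on temperature-softened distributions, scaled by T ** 2.

    For one example this matches
    F.kl_div(F.log_softmax(s / T), F.softmax(t / T), reduction='batchmean') * T ** 2.
    """
    log_p_student = log_softmax(student, temperature)
    log_p_teacher = log_softmax(teacher, temperature)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_teacher, log_p_student))
    return kl * temperature ** 2


def combined_loss(task_loss: float, distil_loss: float, task_weight: float = 0.1) -> float:
    """Mix used in `training`: 0.1 * task loss + 0.9 * distillation loss by default."""
    return task_weight * task_loss + (1 - task_weight) * distil_loss


print(kd_logits_loss([1.0, 2.0, 0.5], [1.0, 2.0, 0.5]))  # 0.0: identical logits
print(kd_logits_loss([3.0, 0.0, 0.0], [0.0, 3.0, 0.0]) > 0)  # True
print(combined_loss(task_loss=1.0, distil_loss=2.0))
```

The heavy 0.9 weight on the distillation term means the finetuned teacher, rather than the hard labels, provides most of the training signal while the pruned student recovers.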

Prepare the pre-trained model and finetune it on the downstream task.

from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from transformers import BertForSequenceClassification


def create_pretrained_model():
    is_regression = task_name == 'stsb'
    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
    model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
    model.bert.config.output_hidden_states = True
    return model


def create_finetuned_model():
    finetuned_model = create_pretrained_model().to(device)
    finetuned_model_state_path = Path(model_dir) / 'finetuned_model_state.pth'

    if finetuned_model_state_path.exists():
        finetuned_model.load_state_dict(torch.load(finetuned_model_state_path, map_location=device))
    elif dev_mode:
        pass
    else:
        steps_per_epoch = len(train_dataloader)
        training_epochs = 3
        optimizer = Adam(finetuned_model.parameters(), lr=3e-5, eps=1e-8)

        def lr_lambda(current_step: int):
            return max(0.0, float(training_epochs * steps_per_epoch - current_step) / float(training_epochs * steps_per_epoch))

        lr_scheduler = LambdaLR(optimizer, lr_lambda)
        training(finetuned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler,
                 max_epochs=training_epochs, train_dataloader=train_dataloader, log_path=log_dir / 'finetuning_on_downstream.log',
                 save_best_model=True, save_path=finetuned_model_state_path, evaluation_func=evaluation_func, device=device)
    return finetuned_model


finetuned_model = create_finetuned_model()
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
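`create_finetuned_model` drives `LambdaLR` with a multiplier that falls linearly from 1 to 0 over all finetuning steps; the movement-pruning stage later puts a linear warmup in front of the same decay. A minimal plain-Python sketch of both multipliers (the function names are illustrative) makes their shape easy to check:

```python
def linear_decay(step: int, total_steps: int) -> float:
    """LR multiplier falling linearly from 1.0 at step 0 to 0.0 at `total_steps`."""
    return max(0.0, (total_steps - step) / total_steps)


def warmup_then_decay(step: int, warmup_steps: int, total_steps: int) -> float:
    """Linear warmup from 0.0 to 1.0, then linear decay back to 0.0 at `total_steps`."""
    if step < warmup_steps:
        return step / warmup_steps
    return max(0.0, (total_steps - step) / (total_steps - warmup_steps))


total, warmup = 100, 10
print([round(linear_decay(s, total), 2) for s in (0, 50, 100)])  # [1.0, 0.5, 0.0]
print([round(warmup_then_decay(s, warmup, total), 2) for s in (0, 5, 10, 55, 100)])  # [0.0, 0.5, 1.0, 0.5, 0.0]
```

`LambdaLR` multiplies the optimizer's base learning rate (3e-5 here) by this value at every step.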

Pruning

In our experience, it is easier to achieve good results by pruning the attention part and the FFN part in separate stages. Pruning them together can achieve a similar effect, but it requires more hyperparameter tuning, so in this section we prune in stages.

First, we prune the attention layer with MovementPruner.

steps_per_epoch = len(train_dataloader)

# Set training steps/epochs for pruning.

if not dev_mode:
    total_epochs = 4
    total_steps = total_epochs * steps_per_epoch
    warmup_steps = 1 * steps_per_epoch
    cooldown_steps = 1 * steps_per_epoch
else:
    total_epochs = 1
    total_steps = 3
    warmup_steps = 1
    cooldown_steps = 1

# Initialize evaluator used by MovementPruner.

import nni
from nni.compression.pytorch import TorchEvaluator

movement_training = functools.partial(training, train_dataloader=train_dataloader,
                                      log_path=log_dir / 'movement_pruning.log',
                                      evaluation_func=evaluation_func, device=device)
traced_optimizer = nni.trace(Adam)(finetuned_model.parameters(), lr=3e-5, eps=1e-8)

def lr_lambda(current_step: int):
    if current_step < warmup_steps:
        return float(current_step) / warmup_steps
    return max(0.0, float(total_steps - current_step) / float(total_steps - warmup_steps))

traced_scheduler = nni.trace(LambdaLR)(traced_optimizer, lr_lambda)
evaluator = TorchEvaluator(movement_training, traced_optimizer, fake_criterion, traced_scheduler)

# Apply block soft movement pruning on attention layers.
# Note that block sparsity is enabled by `sparse_granularity='auto'`, which only supports `bert`, `bart`, and `t5` right now.

from nni.compression.pytorch.pruning import MovementPruner

config_list = [{
    'op_types': ['Linear'],
    'op_partial_names': ['bert.encoder.layer.{}.attention'.format(i) for i in range(layers_num)],
    'sparsity': 0.1
}]

pruner = MovementPruner(model=finetuned_model,
                        config_list=config_list,
                        evaluator=evaluator,
                        training_epochs=total_epochs,
                        training_steps=total_steps,
                        warm_up_step=warmup_steps,
                        cool_down_beginning_step=total_steps - cooldown_steps,
                        regular_scale=10,
                        movement_mode='soft',
                        sparse_granularity='auto')
_, attention_masks = pruner.compress()
pruner.show_pruned_weights()

torch.save(attention_masks, Path(log_dir) / 'attention_masks.pth')
Did not bind any model, no need to unbind model.
Training 1 epochs, 3 steps...
You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
/anaconda/lib/python3.9/site-packages/torch/optim/lr_scheduler.py:122: UserWarning: Seems like `optimizer.step()` has been overridden after learning rate scheduler initialization. Please, make sure to call `optimizer.step()` before `lr_scheduler.step()`. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
  warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
Did not bind any model, no need to unbind model.
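Movement pruning keeps weights that training moves away from zero rather than weights that are merely large. A common first-order form of the movement score accumulates S <- S - lr * (dL/dW) * W over training steps, so a weight whose gradient pushes it away from zero gains score. In the soft variant, units scoring above a threshold are kept, and with block granularity whole blocks (such as the rows of one head) are scored together. The plain-Python sketch below illustrates these ideas only; it is not NNI's MovementPruner implementation, and the block size and zero threshold are illustrative assumptions.

```python
from typing import List, Sequence


def update_movement_scores(scores: Sequence[float], weights: Sequence[float],
                           grads: Sequence[float], lr: float = 1.0) -> List[float]:
    """One accumulation step: S <- S - lr * grad * W (elementwise)."""
    return [s - lr * g * w for s, w, g in zip(scores, weights, grads)]


def block_scores(scores: Sequence[float], block_size: int) -> List[float]:
    """Sum scores over contiguous blocks, e.g. all rows that belong to one attention head."""
    return [sum(scores[i:i + block_size]) for i in range(0, len(scores), block_size)]


def soft_keep_mask(scores: Sequence[float], threshold: float) -> List[int]:
    """Soft variant: keep (1) every unit whose score exceeds the threshold, prune (0) the rest."""
    return [1 if s > threshold else 0 for s in scores]


# Two "heads" of two weights each. Gradient descent pushes head 0's weights away
# from zero and pulls head 1's weights toward zero.
weights = [0.5, -0.4, 0.1, -0.2]
grads = [-0.2, 0.3, 0.1, -0.1]
scores = update_movement_scores([0.0] * 4, weights, grads)
head_scores = block_scores(scores, block_size=2)
print(soft_keep_mask(head_scores, threshold=0.0))  # [1, 0]
```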

Load a fresh copy of the finetuned model to apply the speedup to; you can think of this as using the finetuned state to initialize the pruned model's weights. Note that NNI's speedup does not support replacing attention modules, so here we replace the attention module manually.

If a head is entirely masked, prune it physically; the same loop also builds the config_list for FFN pruning.
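Physically removing a head shrinks the query, key, and value projections (their output size) and the attention output projection (its input size), which is why fully masked heads are pruned rather than left as zeros. A plain-Python parameter count, assuming BERT-base sizes (hidden size 768, 12 heads of 64 dimensions) and ignoring LayerNorm, shows the effect; `attention_params` is a hypothetical helper.

```python
def attention_params(hidden: int, heads_kept: int, head_dim: int) -> int:
    """Parameters of one BERT self-attention block: Q, K, V and output projection, with biases."""
    inner = heads_kept * head_dim          # width of the concatenated kept heads
    qkv = 3 * (hidden * inner + inner)     # three hidden -> inner projections
    output = inner * hidden + hidden       # inner -> hidden projection
    return qkv + output


full = attention_params(hidden=768, heads_kept=12, head_dim=64)
half = attention_params(hidden=768, heads_kept=6, head_dim=64)
print(full, half)  # 2362368 1181568
```

Keeping 6 of 12 heads halves every term except the output bias (768 values), which does not depend on the number of heads.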

attention_pruned_model = create_finetuned_model().to(device)
attention_masks = torch.load(Path(log_dir) / 'attention_masks.pth')

ffn_config_list = []
layer_remained_idxs = []
module_list = []
for i in range(0, layers_num):
    prefix = f'bert.encoder.layer.{i}.'
    value_mask: torch.Tensor = attention_masks[prefix + 'attention.self.value']['weight']
    head_mask = (value_mask.reshape(heads_num, -1).sum(-1) == 0.)
    head_idxs = torch.arange(len(head_mask))[head_mask].long().tolist()
    print(f'layer {i} prune {len(head_idxs)} head: {head_idxs}')
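    # Keep the layer only if at least one head remains; a layer whose heads are all masked is dropped entirely.
    if len(head_idxs) != heads_num: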
        attention_pruned_model.bert.encoder.layer[i].attention.prune_heads(head_idxs)
        module_list.append(attention_pruned_model.bert.encoder.layer[i])
        # The final FFN weight remaining ratio is half of the attention weight remaining ratio.
        # This is just an empirical configuration; you can use any other method to determine this sparsity.
        sparsity = 1 - (1 - len(head_idxs) / heads_num) * 0.5
        # Here we use a simple sparsity schedule: FFN is pruned in 12 iterations, each pruning `sparsity_per_iter`
        # of the channels that remain, so the compounded sparsity after 12 iterations equals `sparsity`.
        sparsity_per_iter = 1 - (1 - sparsity) ** (1 / 12)
        ffn_config_list.append({
            'op_names': [f'bert.encoder.layer.{len(layer_remained_idxs)}.intermediate.dense'],
            'sparsity': sparsity_per_iter
        })
        layer_remained_idxs.append(i)

attention_pruned_model.bert.encoder.layer = torch.nn.ModuleList(module_list)
distil_func = functools.partial(distil_loss_func, encoder_layer_idxs=layer_remained_idxs)
layer 0 prune 0 head: []
layer 1 prune 0 head: []
layer 2 prune 0 head: []
layer 3 prune 0 head: []
layer 4 prune 0 head: []
layer 5 prune 0 head: []
layer 6 prune 0 head: []
layer 7 prune 0 head: []
layer 8 prune 0 head: []
layer 9 prune 0 head: []
layer 10 prune 0 head: []
layer 11 prune 0 head: []
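The loop above makes two computations per layer. First, it treats a head as removable only when every row of its slice of the value-weight mask is zero (rows of the `[out_features, in_features]` weight are grouped by head, which is what `reshape(heads_num, -1)` relies on). Second, it turns the number of pruned heads into an FFN target sparsity s and splits it into 12 compounding rounds: because each round prunes a fraction of the channels left after the previous speedup, the remaining ratio after 12 rounds is (1 - s_iter)^12 = 1 - s. The plain-Python sketch below mirrors both steps with nested lists in place of tensors; the function names are illustrative.

```python
from typing import List


def fully_masked_heads(value_mask: List[List[float]], num_heads: int) -> List[int]:
    """Indices of heads whose rows in the value-weight mask are all zero.

    value_mask has shape [out_features, in_features]; head h owns rows
    h * head_dim .. (h + 1) * head_dim - 1.
    """
    head_dim = len(value_mask) // num_heads
    masked = []
    for h in range(num_heads):
        rows = value_mask[h * head_dim:(h + 1) * head_dim]
        if sum(sum(row) for row in rows) == 0:
            masked.append(h)
    return masked


def ffn_target_sparsity(num_pruned_heads: int, num_heads: int) -> float:
    """Empirical rule from the loop above: FFN keeps half the fraction of weights that attention keeps."""
    return 1 - (1 - num_pruned_heads / num_heads) * 0.5


def sparsity_per_iteration(target: float, iterations: int = 12) -> float:
    """Per-round sparsity such that `iterations` compounding rounds reach `target`."""
    return 1 - (1 - target) ** (1 / iterations)


# 2 heads with head_dim 2: head 0 is partially masked (kept), head 1 is fully masked (pruned).
mask = [[1, 1], [1, 0], [0, 0], [0, 0]]
print(fully_masked_heads(mask, num_heads=2))  # [1]

target = ffn_target_sparsity(num_pruned_heads=6, num_heads=12)  # 0.75
per_iter = sparsity_per_iteration(target)
print(round(per_iter, 4), round((1 - per_iter) ** 12, 6))
```

In the dev-mode run above no head was pruned, so every layer gets an FFN target sparsity of 0.5.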

Retrain the attention-pruned model with distillation.

if not dev_mode:
    total_epochs = 5
    total_steps = None
    distillation = True
else:
    total_epochs = 1
    total_steps = 1
    distillation = False

teacher_model = create_finetuned_model()
optimizer = Adam(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)

def lr_lambda(current_step: int):
    return max(0.0, float(total_epochs * steps_per_epoch - current_step) / float(total_epochs * steps_per_epoch))

lr_scheduler = LambdaLR(optimizer, lr_lambda)
at_model_save_path = log_dir / 'attention_pruned_model_state.pth'
training(attention_pruned_model, optimizer, fake_criterion, lr_scheduler=lr_scheduler, max_epochs=total_epochs,
         max_steps=total_steps, train_dataloader=train_dataloader, distillation=distillation, teacher_model=teacher_model,
         distil_func=distil_func, log_path=log_dir / 'retraining.log', save_best_model=True, save_path=at_model_save_path,
         evaluation_func=evaluation_func, device=device)

if not dev_mode:
    attention_pruned_model.load_state_dict(torch.load(at_model_save_path))
Training 1 epochs, 1 steps...

Iteratively prune the FFN with TaylorFOWeightPruner over 12 iterations, finetuning for 3000 steps after each pruning iteration, then finetune for 2 more epochs after pruning has finished.
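TaylorFOWeightPruner ranks channels by a first-order Taylor estimate of how much the loss would change if they were removed, computed from gradients collected over `taylor_pruner_steps` training steps. Removing output row j of `intermediate.dense` also removes input column j of the second FFN linear layer, which is the dependency ModelSpeedup propagates. The plain-Python sketch below assumes a per-channel score of sum((w * dL/dw)^2); that is one common variant, and NNI's exact formula and rounding may differ. All helper names are hypothetical.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def taylor_fo_channel_scores(weight: Matrix, grad: Matrix) -> List[float]:
    """Per-output-channel importance: sum over the row of (w * dL/dw) ** 2."""
    return [sum((w * g) ** 2 for w, g in zip(w_row, g_row))
            for w_row, g_row in zip(weight, grad)]


def channels_to_prune(scores: Sequence[float], sparsity: float) -> List[int]:
    """Indices of the int(len(scores) * sparsity) lowest-scoring channels (rounding is an assumption)."""
    n_prune = int(len(scores) * sparsity)
    ranked = sorted(range(len(scores)), key=lambda i: scores[i])
    return sorted(ranked[:n_prune])


def prune_ffn_pair(w1: Matrix, w2: Matrix, drop: Sequence[int]) -> Tuple[Matrix, Matrix]:
    """Remove output rows of the 1st FFN weight and the matching input columns of the 2nd.

    w1 has shape [intermediate, hidden]; w2 has shape [hidden, intermediate].
    """
    drop_set = set(drop)
    keep = [i for i in range(len(w1)) if i not in drop_set]
    return [w1[i] for i in keep], [[row[i] for i in keep] for row in w2]


w1 = [[1.0, 1.0], [0.1, 0.1], [2.0, 0.0], [0.0, 0.5]]  # 4 intermediate channels, hidden size 2
g1 = [[1.0, 1.0]] * 4                                   # pretend accumulated gradients
w2 = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
drop = channels_to_prune(taylor_fo_channel_scores(w1, g1), sparsity=0.5)
new_w1, new_w2 = prune_ffn_pair(w1, w2, drop)
print(drop)    # [1, 3]
print(new_w2)  # [[1.0, 3.0], [5.0, 7.0]]
```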

NNI plans to support per-step pruning schedules in the future; a pruner could then replace the following code.

if not dev_mode:
    total_epochs = 7
    total_steps = None
    taylor_pruner_steps = 1000
    steps_per_iteration = 3000
    total_pruning_steps = 36000
    distillation = True
else:
    total_epochs = 1
    total_steps = 6
    taylor_pruner_steps = 2
    steps_per_iteration = 2
    total_pruning_steps = 4
    distillation = False

from nni.compression.pytorch.pruning import TaylorFOWeightPruner
from nni.compression.pytorch.speedup import ModelSpeedup

distil_training = functools.partial(training, train_dataloader=train_dataloader, distillation=distillation,
                                    teacher_model=teacher_model, distil_func=distil_func, device=device)
traced_optimizer = nni.trace(Adam)(attention_pruned_model.parameters(), lr=3e-5, eps=1e-8)
evaluator = TorchEvaluator(distil_training, traced_optimizer, fake_criterion)

current_step = 0
best_result = 0
init_lr = 3e-5

dummy_input = torch.rand(8, 128, 768).to(device)

attention_pruned_model.train()
for current_epoch in range(total_epochs):
    for batch in train_dataloader:
        if total_steps and current_step >= total_steps:
            break
        # pruning with TaylorFOWeightPruner & reinitialize optimizer
        if current_step % steps_per_iteration == 0 and current_step < total_pruning_steps:
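            # Snapshot the weights by cloning: state_dict() alone returns tensors that share storage with the live
            # parameters, so the pruner's training steps would otherwise overwrite the checkpoint too.
            check_point = {name: tensor.clone() for name, tensor in attention_pruned_model.state_dict().items()}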
            pruner = TaylorFOWeightPruner(attention_pruned_model, ffn_config_list, evaluator, taylor_pruner_steps)
            _, ffn_masks = pruner.compress()
            renamed_ffn_masks = {}
            # rename the mask keys, because we only speed up bert.encoder
            for model_name, targets_mask in ffn_masks.items():
                renamed_ffn_masks[model_name.split('bert.encoder.')[1]] = targets_mask
            pruner._unwrap_model()
            attention_pruned_model.load_state_dict(check_point)
            ModelSpeedup(attention_pruned_model.bert.encoder, dummy_input, renamed_ffn_masks).speedup_model()
            optimizer = Adam(attention_pruned_model.parameters(), lr=init_lr)

        batch.to(device)
        # manually schedule lr
        for params_group in optimizer.param_groups:
            params_group['lr'] = (1 - current_step / (total_epochs * steps_per_epoch)) * init_lr

        outputs = attention_pruned_model(**batch)
        loss = outputs.loss

        # distillation
        if distillation:
            assert teacher_model is not None
            with torch.no_grad():
                teacher_outputs = teacher_model(**batch)
            distil_loss = distil_func(outputs, teacher_outputs)
            loss = 0.1 * loss + 0.9 * distil_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        current_step += 1

        if current_step % 1000 == 0 or current_step % len(train_dataloader) == 0:
            result = evaluation_func(attention_pruned_model)
            with (log_dir / 'ffn_pruning.log').open('a+') as f:
                msg = '[{}] Epoch {}, Step {}: {}\n'.format(time.asctime(time.localtime(time.time())),
                                                            current_epoch, current_step, result)
                f.write(msg)
            if current_step >= total_pruning_steps and best_result < result['default']:
                torch.save(attention_pruned_model, log_dir / 'best_model.pth')
                best_result = result['default']
Did not bind any model, no need to unbind model.
Training 1 epochs, 2 steps...
no multi-dimension masks found.
throw some args away when calling the function "view"
/anaconda/lib/python3.9/site-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:417.)
  return self._grad
Did not bind any model, no need to unbind model.
Training 1 epochs, 2 steps...
no multi-dimension masks found.
throw some args away when calling the function "view"

Result

The speedup was measured on the entire validation dataset with batch size 128 on an A100. We tested under two PyTorch versions and found that the latency varies widely between them.

Setting 1 (S1): PyTorch 1.12.1

Setting 2 (S2): PyTorch 1.10.0
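The timing procedure described above (one pass over the full validation set in batches of 128, then baseline latency divided by pruned latency) can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the script that produced the table: `batched`, `measure_latency`, and `speedup` are hypothetical helper names, and the "model" is any callable that accepts one batch. For a CUDA model, kernels run asynchronously, so the sketch accepts an optional `sync` callable (for example `torch.cuda.synchronize`) that is invoked before starting and before stopping the clock.

```python
import time
from typing import Callable, Iterator, Optional, Sequence, TypeVar

T = TypeVar("T")


def batched(samples: Sequence[T], batch_size: int = 128) -> Iterator[Sequence[T]]:
    """Yield consecutive slices of `samples`; the last one may be shorter."""
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]


def measure_latency(
    model: Callable[[Sequence[T]], object],
    samples: Sequence[T],
    batch_size: int = 128,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the wall-clock seconds needed to run `model` over all of `samples`."""
    if sync is not None:
        sync()  # finish any pending asynchronous work before starting the clock
    start = time.perf_counter()
    for batch in batched(samples, batch_size):
        model(batch)
    if sync is not None:
        sync()  # wait for queued (e.g. CUDA) kernels before stopping the clock
    return time.perf_counter() - start


def speedup(baseline_seconds: float, pruned_seconds: float) -> float:
    """Speedup factor in the table's sense: baseline latency / pruned latency."""
    return baseline_seconds / pruned_seconds
```

With a real PyTorch model one would typically wrap the forward calls in `torch.no_grad()` and pass `sync=torch.cuda.synchronize`; because only the framework version differs between S1 and S2, the same helper would be run once per setting.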

Prune Bert-base-uncased on MNLI

| Attention Pruning Method | FFN Pruning Method | Total Sparsity (remaining params) | Accuracy (matched / mismatched) | Acc. Drop | Latency / Speedup (S1) | Latency / Speedup (S2) |
|---|---|---|---|---|---|---|
| None (unpruned baseline) | None | 85.1M (-0.0%) | 84.85 / 85.28 | +0.0 / +0.0 | 25.60s (x1.00) | 8.10s (x1.00) |
| Movement Pruner (soft, sparsity=0.1, regular_scale=1) | Taylor FO Weight Pruner | 54.1M (-36.43%) | 85.38 / 85.41 | +0.53 / +0.13 | 17.93s (x1.43) | 7.22s (x1.12) |
| Movement Pruner (soft, sparsity=0.1, regular_scale=5) | Taylor FO Weight Pruner | 37.1M (-56.40%) | 84.73 / 85.12 | -0.12 / -0.16 | 12.83s (x2.00) | 5.61s (x1.44) |
| Movement Pruner (soft, sparsity=0.1, regular_scale=10) | Taylor FO Weight Pruner | 24.1M (-71.68%) | 84.14 / 84.78 | -0.71 / -0.50 | 8.93s (x2.87) | 4.55s (x1.78) |
| Movement Pruner (soft, sparsity=0.1, regular_scale=20) | Taylor FO Weight Pruner | 14.3M (-83.20%) | 83.26 / 82.96 | -1.59 / -2.32 | 5.98s (x4.28) | 3.56s (x2.28) |
| Movement Pruner (soft, sparsity=0.1, regular_scale=30) | Taylor FO Weight Pruner | 9.9M (-88.37%) | 82.22 / 82.19 | -2.63 / -3.09 | 4.36s (x5.88) | 3.12s (x2.60) |
| Movement Pruner (soft, sparsity=0.1, regular_scale=40) | Taylor FO Weight Pruner | 8.8M (-89.66%) | 81.64 / 82.39 | -3.21 / -2.89 | 3.88s (x6.60) | 2.81s (x2.88) |
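The derived columns in the table follow from the raw measurements: the percentage is the change in parameter count relative to the 85.1M baseline, Acc. Drop is pruned accuracy minus baseline accuracy (so a positive value is an improvement), and each speedup is baseline latency divided by pruned latency. The Python sketch below recomputes them as a sanity check; `BASELINE`, `PRUNED`, and `derived_columns` are illustrative names, not part of NNI. The recomputed values agree with the table to within rounding; for example, with regular_scale=30, 25.60 / 4.36 gives x5.87 against the reported x5.88, presumably because the published latencies are themselves rounded.

```python
# Raw measurements copied from the table: remaining parameters (millions),
# accuracy (matched, mismatched), and latency in seconds under S1 and S2.
BASELINE = {"params": 85.1, "acc": (84.85, 85.28), "s1": 25.60, "s2": 8.10}

# Keyed by the Movement Pruner regular_scale used in each row.
PRUNED = {
    1: {"params": 54.1, "acc": (85.38, 85.41), "s1": 17.93, "s2": 7.22},
    5: {"params": 37.1, "acc": (84.73, 85.12), "s1": 12.83, "s2": 5.61},
    10: {"params": 24.1, "acc": (84.14, 84.78), "s1": 8.93, "s2": 4.55},
    20: {"params": 14.3, "acc": (83.26, 82.96), "s1": 5.98, "s2": 3.56},
    30: {"params": 9.9, "acc": (82.22, 82.19), "s1": 4.36, "s2": 3.12},
    40: {"params": 8.8, "acc": (81.64, 82.39), "s1": 3.88, "s2": 2.81},
}


def derived_columns(row: dict, base: dict = BASELINE) -> tuple:
    """Recompute (param change %, accuracy drop, S1 speedup, S2 speedup)."""
    # Negative percentage = fewer parameters than the unpruned baseline.
    param_change = -100.0 * (base["params"] - row["params"]) / base["params"]
    # Acc. Drop is pruned minus baseline, so a positive value is an improvement.
    acc_drop = tuple(round(p - b, 2) for p, b in zip(row["acc"], base["acc"]))
    # Speedup is baseline latency divided by pruned latency.
    return (
        round(param_change, 2),
        acc_drop,
        round(base["s1"] / row["s1"], 2),
        round(base["s2"] / row["s2"], 2),
    )


for scale, row in PRUNED.items():
    print(f"regular_scale={scale}: {derived_columns(row)}")
```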

Total running time of the script: (1 minute 32.808 seconds)
