Pruning Bert on Task MNLI
This is a new tutorial on pruning transformers in nni v3.0 (old tutorial). The main differences from the previous tutorial are that it integrates nni's fusion compression feature (pruning + distillation), uses a new, more powerful and stable pruning speedup tool, and additionally prunes the whole model's hidden dimensions (i.e., prunes the embedding layers), which greatly reduces the model size.
At the same time, the huggingface transformers.Trainer is used in this tutorial to reduce the burden of writing training and evaluation logic.
Workable Pruning Process
The whole pruning process is divided into three steps:
pruning attention layers,
pruning feed forward layers,
pruning embedding layers.
In each step, the pruner is first used for simulated pruning to generate masks for the module pruning targets (weight, input, output). Then comes the speedup stage: sparsity propagation explores the global redundancy implied by the local masks, and the original model is modified into a smaller one by replacing the affected submodules.
Model compression pairs naturally with distillation, so in this tutorial distillers are also used to help restore the model accuracy.
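The remainder of this tutorial defines one function per step. Assuming all of them are defined (and skip_exec is set to False), the end-to-end flow can be sketched as follows; intermediate models and masks are saved under ./output/ between the steps:

def run_pruning_pipeline():
    # step 1: attention layers
    pruning_attn()          # MovementPruner generates block masks
    speedup_attn()          # prune heads, then dynamic distillation
    # step 2: feed forward layers
    pruning_ffn()           # AGPPruner + TaylorPruner + DynamicLayerwiseDistiller
    speedup_ffn()           # shrink FFN, then dynamic distillation
    # step 3: embedding layers (hidden dimensions)
    pruning_embedding()
    speedup_embedding()     # shrink hidden dim, then adaptive distillation
    evaluate_pruned_model()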
Experiment
Preparations
The preparations mainly include preparing the transformers trainer and the model.
This is largely the same as the preparations required to train a Bert model normally.
The only difference is that the transformers.Trainer needs to be wrapped by nni.trace to trace its init arguments, because nni needs to re-create the trainer during training-aware pruning and distillation.
Note
Please set skip_exec to False to run this tutorial. skip_exec is True by default only for generating the documentation.
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import ConcatDataset
import nni
from datasets import load_dataset, load_metric
from transformers import BertTokenizerFast, DataCollatorWithPadding, BertForSequenceClassification, EvalPrediction
from transformers.trainer import Trainer
from transformers.training_args import TrainingArguments
skip_exec = True
Set the downstream task name here. You could replace it with any other GLUE task.
task_name = 'mnli'
Here BertForSequenceClassification is used as the base model for demonstration. If you want to prune another kind of transformer model, you could replace the base model here.
def build_model(pretrained_model_name_or_path: str, task_name: str):
    is_regression = task_name == 'stsb'
    num_labels = 1 if is_regression else (3 if task_name == 'mnli' else 2)
    model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=num_labels)
    return model
Prepare the GLUE train & validation datasets. If the task has multiple validation datasets, concatenate them with ConcatDataset.
def prepare_datasets(task_name: str, tokenizer: BertTokenizerFast, cache_dir: str):
    task_to_keys = {
        'cola': ('sentence', None),
        'mnli': ('premise', 'hypothesis'),
        'mrpc': ('sentence1', 'sentence2'),
        'qnli': ('question', 'sentence'),
        'qqp': ('question1', 'question2'),
        'rte': ('sentence1', 'sentence2'),
        'sst2': ('sentence', None),
        'stsb': ('sentence1', 'sentence2'),
        'wnli': ('sentence1', 'sentence2'),
    }
    sentence1_key, sentence2_key = task_to_keys[task_name]

    # used to preprocess the raw data
    def preprocess_function(examples):
        # Tokenize the texts
        args = (
            (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*args, padding=False, max_length=128, truncation=True)

        if 'label' in examples:
            # In all cases, rename the column to labels because the model will expect that.
            result['labels'] = examples['label']
        return result

    raw_datasets = load_dataset('glue', task_name, cache_dir=cache_dir)
    for key in list(raw_datasets.keys()):
        if 'test' in key:
            raw_datasets.pop(key)

    processed_datasets = raw_datasets.map(preprocess_function, batched=True,
                                          remove_columns=raw_datasets['train'].column_names)

    train_dataset = processed_datasets['train']
    if task_name == 'mnli':
        validation_datasets = {
            'validation_matched': processed_datasets['validation_matched'],
            'validation_mismatched': processed_datasets['validation_mismatched']
        }
    else:
        validation_datasets = {
            'validation': processed_datasets['validation']
        }

    return train_dataset, validation_datasets
Prepare the trainer. Note that the Trainer class is wrapped by nni.trace.
def prepare_traced_trainer(model, task_name, load_best_model_at_end=False):
    is_regression = task_name == 'stsb'
    metric = load_metric('glue', task_name)

    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
        preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
        result = metric.compute(predictions=preds, references=p.label_ids)
        result['default'] = result.get('f1', result.get('accuracy', 0.))
        return result

    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    train_dataset, validation_datasets = prepare_datasets(task_name, tokenizer, None)
    merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()])
    data_collator = DataCollatorWithPadding(tokenizer)
    training_args = TrainingArguments(output_dir='./output/trainer',
                                      do_train=True,
                                      do_eval=True,
                                      evaluation_strategy='steps',
                                      per_device_train_batch_size=32,
                                      per_device_eval_batch_size=32,
                                      num_train_epochs=3,
                                      dataloader_num_workers=12,
                                      learning_rate=3e-5,
                                      save_strategy='steps',
                                      save_total_limit=1,
                                      metric_for_best_model='default',
                                      load_best_model_at_end=load_best_model_at_end,
                                      disable_tqdm=True,
                                      optim='adamw_torch',
                                      seed=1024)
    trainer = nni.trace(Trainer)(model=model,
                                 args=training_args,
                                 data_collator=data_collator,
                                 train_dataset=train_dataset,
                                 eval_dataset=merged_validation_dataset,
                                 tokenizer=tokenizer,
                                 compute_metrics=compute_metrics)
    return trainer
If the finetuned model already exists, load it directly; otherwise, finetune the pretrained model with the trainer.
def build_finetuning_model(task_name: str, state_dict_path: str):
    model = build_model('bert-base-uncased', task_name)
    if Path(state_dict_path).exists():
        model.load_state_dict(torch.load(state_dict_path))
    else:
        trainer = prepare_traced_trainer(model, task_name, True)
        trainer.train()
        torch.save(model.state_dict(), state_dict_path)
    return model


if not skip_exec:
    Path('./output/bert_finetuned').mkdir(exist_ok=True, parents=True)
    build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
The following code creates distillers for distillation.
from nni.compression.distillation import DynamicLayerwiseDistiller, Adaptive1dLayerwiseDistiller
from nni.compression.utils import TransformersEvaluator
Dynamic distillation is suitable for situations where the distillation state dimensions of the student and the teacher match. Each student state can try to distill against multiple teacher states, and the teacher state with the smallest distillation loss is finally selected as the distillation target. In the config below, student layer i is linked to teacher layers i through layer_num - 1, so an early student block may learn from any deeper teacher block.
In this tutorial, dynamic distillation is applied before the embedding pruning speedup.
def dynamic_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                      student_trainer: Trainer):
    layer_num = len(student_model.bert.encoder.layer)
    config_list = [{
        'op_names': [f'bert.encoder.layer.{i}'],
        'link': [f'bert.encoder.layer.{j}' for j in range(i, layer_num)],
        'lambda': 0.9,
        'apply_method': 'mse',
    } for i in range(layer_num)]
    config_list.append({
        'op_names': ['classifier'],
        'link': ['classifier'],
        'lambda': 0.9,
        'apply_method': 'kl',
    })

    evaluator = TransformersEvaluator(student_trainer)

    def teacher_predict(batch, teacher_model):
        return teacher_model(**batch)

    return DynamicLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)


def dynamic_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                         max_steps: int | None, max_epochs: int | None):
    student_trainer = prepare_traced_trainer(student_model, task_name, True)

    ori_teacher_device = teacher_model.device
    training = teacher_model.training
    teacher_model.to(student_trainer.args.device).eval()

    distiller = dynamic_distiller(student_model, teacher_model, student_trainer)
    distiller.compress(max_steps, max_epochs)
    distiller.unwrap_model()

    teacher_model.to(ori_teacher_device).train(training)
Adaptive distillation is applied after pruning the embedding layers.
After embedding pruning, the hidden state dimensions of the student and the teacher model no longer match, so the adaptive distiller adds a linear layer for each distillation module pair to align the dimensions.
For example, when the hidden dimension is pruned from 768 to 384, a Linear(in_features=384, out_features=768) is added for each student transformer block to project the 384-dimensional output up to 768 dimensions, aligned with the output of the corresponding teacher transformer block.
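The snippet below is only a conceptual illustration of this alignment (the distiller inserts and trains such projection layers internally); the shapes assume the hidden dimension is pruned from 768 to 384 and a batch of 8 sequences of length 128:

import torch
import torch.nn.functional as F

student_hidden = torch.randn(8, 128, 384)   # pruned student block output
teacher_hidden = torch.randn(8, 128, 768)   # teacher block output
proj = torch.nn.Linear(in_features=384, out_features=768)   # alignment layer
align_loss = F.mse_loss(proj(student_hidden), teacher_hidden)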
def adapt_distiller(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                    student_trainer: Trainer):
    layer_num = len(student_model.bert.encoder.layer)
    config_list = [{
        'op_names': [f'bert.encoder.layer.{i}'],
        'lambda': 0.9,
        'apply_method': 'mse',
    } for i in range(layer_num)]
    config_list.append({
        'op_names': ['classifier'],
        'link': ['classifier'],
        'lambda': 0.9,
        'apply_method': 'kl',
    })

    evaluator = TransformersEvaluator(student_trainer)

    def teacher_predict(batch, teacher_model):
        return teacher_model(**batch)

    return Adaptive1dLayerwiseDistiller(student_model, config_list, evaluator, teacher_model, teacher_predict, origin_loss_lambda=0.1)


def adapt_distillation(student_model: BertForSequenceClassification, teacher_model: BertForSequenceClassification,
                       max_steps: int | None, max_epochs: int | None):
    student_trainer = prepare_traced_trainer(student_model, task_name, True)

    ori_teacher_device = teacher_model.device
    training = teacher_model.training
    teacher_model.to(student_trainer.args.device).eval()

    distiller = adapt_distiller(student_model, teacher_model, student_trainer)
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    dummy_input = [_.to(student_trainer.args.device) for _ in dummy_input]
    distiller.track_forward(*dummy_input)
    distiller.compress(max_steps, max_epochs)
    distiller.unwrap_model()

    teacher_model.to(ori_teacher_device).train(training)
Pruning Attention Layers
Here we use MovementPruner to generate block sparse masks. A 64 x 64 block is chosen because the head width in bert-base-uncased is 64 (hidden size 768 split across 12 heads), so this granularity sits between head pruning and fine-grained pruning; you can also try 64 x 32, 32 x 32, or any other granularity here.
We use sparse_threshold instead of sparse_ratio here to apply adaptive sparse allocation. sparse_threshold is a float between 0. and 1., but its value has little effect on the final sparse ratio.
If you want a sparser model, you could set a larger regular_scale in MovementPruner; a variant config is sketched after the example below. You could refer to the experiment results to choose an appropriate regular_scale.
from nni.compression.pruning import MovementPruner
from nni.compression.speedup import ModelSpeedup
from nni.compression.utils.external.external_replacer import TransformersAttentionReplacer
def pruning_attn():
    Path('./output/bert_finetuned/').mkdir(parents=True, exist_ok=True)
    model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
    trainer = prepare_traced_trainer(model, task_name)
    evaluator = TransformersEvaluator(trainer)

    config_list = [{
        'op_types': ['Linear'],
        'op_names_re': [r'bert\.encoder\.layer\.[0-9]*\.attention\.*'],
        'sparse_threshold': 0.1,
        'granularity': [64, 64]
    }]
    pruner = MovementPruner(model, config_list, evaluator, warmup_step=9000, cooldown_begin_step=36000, regular_scale=10)
    pruner.compress(None, 4)
    pruner.unwrap_model()

    masks = pruner.get_masks()
    Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
    torch.save(masks, './output/pruning/attn_masks.pth')
    torch.save(model, './output/pruning/attn_masked_model.pth')


if not skip_exec:
    pruning_attn()
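As mentioned above, other granularities and a larger regular_scale can be explored. The following variant is illustrative only (the values are not from the original experiments); it would replace the config_list and pruner construction inside pruning_attn:

# illustrative variant: finer 32 x 32 blocks and stronger movement regularization
finer_config_list = [{
    'op_types': ['Linear'],
    'op_names_re': [r'bert\.encoder\.layer\.[0-9]*\.attention\.*'],
    'sparse_threshold': 0.1,
    'granularity': [32, 32],
}]
# pruner = MovementPruner(model, finer_config_list, evaluator, warmup_step=9000,
#                         cooldown_begin_step=36000, regular_scale=20)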
We apply head pruning during the speedup stage: if a head is fully masked, it will be pruned; if a head is only partially masked, it will be restored.
def speedup_attn():
    model = torch.load('./output/pruning/attn_masked_model.pth', map_location='cpu')
    masks = torch.load('./output/pruning/attn_masks.pth', map_location='cpu')
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    replacer = TransformersAttentionReplacer(model)
    ModelSpeedup(model, dummy_input, masks, customized_replacers=[replacer]).speedup_model()

    # finetuning
    teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
    dynamic_distillation(model, teacher_model, None, 3)
    torch.save(model, './output/pruning/attn_pruned_model.pth')


if not skip_exec:
    speedup_attn()
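If you want to check how many heads survive in each layer, a small helper like the one below can be run on the pruned model (a convenience sketch, not part of the original tutorial; it relies on the same num_attention_heads / pruned_heads attributes used in the next section):

def print_remaining_heads(model: BertForSequenceClassification):
    for i, layer in enumerate(model.bert.encoder.layer):
        retained = layer.attention.self.num_attention_heads
        pruned = len(layer.attention.pruned_heads)
        print(f'layer {i}: {retained} heads retained, {pruned} heads pruned')


if not skip_exec:
    print_remaining_heads(torch.load('./output/pruning/attn_pruned_model.pth', map_location='cpu'))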
Pruning Feed Forward Layers
Here we use TaylorPruner to prune the feed forward layers, with the sparse ratio tied to the number of pruned heads in the same transformer block: the more heads are pruned, the higher the sparse ratio set for the feed forward layers.
Note that TaylorPruner has no scheduled-sparse-ratio feature, so we use AGPPruner to schedule the sparse ratio and achieve better pruning performance. A worked example of the sparse ratio rule follows.
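As a worked example of this linear rule (the numbers are illustrative; bert-base-uncased starts with 12 heads per layer): a block that retains 8 of its 12 heads gets a feed forward sparse ratio of 1 - 8 / 12 / 2 ≈ 0.67, while a block that retains all 12 heads gets the minimum ratio of 0.5:

retained_head_num, ori_head_num = 8, 12
ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
print(ffn_sparse_ratio)  # 0.666...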
from nni.compression.pruning import TaylorPruner, AGPPruner
from transformers.models.bert.modeling_bert import BertLayer
def pruning_ffn():
model: BertForSequenceClassification = torch.load('./output/pruning/attn_pruned_model.pth')
teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
# create ffn config list, here simply use a linear function related to the number of retained heads to determine the sparse ratio
config_list = []
for name, module in model.named_modules():
if isinstance(module, BertLayer):
retained_head_num = module.attention.self.num_attention_heads
ori_head_num = len(module.attention.pruned_heads) + retained_head_num
ffn_sparse_ratio = 1 - retained_head_num / ori_head_num / 2
config_list.append({'op_names': [f'{name}.intermediate.dense'], 'sparse_ratio': ffn_sparse_ratio})
trainer = prepare_traced_trainer(model, task_name)
teacher_model.eval().to(trainer.args.device)
# create a distiller for restoring the accuracy
distiller = dynamic_distiller(model, teacher_model, trainer)
# fusion compress: TaylorPruner + DynamicLayerwiseDistiller
taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
# fusion compress: AGPPruner(TaylorPruner) + DynamicLayerwiseDistiller
agp_pruner = AGPPruner(taylor_pruner, 1000, 36)
agp_pruner.compress(None, 3)
agp_pruner.unwrap_model()
distiller.unwrap_teacher_model()
masks = agp_pruner.get_masks()
Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
torch.save(masks, './output/pruning/ffn_masks.pth')
torch.save(model, './output/pruning/ffn_masked_model.pth')
if not skip_exec:
pruning_ffn()
Speedup the feed forward layers.
def speedup_ffn():
    model = torch.load('./output/pruning/ffn_masked_model.pth', map_location='cpu')
    masks = torch.load('./output/pruning/ffn_masks.pth', map_location='cpu')
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    ModelSpeedup(model, dummy_input, masks).speedup_model()

    # finetuning
    teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
    dynamic_distillation(model, teacher_model, None, 3)
    torch.save(model, './output/pruning/ffn_pruned_model.pth')


if not skip_exec:
    speedup_ffn()
Pruning Embedding Layers
To better simulate the pruning effect, we register an output mask setting for BertAttention and BertOutput. After registering this setting template for them, output masks can be generated and applied to these modules.
from nni.compression.base.setting import PruningSetting

output_align_setting = {
    '_output_': {
        'align': {
            'module_name': None,
            'target_name': 'weight',
            'dims': [0],
        },
        'apply_method': 'mul',
    }
}
PruningSetting.register('BertAttention', output_align_setting)
PruningSetting.register('BertOutput', output_align_setting)
Similar to pruning the feed forward layers, we also use AGPPruner + TaylorPruner + DynamicLayerwiseDistiller here.
To better simulate the pruning effect, output-align mask generation is set in the config_list, so the relevant layers generate their own output masks according to the embedding masks.
def pruning_embedding():
    model: BertForSequenceClassification = torch.load('./output/pruning/ffn_pruned_model.pth')
    teacher_model: BertForSequenceClassification = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')

    sparse_ratio = 0.5
    config_list = [{
        'op_types': ['Embedding'],
        'op_names_re': [r'bert\.embeddings.*'],
        'sparse_ratio': sparse_ratio,
        'dependency_group_id': 1,
        'granularity': [-1, 1],
    }, {
        'op_names_re': [r'bert\.encoder\.layer\.[0-9]*\.attention$',
                        r'bert\.encoder\.layer\.[0-9]*\.output$'],
        'target_names': ['_output_'],
        'target_settings': {
            '_output_': {
                'align': {
                    'module_name': 'bert.embeddings.word_embeddings',
                    'target_name': 'weight',
                    'dims': [1],
                }
            }
        }
    }, {
        'op_names_re': [r'bert\.encoder\.layer\.[0-9]*\.attention.output.dense',
                        r'bert\.encoder\.layer\.[0-9]*\.output.dense'],
        'target_names': ['weight'],
        'target_settings': {
            'weight': {
                'granularity': 'out_channel',
                'align': {
                    'module_name': 'bert.embeddings.word_embeddings',
                    'target_name': 'weight',
                    'dims': [1],
                }
            }
        }
    }]

    trainer = prepare_traced_trainer(model, task_name)
    teacher_model.eval().to(trainer.args.device)
    distiller = dynamic_distiller(model, teacher_model, trainer)
    taylor_pruner = TaylorPruner.from_compressor(distiller, config_list, 1000)
    agp_pruner = AGPPruner(taylor_pruner, 1000, 36)

    agp_pruner.compress(None, 3)
    agp_pruner.unwrap_model()
    distiller.unwrap_teacher_model()

    masks = agp_pruner.get_masks()
    Path('./output/pruning/').mkdir(parents=True, exist_ok=True)
    torch.save(masks, './output/pruning/embedding_masks.pth')
    torch.save(model, './output/pruning/embedding_masked_model.pth')


if not skip_exec:
    pruning_embedding()
Speedup the embedding layers.
def speedup_embedding():
    model = torch.load('./output/pruning/embedding_masked_model.pth', map_location='cpu')
    masks = torch.load('./output/pruning/embedding_masks.pth', map_location='cpu')
    dummy_input = (torch.randint(0, 10000, [8, 128]), torch.randint(0, 2, [8, 128]), torch.randint(0, 2, [8, 128]))
    ModelSpeedup(model, dummy_input, masks).speedup_model()

    # finetuning
    teacher_model = build_finetuning_model('mnli', f'./output/bert_finetuned/{task_name}.bin')
    adapt_distillation(model, teacher_model, None, 4)
    torch.save(model, './output/pruning/embedding_pruned_model.pth')


if not skip_exec:
    speedup_embedding()
Evaluation
Evaluate the pruned model size and accuracy.
def evaluate_pruned_model():
    model: BertForSequenceClassification = torch.load('./output/pruning/embedding_pruned_model.pth')
    trainer = prepare_traced_trainer(model, task_name)
    metric = trainer.evaluate()
    pruned_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())

    model = build_finetuning_model(task_name, f'./output/bert_finetuned/{task_name}.bin')
    ori_num_params = sum(param.numel() for param in model.parameters()) + sum(buffer.numel() for buffer in model.buffers())
    print(f'Metric: {metric}\nSparsity: {1 - pruned_num_params / ori_num_params}')


if not skip_exec:
    evaluate_pruned_model()
Results

Total Sparsity | Embedding Sparsity | Encoder Sparsity | Pooler Sparsity | Acc. (m/mm avg.)
---|---|---|---|---
0.% | 0.% | 0.% | 0.% | 84.95%
57.76% | 33.33% (15.89M) | 64.78% (29.96M) | 33.33% (0.39M) | 84.42%
68.31% (34.70M) | 50.00% (11.92M) | 73.57% (22.48M) | 50.00% (0.30M) | 83.33%
70.95% (31.81M) | 33.33% (15.89M) | 81.75% (15.52M) | 33.33% (0.39M) | 83.79%
78.20% (23.86M) | 50.00% (11.92M) | 86.31% (11.65M) | 50.00% (0.30M) | 82.53%
81.65% (20.12M) | 50.00% (11.92M) | 90.71% (7.90M) | 50.00% (0.30M) | 82.08%
84.32% (17.17M) | 50.00% (11.92M) | 94.18% (4.95M) | 50.00% (0.30M) | 81.35%