Structured Pruning for Transformer-Based Models

A Few Lines of Code, A Satisfying Sparse Transformer Model

Intel(R) Neural Compressor
Intel Analytics Software
Jan 9, 2023

Yiyang Cai, Wenhua Cheng, Dong Bo, Hanwen Chang, and Haihao Shen, Intel Corporation

Neural network pruning is a promising model compression technique. It removes the least important parameters in the network to improve inference performance, ideally with little or no loss in prediction accuracy.

Neural Network Pruning

In recent years, driven by the success of attention-based models, Transformer architectures have achieved impressive results in natural language processing and computer vision. However, these models are more resource intensive than earlier architectures, which makes deploying them on resource-limited systems impractical. Applying pruning techniques to these models is therefore critical.

In this article, we describe how to prune a transformer model using Intel Neural Compressor.

Structured pruning removes parameters in groups rather than one at a time: entire blocks, filters, or channels are deleted according to a pruning criterion. Below is an example of structured pruning on a 4x1 block:

4x1 Structured Pruning
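
To make the 4x1 pattern concrete, here is a minimal pure-Python sketch of block-wise magnitude pruning. It is not the Intel Neural Compressor implementation. It assumes that a 4x1 block is 4 consecutive rows in one column of a 2-D weight matrix (the library's axis convention may differ) and it scores blocks by summed absolute value, which is a simpler criterion than the one used later in this article. The function name `prune_4x1` is hypothetical.

```python
def prune_4x1(weight, target_sparsity):
    """Zero the lowest-magnitude 4x1 blocks of a 2-D weight (list of rows).

    Assumption: a 4x1 block is 4 consecutive rows in a single column.
    """
    rows, cols = len(weight), len(weight[0])
    assert rows % 4 == 0, "row count must be divisible by the block height"

    # Score each block by the sum of the absolute values of its 4 entries.
    scores = []
    for r in range(0, rows, 4):
        for c in range(cols):
            score = sum(abs(weight[r + i][c]) for i in range(4))
            scores.append((score, r, c))

    # Zero whole blocks, lowest score first, until the target ratio is met.
    num_to_prune = int(len(scores) * target_sparsity)
    scores.sort(key=lambda item: item[0])
    pruned = [row[:] for row in weight]  # copy so the input stays untouched
    for _, r, c in scores[:num_to_prune]:
        for i in range(4):
            pruned[r + i][c] = 0.0
    return pruned


# Toy 4x4 weight: columns 1 and 3 hold the smallest values.
weight = [
    [0.9, 0.1, -0.8, 0.02],
    [0.7, -0.2, 0.6, 0.01],
    [-0.8, 0.1, 0.9, -0.03],
    [0.6, 0.05, -0.7, 0.02],
]
pruned = prune_4x1(weight, target_sparsity=0.5)
for row in pruned:
    print(row)
```

Because each block is kept or removed as a whole, the zeros form regular groups that a sparse kernel can skip, which is harder to exploit when zeros are scattered element by element.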

We will use this pattern to illustrate how pruning works, then introduce the Intel Neural Compressor API. This API allows users to define personalized pruning strategies, including different pruning criteria, patterns, and desired network sparsity. We will provide a quick tutorial of pruning a Bert-Mini model on the SQuAD v1.1 dataset.

First, install Intel Neural Compressor:

pip install neural_compressor

Once installed, you are ready to prune your model! You just need to follow these two steps:

Step 1: Define a dict-like configuration in your training code. We provide a template configuration below (please refer to the documentation for more details):

configs = [
    {
        'target_sparsity': 0.9,  # Target sparsity ratio of modules.
        'pruning_type': "snip_momentum",  # Default pruning type.
        'pattern': "4x1",  # Default pruning pattern.
        'excluded_op_names': ['embedding.*', 'classifier.*'],  # A list of modules that would not be pruned.
    }
]
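
The entries in 'excluded_op_names' look like regular expressions. The sketch below shows how such patterns could select which modules are left unpruned. It assumes the patterns are searched anywhere in the module name, which may not match the library's exact matching rule. The module names and the helper `is_excluded` are hypothetical and only serve as an illustration.

```python
import re

excluded_op_names = ["embedding.*", "classifier.*"]

# Hypothetical module names, for illustration only.
module_names = [
    "bert.embeddings.word_embeddings",
    "bert.encoder.layer.0.attention.self.query",
    "bert.encoder.layer.0.output.dense",
    "classifier",
]


def is_excluded(name, patterns):
    """Return True if any pattern is found somewhere in the module name."""
    return any(re.search(pattern, name) for pattern in patterns)


prunable = [name for name in module_names if not is_excluded(name, excluded_op_names)]
print(prunable)
```

Under these assumptions, only the two encoder modules remain candidates for pruning. It is worth printing your own model's module names to confirm that the patterns match what you intend.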

Step 2: Insert API functions into your code. Only four API calls are required:


""" All you need is to insert following API functions:
pruner.on_train_begin() # Setup pruner
pruner.on_step_begin() # Prune weights
pruner.on_before_optimizer_step() # Do weight regularization
pruner.on_after_optimizer_step() # Update weights' criteria, mask weights
"""
from neural_compressor.pruner.pruning import Pruning, WeightPruningConfig
config = WeightPruningConfig(configs)
pruner = Pruning(config) # Define a pruning object
pruner.model = model # Set model object to prune
pruner.on_train_begin()
for epoch in range(num_train_epochs):
    model.train()
    for step, batch in enumerate(train_dataloader):
        pruner.on_step_begin(step)
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        pruner.on_before_optimizer_step()
        optimizer.step()
        pruner.on_after_optimizer_step()
        lr_scheduler.step()
        model.zero_grad()
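
The comments above say that the pruner updates the weights' criteria after each optimizer step. The name "snip_momentum" suggests a SNIP-style saliency, |weight * gradient|, accumulated over steps with momentum. The sketch below shows one common way to write such a score. It is an assumption made for illustration, not the library's code, and the function name and the `alpha` and `beta` values are hypothetical.

```python
def update_snip_momentum(scores, weights, grads, alpha=0.9, beta=1.0):
    """Return scores updated as alpha * old_score + beta * |w * g| per weight."""
    return [
        alpha * s + beta * abs(w * g)
        for s, w, g in zip(scores, weights, grads)
    ]


scores = [0.0, 0.0, 0.0]
weights = [0.5, -1.0, 0.1]
# Gradients from two hypothetical training steps.
step_grads = [
    [0.2, 0.0, 1.0],
    [0.2, 0.0, -1.0],
]
for grads in step_grads:
    scores = update_snip_momentum(scores, weights, grads)
print(scores)
```

In this toy run the weight with the largest magnitude (-1.0) receives the lowest score because its gradient is zero. That is the main difference from pure magnitude pruning: a weight counts as important only if changing it would affect the loss.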

Finally, run the code as you normally would to obtain a sparse Bert-Mini model!

We provide some 4x1 structured pruning results on BERT-series models (including Bert-Mini, Bert-Base, Bert-Large, etc.) below. We can prune these Transformer models to a high sparsity ratio while keeping the relative accuracy loss within 1% of the original model.
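
The 1% figure is relative, meaning it is measured against the original model's score rather than in absolute points. The short sketch below shows the arithmetic. The two scores are hypothetical placeholders, not measured results, and the function name is ours.

```python
def relative_accuracy_loss(baseline, pruned):
    """Fraction of the baseline metric that is lost after pruning."""
    return (baseline - pruned) / baseline


# Hypothetical scores for illustration only, not measured results.
baseline_f1 = 80.0
pruned_f1 = 79.4
loss = relative_accuracy_loss(baseline_f1, pruned_f1)
print(f"relative loss: {loss:.2%}")
print("within 1%:", loss <= 0.01)
```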

To exploit Intel’s hardware benefits for sparse models, we leverage Intel Extension for Transformers, a new toolkit to accelerate Transformer-based models. We demonstrate up to 25x speedup on extremely compressed BERT-Mini over BERT-Base on Intel Xeon Platinum 8380 processors (formerly codenamed Ice Lake, ICX) through distillation, quantization, and sparsity.

We also provide an example that uses this acceleration library to deploy a Bert-Mini model with high sparsity.

Besides Transformer-based models like the BERT series, we also provide pruning examples for a wide range of models in Intel Neural Compressor that you can try.

You can also visit Intel Neural Compressor and Intel Extension for Transformers for more details, or reach out to us on GitHub issues if you have any questions.
