Poor Man’s BERT — Exploring Pruning as an Alternative to Knowledge Distillation

Viktor
Published in DAIR.AI, Jul 26, 2020

The ever-increasing size of NLP models, and the reduced usability that ensues, is something I’ve discussed in many previous paper summaries (TinyBERT, MobileBERT, and DistilBERT among them). Each of these papers proposes a unique knowledge distillation framework with the common goal of reducing model size while preserving performance.

While these methods have all been successful in their own ways, they share a common drawback: knowledge distillation requires additional training on top of the already expensive training of the teacher model, so the savings only materialize at inference time.

An alternative to knowledge distillation that offers a simple solution to this issue is pruning. Previous works (“Are Sixteen Heads Really Better than One?” and “Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned”) have shown that transformer-based architectures can drop some attention heads during inference without a significant reduction in performance.

Continuing this line of thought, what happens to model performance if we drop entire transformer layers from a pre-trained model? Would the resulting model still be usable for further fine-tuning? Does the performance depend on which layers we drop? These are some of the questions the authors of Poor Man’s BERT: Smaller and Faster Transformer Models analyze, and which I summarize in this article. Let’s get into their contributions!

Contribution

Pruning techniques

To thoroughly examine the effect of dropping transformer layers, the authors propose five different strategies:

1. Top-layer dropping. Remove the last layers of the network. Previous work has shown that later layers in the network specialize for the pre-training objective, which might not be helpful during the fine-tuning stage of training.

2. Bottom-layer dropping. Remove the initial layers of the network. This form of pruning is included for completeness, even though previous work has shown that the initial layers model the local interactions between tokens.

3. Alternate dropping. Remove every other layer, either the even or the odd ones, counting from the top of the network. These variants test whether adjacent layers learn transformations similar enough that one of them can be removed.

4. Symmetric dropping. Remove layers from the center of the network. The motivation is that these layers learn less important features compared to the bottom and top layers.

5. Contribution-based dropping. Remove layers based on how much they modify their input, measured as the average cosine similarity between each layer’s input and output embeddings. A high similarity indicates that the layer changes its input very little, which makes it a reasonable candidate for removal.

Schematic overview of the proposed pruning techniques. Source
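To make the five strategies concrete, here is a minimal Python sketch that returns which layer indices each one would remove from an n-layer model. The function name `layers_to_drop`, the 0-based indexing, and the exact layers chosen for the alternate and symmetric variants are assumptions made for illustration; the paper’s precise index conventions may differ. For contribution-based dropping, the sketch assumes the per-layer average cosine similarities have already been computed.

```python
from typing import List, Sequence


def layers_to_drop(strategy: str, n_layers: int, k: int,
                   similarity: Sequence[float] = ()) -> List[int]:
    """Return sorted 0-based indices of the k layers to remove.

    Index 0 is the bottom layer (next to the embeddings) and
    index n_layers - 1 is the top layer. Hypothetical helper.
    """
    if not 0 <= k <= n_layers:
        raise ValueError("k must be between 0 and n_layers")

    if strategy == "top":
        # 1. Top-layer dropping: remove the last k layers.
        return list(range(n_layers - k, n_layers))

    if strategy == "bottom":
        # 2. Bottom-layer dropping: remove the first k layers.
        return list(range(k))

    if strategy in ("even", "odd"):
        # 3. Alternate dropping (assumed convention): walk down from the top
        # and remove layers whose 1-based number is even (or odd), e.g.
        # 12, 10, 8, ... for "even" and 11, 9, 7, ... for "odd".
        parity = 0 if strategy == "even" else 1
        candidates = [n for n in range(n_layers, 0, -1) if n % 2 == parity]
        if k > len(candidates):
            raise ValueError("not enough layers with that parity")
        return sorted(n - 1 for n in candidates[:k])

    if strategy == "symmetric":
        # 4. Symmetric dropping (assumed convention): remove a contiguous
        # block of k layers from the middle of the stack.
        start = (n_layers - k) // 2
        return list(range(start, start + k))

    if strategy == "contribution":
        # 5. Contribution-based dropping: remove the k layers whose output is
        # most similar to their input (highest average cosine similarity).
        if len(similarity) != n_layers:
            raise ValueError("need one similarity score per layer")
        ranked = sorted(range(n_layers), key=lambda i: similarity[i],
                        reverse=True)
        return sorted(ranked[:k])

    raise ValueError(f"unknown strategy: {strategy}")
```

For example, `layers_to_drop("top", 12, 6)` returns `[6, 7, 8, 9, 10, 11]`, the upper half of a 12-layer model, while `layers_to_drop("symmetric", 12, 4)` returns `[4, 5, 6, 7]`.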

Benefits of proposed methods

One of the main benefits of these pruning techniques is that the resulting model does not require any further pre-training. The authors suggest that it is sufficient to fine-tune the pruned model on the target task.

This allows practitioners without access to massive amounts of computing hardware to easily create smaller versions of already existing pre-trained models.
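Part of what makes this so easy is structural: every transformer layer maps hidden states to hidden states of the same shape, so the remaining layers still compose into a working network that can be fine-tuned directly. The toy sketch below illustrates only that point. Each “layer” is a plain Python function on a number, a deliberate simplification rather than a real transformer layer, and `prune_stack` and `forward` are hypothetical names.

```python
from typing import Callable, List, Sequence

# A toy "layer": any function from a hidden state to a hidden state of the
# same shape. A plain float stands in for the real (tokens x hidden) tensor.
Layer = Callable[[float], float]


def prune_stack(layers: Sequence[Layer], drop: Sequence[int]) -> List[Layer]:
    """Return a new stack keeping every layer whose index is not in `drop`."""
    dropped = set(drop)
    if any(i < 0 or i >= len(layers) for i in dropped):
        raise IndexError("layer index out of range")
    return [layer for i, layer in enumerate(layers) if i not in dropped]


def forward(layers: Sequence[Layer], x: float) -> float:
    """Run the input through the stack from bottom to top."""
    for layer in layers:
        x = layer(x)
    return x


# Twelve toy layers; layer i simply adds i to its input.
layers = [lambda x, i=i: x + i for i in range(12)]
pruned = prune_stack(layers, drop=range(6, 12))  # top-layer dropping, k = 6
print(len(pruned), forward(layers, 0.0), forward(pruned, 0.0))  # 6 66.0 15.0
```

In Hugging Face’s PyTorch `BertModel`, the encoder layers are stored in `model.encoder.layer`, a `torch.nn.ModuleList`, so the same filtering could in principle be applied there before fine-tuning (keeping `config.num_hidden_layers` in sync). That is an assumption about how one might reproduce the setup, not the authors’ code.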

Results

The pruning techniques described above were evaluated on three models: the 12-layer BERT-base and XLNet models, and the 6-layer DistilBERT. Including DistilBERT also enabled a direct comparison between the proposed pruning approaches and knowledge distillation.

Best pruning technique

The authors find that top-layer dropping outperforms the other pruning techniques on the GLUE benchmark, especially when removing 4 or 6 layers. In the latter case, with half of the model’s layers removed, performance degraded by only 2.9 and 1.8 GLUE points for BERT and XLNet, respectively. This matches the performance of DistilBERT, which is comparable in size to these pruned models.

Performance for the different pruning techniques for BERT and XLNet while dropping 2, 4 or 6 layers. Source

Pruning the six top layers of either BERT or XLNet results in a model that matches both DistilBERT’s performance and its size, without requiring any specific training procedure.
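To see why the pruned models end up close to DistilBERT in size, it helps to count parameters. The sketch below assumes the standard BERT-base hyperparameters (hidden size 768, feed-forward size 3,072, a 30,522-token vocabulary, 512 positions, 2 token types) and that dropping a layer removes its attention and feed-forward blocks while the embeddings and pooler stay. `bert_param_count` is a hypothetical helper, not code from the paper.

```python
def bert_param_count(n_layers: int, hidden: int = 768, ffn: int = 3072,
                     vocab: int = 30522, max_pos: int = 512,
                     type_vocab: int = 2) -> int:
    """Count the parameters of a BERT-style encoder with a pooler."""
    # Embeddings: word, position and token-type tables plus one LayerNorm
    # (a LayerNorm has a weight and a bias vector of size `hidden`).
    embeddings = (vocab + max_pos + type_vocab) * hidden + 2 * hidden
    # Self-attention: Q, K, V and output projections (weights + biases),
    # followed by a LayerNorm.
    attention = 4 * (hidden * hidden + hidden) + 2 * hidden
    # Feed-forward: hidden -> ffn -> hidden (weights + biases), plus LayerNorm.
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden) + 2 * hidden
    # Pooler: one dense layer applied to the [CLS] representation.
    pooler = hidden * hidden + hidden
    return embeddings + n_layers * (attention + feed_forward) + pooler


full = bert_param_count(12)
pruned = bert_param_count(6)  # top six layers dropped
print(full, pruned, pruned / full)
```

With these defaults the full model has 109,482,240 parameters and the six-layer version 66,955,008, roughly 61% of the original. Because the embedding tables alone hold about 23.8M parameters, removing half of the layers removes only about 39% of the weights, which lands near DistilBERT’s reported size of about 66M.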

Dropping layers from DistilBERT also results in well-performing models: dropping one or two layers performs comparably to the original model. Again, top-layer dropping is the most consistent, while both alternate dropping methods (even and odd) are competitive.

Task-specific results

Since top-layer dropping proved to be the best alternative, the following experiments are limited to that method.

Instead of studying what happens when removing a fixed number of layers, it’s possible to approach this question from another angle: given that we accept a certain performance drop, how many layers can we afford to drop?

Accepting a performance drop of 1, 2, or 3% reveals that it is possible to drop up to 9 (!) layers from both BERT and XLNet (and up to 4 from DistilBERT) on some of the GLUE tasks! This, if anything, is a strong indication of how over-parameterized these transformer models are.

It is possible, for some tasks, to prune up to 9 of the top layers from a 12-layer model without degrading performance by more than 3%.
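The “how many layers can we drop” question amounts to a simple search over evaluation results. A minimal sketch, assuming you already have a dev-set score for each number of dropped top layers k (including k = 0 for the unpruned model), and interpreting the 1/2/3% thresholds as relative drops from the unpruned score. That interpretation is an assumption; an absolute drop in points would change only the threshold line. `max_droppable_layers` is a hypothetical name.

```python
from typing import Mapping


def max_droppable_layers(scores: Mapping[int, float], tolerance: float) -> int:
    """Largest k whose score stays within `tolerance` of the unpruned model.

    `scores` maps the number of dropped top layers k to a dev-set score and
    must contain k = 0 (the unpruned baseline). `tolerance` is a relative
    drop, e.g. 0.03 for "at most 3% below the baseline". Scores are not
    required to decrease monotonically with k.
    """
    if 0 not in scores:
        raise ValueError("scores must include the unpruned model (k = 0)")
    threshold = (1.0 - tolerance) * scores[0]
    # k = 0 always qualifies for a positive baseline, so max() is safe.
    return max(k for k, score in scores.items() if score >= threshold)
```

Running it once per task and per tolerance (0.01, 0.02, 0.03) yields the kind of “layers you can afford to drop” table the authors report.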

BERT vs XLNet

The authors also provide a detailed comparison of how well BERT and XLNet tolerate pruning. Here’s what they found:

XLNet is more robust to pruning of its top layers. This leads the authors to conclude that XLNet learns more complex, task-specific information earlier in the network. They tested this hypothesis by adding a classifier head on top of each transformer layer of both BERT and XLNet. XLNet already reaches its best performance at layer 7, while BERT needs at least 11 layers to converge (see graph below). This result helps explain XLNet’s robustness to pruning.

Classifier performance when applied to all layers in the networks. Source

Fine-tuning affects the layers of BERT and XLNet in entirely different ways. Previous work has shown that fine-tuning changes the later layers of BERT more than the early ones, but its effect on the layers of XLNet had not been studied. This work confirms the previous findings for BERT and contrasts them with XLNet, where the middle layers change much more during fine-tuning than both the early and the late layers. The graph below compares the layers before and after fine-tuning for both models.

Difference, measured through average cosine similarity, between pre-trained and fine-tuned layers for BERT and XLNet. Source
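The comparison in this graph boils down to averaging cosine similarities layer by layer. Below is a minimal pure-Python sketch, assuming that for each layer we have one vector per token from the pre-trained model and the corresponding vector from the fine-tuned model on the same input. The caption does not say whether representations or weights were compared, so treat that choice as an assumption; `layerwise_similarity` is a hypothetical name. The same averaged cosine similarity, computed between a layer’s input and output, is the score behind contribution-based dropping.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity between two equally sized vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


def layerwise_similarity(before: Sequence[Sequence[Vector]],
                         after: Sequence[Sequence[Vector]]) -> List[float]:
    """Average cosine similarity per layer between two sets of token vectors.

    before[l][t] and after[l][t] are the vectors for token t at layer l,
    e.g. from the pre-trained and the fine-tuned model on the same input.
    """
    if len(before) != len(after):
        raise ValueError("both inputs need the same number of layers")
    scores = []
    for layer_before, layer_after in zip(before, after):
        if not layer_before or len(layer_before) != len(layer_after):
            raise ValueError("each layer needs the same non-zero token count")
        sims = [cosine(u, v) for u, v in zip(layer_before, layer_after)]
        scores.append(sum(sims) / len(sims))
    return scores
```

A value close to 1 means the layer barely changed; lower values mean fine-tuning moved it more, which is the pattern the graph shows for XLNet’s middle layers.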

The authors speculate that this difference stems from the pre-training process. Without getting stuck in the details: XLNet is an autoregressive (AR) language model trained to maximize the expected likelihood over all possible permutations of the factorization order, which allows it to learn bidirectional context despite being an AR model.
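The permutation idea can be illustrated in a few lines. Under a factorization order z, the token at position z[t] is predicted from the tokens at positions z[0], …, z[t-1], wherever those sit in the original sentence. The sketch below, with the hypothetical helper `visible_context`, only lists which positions each token may condition on; XLNet’s actual implementation (two-stream attention, predicting only part of each sequence) is more involved and is not shown.

```python
from typing import Dict, List, Sequence


def visible_context(order: Sequence[int]) -> Dict[int, List[int]]:
    """Map each position to the positions it may condition on under `order`."""
    if sorted(order) != list(range(len(order))):
        raise ValueError("order must be a permutation of 0..n-1")
    # The t-th position in the factorization order sees the t positions
    # that precede it in that order, regardless of left/right placement.
    return {pos: sorted(order[:t]) for t, pos in enumerate(order)}


print(visible_context([2, 0, 3, 1]))
```

With the order `[2, 0, 3, 1]`, position 0 is predicted from position 2, a token to its right, something a standard left-to-right model never does. Training in expectation over many such orders is what gives every position access to context on both sides.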

Pruning a fine-tuned model

A natural follow-up is to ask what happens if we prune an already fine-tuned model instead. The authors show that doing so does not significantly improve performance (for BERT, it actually gives worse results).

Conclusion

Knowledge distillation has shown promising results for reducing model size while preserving much of the original model’s performance. Its main drawback has been that building these models requires additional training, which prevents researchers with limited computing resources from creating these smaller models.

An alternative approach is to prune the model by simply removing a set of its layers. This work has shown that dropping the top layers provides the most consistent results and that, for some tasks, it is possible to drop 9 out of 12 layers while retaining 97% of the original model’s performance.

Finally, this work provided insights into the differences between BERT and XLNet, which indicate that XLNet is more robust to pruning. The authors attribute this, in part, to its novel pre-training objective.

If you found this summary helpful in understanding the broader picture of this particular research paper, please consider reading my other articles! I’ve already written a bunch and more will definitely be added. I think you might find this one interesting👋🏼🤖
