MobileBERT — A task agnostic BERT for resource-limited devices ☎️
As NLP models grow to hundreds of billions of parameters, it becomes increasingly important to be able to create more compact representations of these models. Knowledge distillation has successfully enabled this: in one instance, 96% of the teacher’s performance was retained in a 7x smaller model. However, knowledge distillation is still treated as an afterthought when teacher models are designed, which could reduce its effectiveness and leave potential performance improvements for the student on the table.
Furthermore, small student models are difficult to fine-tune after the initial distillation without degrading their performance, which requires us to both pre-train and fine-tune the teachers on the tasks we want the student to be able to perform. Training a student model through knowledge distillation will, therefore, require more training than training the teacher alone, which limits the benefits of a student model to inference time.
What would be possible if, instead, knowledge distillation were put front and center during the design and training of the teacher model? Could we design and successfully train a model that is meant to be distilled, and could the distilled version successfully be fine-tuned on any downstream task? These are some of the questions addressed in MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited Devices, which we summarize in this article.
The MobileBERT architectures
Knowledge distillation requires us to compare teacher and student representations so that the difference between them can be minimised. This is straightforward when both matrices or vectors have the same dimensions. Therefore, MobileBERT introduces a bottleneck layer into the transformer block. This allows the input to both student and teacher to be equal in size while their internal representations can differ. These bottlenecks are shown as green trapezoids marked with “Linear” in the figure above. In this particular case, the shared dimension is 512, while the internal representation sizes for the teacher and student are 1024 and 128 respectively. This allows us to use a BERT-large (340M parameters) equivalent model to train a 25M parameter student.
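The role of the bottleneck can be sketched in a few lines of plain Python. This is a toy illustration and not the paper's implementation: the helper names (`linear`, `matvec`, `bottleneck_block`) are made up for this example, the weights are random, and the body of the block is reduced to a single placeholder nonlinearity. Only the sizes 512, 1024 and 128 come from the text above.

```python
import random

def linear(in_dim, out_dim, rng):
    """Return a random weight matrix of shape (out_dim, in_dim)."""
    return [[rng.uniform(-0.1, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]

def matvec(weight, vec):
    """Multiply an (out_dim, in_dim) matrix with an in_dim vector."""
    return [sum(w * x for w, x in zip(row, vec)) for row in weight]

def bottleneck_block(x, inner_dim, rng):
    """Toy block: project to inner_dim, apply a stand-in body, project back."""
    shared_dim = len(x)
    w_in = linear(shared_dim, inner_dim, rng)    # first "Linear" trapezoid
    w_out = linear(inner_dim, shared_dim, rng)   # second "Linear" trapezoid
    hidden = matvec(w_in, x)
    hidden = [max(0.0, h) for h in hidden]       # placeholder for MHA + FFN
    return matvec(w_out, hidden), len(hidden)

rng = random.Random(0)
SHARED_DIM = 512  # shared input/output size from the text
token = [rng.uniform(-1, 1) for _ in range(SHARED_DIM)]

teacher_out, teacher_inner = bottleneck_block(token, 1024, rng)
student_out, student_inner = bottleneck_block(token, 128, rng)

print(teacher_inner, student_inner)        # internal widths differ
print(len(teacher_out), len(student_out))  # block outputs share one size
```

Because both outputs have the same length, they can be compared element by element, which is what the layer-wise distillation losses rely on.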
Further, since the input and output dimensions of each transformer block are the same for both models, it is possible to transfer both embedding and classifier parameters from teacher to student simply by copying them!
The observant reader will have noticed that the input to the Multi-Head Attention block (MHA) is not the output from the prior linear projection. Instead, the initial input is used. The paper gives no motivation for this design choice, which leaves us to speculate. I believe the reason is the additional degrees of freedom it allows. Basically, we separate how the model is forced to process the information into two streams, one fed into the MHA block and the other used as a skip-connection. (It’s also quite easy to convince oneself that using the output from the linear projection does not change the behavior of the MHA block due to its initial linear transform.)
To achieve high enough capacity within the small student, the authors introduce what they call stacked FFN, shown as a dashed box within the student model overview in the image above. Stacked FFNs simply repeat the Feed Forward + Add & Norm blocks 4 times, a number chosen to achieve a good parameter ratio between the MHA and FFN blocks. Ablation studies in this work show that the best performance is achieved when this ratio is in the range of 0.4–0.6.
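A back-of-the-envelope calculation shows how stacking FFNs moves the MHA-to-FFN parameter ratio. The sketch below makes several assumptions that are not stated in this article: the FFN intermediate size is taken to be 512, all of the query, key and value projections read the 512-dimensional block input, and biases are ignored. The function names are hypothetical, so treat the numbers as illustrative rather than as the paper's exact counts.

```python
def mha_params(input_dim, attn_dim):
    """Weights of the Q, K, V projections plus the output projection (no biases)."""
    return 3 * input_dim * attn_dim + attn_dim * attn_dim

def ffn_params(hidden_dim, intermediate_dim):
    """Weights of one feed-forward network: up-projection plus down-projection."""
    return 2 * hidden_dim * intermediate_dim

BLOCK_INPUT_DIM = 512    # shared dimension from the text
STUDENT_DIM = 128        # student internal size from the text
FFN_INTERMEDIATE = 512   # assumption, not stated in the article
NUM_STACKED_FFN = 4      # from the text

mha = mha_params(BLOCK_INPUT_DIM, STUDENT_DIM)
ratio_by_stack = {
    n: mha / (n * ffn_params(STUDENT_DIM, FFN_INTERMEDIATE))
    for n in range(1, NUM_STACKED_FFN + 1)
}
for n, ratio in ratio_by_stack.items():
    print(f"{n} stacked FFN(s): MHA/FFN parameter ratio = {ratio:.3f}")
```

Under these assumptions a single FFN gives a ratio of about 1.6, while four stacked FFNs bring it down to about 0.41, which falls inside the 0.4–0.6 range mentioned above.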
Since one of the goals is to enable fast inference on resource-limited devices, the authors identify two areas where their architecture could be further improved.
- Replace the smooth GeLU activation function with ReLU
- Swap the normalization operation for an element-wise linear transformation
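To see why these two swaps are cheaper, here is a minimal sketch of all four operations in plain Python. The function names and the example vector are made up for illustration; the `no_norm` function is my reading of the "element-wise linear transformation" described above, with `gamma` and `beta` standing in for learned per-element parameters.

```python
import math

def gelu(x):
    """Exact GELU: x * Phi(x), which needs an erf evaluation per element."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def relu(x):
    """ReLU: a single comparison per element."""
    return max(0.0, x)

def layer_norm(h, gamma, beta, eps=1e-12):
    """Layer normalization: needs the mean and variance over the whole vector."""
    mean = sum(h) / len(h)
    var = sum((v - mean) ** 2 for v in h) / len(h)
    return [g * (v - mean) / math.sqrt(var + eps) + b
            for v, g, b in zip(h, gamma, beta)]

def no_norm(h, gamma, beta):
    """Element-wise linear transformation: gamma * h + beta, no statistics."""
    return [g * v + b for v, g, b in zip(h, gamma, beta)]

h = [1.0, -2.0, 0.5, 3.0]
gamma = [1.0] * len(h)   # identity scale, for illustration only
beta = [0.0] * len(h)    # zero shift, for illustration only

print([round(gelu(v), 4) for v in h])
print([relu(v) for v in h])
print([round(v, 4) for v in layer_norm(h, gamma, beta)])
print(no_norm(h, gamma, beta))
```

The element-wise version never has to reduce over the hidden dimension, and ReLU avoids the transcendental function, which is where the latency savings on a phone would be expected to come from.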
Proposed knowledge distillation objectives
To achieve knowledge transfer between the proposed teacher and student, the authors apply knowledge distillation at three stages of the model:
- Feature map transfer. Allows the student to mimic the teacher at each transformer layer output by minimizing the mean squared error between the two feature maps. In the architecture figure above, this is shown as a dashed arrow between the outputs of the models.
- Attention map transfer. Which tokens the teacher attends to at different layers and heads is another important property we want the student to learn. This is enabled by minimizing the KL-divergence between the teacher’s and the student’s attention distributions at each layer and head.
- Pre-training distillation. It is also possible to use distillation during pre-training by linearly combining the original Masked Language Modelling and Next Sentence Prediction losses with a new Masked Language Modelling distillation loss.
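The three objectives above can be sketched as small loss functions. This is a simplified illustration under stated assumptions, not the authors' code: the function names are hypothetical, the feature map loss is taken to be a mean squared error, the attention loss is the KL-divergence from teacher to student, and the pre-training loss follows my understanding of the paper's form, alpha * MLM + (1 - alpha) * MLM distillation + NSP. The value of `alpha` and all input numbers are arbitrary.

```python
import math

def feature_map_loss(teacher_feats, student_feats):
    """Mean squared error between two equally sized feature maps (tokens x dims)."""
    total, count = 0.0, 0
    for t_row, s_row in zip(teacher_feats, student_feats):
        for t, s in zip(t_row, s_row):
            total += (t - s) ** 2
            count += 1
    return total / count

def attention_map_loss(teacher_attn, student_attn, eps=1e-12):
    """KL(teacher || student), averaged over heads and query positions.

    Both inputs are indexed as [head][query][key]; each innermost list
    is a probability distribution over key positions.
    """
    total, count = 0.0, 0
    for t_head, s_head in zip(teacher_attn, student_attn):
        for t_dist, s_dist in zip(t_head, s_head):
            total += sum(t * math.log((t + eps) / (s + eps))
                         for t, s in zip(t_dist, s_dist))
            count += 1
    return total / count

def pretraining_distillation_loss(mlm, mlm_kd, nsp, alpha=0.5):
    """Linear combination of the MLM, MLM distillation and NSP losses."""
    return alpha * mlm + (1.0 - alpha) * mlm_kd + nsp

# Tiny made-up inputs: 2 tokens x 2 dims, and 1 head x 2 queries x 2 keys.
teacher_feats = [[0.2, -0.1], [0.4, 0.3]]
student_feats = [[0.1, -0.1], [0.5, 0.1]]
teacher_attn = [[[0.7, 0.3], [0.5, 0.5]]]
student_attn = [[[0.6, 0.4], [0.5, 0.5]]]

print(feature_map_loss(teacher_feats, student_feats))
print(attention_map_loss(teacher_attn, student_attn))
print(pretraining_distillation_loss(mlm=1.0, mlm_kd=2.0, nsp=0.5))
```

Note that the feature map loss only works because the bottleneck gives teacher and student outputs of the same size; the attention maps are comparable as long as both models use the same number of heads and the same sequence length.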
With these objectives, there is more than one way we can perform knowledge distillation. The authors propose three alternatives:
- Auxiliary knowledge transfer. The layer-wise knowledge transfer objectives are minimised together with the main objectives (masked language modelling and next sentence prediction). This can be considered the simplest approach.
- Joint knowledge transfer. Instead of trying to achieve all objectives at once, it is possible to separate knowledge distillation and pre-training into two stages of training. First, all layer-wise knowledge distillation losses are minimised until convergence, and then further training with the pre-training objective is performed.
- Progressive knowledge transfer. The two-step approach can be taken even further. Errors not yet minimized properly in early layers will propagate and affect the training of later layers if all layers are trained simultaneously. It might, therefore, be better to train one layer at a time while freezing or reducing the learning rate of previous layers.
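The progressive variant is the least obvious of the three, so here is a minimal sketch of what such a schedule could look like. The function `progressive_schedule` and its `frozen_lr_scale` parameter are hypothetical names for this example; the sketch only generates the per-layer learning-rate multipliers for each stage and leaves the actual training loop out.

```python
def progressive_schedule(num_layers, frozen_lr_scale=0.0):
    """Yield one training stage per layer.

    Each stage lists a learning-rate multiplier for every layer: 1.0 for
    the layer currently being trained, `frozen_lr_scale` for earlier
    layers, and None for later layers that are not trained yet.
    """
    for current in range(num_layers):
        scales = []
        for layer in range(num_layers):
            if layer < current:
                scales.append(frozen_lr_scale)  # frozen (0.0) or slowed down
            elif layer == current:
                scales.append(1.0)              # layer being distilled now
            else:
                scales.append(None)             # not reached yet
        yield {"stage": current + 1, "lr_scale": scales}

stages = list(progressive_schedule(num_layers=4))
for stage in stages:
    print(stage)
```

Setting `frozen_lr_scale=0.0` corresponds to freezing earlier layers, while a small positive value such as 0.1 corresponds to the softer option of reducing their learning rate.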
The authors evaluate their proposed MobileBERT in three configurations: the main model with 25M parameters (MobileBERT), the same model without the operational optimizations (MobileBERT w/o OPT), and a model with only 15M parameters (MobileBERT-tiny). These models were compared to baseline models such as ELMo, GPT and BERT-base as well as to related distillation work: BERT-PKD, DistilBERT and TinyBERT.
Training these variations of MobileBERT was found to be most effective through the progressive knowledge transfer process, which consistently outperformed the other two by a significant margin.
What we find is that MobileBERT w/o OPT outperforms the much larger BERT-base by 0.2 average GLUE score, while being 4x smaller. MobileBERT, on the other hand, is only 0.6 points behind BERT-base while having a much faster inference time: 62 ms for a sequence of 128 tokens on a Pixel 4 phone! Its performance is still competitive, however, since it outperforms GPT and ELMo by a significant margin.
Therefore, it’s safe to conclude that it’s possible to create a distilled model that is both accurate and fast on resource-limited devices!
MobileBERT-tiny achieves slightly better performance than TinyBERT. This becomes even more impressive when you consider how TinyBERT was fine-tuned for the GLUE tasks. Remember, prior to this work it was not possible to fine-tune the students due to their small capacity. TinyBERT’s teacher, BERT-base, therefore had to be fine-tuned before its knowledge could be distilled into TinyBERT! That is not the case for MobileBERT.
MobileBERT has been fine-tuned by itself on GLUE, which shows that it’s possible to create a task-agnostic model through the proposed distillation process!
MobileBERT introduces bottlenecks in the transformer blocks, which allows us to more easily distil the knowledge from larger teachers into smaller students. This technique allows us to reduce the width rather than the depth of the student, which is known to produce a more capable model. The model highlights the fact that it’s possible to create a student model that can be fine-tuned by itself after the initial distillation process.
Further, the results show that this holds true in practice, as MobileBERT is able to reach 99.2% of BERT-base’s performance on GLUE with 4x fewer parameters and 5.5x faster inference on a Pixel 4 phone.