🏎 Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT
You can find the code to reproduce the training of DistilBERT along with pre-trained weights for DistilBERT here.
In the last 18 months, transfer learning from large-scale language models has significantly improved upon the state-of-the-art on pretty much every Natural Language Processing task.
Usually based on the Transformer architecture of Vaswani et al., these pre-trained language models keep getting larger and are trained on ever bigger datasets. The latest model from Nvidia has 8.3 billion parameters: 24 times larger than BERT-large and 5 times larger than GPT-2, while RoBERTa, the latest work from Facebook AI, was trained on 160GB of text 😵
At Hugging Face, we experienced first-hand the growing popularity of these models as our NLP library — which encapsulates most of them — got installed more than 400,000 times in just a few months.
However, as these models were reaching a larger NLP community, an important and challenging question started to emerge. How should we put these monsters in production? How can we use such large models under low latency constraints? Do we need (costly) GPU servers to serve at scale?
For many researchers and developers, these can be deal-breaking issues 💸
To build more privacy-respecting systems, we noticed an increasing need to have machine learning systems operate on the edge rather than calling a cloud API and sending possibly private data to servers. Running models on devices like your smartphone 📲 also requires light-weight, responsive and energy-efficient models!
Last but not least, we are increasingly concerned about the environmental cost of the exponentially scaling computing requirements of these models.
So, how can we reduce the size of these monster models⁉️
There are many techniques available to tackle the previous questions. The most common tools include quantization (approximating the weights of a network with lower precision) and weight pruning (removing some connections in the network). For these techniques, you can have a look at Rasa's excellent blog post on quantizing BERT.
We decided to focus on distillation: a technique you can use to compress a large model, called the teacher, into a smaller model, called the student.
⚗️ Knowledge Distillation — Transferring generalization capabilities
Knowledge distillation (sometimes also referred to as teacher-student learning) is a compression technique in which a small model is trained to reproduce the behavior of a larger model (or an ensemble of models). It was introduced by Bucila et al. and generalized by Hinton et al. a few years later. We will follow the latter method.
In supervised learning, a classification model is generally trained to predict a gold class by maximizing its probability (softmax of logits) using the log-likelihood signal. In many cases, a well-performing model will predict an output distribution with the correct class having a high probability, leaving other classes with probabilities near zero.
But some of these “almost-zero” probabilities are larger than others, and this reflects, in part, the generalization capabilities of the model.
For instance, a desk chair might be mistaken for an armchair but should usually not be mistaken for a mushroom. This uncertainty is sometimes referred to as “dark knowledge” 🌚
Another way to understand distillation is that it prevents the model from being too confident in its predictions (similarly to label smoothing).
Here is an example to see this idea in practice. In language modeling, we can easily observe this uncertainty by looking at the distribution over the vocabulary. Here are the top 20 guesses by BERT for completing this famous quote from the Casablanca movie:
👯‍♂️ How can we copy this dark knowledge?
In teacher-student training, we train a student network to mimic the full output distribution of the teacher network (its knowledge).
We are training the student to generalize the same way as the teacher by matching the output distribution.
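Rather than training with a cross-entropy over the hard targets (one-hot encoding of the gold class), we transfer the knowledge from the teacher to the student with a cross-entropy over the soft targets (probabilities of the teacher). Our training loss thus becomes:

$$L = -\sum_i t_i \log(s_i)$$

where $t_i$ and $s_i$ are the probabilities that the teacher and the student, respectively, assign to class $i$.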
This loss is a richer training signal, since the soft target of a single example enforces much more constraint than a single hard target.
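To make the “richer signal” point concrete, here is a small pure-Python sketch (not the code we used for DistilBERT) comparing a hard and a soft target for the same student prediction. The four classes, the teacher probabilities and the student logits are all made up for illustration. It also prints the gradient of each loss with respect to the student logits, which for a softmax followed by a cross-entropy is s_i - t_i.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(targets, probs):
    """L = -sum_i t_i * log(s_i) for target probabilities t and predicted probabilities s."""
    return -sum(t * math.log(s) for t, s in zip(targets, probs) if t > 0)

def grad_wrt_logits(targets, probs):
    """Gradient of softmax + cross-entropy w.r.t. the logits: s_i - t_i (targets sum to 1)."""
    return [s - t for t, s in zip(targets, probs)]

# Hypothetical 4-class example: desk chair (gold), armchair, sofa, mushroom.
hard_target = [1.0, 0.0, 0.0, 0.0]          # one-hot gold class
teacher_probs = [0.80, 0.15, 0.049, 0.001]  # soft targets from a made-up teacher
student_probs = softmax([2.0, 1.0, 0.5, -1.0])

hard_loss = cross_entropy(hard_target, student_probs)
soft_loss = cross_entropy(teacher_probs, student_probs)

# With the hard target, every wrong class is pushed down in proportion to its
# probability; with soft targets, "armchair" is only pushed down towards 0.15.
print("hard grad:", [round(g, 3) for g in grad_wrt_logits(hard_target, student_probs)])
print("soft grad:", [round(g, 3) for g in grad_wrt_logits(teacher_probs, student_probs)])
```

The soft-target gradient tells the student not only which class is right, but also how plausible each wrong class is.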
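To further expose the mass of the distribution over the classes, Hinton et al. introduce a softmax-temperature:

$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

where $z_i$ is the logit of class $i$ and $T$ is the temperature.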
When T → 0, the distribution becomes a Kronecker delta (equivalent to the one-hot target vector); when T → +∞, it becomes a uniform distribution. The same temperature is applied to both the student and the teacher at training time, revealing more signal for each training example. At inference, T is set to 1, which recovers the standard softmax.
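A quick way to see the effect of T is to apply the softened softmax to a fixed set of made-up logits at several temperatures. This is a minimal pure-Python sketch: a small T concentrates the mass on the top class, a large T flattens it towards uniform, and T = 1 is the ordinary softmax.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """p_i = exp(z_i / T) / sum_j exp(z_j / T), computed stably."""
    if T <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [4.0, 2.0, 1.0, -2.0]  # hypothetical logits
for T in (0.1, 1.0, 2.0, 10.0, 1000.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T:>7}: " + ", ".join(f"{p:.3f}" for p in probs))
```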
🗜Hands-on coding in PyTorch — Compressing BERT
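We want to compress a large language model using distillation. For distillation, we’ll use the Kullback-Leibler loss, since optimizing it is equivalent to optimizing the cross-entropy over the soft targets:

$$KL(p \| q) = \mathbb{E}_p\left[\log \frac{p}{q}\right] = \sum_i p_i \log(p_i) - \sum_i p_i \log(q_i)$$

where $p$ is the teacher distribution and $q$ the student distribution.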
The first term does not depend on q, so computing the gradients with respect to q (the student distribution) gives the same gradients for both losses. This allows us to leverage PyTorch’s KL-divergence implementation for faster computation.
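The equivalence can be checked numerically. The sketch below (plain Python, with made-up teacher and student logits and an arbitrary temperature) shows that the cross-entropy and the KL divergence differ exactly by the teacher’s entropy H(p), and that finite-difference gradients with respect to the student logits agree.

```python
import math

def softmax(logits, T=1.0):
    """Temperature softmax: exp(z_i / T) / sum_j exp(z_j / T)."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i log q_i."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def entropy(p):
    """H(p) = -sum_i p_i log p_i."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def numerical_grad(loss_fn, logits, eps=1e-6):
    """Central finite-difference gradient of loss_fn w.r.t. each logit."""
    grads = []
    for i in range(len(logits)):
        up = list(logits)
        up[i] += eps
        down = list(logits)
        down[i] -= eps
        grads.append((loss_fn(up) - loss_fn(down)) / (2 * eps))
    return grads

T = 2.0                                # hypothetical temperature
p = softmax([3.0, 1.5, 0.2, -1.0], T)  # fixed teacher distribution (made-up logits)
student_logits = [2.0, 2.0, 0.0, 0.5]  # made-up student logits

def ce_loss(z):
    return cross_entropy(p, softmax(z, T))

def kl_loss(z):
    return kl_divergence(p, softmax(z, T))

# The losses differ by H(p), a constant that does not depend on the student...
gap = ce_loss(student_logits) - kl_loss(student_logits)
# ...so their gradients w.r.t. the student logits are the same.
g_ce = numerical_grad(ce_loss, student_logits)
g_kl = numerical_grad(kl_loss, student_logits)
print(f"CE - KL = {gap:.6f}, H(p) = {entropy(p):.6f}")
print("max gradient difference:", max(abs(a - b) for a, b in zip(g_ce, g_kl)))
```

In PyTorch, a common way to compute this loss is `torch.nn.KLDivLoss(reduction="batchmean")`, which expects the student’s log-probabilities (e.g. `log_softmax(student_logits / T, dim=-1)`) as input and the teacher’s probabilities (`softmax(teacher_logits / T, dim=-1)`) as target. Treat that as one reasonable setup rather than a description of our exact training code.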
Using the teacher signal, we are able to train a smaller language model, which we call DistilBERT, under the supervision of BERT 👨‍👦 (we used the English bert-base-uncased version of BERT).
Following Hinton et al., the training loss is a linear combination of the distillation loss and the masked language modeling loss. Our student is a small version of BERT in which we removed the token-type embeddings and the pooler (used for the next sentence classification task) and kept the rest of the architecture identical while reducing the number of layers by a factor of two.
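For a single masked position, the combined objective can be sketched as below. The weights `alpha_distil` and `alpha_mlm`, the temperature and the toy logits are hypothetical values chosen for illustration; we don’t claim they match our training configuration. The `T ** 2` factor on the distillation term is the scaling suggested by Hinton et al. to keep soft-target gradients on a comparable scale.

```python
import math

def softmax(logits, T=1.0):
    """Temperature softmax over a list of logits."""
    m = max(z / T for z in logits)
    exps = [math.exp(z / T - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, T):
    """Soft-target cross-entropy: -sum_i t_i log s_i, both sides softened by T."""
    t = softmax(teacher_logits, T)
    s = softmax(student_logits, T)
    return -sum(ti * math.log(si) for ti, si in zip(t, s))

def mlm_loss(student_logits, gold_index):
    """Masked-language-modeling loss: cross-entropy against the gold token at T = 1."""
    return -math.log(softmax(student_logits)[gold_index])

def total_loss(student_logits, teacher_logits, gold_index,
               T=2.0, alpha_distil=0.5, alpha_mlm=0.5):
    """Linear combination of the two losses (illustrative weights, not DistilBERT's)."""
    return (alpha_distil * (T ** 2) * distillation_loss(student_logits, teacher_logits, T)
            + alpha_mlm * mlm_loss(student_logits, gold_index))

# Made-up logits over a toy 5-token vocabulary for a single masked position.
teacher = [3.0, 2.5, 0.1, -1.0, -2.0]
student = [2.0, 1.0, 0.5, 0.0, -1.0]
print(total_loss(student, teacher, gold_index=0))
```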
Overall, our distilled model, DistilBERT, has about half the total number of parameters of BERT base and retains 95% of BERT’s performance on the GLUE language understanding benchmark.
❓Note 1 — Why not reduce the hidden size as well?
Reducing it from 768 to 512 would reduce the total number of parameters by a factor of ~2. However, in modern frameworks, most operations are highly optimized, and variations in the last dimension of the tensor (the hidden dimension) have a small impact on most of the operations used in the Transformer architecture (linear layers and layer normalisation). In our experiments, the number of layers was the determining factor for inference time, more than the hidden size.
Smaller does not necessarily imply faster…
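Note 1 can be sanity-checked with a back-of-the-envelope parameter count. The helper below is our own approximation of a BERT-style encoder (embedding tables, four attention projections, a feed-forward block of size 4 * hidden and two LayerNorms per layer, ignoring the pre-training heads); the defaults assume bert-base-uncased’s 30,522-token vocabulary and 512 positions.

```python
def bert_like_param_count(num_layers, hidden, vocab_size=30522, max_positions=512,
                          type_vocab_size=2, intermediate=None, pooler=True):
    """Back-of-the-envelope parameter count (weights + biases) of a BERT-style encoder."""
    if intermediate is None:
        intermediate = 4 * hidden
    # Word, position and token-type embedding tables, plus one LayerNorm (gain + bias).
    embeddings = (vocab_size + max_positions + type_vocab_size) * hidden + 2 * hidden
    # Self-attention: Q, K, V and output projections, each hidden x hidden + bias.
    attention = 4 * (hidden * hidden + hidden)
    # Feed-forward block: hidden -> intermediate -> hidden, with biases.
    feed_forward = 2 * hidden * intermediate + intermediate + hidden
    # Two LayerNorms per layer (gain + bias each).
    layer_norms = 2 * 2 * hidden
    per_layer = attention + feed_forward + layer_norms
    pooler_params = hidden * hidden + hidden if pooler else 0
    return embeddings + num_layers * per_layer + pooler_params

bert_base = bert_like_param_count(12, 768)
# Student: half the layers, no token-type embeddings, no pooler.
distilbert = bert_like_param_count(6, 768, type_vocab_size=0, pooler=False)
# Alternative: keep 12 layers but shrink the hidden size to 512.
narrow_bert = bert_like_param_count(12, 512)

for name, count in [("BERT base", bert_base), ("6 layers, H=768", distilbert),
                    ("12 layers, H=512", narrow_bert)]:
    print(f"{name:>17}: {count / 1e6:6.1f}M params, {count / bert_base:.0%} of BERT base")
```

Under these assumptions, the count is about 109M parameters for BERT base, about 66M for the 6-layer student, and about 54M for a 12-layer model with hidden size 512. Shrinking the hidden size actually removes more parameters than halving the layers, but, as noted above, the number of layers drove inference time more than the hidden size.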
❓Note 2 — Some works on distillation like Tang et al. use the L2 distance as a distillation loss directly on downstream tasks.
Our early experiments suggested that the cross-entropy loss leads to significantly better performance in our case. We hypothesize that in a language modeling setup, the output space (vocabulary) is significantly larger than the output space of a downstream task. The logits may thus compensate for each other in the L2 loss.
Training a sub-network is not only about the architecture. It is also about finding the right initialization for the sub-network to converge (see The Lottery Ticket Hypothesis for instance). We thus initialize our student, DistilBERT, from its teacher, BERT, by taking one layer out of two, leveraging the common hidden size between student and teacher.
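Here is a minimal sketch of the one-layer-out-of-two initialization, written over a plain dictionary of parameters. The key names (`encoder.layer.<i>...`, `pooler.`, `embeddings.token_type...`) are a hypothetical naming scheme rather than the exact format of a real checkpoint, and keeping the even-indexed teacher layers is an assumption: any selection of one layer out of two follows the same pattern. Because student and teacher share the same hidden size, the copied tensors fit without any reshaping.

```python
def init_student_from_teacher(teacher_state, num_teacher_layers=12,
                              layer_prefix="encoder.layer.",
                              dropped_prefixes=("pooler.", "embeddings.token_type")):
    """Build a student state dict from a teacher by keeping one layer out of two.

    teacher_state maps parameter names to values; the naming scheme is hypothetical.
    """
    kept = list(range(0, num_teacher_layers, 2))  # teacher layers 0, 2, 4, ...
    student_state = {}
    for key, value in teacher_state.items():
        if key.startswith(dropped_prefixes):
            continue  # the student has no pooler and no token-type embeddings
        if key.startswith(layer_prefix):
            index_str, rest = key[len(layer_prefix):].split(".", 1)
            teacher_idx = int(index_str)
            if teacher_idx in kept:
                # Teacher layer 2k becomes student layer k.
                student_state[f"{layer_prefix}{kept.index(teacher_idx)}.{rest}"] = value
        else:
            student_state[key] = value  # e.g. word/position embeddings (same hidden size)
    return student_state

# Toy "teacher" whose values just record where each parameter came from.
teacher = {f"encoder.layer.{i}.attention.weight": f"teacher layer {i}" for i in range(12)}
teacher["embeddings.word_embeddings.weight"] = "word embeddings"
teacher["embeddings.token_type_embeddings.weight"] = "token-type embeddings"
teacher["pooler.dense.weight"] = "pooler"

student = init_student_from_teacher(teacher)
for key in sorted(student):
    print(key, "<-", student[key])
```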
We also used a few training tricks from the recent RoBERTa paper, which showed that the way BERT is trained is crucial for its final performance. Following RoBERTa, we trained DistilBERT on very large batches using gradient accumulation (up to 4,000 examples per batch) with dynamic masking, and we removed the next sentence prediction objective.
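Dynamic masking simply means drawing a new mask every time a sequence is fed to the model, instead of fixing the masks once during preprocessing. The sketch below assumes BERT’s usual recipe (15% of tokens selected; of those, 80% replaced by [MASK], 10% by a random token, 10% left as-is) and uses -100, the default `ignore_index` of PyTorch’s `CrossEntropyLoss`, for positions that should not contribute to the loss. The token ids and the [MASK] id are placeholders.

```python
import random

def dynamic_mask(token_ids, vocab_size, mask_id, rng, mlm_prob=0.15):
    """Sample a fresh BERT-style mask each time a sequence is used.

    Returns (inputs, labels); unselected positions get the label -100.
    """
    inputs = list(token_ids)
    labels = [-100] * len(token_ids)
    for i, token in enumerate(token_ids):
        if rng.random() < mlm_prob:
            labels[i] = token  # the model must predict the original token here
            r = rng.random()
            if r < 0.8:
                inputs[i] = mask_id
            elif r < 0.9:
                inputs[i] = rng.randrange(vocab_size)
            # otherwise: keep the original token
    return inputs, labels

rng = random.Random(0)
sequence = list(range(1000, 1040))  # made-up token ids
MASK_ID = 103                       # assumed [MASK] id, for illustration only
# Static masking would compute one mask during preprocessing and reuse it every epoch;
# dynamic masking draws a new one every time the sequence is seen.
for epoch in range(2):
    _, labels = dynamic_mask(sequence, vocab_size=30522, mask_id=MASK_ID, rng=rng)
    print(f"epoch {epoch}: masked positions", [i for i, lab in enumerate(labels) if lab != -100])
```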
Our training setup is deliberately limited in terms of resources. We train DistilBERT on eight 16GB V100 GPUs for approximately three and a half days using the concatenation of Toronto Book Corpus and English Wikipedia (the same data as the original BERT).
The code for DistilBERT is adapted in part from Facebook XLM’s code and in part from our PyTorch version of Google AI’s BERT. It is available in our pytorch-transformers library 👾 along with several trained and fine-tuned versions of DistilBERT and the code to reproduce the training and fine-tuning.
🎢 Model performances — Testing DistilBERT
We compare the performance of DistilBERT on the development sets of the GLUE benchmark against two baselines: BERT base (DistilBERT’s teacher) and a strong non-transformer baseline from NYU: two BiLSTMs on top of ELMo. We use the jiant library from NYU for ELMo baselines and pytorch-transformers for the BERT baseline.
As shown in the following table, DistilBERT’s performance compares favorably with the baselines while having about half and one third as many parameters as BERT base and ELMo+BiLSTM, respectively (more on this below). On all 9 tasks, DistilBERT is on par with or better than the ELMo baseline (up to 14 points of accuracy on QNLI). DistilBERT also compares surprisingly well to BERT: we are able to retain more than 95% of the performance while having 40% fewer parameters.
In terms of inference time, DistilBERT is more than 60% faster than BERT and 120% faster than ELMo+BiLSTM, while also being smaller than both 🐎
To further investigate the speed-up/size trade-off of DistilBERT, we compare, in the left table, the number of parameters of each model along with the inference time needed to do a full pass on the STS-B dev set on CPU (using a batch size of 1).
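The timing protocol can be sketched as a loop that feeds one example at a time and measures wall-clock time with `time.perf_counter`. The predict function and the dev set below are placeholders standing in for a real model and the STS-B dev set; for a fair comparison you would run every model on the same CPU with the same thread settings and, for a PyTorch model, wrap the loop in `torch.no_grad()`.

```python
import time

def time_full_pass(predict_fn, examples):
    """Wall-clock seconds for one pass over `examples` with a batch size of 1."""
    start = time.perf_counter()
    for example in examples:
        predict_fn([example])  # each call sees a batch containing a single example
    return time.perf_counter() - start

# Stand-in for a model's forward pass on CPU; swap in a real model call here.
def dummy_predict(batch):
    return [sum(ord(c) for c in text) for text in batch]

dev_set = ["a sentence pair", "another sentence pair"] * 500  # placeholder data
print(f"{time_full_pass(dummy_predict, dev_set):.4f} s")
```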
🔮 Downstream task: Distillation & transfer-learning
We further study the use of DistilBERT on downstream tasks under efficient inference constraints, fine-tuning our compact pre-trained language model on a classification task. A nice way to actually mix distillation pre-training and transfer learning!
We selected the IMDB Review Sentiment Classification dataset, which is composed of 50,000 English reviews labeled as positive or negative: 25,000 for training and 25,000 for testing (with balanced classes). We trained on a single 12GB K80.
First, we train bert-base-uncased on our dataset. Our dear BERT 💋 reaches an accuracy of 93.46% (average of 6 runs) without any hyper-parameter search.
We then train DistilBERT using the same hyper-parameters. The compressed model reaches an accuracy of 93.07% (average of 6 runs): an absolute difference of 0.4 percentage points for a 60% reduction in latency and a 40% reduction in size 🏎!
❓Note 3 — As noted by the community, you can reach comparable or better scores on the IMDB benchmark with lighter methods (size-wise and inference-wise) like ULMFiT. We encourage you to compare on your own use case! In particular, DistilBERT can give a sensible lower bound on BERT’s performance with the advantage of faster training.
Another common application of NLP is Question Answering. We compared the results of the bert-base-uncased version of BERT with DistilBERT on the SQuAD 1.1 dataset. On the development set, BERT reaches an F1 score of 88.5 and an EM (exact match) score of 81.2. We train DistilBERT with the same set of hyper-parameters and reach scores of 85.1 F1 and 76.5 EM, within 3 to 5 points of the full BERT.
We also studied whether we could add another step of distillation during the adaptation phase by finetuning DistilBERT on SQuAD using the finetuned BERT model as a teacher with a knowledge distillation loss.
Here we are fine-tuning by distilling a question answering model into a language model previously pre-trained with knowledge distillation! That’s a lot of teachers and students 🎓
In this case, we were able to reach interesting performance given the size of the network: 86.2 F1 and 78.1 EM, i.e. within about 3 points of the full model!
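For extractive question answering, a BERT-style model outputs one logit per context token for the answer start and one for the answer end, so a distillation loss can be applied to both distributions. The sketch below averages the two soft cross-entropies; the averaging, the temperature and the toy logits are assumptions for illustration, and in practice such a term is typically combined with the usual loss on the gold start and end positions.

```python
import math

def log_softmax(logits, T=1.0):
    """Log of the temperature softmax, computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_total = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_total for z in scaled]

def soft_cross_entropy(teacher_logits, student_logits, T):
    """-sum_i t_i log s_i, with both distributions softened by T."""
    t_log = log_softmax(teacher_logits, T)
    s_log = log_softmax(student_logits, T)
    return -sum(math.exp(tl) * sl for tl, sl in zip(t_log, s_log))

def qa_distillation_loss(student_start, student_end, teacher_start, teacher_end, T=2.0):
    """Average the soft cross-entropies of the start- and end-position distributions."""
    return 0.5 * (soft_cross_entropy(teacher_start, student_start, T)
                  + soft_cross_entropy(teacher_end, student_end, T))

# Made-up logits over 5 candidate token positions in a context paragraph.
teacher_start = [0.1, 4.0, 1.0, -1.0, -2.0]
teacher_end = [-2.0, 0.5, 3.5, 1.0, -1.0]
student_start = [0.0, 2.0, 1.5, 0.0, -1.0]
student_end = [-1.0, 1.0, 2.0, 1.5, 0.0]
print(qa_distillation_loss(student_start, student_end, teacher_start, teacher_end))
```

The loss is smallest when the student’s start and end distributions match the teacher’s, which is exactly the behavior the adaptation-phase distillation aims for.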
Other works have also attempted to accelerate question answering models. Notably, Debajyoti Chatterjee uploaded an interesting work on arXiv that follows a similar method for the adaptation phase on SQuAD (initializing a student from its teacher and training a question-answering model via distillation). His experiments show similar relative performance with regard to BERT (base uncased). The main difference with our present work is that we pre-train DistilBERT with a general objective (Masked Language Modeling) in order to obtain a model that can be used for transfer learning on a wide range of tasks via fine-tuning (GLUE, SQuAD, classification…).
🙌 Less is more: smaller models also spark joy 🌟
We are very excited about DistilBERT’s potential. The work we’ve presented is just the beginning of what can be done and raises many questions: How far can we compress these models with knowledge distillation? Can these techniques be used to get further insights into the knowledge stored in the large version? What linguistic or semantic aspects do we lose in this type of compression?…
Open source and knowledge sharing are essential aspects of our work at Hugging Face, as you can see from our GitHub and Medium pages. We think this is both the easiest and the fairest way for everyone to participate in and reap the fruits of the remarkable progress of deep learning for NLP.
Thus, together with this blog post, we release the code of our experiments 🎮 (in particular the code to reproduce the training and fine-tuning of DistilBERT) along with a trained version of DistilBERT in our pytorch-transformers library 🔥.
Many thanks to Sam Bowman, Alex Wang and Thibault Févry for feedback and discussions!