Understanding BERT Variants: Part 2

DistilBERT & TinyBERT using knowledge distillation

Mehul Gupta
Data Science in your pocket

The past 3 blogs have been a rollercoaster where we first dived into Transformers, then BERT & then a few prominent BERT variants. Continuing the trend, we will discuss the concept of Knowledge Distillation & how it led to the inception of some more BERT variants.

Note: You may wish to refer to the past blogs in this series for a better understanding.

As discussed last time, BERT has a few evident issues, chiefly that it is very bulky & difficult to train & deploy because of its numerous parameters (110 million for BERT-Base, to be precise). Hence, amongst many new ideas for preparing a lighter version with negligible loss in performance, the concept of Knowledge Distillation came in.

What is Knowledge Distillation?

It is a model compression technique where a small model is trained to replicate a bigger network’s behavior. The bigger network is called the Teacher Network, while the smaller one trained using it is called the Student Network.

The Teacher Network

Assume we have a network that can predict the next token in an input sequence:

Observe a few points in the example above:

  • ‘Homework’ has the highest probability for the missing token, hence it is the most apt token
  • But the model was also able to conclude that ‘Book’ or ‘Assignment’ is more relevant than ‘Car’ or ‘Cake’ (higher probabilities)

This is what is called dark knowledge: the relative probabilities a model assigns to the tokens other than the best one. When transferring knowledge from Teacher to Student, we want the Student to learn this dark knowledge too. However, the models we train often produce a probability close to 1 for the best token, like this:

Extracting dark knowledge is tough here because, apart from ‘Homework’, the probabilities of all other tokens are ~0.

So, how to extract dark knowledge in such cases?

SoftMax Temperature

If you look closely at the diagram above, we used the SoftMax function to generate a probability distribution over the predicted tokens. To extract dark knowledge, we will replace it with a ‘SoftMax with temperature’ function. The image below compares the two.

SoftMax Function (Left), SoftMax with Temperature (Right)

The constant T introduced here is called the Temperature, and it smooths the probability distribution: for logits zᵢ, pᵢ = exp(zᵢ/T) / Σⱼ exp(zⱼ/T). The plain SoftMax function is simply SoftMax with temperature at T=1. The bigger the value of T, the smoother the distribution. Let us see how the probability distribution is affected by different T values.
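To make the effect of T concrete, here is a minimal Python sketch of SoftMax with temperature. The logits below are made-up numbers for the ‘Homework’ example, not values from any real model; the point is only to show how raising T spreads probability mass onto ‘Book’ & ‘Assignment’ while keeping the ranking of the tokens intact.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Turn raw scores (logits) into probabilities; T > 1 flattens the distribution."""
    scaled = [z / T for z in logits]
    m = max(scaled)                          # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits for the candidate tokens (made-up numbers for illustration)
tokens = ["Homework", "Book", "Assignment", "Car", "Cake"]
logits = [10.0, 4.0, 3.5, 0.5, 0.2]

for T in (1, 2, 5):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T}: " + ", ".join(f"{t}={p:.3f}" for t, p in zip(tokens, probs)))
```

Subtracting the maximum scaled logit before exponentiating does not change the result, but it avoids overflow when the logits are large.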

So, the idea is clear: we wish to transfer the dark knowledge (extracted using SoftMax temperature) from the Teacher Network to the Student Network while training the latter. Fine, but how will this transfer take place? Let’s talk about that.

The Student Network

So, for now, we have a pre-trained ‘Next token detector’ Teacher Network with SoftMax Temperature. We will next walk through how the Student Network learns.

So, let ‘A’ be the Student Network & ‘B’ be the Teacher Network such that the size of A << the size of B, & B is pre-trained.

Now, the probability distribution output by model B (Teacher) is treated as a target (a sort of ground truth) called the Soft Target, & the prediction made by the Student model is called the Soft Prediction. The picture below will clear up many things.

The pipeline for training the Student Network goes something like this:

  • The input sequence is fed to the Teacher Network & the ‘Soft Target’ (the probability distribution over different tokens) is calculated
  • The same input sequence is fed to the Student Network & a Soft Prediction is calculated

The loss function has 2 main parts:

  • Distillation loss: The cross-entropy loss between the ‘Soft Target’ & the ‘Soft Prediction’
  • Student loss: The cross-entropy loss between the ‘Hard Target’ & the ‘Hard Prediction’

To understand the Student loss, a couple more concepts are required:

  1. Hard Target: Convert the ‘Soft Target’ probability distribution into a one-hot vector by setting the highest probability to 1 & the rest to 0
  2. Hard Prediction: This changes a bit. Remember that while calculating the Soft Prediction, we used the Temperature (T>1). For the Hard Prediction, we keep T=1 & calculate the probability distribution across all tokens

Hence, in Distillation loss, T>1 but in Student loss, T=1.
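The two losses can be sketched in plain Python as below. This is a minimal illustration under a few assumptions: the logits are made up, the Hard Target is built from the Teacher's top token exactly as described above (in Hinton et al.'s original formulation the hard target is usually the ground-truth label), & the two parts are combined with a hypothetical weight `alpha`.

```python
import math

def softmax_with_temperature(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target, prediction, eps=1e-12):
    # CE(p, q) = -sum_i p_i * log(q_i); eps guards against log(0)
    return -sum(p * math.log(q + eps) for p, q in zip(target, prediction))

def distillation_losses(teacher_logits, student_logits, T=2.0, alpha=0.5):
    # Distillation loss: Soft Target vs Soft Prediction, both at temperature T > 1
    soft_target = softmax_with_temperature(teacher_logits, T)
    soft_prediction = softmax_with_temperature(student_logits, T)
    distill = cross_entropy(soft_target, soft_prediction)

    # Student loss: Hard Target (one-hot of the Teacher's top token) vs Hard Prediction (T = 1)
    best = max(range(len(soft_target)), key=soft_target.__getitem__)
    hard_target = [1.0 if i == best else 0.0 for i in range(len(soft_target))]
    hard_prediction = softmax_with_temperature(student_logits, 1.0)
    student = cross_entropy(hard_target, hard_prediction)

    # Hypothetical weighting of the two parts (alpha is an assumed hyperparameter)
    total = alpha * distill + (1 - alpha) * student
    return distill, student, total

teacher = [10.0, 4.0, 3.5, 0.5, 0.2]   # made-up Teacher logits
student = [2.0, 1.0, 1.5, 0.3, 0.1]    # made-up Student logits
d, s, total = distillation_losses(teacher, student, T=2.0)
print(f"distillation={d:.3f}  student={s:.3f}  total={total:.3f}")
```

Hinton et al. also multiply the distillation term by T² so that its gradients keep a comparable scale as T changes; that factor is left out here to mirror the description above.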

With this, let’s get back on track & see which Knowledge Distillation-based BERT variants exist!

DistilBERT

As the name suggests, we prepare a Student BERT from a pre-trained Teacher BERT (BERT-Base) using Knowledge Distillation. What are the major gains?

  • It's 60% faster
  • It's 40% smaller in size, reducing the total parameters from 110 million to 66 million
  • The number of layers is halved (12 to 6), which speeds up both training & inference, though the hidden size (d_model) remains the same, i.e. 768
  • And despite all these optimizations, it retains ~97% of the original BERT's performance

So how does it work?

A picture should be enough!

Here,

  • The Teacher BERT is BERT-Base (remember my previous blog on BERT) & is pre-trained on the MLM task (predicting masked tokens)
  • The Student BERT has fewer layers than the Teacher & is not pre-trained on its own; everything it learns comes from distillation (the DistilBERT paper initializes it by taking one out of every two Teacher layers)
  • The Distillation loss & Student loss discussed above carry over to DistilBERT directly, as the image shows

Additionally, one more loss function is added: a cosine embedding loss, which aligns the directions of the Student's & Teacher's hidden-state vectors. So, this was more or less about DistilBERT.
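The cosine term can be sketched with NumPy as follows. In the DistilBERT paper it is computed between the Student's & Teacher's hidden-state vectors. This sketch assumes per-token hidden states of shape [seq_len, 768] for both models, which works because DistilBERT keeps d_model = 768, so no projection is needed, & averages 1 − cos over tokens. In the DistilBERT paper the final objective is a weighted combination of the distillation loss, the MLM loss & this cosine loss; those weights are not shown here.

```python
import numpy as np

def cosine_embedding_loss(student_hidden, teacher_hidden):
    """Mean over tokens of (1 - cosine similarity); inputs are [seq_len, 768] arrays."""
    s = student_hidden / np.linalg.norm(student_hidden, axis=-1, keepdims=True)
    t = teacher_hidden / np.linalg.norm(teacher_hidden, axis=-1, keepdims=True)
    cos = np.sum(s * t, axis=-1)        # one cosine similarity per token
    return float(np.mean(1.0 - cos))

rng = np.random.default_rng(0)
teacher = rng.normal(size=(4, 768))     # hypothetical Teacher hidden states
print(cosine_embedding_loss(2.0 * teacher, teacher))              # ~0: same direction
print(cosine_embedding_loss(rng.normal(size=(4, 768)), teacher))  # ~1: unrelated vectors
```

Only the direction of the vectors matters here: scaling the Student's hidden states leaves the loss unchanged. This complements the distillation loss, which only looks at the output distribution.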

Moving on to…

TinyBERT

As we saw in DistilBERT, we were able to teach our Student network using the output produced by a pre-trained Teacher Network.

Can we make our Student learn from what each layer of the Teacher Network has learned? For example, the attention module of the Student could learn from the attention module of the Teacher, & likewise for the other layers.

A definite yes!

This is what TinyBERT does.

But how would this be useful?

This can be very useful because the Student can then mimic the Teacher Network's internal behavior, not just its output. That helps it pick up linguistic information from the Teacher that DistilBERT may not learn properly when it mainly tries to mimic the output.

Let us dissect a Teacher BERT

Observe the dissection above & note how TinyBERT differs from DistilBERT:

  • In DistilBERT, the Student learns mainly from the logits produced at the top of the Network
  • In TinyBERT, every part of the Teacher, be it the attention matrices, the embeddings learned by the Embedding layer, the hidden states, etc., is used for training the Student

Let’s bring in another diagram

As is evident, TinyBERT has M = 4 encoders (each a block of Attention, Normalization, FFN & Normalization stacked together), whereas the Teacher BERT has N = 12 such encoders. TinyBERT also reduces d_model (the embedding size fed into each encoder) from 768 in the Teacher BERT (BERT-Base) to 312. This is why it is so compact compared to BERT, bringing the total number of trainable parameters down from 110 million to as low as 14.5 million.
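As a sanity check on these sizes, here is a rough parameter counter for a BERT-style encoder. It is a back-of-the-envelope sketch under several assumptions: a 30,522-token WordPiece vocabulary, 512 positions, FFN sizes of 3,072 for BERT-Base & DistilBERT and 1,200 for TinyBERT-4 (a value from the TinyBERT paper, not from this post), & DistilBERT dropping the segment embeddings & pooler.

```python
def bert_param_count(layers, hidden, ffn, vocab=30522, max_pos=512,
                     type_vocab=2, pooler=True):
    """Rough weight + bias count for a BERT-style encoder (assumed architecture)."""
    # Embedding layer: token + position + segment tables, then one LayerNorm (gamma, beta)
    embeddings = (vocab + max_pos + type_vocab) * hidden + 2 * hidden
    # Self-attention: Q, K, V & output projections (weights + biases) + LayerNorm
    attention = 4 * (hidden * hidden + hidden) + 2 * hidden
    # Feed-forward: hidden -> ffn -> hidden (weights + biases) + LayerNorm
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden) + 2 * hidden
    pooler_params = hidden * hidden + hidden if pooler else 0
    return embeddings + layers * (attention + feed_forward) + pooler_params

configs = {
    "BERT-Base": dict(layers=12, hidden=768, ffn=3072),
    "DistilBERT": dict(layers=6, hidden=768, ffn=3072, type_vocab=0, pooler=False),
    "TinyBERT-4": dict(layers=4, hidden=312, ffn=1200),
}
for name, cfg in configs.items():
    print(f"{name:<11}{bert_param_count(**cfg) / 1e6:6.1f}M")
```

This prints roughly 109.5M, 66.4M & 14.4M. The first two match the usual 110M & 66M figures. The small gap to the quoted 14.5M for TinyBERT depends on exactly which extra layers get counted.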

So, the knowledge distillation process happens at 3 levels:

  • Transformer layer distillation
  • Embedding layer distillation
  • Prediction layer distillation

Starting off with

Transformer Layer distillation

This involves 2 kinds of distillation:

  • Attention-based distillation (the attention matrices produced by the N-Head Attention)
  • Hidden state-based distillation

Attention-based distillation makes the Student learn the attention matrices produced by the Teacher's N-Head Attention layer. We do this by minimizing the loss between the attention matrices of the Teacher & the Student using the loss function below:

L_attn = (1/h) Σᵢ₌₁ʰ MSE(Aᵢˢ, Aᵢᵗ)

Where:

  • h = number of heads
  • Aᵢˢ = attention matrix of the i-th head, Student
  • Aᵢᵗ = attention matrix of the i-th head, Teacher
  • MSE = Mean Squared Error

Pictorial Representation of how Attention-based Distillation works
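Here is a minimal NumPy sketch of the attention loss, i.e., the per-head MSE averaged over the h heads. It assumes the attention matrices of one layer are stacked into arrays of shape [h, seq_len, seq_len]. Because both TinyBERT-4 & BERT-Base use h = 12 heads, the matrices line up head by head & no projection is needed. The TinyBERT paper reports using the unnormalized attention scores (before the SoftMax) as the target; the random arrays below are hypothetical stand-ins.

```python
import numpy as np

def attention_distillation_loss(student_attn, teacher_attn):
    """L_attn = (1/h) * sum_i MSE(A_i^S, A_i^T) for arrays of shape [h, seq_len, seq_len]."""
    if student_attn.shape != teacher_attn.shape:
        raise ValueError("Student & Teacher attention maps must have the same shape")
    h = student_attn.shape[0]
    per_head_mse = [np.mean((student_attn[i] - teacher_attn[i]) ** 2) for i in range(h)]
    return float(sum(per_head_mse) / h)

rng = np.random.default_rng(0)
h, seq_len = 12, 6
A_teacher = rng.normal(size=(h, seq_len, seq_len))   # hypothetical Teacher attention scores
A_student = rng.normal(size=(h, seq_len, seq_len))   # hypothetical Student attention scores
print(attention_distillation_loss(A_student, A_teacher))
```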

Similarly, in Hidden State distillation, we minimize the MSE loss between the hidden states (outputs) of the Student's M-encoder stack & the Teacher's N-encoder stack.

But remember, the embedding dimensions of the Teacher (768) & the Student (312) differ. So, is calculating the MSE as straightforward as we think?

Yes, that’s a complication, but there is an easy solution, similar in spirit to the Matrix Factorization explained in my previous post on BERT: while minimizing the MSE loss between the two hidden states, we multiply the Student's hidden state by a weight matrix such that:

Hidden(S) [n × 312] × W [312 × 768] = Hidden(S)·W [n × 768], where n is the sequence length

The result has the same dimensions as Hidden(T). The weight matrix W is learned as well, & the same trick applies to the following distillations.
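Here is a small NumPy sketch of the hidden-state loss with a learnable W. For illustration only, it takes random hidden states for one Student layer & its matched Teacher layer, following the 312 → 768 example above, & runs plain gradient descent on W alone. In real TinyBERT training the Student's own weights are updated too.

```python
import numpy as np

def hidden_state_loss(H_s, W, H_t):
    """MSE between the projected Student hidden states (H_s @ W) & the Teacher's H_t."""
    return float(np.mean((H_s @ W - H_t) ** 2))

def hidden_state_grad_W(H_s, W, H_t):
    """Gradient of hidden_state_loss with respect to the projection W."""
    diff = H_s @ W - H_t                     # [seq_len, 768]
    return 2.0 / diff.size * (H_s.T @ diff)  # [312, 768], same shape as W

rng = np.random.default_rng(0)
seq_len, d_student, d_teacher = 8, 312, 768
H_s = rng.normal(size=(seq_len, d_student))               # hypothetical Student layer output
H_t = rng.normal(size=(seq_len, d_teacher))               # hypothetical Teacher layer output
W = rng.normal(scale=0.02, size=(d_student, d_teacher))   # learnable projection

loss_before = hidden_state_loss(H_s, W, H_t)
for _ in range(100):                          # plain gradient descent on W only
    W -= 1.0 * hidden_state_grad_W(H_s, W, H_t)
loss_after = hidden_state_loss(H_s, W, H_t)
print(f"MSE before: {loss_before:.3f}, after: {loss_after:.3f}")
```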

Embedding Layer Distillation

In Embedding Layer Distillation, the Student learns from the Teacher's Embedding Layer, present just before the N stacked encoders. It is pretty similar to what we did above, i.e., we minimize the MSE loss between the two layers' outputs & multiply by a learnable weight matrix to cope with the different embedding sizes.
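Combining the embedding-layer & hidden-state losses requires deciding which Teacher encoder each Student encoder learns from. The sketch below assumes the uniform mapping described in the TinyBERT paper, g(m) = 3m for M = 4 & N = 12, with index 0 standing for the embedding layer. It also assumes one projection `W_e` for the embeddings & one shared `W_h` for all encoder layers. The function names are hypothetical.

```python
import numpy as np

def uniform_layer_mapping(student_layers=4, teacher_layers=12):
    """Student layer m learns from Teacher layer g(m) = m * (N // M); 0 = embedding layer."""
    step = teacher_layers // student_layers
    return {m: m * step for m in range(student_layers + 1)}

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def layerwise_distillation_loss(student_states, teacher_states, W_e, W_h, mapping):
    """student_states[0] & teacher_states[0] are embedding-layer outputs;
    index k >= 1 is the output of encoder k. Student width 312, Teacher width 768."""
    # Embedding-layer distillation, projected with its own learnable matrix W_e
    loss = mse(student_states[0] @ W_e, teacher_states[0])
    # Hidden-state distillation: each Student encoder vs its mapped Teacher encoder
    for m in range(1, len(student_states)):
        loss += mse(student_states[m] @ W_h, teacher_states[mapping[m]])
    return loss

print(uniform_layer_mapping(4, 12))  # {0: 0, 1: 3, 2: 6, 3: 9, 4: 12}
```

With this mapping, Student encoder 1 imitates Teacher encoder 3, encoder 2 imitates encoder 6, & so on. Teacher layers in between are skipped rather than averaged.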

Prediction Layer Distillation

In Prediction Layer Distillation, as in DistilBERT, the Student learns from the overall output of the Teacher. Unlike the previous two distillations (Transformer & Embedding), where we minimized the MSE loss, here we minimize the cross-entropy loss between the soft target & the soft prediction using the formula below:

L_pred = −softmax(zᵗ/T) · log softmax(zˢ/T)

where zᵗ & zˢ are the logits of the Teacher & the Student, & T is the temperature.

And with this, I am done for now.

For a more detailed explanation, do refer to Getting Started with Google BERT by Packt Publishing.
