BERT Distillation with Catalyst

How to distill BERT with Catalyst.

Nikita Balagansky
PyTorch

--

Intro

In the last few years, NLP models have made a big leap on most machine learning tasks. BERT and BERT-based models are currently the most popular NLP models, and the typical pipeline has changed a lot: instead of training models from scratch, we can fine-tune a pre-trained model for our specific task and get near-SOTA results. But there is one big problem: even the plain BERT-base model has about 110 million parameters, which slows down inference. This is not a problem at all if you only have one dataset to predict on, but if you need your model to process requests and respond online, or if you want to integrate it into a mobile app, model size can be a critical factor. What if we could make our model smaller without losing quality?

Three basic techniques allow us to do so:

  1. Quantization
  2. Pruning
  3. Distillation

In this blog post, we will focus on distillation.

Main Idea

BERT distillation schema

Loss function

The BERT-base model contains 12 multi-head attention encoder layers. Let's take just 6 of them and try to transfer the knowledge from the big model (the teacher) to the small one (the student). We will train the student on the masked LM task, but we will also use the teacher. The overall logic is simple: if we look at the five most probable words for a masked position according to the teacher model, all of them make sense and any of them could appear in that position. So we want the student to output the same distribution of word probabilities for masked tokens as the teacher, and we can add the distance between these distributions to our loss function. A non-symmetric distance between distributions is given by the Kullback–Leibler divergence:
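In standard notation, with p_t and p_s denoting the teacher's and student's predicted probability distributions over the vocabulary at a masked position, this loss term is

$$\mathcal{L}_{\mathrm{KL}} = D_{\mathrm{KL}}\left(p_t \,\|\, p_s\right) = \sum_{i} p_t(i) \log \frac{p_t(i)}{p_s(i)},$$

where the sum runs over the vocabulary. In practice, both distributions are usually softened with a temperature before the divergence is computed.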

KL divergence loss

We can also add the cosine distance between the teacher's and student's hidden states to the loss:
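For hidden-state vectors h_t (teacher) and h_s (student) taken from corresponding layers, a common form of this term is

$$\mathcal{L}_{\cos} = 1 - \cos\left(h_t, h_s\right) = 1 - \frac{h_t \cdot h_s}{\lVert h_t \rVert \, \lVert h_s \rVert}.$$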

Cosine loss

The final loss is a weighted sum of these terms.
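Following the DistilBERT recipe, it can be written schematically as

$$\mathcal{L} = \alpha\,\mathcal{L}_{\mathrm{MLM}} + \beta\,\mathcal{L}_{\mathrm{KL}} + \gamma\,\mathcal{L}_{\cos},$$

where L_MLM is the usual masked-LM cross-entropy and the weights alpha, beta, gamma are hyperparameters.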

Initialization

Initialization has a significant impact on the final result. With random initialization, the distillation process takes much more time and may not converge to a good minimum. The original paper proposes initializing the student layers with layers [0, 2, 4, 7, 9, 11] of the teacher model, but you can also try, for example, the first six layers.

Catalyst

We are going to use Catalyst to implement the network.

Catalyst is a high-level framework for PyTorch deep learning research and development. It is focused on reproducibility, fast experimentation and code re-use. Here is a minimal example:
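The snippet below is a sketch in the spirit of the official minimal example, using a small toy dataset so that it is self-contained; exact method names such as handle_batch depend on the Catalyst version.

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from catalyst import dl

# toy data: flattened 28x28 "images" with 10 classes
X, y = torch.rand(1024, 28 * 28), torch.randint(0, 10, (1024,))
loaders = {
    "train": DataLoader(TensorDataset(X, y), batch_size=32),
    "valid": DataLoader(TensorDataset(X, y), batch_size=32),
}

model = nn.Linear(28 * 28, 10)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


class CustomRunner(dl.Runner):
    def handle_batch(self, batch):
        # one train/valid step: forward pass, loss, metric logging
        features, targets = batch
        logits = self.model(features)
        loss = F.cross_entropy(logits, targets)
        self.batch_metrics.update({"loss": loss})
        # backward pass only on the train loader
        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()


runner = CustomRunner()
runner.train(
    model=model,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=1,
    logdir="./logs",
    verbose=True,
)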

Catalyst CustomRunner example

As you can see, most of the code just sets up the model, criterion, and loaders. The main driver code is handled by Catalyst: no loops are needed, and you can focus on the key parts of your deep learning research!

You can read a more detailed post about the framework here:

Nevertheless, we need something much more interesting than a typical supervised learning pipeline, don't we? We need to implement teacher-student distillation. How can we do it with Catalyst? Let's check it out!

If you don't have time to read this article all the way through, you can go directly to my GitHub repository, clone it, set it up, and run it.

Implementation

First of all, we need to implement a dataset for our data. The only difference from a classic torch dataset is that the data loader in Catalyst should return a dict. Ours returns one like this:
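A rough sketch of such a dataset, with illustrative key names (the real implementation is in the repository), could look like this:

import torch
from torch.utils.data import Dataset


class MaskedLMDataset(Dataset):
    """Sketch of a masked-LM dataset; key names and masking details are illustrative."""

    def __init__(self, token_ids, tokenizer, mask_prob=0.15):
        self.token_ids = token_ids  # list of fixed-length token-id sequences
        self.tokenizer = tokenizer
        self.mask_prob = mask_prob

    def __len__(self):
        return len(self.token_ids)

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.token_ids[idx])
        labels = input_ids.clone()
        # randomly mask a fraction of the tokens; unmasked positions get the ignore label -100
        mask = torch.rand(input_ids.shape) < self.mask_prob
        labels[~mask] = -100
        input_ids[mask] = self.tokenizer.mask_token_id
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "mlm_labels": labels,
        }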

I’m not going to focus on the implementation of this dataset because it’s pretty standard.

The next thing is the model implementation. Let's take a look at the student model. As I mentioned, initialization matters, so there is an extract method which produces the student model's state dict. Basically, it just copies the specified layers from the teacher to the student. The next step is to load the extracted dict into our model.
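Here is a rough sketch of that extraction step (the exact key handling in the repository may differ):

from collections import OrderedDict


def extract(teacher_state_dict, layers=(0, 2, 4, 7, 9, 11)):
    """Build a 6-layer student state dict from the chosen teacher layers (sketch)."""
    student_state_dict = OrderedDict()
    for key, weight in teacher_state_dict.items():
        if "encoder.layer." not in key:
            # embeddings, LM head, etc. are copied as-is
            student_state_dict[key] = weight
            continue
        teacher_layer = int(key.split("encoder.layer.")[1].split(".")[0])
        if teacher_layer in layers:
            student_layer = layers.index(teacher_layer)
            student_key = key.replace(
                f"encoder.layer.{teacher_layer}.", f"encoder.layer.{student_layer}."
            )
            student_state_dict[student_key] = weight
    return student_state_dict


# inside the student model's init (simplified):
# self.load_state_dict(extract(teacher.state_dict()), strict=False)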

This is a simplified version of what happens in the student model's init method.

Moving on to the runner: the main logic lives in the handle_batch method.
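A sketch of such a runner is shown below; the dict keys are illustrative, the models are assumed to return (logits, hidden_states), and the exact Runner API depends on the Catalyst version.

import torch
from catalyst import dl


class DistillRunner(dl.Runner):
    """Runs the teacher and the student on the same masked-LM batch (sketch)."""

    def handle_batch(self, batch):
        # the teacher is frozen, so no gradients are needed for it
        with torch.no_grad():
            teacher_logits, teacher_hidden = self.model["teacher"](
                batch["input_ids"], attention_mask=batch["attention_mask"]
            )
        student_logits, student_hidden = self.model["student"](
            batch["input_ids"], attention_mask=batch["attention_mask"]
        )
        # expose everything the loss callbacks will need
        self.output = {
            "teacher_logits": teacher_logits,
            "teacher_hidden_states": teacher_hidden,
            "student_logits": student_logits,
            "student_hidden_states": student_hidden,
        }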

Pretty simple, isn't it? We just put all the model outputs into self.output.

Callbacks is all you need!

In Catalyst, you can use callbacks to describe custom losses, metrics, logging, and other actions.

Speaking of distillation, we need to implement custom losses, for example the masked language model loss:
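Here is a sketch of such a callback; the key names match the runner sketch above, and the exact Callback base-class API depends on the Catalyst version (the actual implementations are in the repository).

from catalyst.core import Callback, CallbackOrder
from torch.nn import functional as F


class MaskedLMLossCallback(Callback):
    """Cross-entropy computed over the masked positions only (sketch)."""

    def __init__(self, output_key="student_logits", target_key="mlm_labels", prefix="mlm_loss"):
        super().__init__(order=CallbackOrder.Metric)
        self.output_key = output_key
        self.target_key = target_key
        self.prefix = prefix

    def on_batch_end(self, runner):
        logits = runner.output[self.output_key]  # student predictions
        targets = runner.input[self.target_key]  # labels from the data loader
        # positions that were not masked carry the ignore label -100
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-100,
        )
        runner.batch_metrics[self.prefix] = loss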

I also implemented cosine, KL divergence, and MSE losses, as well as a perplexity metric callback.

Now we are all set, but what’s next?

Notebook API

Let's take a look at the Catalyst Notebook API.
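A sketch of what this looks like, reusing the DistillRunner and MaskedLMLossCallback sketched above (teacher_model, student_model, optimizer, and loaders are assumed to be defined; the full callback list lives in the repository):

runner = DistillRunner()
runner.train(
    model={"teacher": teacher_model, "student": student_model},
    optimizer=optimizer,
    loaders=loaders,
    callbacks=[
        MaskedLMLossCallback(),
        # plus the KL divergence, cosine, MSE and perplexity callbacks
    ],
    num_epochs=5,
    logdir="./logs",
    verbose=True,
)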

Yes, it's that easy! We can also add any callback we want with a single line of code: logging, gradient accumulation, schedulers, model checkpointing, tracing, and more.

Config API

Catalyst also has a convenient Config API, which can be used for more production-ready, reproducible pipelines. To use it, we need to implement an Experiment class (you can find it in my repository). Let's take a look at the same pipeline, but with the Config API.

Now you can focus on your experiment without worrying about the implementation. You almost can't lose a trained model with Catalyst, because it saves all your code, the config, and the model state dict in the specified logdir.

We can run it with:

catalyst-dl run -C config.yml

Or we can add distributed training:

catalyst-dl run -C config.yml --distributed

What about mixed precision?

catalyst-dl run -C config.yml --distributed --apex

This out-of-the-box functionality comes with the framework.

Experiments & Results

You can take a dataset in any language of your choice. You just need to export it to CSV format, and then you can use my script.

I ran training for 100 hours on 4x1080Ti GPUs on the Lenta Russian news dataset. I took DeepPavlov's RuBERT-base-cased as the teacher model.

Then I tried my Catalyst model on a Russian sentiment Twitter dataset and compared it with multilingual DistilBERT and the teacher. I noticed that my model converges a little faster than the multilingual one. EMA smoothing is applied.

I ran every model three times with different random seeds. Here are my results.

Future Work

  • Add more losses
  • More datasets and various languages
  • How about adding quantization?

Contributions are welcome 😉

References

Authors: Nikita Balagansky

Source code:

Catalyst:

Paper by Hugging Face:

DeepPavlov:

Thank you!
