# A Primer on Knowledge Distillation in NLP — Part 1

In recent years, deep neural networks have achieved impressive performance on a variety of Natural Language Processing (NLP) tasks. However, their high computational and storage requirements make them hard to deploy in real-world applications, especially on resource-constrained devices such as mobile phones. To address this, a variety of model compression and acceleration techniques have been developed. One such technique is **knowledge distillation** (KD), which trains a small **student** model to reproduce the behavior of a large **teacher** model. In this article series, we will first study the basics of knowledge distillation and then discuss its applications in NLP.

## Introduction

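Neural networks typically produce class probabilities with a softmax output layer, which converts the logit $z_i$ computed for each class into a probability $q_i$ by comparing it with the other logits:

$$
q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}
$$

where **T** is a temperature that is normally set to 1. Using a higher value for **T** produces a **softer** probability distribution over classes.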
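To make the effect of the temperature concrete, here is a minimal sketch in plain Python. The function name `softmax_with_temperature` and the three logits are made up for illustration; they do not come from the training scripts discussed later.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Turn logits into probabilities after dividing each logit by T."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the maximum for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a three-class problem.
logits = [5.0, 2.0, 0.5]

for T in (1.0, 5.0, 20.0):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T:>4}: {[round(p, 3) for p in probs]}")
```

As **T** grows, the largest probability shrinks and the smaller ones grow, so the distribution moves toward uniform while keeping the same ranking of classes.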

In a pre-trained model, knowledge lies in the class probabilities produced by its softmax layer. Every probability, not only that of the target class, carries information about the input: the relative probabilities of the incorrect classes show which classes the model considers similar. A one-hot target label contributes only the target class to the cross-entropy loss, whereas the full distribution from the pre-trained model provides more information per training example and can therefore teach a new model more efficiently.

The probabilities produced by the pre-trained model are called **soft target probabilities**. The student model is trained to mimic the teacher model by replicating its outputs; that is, the generalization ability of the teacher is transferred to the student by using the class probabilities produced by the teacher as **soft targets** when training the student.
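The difference between hard and soft targets can be seen in a small, self-contained sketch. All distributions below are invented for illustration (they are not outputs of any model in this article): two hypothetical students assign the same probability to the correct class but distribute the rest differently.

```python
import math


def cross_entropy(target_probs, predicted_probs):
    """H(p, q) = -sum_i p_i * log(q_i); terms with p_i == 0 are skipped."""
    return -sum(
        p * math.log(q)
        for p, q in zip(target_probs, predicted_probs)
        if p > 0
    )


one_hot_target = [1.0, 0.0, 0.0]     # hard label: class 0 is correct
soft_target = [0.70, 0.25, 0.05]     # assumed teacher output for the same input

student_a = [0.6, 0.3, 0.1]          # ranks the wrong classes like the teacher
student_b = [0.6, 0.1, 0.3]          # ranks the wrong classes the other way

hard_a = cross_entropy(one_hot_target, student_a)
hard_b = cross_entropy(one_hot_target, student_b)
soft_a = cross_entropy(soft_target, student_a)
soft_b = cross_entropy(soft_target, student_b)

print("hard-target loss:", round(hard_a, 4), round(hard_b, 4))
print("soft-target loss:", round(soft_a, 4), round(soft_b, 4))
```

The hard-target loss cannot tell the two students apart because it only looks at the probability of the correct class, while the soft-target loss is lower for the student whose whole distribution is closer to the teacher's.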

**Knowledge Distillation consists of two steps:**

- Train a Teacher Model (typically a large network)
- Distill the knowledge from the teacher model (information present in soft labels of the predictions) to a student model (typically a small network)

*Figure 1* illustrates KD from a *teacher* to a *student* model.

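In PyTorch, the loss function for KD can be written as:

```python
import torch
import torch.nn.functional as F

# x, y, T, alpha, teacher_model and student_model are assumed to be defined.
outputs_teacher = teacher_model(x)  # teacher logits
outputs = student_model(x)          # student logits

# First term: KL divergence between the temperature-softened distributions.
# Second term: cross-entropy (log_softmax + nll_loss) against the labels y.
# Note: KLDivLoss() defaults to reduction="mean", which averages over every
# element; reduction="batchmean" matches the mathematical definition of KL.
loss = torch.nn.KLDivLoss()(
    F.log_softmax(outputs / T, dim=1),
    F.softmax(outputs_teacher / T, dim=1),
) * (alpha * T * T) + F.nll_loss(F.log_softmax(outputs, dim=1), y) * (1.0 - alpha)
```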

The first loss term is the **KL Divergence** between the soft target probability obtained from the teacher’s predictions and the student’s prediction probability distribution.

The second loss term is the regular **cross-entropy** loss term used in classification problems.

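The final objective is a weighted average of these two terms: the KL term is weighted by **alpha** (and rescaled by **T²**), and the cross-entropy term by **1 − alpha**.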
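Written out with the names used in the snippet above, and introducing $z^{s}$ and $z^{t}$ as assumed symbols for the student and teacher logits, the combined objective can be sketched as:

$$
\mathcal{L} = \alpha \, T^{2} \, \mathrm{KL}\!\left(p^{t}_{T} \,\|\, p^{s}_{T}\right) + (1 - \alpha)\, \mathrm{CE}\!\left(y,\, p^{s}_{1}\right),
\qquad p^{m}_{T} = \mathrm{softmax}\!\left(z^{m} / T\right), \quad m \in \{s, t\}
$$

The factor $T^{2}$ follows Hinton et al. (2015): the gradients of the soft-target term scale as $1/T^{2}$, so multiplying by $T^{2}$ keeps the relative contributions of the two terms roughly stable when the temperature changes.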

Let’s understand this with a simple example.

- We first trained a 2-layer neural network (with **13K** parameters) using the **60,000** images of the MNIST training dataset. It achieved an accuracy of **92.51**% on the test set. We refer to this model as the **teacher model**.
- We then trained a smaller model (with **3.2K** parameters) using **800** training images only. It achieved just **63.18**% accuracy.
- Finally, we trained this smaller model with **KD** from the teacher model using the same **800** training images. It achieved **71.38**% accuracy.

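This shows that soft targets can transfer a great deal of knowledge to the distilled model: with the same 800 training images, distillation raised the smaller model's accuracy from 63.18% to 71.38%, a gain of about 8.2 percentage points.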

**Script to Train a Model on MNIST Training Set**

**Command to Train Teacher Model:**

```shell
python mnist.py --n_gpu 0 --num_workers 0 --num_labels 10 --fc1_size 16 \
    --output_dir ./outputs/mnist/ --train_batch_size 64 --eval_batch_size 64 \
    --num_train_epochs 3 --seed 42 --save_top_k 1 --save_last \
    --learning_rate 5e-3 --data_dir ../data/mnist/ --max_train_samples -1
```

**Command to Train a Smaller Model** (note that this is not our student model; it only serves as a baseline against which the student is compared):

```shell
python mnist.py --n_gpu 0 --num_workers 0 --num_labels 10 --fc1_size 4 \
    --output_dir ./outputs/mnist_800/ --train_batch_size 64 --eval_batch_size 64 \
    --num_train_epochs 40 --seed 42 --save_top_k 0 \
    --learning_rate 5e-3 --data_dir ../data/mnist/ --max_train_samples 800
```

**Script to Train With Knowledge Distillation**

**Command to Train Student Model**:

```shell
python mnist_with_kd.py --n_gpu 0 --num_workers 0 --num_labels 10 --fc1_size 4 \
    --output_dir ./outputs/mnist_800kd/ --train_batch_size 64 --eval_batch_size 64 \
    --num_train_epochs 40 --seed 42 --save_top_k 0 \
    --learning_rate 5e-3 --data_dir ../data/mnist/ --max_train_samples 800 \
    --teacher_model ./outputs/mnist/last.ckpt --alpha_for_kd 0.8 --temperature_for_kd 20
```
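The scripts themselves are not shown in this text, so the following PyTorch sketch is only a rough illustration of what a single KD update could look like. The function name `distillation_step`, the two-layer architectures, the Adam optimizer, and the random tensors standing in for MNIST batches are all assumptions chosen to echo the command-line flags above (`fc1_size`, `learning_rate`, `alpha_for_kd`, `temperature_for_kd`); none of them are taken from `mnist_with_kd.py`. The sketch also uses `reduction="batchmean"` for the KL term, which differs from the default reduction of `torch.nn.KLDivLoss()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_step(student, teacher, optimizer, x, y, T=20.0, alpha=0.8):
    """Run one optimizer step on the student against a frozen teacher."""
    teacher.eval()
    with torch.no_grad():  # the teacher is already trained and is not updated
        teacher_logits = teacher(x)

    student.train()
    student_logits = student(x)

    # Soft-target term: KL(teacher || student) at temperature T.
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    )
    # Hard-target term: ordinary cross-entropy against the true labels.
    hard_loss = F.cross_entropy(student_logits, y)
    loss = alpha * T * T * soft_loss + (1.0 - alpha) * hard_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)

# Assumed architectures on flattened 28x28 inputs:
# 12,730 parameters for the teacher and 3,190 for the student.
teacher = nn.Sequential(nn.Linear(784, 16), nn.ReLU(), nn.Linear(16, 10))
student = nn.Sequential(nn.Linear(784, 4), nn.ReLU(), nn.Linear(4, 10))
optimizer = torch.optim.Adam(student.parameters(), lr=5e-3)

# Random data standing in for one batch of 64 images and labels.
x = torch.randn(64, 784)
y = torch.randint(0, 10, (64,))

step_loss = distillation_step(student, teacher, optimizer, x, y)
print("loss after one step:", step_loss)
```

In a real run the teacher would be loaded from a checkpoint rather than randomly initialized, and the step would be repeated over the 800 training images for the configured number of epochs.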

In the next part of this article series (ETA: 4th Jan 2022), we will focus on NLP and discuss recent research papers related to knowledge distillation. Specifically, we will discuss the following papers:

- Improving Multi-Task Deep Neural Networks via Knowledge Distillation for Natural Language Understanding
- BAM! Born-Again Multi-Task Networks for Natural Language Understanding
- Distilling Task-specific Knowledge from BERT into Simple Neural Networks
- MixKD: Towards Efficient Distillation of Large-scale Language Models
- Improved Knowledge Distillation via Teacher Assistant
- Patient Knowledge Distillation for BERT Model Compression
- Well-Read Students Learn Better: On the Importance of Pre-training Compact Models
- Self-Knowledge Distillation in Natural Language Processing
- TinyBERT: Distilling BERT for Natural Language Understanding
- …


**References**

- MNIST — http://yann.lecun.com/exdb/mnist/
- Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. “Distilling the knowledge in a neural network.” *arXiv preprint arXiv:1503.02531* (2015).