A Primer on Knowledge Distillation in NLP — Part 1

Neeraj Varshney
Published in Analytics Vidhya · 4 min read · Dec 22, 2021

In recent years, deep neural networks have achieved impressive performance on a variety of Natural Language Processing (NLP) tasks. However, their high computational and storage requirements make them difficult to deploy in real-world applications, especially on resource-constrained devices such as mobile phones. To address this, a variety of model compression and acceleration techniques have been developed. One such technique is knowledge distillation, which trains a small student model to mimic a large teacher model. In this article series, we will first study the basics of knowledge distillation and then discuss its applications in NLP.

Figure 1: In Knowledge Distillation, the student model learns from both the soft labels of the teacher and the true hard labels of the dataset.

Introduction

Knowledge distillation was popularized by Hinton et al. (2015). A trained network produces class probabilities by passing the logits z_i it computes for each class through a softmax:

q_i = exp(z_i / T) / Σ_j exp(z_j / T)

where T is a temperature that is normally set to 1. Using a higher value for T produces a softer probability distribution over the classes.

For a trained (teacher) model, this knowledge lies in the class probabilities produced by its softmax. All of the probability values, not just the one assigned to the target class, carry relevant information about the input. Thus, instead of a one-hot representation of the label, in which only the target class contributes to the cross-entropy, the full distribution over classes produced by the teacher conveys more information about the input and can teach a new model more efficiently.

All of these probabilities from the teacher model are treated as soft targets. The student model is trained to replicate the teacher's outputs; in other words, one way to transfer the generalization ability of the teacher model to a student model is to use the class probabilities produced by the teacher as soft targets for training the student.
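To get a feel for what the temperature does, here is a tiny illustrative snippet (the logits below are made up, not from a real model) that softens a distribution as T grows:

import torch
import torch.nn.functional as F

# Illustrative logits for a 3-class problem (not taken from a trained model)
logits = torch.tensor([4.0, 1.0, 0.2])
for T in [1, 2, 5]:
    # Softmax with temperature: exp(z_i / T) / sum_j exp(z_j / T)
    probs = F.softmax(logits / T, dim=0)
    print(f"T={T}: {probs.tolist()}")

At T=1 almost all the probability mass sits on the top class; as T increases, the distribution spreads over the non-target classes, exposing the relative similarities between classes that make soft targets informative.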

Knowledge Distillation consists of two steps:

  1. Train a Teacher Model (typically a large network)
  2. Distill the knowledge from the teacher model (information present in soft labels of the predictions) to a student model (typically a small network)

Figure 1 illustrates KD from a teacher to a student model.

The loss function for KD is:

import torch
import torch.nn.functional as F

outputs_teacher = teacher_model(x)      # teacher logits
outputs = student_model(x)              # student logits
logits = F.log_softmax(outputs, dim=1)
# Distillation (KL) term on temperature-scaled distributions + cross-entropy on the true labels y
loss = torch.nn.KLDivLoss()(F.log_softmax(outputs / T, dim=1),
                            F.softmax(outputs_teacher / T, dim=1)) * (alpha * T * T) + \
       F.nll_loss(logits, y) * (1. - alpha)

The first loss term is the KL divergence between the soft target distribution obtained from the teacher's predictions (softened with temperature T) and the student's corresponding prediction distribution.

The second loss term is the regular cross-entropy loss term used in classification problems.

The final objective is a weighted combination of these two terms, controlled by the weight alpha.
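Spelled out, with student logits z_s, teacher logits z_t, true label y, temperature T, and mixing weight α, the objective computed by the code above is:

L = α · T² · KL( softmax(z_t / T) ‖ softmax(z_s / T) ) + (1 − α) · CE( softmax(z_s), y )

The T² factor on the distillation term compensates for the 1/T² scaling of its gradients, so the relative contribution of the two terms stays roughly constant when the temperature is changed (Hinton et al., 2015).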

Let’s understand this with a simple example.

  • We first trained a 2-layer neural network (with ~13K parameters) on the 60,000 images of the MNIST training set. It achieved 92.51% accuracy on the test set. We refer to this model as the teacher model.
  • We then trained a smaller model (with ~3.2K parameters) on only 800 training images. It achieved just 63.18% accuracy.
  • Finally, we trained the same smaller model on those 800 images with KD from the teacher model. It reached 71.38% accuracy.

This shows that soft targets can transfer a great deal of knowledge to the distilled model.

Script to Train a Model on MNIST Training Set

Command to Train Teacher Model:

python mnist.py --n_gpu 0 --num_workers 0 --num_labels 10 --fc1_size 16 --output_dir ./outputs/mnist/ --train_batch_size 64 --eval_batch_size 64 --num_train_epochs 3 --seed 42 --save_top_k 1 --save_last --learning_rate 5e-3 --data_dir ../data/mnist/ --max_train_samples -1
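The full mnist.py script is not reproduced here. Purely as an assumption based on the flags and the parameter counts mentioned above, the teacher could be a simple two-layer fully connected network in which --fc1_size sets the hidden width (the class name and layer layout below are illustrative, not the author's actual script):

import torch
import torch.nn as nn

class TwoLayerNet(nn.Module):
    """Illustrative 2-layer MNIST classifier; --fc1_size controls the hidden width."""
    def __init__(self, fc1_size: int = 16, num_labels: int = 10):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, fc1_size)     # 784 -> hidden
        self.fc2 = nn.Linear(fc1_size, num_labels)  # hidden -> 10 classes

    def forward(self, x):
        x = x.view(x.size(0), -1)                   # flatten the 28x28 image
        return self.fc2(torch.relu(self.fc1(x)))

With fc1_size=16 such a network has roughly 13K parameters, and with fc1_size=4 roughly 3.2K, which lines up with the model sizes reported in the example above.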

Command to Train a Smaller Model (note that this is not our student model; it is only used to compare against the student's performance):

python mnist.py --n_gpu 0 --num_workers 0 --num_labels 10 --fc1_size 4 --output_dir ./outputs/mnist_800/ --train_batch_size 64 --eval_batch_size 64 --num_train_epochs 40 --seed 42 --save_top_k 0 --learning_rate 5e-3 --data_dir ../data/mnist/ --max_train_samples 800

Script to Train With Knowledge Distillation

Command to Train Student Model:

python mnist_with_kd.py --n_gpu 0 --num_workers 0 --num_labels 10 --fc1_size 4 --output_dir ./outputs/mnist_800kd/ --train_batch_size 64 --eval_batch_size 64 --num_train_epochs 40 --seed 42 --save_top_k 0 --learning_rate 5e-3 --data_dir ../data/mnist/ --max_train_samples 800 --teacher_model ./outputs/mnist/last.ckpt --alpha_for_kd 0.8 --temperature_for_kd 20
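Again as an illustration only (the function name, data loader, and optimizer handling below are assumptions; the loss mirrors the snippet shown earlier, and T=20 and alpha=0.8 correspond to the command's --temperature_for_kd and --alpha_for_kd), the distillation loop inside mnist_with_kd.py might look roughly like this:

import torch
import torch.nn.functional as F

def kd_train_epoch(student_model, teacher_model, loader, optimizer, T=20.0, alpha=0.8):
    # One illustrative epoch of distillation: a frozen teacher provides soft targets.
    teacher_model.eval()
    student_model.train()
    for x, y in loader:
        with torch.no_grad():                     # the teacher is never updated
            teacher_logits = teacher_model(x)
        student_logits = student_model(x)
        # Same objective as the KD loss snippet shown earlier
        kd_term = torch.nn.KLDivLoss()(F.log_softmax(student_logits / T, dim=1),
                                       F.softmax(teacher_logits / T, dim=1)) * (alpha * T * T)
        ce_term = F.nll_loss(F.log_softmax(student_logits, dim=1), y) * (1.0 - alpha)
        loss = kd_term + ce_term
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()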

In the next part of this article series (ETA: 4th Jan 2022), we will focus on NLP and discuss recent research papers related to Knowledge Distillation. Specifically, we will discuss the following papers:

  • Improving Multi-Task Deep Neural Networks via Knowledge Distillation for Natural Language Understanding
  • BAM! Born-Again Multi-Task Networks for Natural Language Understanding
  • Distilling Task-specific Knowledge from BERT into Simple Neural Networks
  • MixKD: Towards Efficient Distillation of Large-Scale Language Models
  • Improved Knowledge Distillation via Teacher Assistant
  • Patient Knowledge Distillation for BERT Model Compression
  • Well-Read Students Learn Better: On the Importance of Pre-training Compact Models
  • Self-Knowledge Distillation in Natural Language Processing
  • TinyBERT: Distilling BERT for Natural Language Understanding

Check out my other articles here.

References

  • MNIST — http://yann.lecun.com/exdb/mnist/
  • Hinton, Geoffrey, Oriol Vinyals, and Jeff Dean. “Distilling the knowledge in a neural network.” arXiv preprint arXiv:1503.02531 (2015).
