Experiments with Knowledge Distillation of Dual Encoder Models

Dhairya Dalal
Published in Posh Engineering
7 min read · Nov 2, 2020


In this blog post, we introduce the concept of knowledge distillation, explain how to implement it, and describe our experience with it. Knowledge distillation is a model compression technique in which a smaller, more efficient model is trained to mimic a large, high-performing model or ensemble of models. We investigated knowledge distillation to see whether it could solve several technical challenges related to our production intent recognition model.

Background and Motivation

At Posh, we use deep learning and pretrained language models for the natural language understanding (NLU) component of our conversational AI products (Smart IVR, FAQbot, and banking bot). The goal of NLU (specifically, our intent recognition model) is to map a user’s utterance or question to one of a set of known intents (e.g., bank-balance, atm-hours). The conversational bots use these intents to generate an appropriate response to the user. You can learn more about our NLU in a separate blog post.

We put into production a Dual Encoder model based on the recent ACL paper Efficient Intent Detection with Dual Sentence Encoders. Our implementation generates sentence embeddings with two different language models (ConveRT and Sentence-Roberta) and feeds their concatenation into a single-layer MLP classifier. The ConveRT embedding in particular was very effective and contributed significantly to the model’s overall accuracy.
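To make the teacher’s architecture concrete, here is a minimal pure-Python sketch of the classification head described above. It assumes the two sentence embeddings have already been computed by the frozen encoders and are passed in as plain lists. We read “single-layer MLP” as one hidden layer with a ReLU; that reading, the toy dimensions, the random initialization, and the class name `DualEncoderHead` are all hypothetical and not taken from our production code.

```python
import random


def linear(W, b, x):
    """Affine map y = W x + b, with W stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]


def relu(v):
    return [max(0.0, x) for x in v]


class DualEncoderHead:
    """Hypothetical MLP head over two concatenated sentence embeddings."""

    def __init__(self, dim_a, dim_b, hidden, n_intents, seed=0):
        rng = random.Random(seed)
        n_in = dim_a + dim_b
        self.W1 = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(hidden)]
        self.b1 = [0.0] * hidden
        self.W2 = [[rng.gauss(0.0, 0.1) for _ in range(hidden)] for _ in range(n_intents)]
        self.b2 = [0.0] * n_intents

    def logits(self, emb_a, emb_b):
        x = list(emb_a) + list(emb_b)          # concatenate the two embeddings
        h = relu(linear(self.W1, self.b1, x))  # single hidden layer
        return linear(self.W2, self.b2, h)     # one logit per intent


# Toy sizes only; real ConveRT / Sentence-Roberta embeddings are much larger.
head = DualEncoderHead(dim_a=4, dim_b=3, hidden=5, n_intents=2)
z = head.logits([0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7])
print(len(z))  # 2 logits, one per intent
```

In the real model the two inputs would be the ConveRT and Sentence-Roberta embeddings of the same utterance; the head itself is tiny compared to the encoders that produce them.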

Our implementation was a bit clunky because it relied on two separate deep learning libraries (TensorFlow and PyTorch). While we are primarily a PyTorch shop, the ConveRT model was originally released as a TensorFlow Hub model. About a month ago, its authors withdrew the ConveRT model from public availability. Luckily, we had cached the public ConveRT weights, but the withdrawal introduced a set of liabilities and technical scaling challenges. Porting ConveRT to PyTorch was impossible (the original checkpoints were never made available) without retraining it from scratch on the 3.7 billion Reddit conversations used by the original implementation.

So we decided to investigate knowledge distillation as a potential solution. The anticipated benefits of knowledge distillation for us were:

  • A single model that ideally learns to emulate the information from concatenated embeddings generated by ConveRT and Sentence-Roberta
  • A student model that ideally learns the weights of ConveRT without requiring us to retrain a separate ConveRT model from scratch
  • The ability to move the model entirely to PyTorch without relying on TensorFlow Hub

What is Knowledge Distillation?

Knowledge distillation is a model compression technique in which a student model is trained to learn the output distribution of a teacher model. The teacher model is already trained and is often a larger model or an ensemble of multiple trained models. The student learns by observing the teacher’s predictions on a fine-tuning task and attempts to mimic the teacher’s output distributions. Ideally, this process transfers knowledge from the teacher to the student without requiring the student to have the teacher’s larger capacity or to repeat all of the teacher’s prior training.

In a standard training setup for multi-class classification, the loss function minimizes the cross entropy between the softmax of the model’s logits and the ground truth. The problem with the ground-truth representation is that it is one-hot: the correct class is set to one and all other classes are set to zero, so it carries no information about how the classes relate to each other. The softmax probability distribution produced by a trained model, by contrast, conveys not only which class is predicted but also how closely the model relates that class to the other classes. Aggregated over many inputs, these distributions reveal how the model maps its inputs to the output space. Geoffrey Hinton refers to this class-probability information as “dark knowledge”, and it is what the student aims to learn during knowledge distillation.
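A small worked example makes the point about one-hot labels concrete. The sketch below scores two hypothetical student predictions that assign the same probability to the correct class but spread the leftover mass differently. Against the one-hot label their cross entropy is identical; against a hypothetical soft teacher distribution, the student that agrees with the teacher about which other class is related gets a lower loss. All probabilities here are made-up illustrative values.

```python
import math


def cross_entropy(target, probs):
    """H(target, probs) = -sum_i target_i * log(probs_i)."""
    return -sum(t * math.log(p) for t, p in zip(target, probs) if t > 0)


# Three hypothetical intents; class 0 is correct, class 1 is closely related to it.
one_hot = [1.0, 0.0, 0.0]      # ground truth
soft = [0.7, 0.25, 0.05]       # hypothetical teacher distribution

student_a = [0.7, 0.25, 0.05]  # leftover mass on the related class
student_b = [0.7, 0.05, 0.25]  # leftover mass on the unrelated class

hard_a, hard_b = cross_entropy(one_hot, student_a), cross_entropy(one_hot, student_b)
soft_a, soft_b = cross_entropy(soft, student_a), cross_entropy(soft, student_b)
print(f"one-hot target: {hard_a:.3f} vs {hard_b:.3f}")  # 0.357 vs 0.357
print(f"soft target:    {soft_a:.3f} vs {soft_b:.3f}")  # 0.746 vs 1.068
```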

Knowledge distillation was formalized in the paper Distilling the Knowledge in a Neural Network by Geoffrey Hinton, Oriol Vinyals and Jeff Dean. The paper introduces the concept of “softmax temperature”, where the probability p_i of class i is calculated from the logits z and a temperature T:

$$p_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

The intuition behind T is that T=1 gives the standard softmax, while increasing T produces a softer distribution that spreads probability mass from the top class to the remaining classes, revealing more about which classes the teacher found similar to the predicted class. The distillation loss reduces the distance between the class probabilities generated by the teacher and those generated by the student. Back-propagating the distillation loss communicates some of the teacher’s underlying knowledge to the student.
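A minimal pure-Python sketch of the temperature softmax shows this effect. The logits are made-up values for three hypothetical intents. At each temperature the ranking of the classes stays the same, but the distribution gets flatter and its entropy rises.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """p_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / T for z in logits]
    m = max(scaled)                            # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    return -sum(x * math.log(x) for x in p if x > 0)


teacher_logits = [6.0, 3.0, 0.5]  # hypothetical logits for three intents
for T in (1.0, 2.0, 5.0):
    p = softmax_with_temperature(teacher_logits, T)
    print(f"T={T}: probs={[round(x, 3) for x in p]}, entropy={entropy(p):.3f}")
```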

In addition to the distillation loss, the ground truth is also taken into account. The student loss is calculated between the softmax of the student’s logits and the ground truth. Two weighting factors, alpha and beta, weight the student loss and the distillation loss in the final combined loss: Loss = (alpha * student_loss) + (beta * distill_loss). At each training step, this combined loss is calculated and back-propagated through the student model until the model converges.
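Written out in full, the combined objective looks like the following. The symbols z^s and z^t (student and teacher logits) and the one-hot label vector y are our own notation, not the post’s. Note that Hinton et al. additionally multiply the soft-target term by T^2, because its gradients scale as 1/T^2; in the alpha/beta formulation above, that factor can be folded into beta.

```latex
\begin{aligned}
p_i(z, T) &= \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)} \\
\mathcal{L}_{\text{student}} &= -\sum_i y_i \log p_i(z^{s}, 1) \\
\mathcal{L}_{\text{distill}} &= \sum_i p_i(z^{t}, T) \log \frac{p_i(z^{t}, T)}{p_i(z^{s}, T)}
  = \mathrm{KL}\big(p(z^{t}, T) \,\|\, p(z^{s}, T)\big) \\
\mathcal{L} &= \alpha \, \mathcal{L}_{\text{student}} + \beta \, \mathcal{L}_{\text{distill}}
\end{aligned}
```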

Implementing Knowledge Distillation

In this section we describe our implementation of the knowledge distillation process. Implementing knowledge distillation is relatively straightforward. To start, you need a teacher model (in our case, the trained Dual Encoder model), a student model (we considered several, including a baseline BiLSTM classifier, a randomly initialized BERT, and a small pretrained BERT), and a fine-tuning task (our intent classification dataset).

The knowledge distillation training process follows a standard training loop: for each batch, run a forward pass, compute the loss, back-propagate it, and update the model’s weights.

For knowledge distillation, you now have two models generating prediction logits and three separate loss calculations. The student loss is the standard cross entropy loss between the student logits and the ground-truth labels. The distillation loss can be calculated as the KL divergence between the teacher’s and the student’s temperature-scaled distributions: both sets of logits are divided by the temperature before the (log) softmax is applied. The final loss is the weighted sum of the student and distillation losses.
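A minimal pure-Python sketch of the three loss calculations for a single example follows. The function name `kd_losses`, the default weights, and the optional T^2 factor (from Hinton et al.) are our assumptions for illustration, not necessarily what our production code does. In PyTorch the same quantities are typically computed with `F.cross_entropy(student_logits, labels)` and `F.kl_div(F.log_softmax(student_logits / T, dim=-1), F.softmax(teacher_logits / T, dim=-1), reduction="batchmean")`; note that `kl_div` expects log-probabilities as its first argument and probabilities as its second.

```python
import math


def log_softmax(logits, T=1.0):
    """log p_i = z_i / T - logsumexp(z / T), computed stably."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_norm for s in scaled]


def kd_losses(student_logits, teacher_logits, label, alpha=0.1, beta=0.9, T=2.0,
              scale_by_t2=True):
    """Return (student_loss, distill_loss, total_loss) for one example.

    alpha, beta and T are placeholder values, not tuned settings.
    """
    # Student loss: cross entropy between the student's T=1 softmax and the label.
    student_loss = -log_softmax(student_logits)[label]
    # Distillation loss: KL(teacher_T || student_T) on temperature-scaled distributions.
    log_p_s = log_softmax(student_logits, T)
    log_p_t = log_softmax(teacher_logits, T)
    distill_loss = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    if scale_by_t2:
        distill_loss *= T * T  # Hinton et al.'s correction for the 1/T^2 gradient scale
    total = alpha * student_loss + beta * distill_loss
    return student_loss, distill_loss, total


print(kd_losses([2.0, 0.0, 0.0], [0.0, 3.0, 0.0], label=1))
```

When the student’s logits equal the teacher’s, the distillation term is exactly zero and only the ground-truth term remains.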

When putting the full knowledge distillation process together, keep a couple of things in mind. Make sure the teacher model is set to evaluation mode, since it needs to generate fixed predictions. Alpha, beta and the temperature can be tuned through grid search. The general intuition is that beta should be large relative to alpha, so that the combined loss is weighted toward the distillation loss, which carries more information than the student loss. Usually a temperature between 1 and 2 is sufficient. In our experiments, large temperature values were detrimental to the distillation process, though this could be specific to our datasets. We ended up using 2, as it is the value used by Hugging Face in their DistilBERT implementation.
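The end-to-end loop below is a toy, pure-Python sketch rather than our production code. The teacher is a hypothetical fixed linear model standing in for the trained Dual Encoder, the student is a linear softmax classifier, the toy labels are simply the teacher’s argmax predictions, and the gradients are derived by hand instead of by autograd. It applies Hinton et al.’s T^2 scaling to the distillation term, and alpha=0.1, beta=0.9, T=2 and lr=0.05 are placeholder values. The structural point carries over to PyTorch: the teacher is frozen (there, `teacher.eval()` plus `torch.no_grad()` when computing its logits), and only the student’s parameters are updated.

```python
import math
import random


def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def linear(W, b, x):
    return [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]


def kd_loss_and_grad(z_s, z_t, label, alpha, beta, T):
    """Combined loss (with T^2 scaling) and its gradient w.r.t. the student logits."""
    p_s1 = softmax(z_s)      # student at T=1, for the ground-truth loss
    p_sT = softmax(z_s, T)   # student at temperature T
    p_tT = softmax(z_t, T)   # teacher at temperature T: fixed soft targets
    student_loss = -math.log(p_s1[label])
    distill_loss = T * T * sum(pt * math.log(pt / ps) for pt, ps in zip(p_tT, p_sT) if pt > 0)
    # d(student_loss)/dz = p_s1 - one_hot;  d(T^2 * KL)/dz = T * (p_sT - p_tT)
    grad = [alpha * (p1 - (1.0 if k == label else 0.0)) + beta * T * (ps - pt)
            for k, (p1, ps, pt) in enumerate(zip(p_s1, p_sT, p_tT))]
    return alpha * student_loss + beta * distill_loss, grad


rng = random.Random(0)
n_in, n_classes = 4, 3
# Hypothetical frozen teacher: a fixed linear map standing in for the Dual Encoder.
W_t = [[rng.gauss(0.0, 2.0) for _ in range(n_in)] for _ in range(n_classes)]
b_t = [0.0] * n_classes
data = [[rng.gauss(0.0, 1.0) for _ in range(n_in)] for _ in range(200)]
# The teacher is never updated, so its logits are computed once up front.
teacher_logits = [linear(W_t, b_t, x) for x in data]
labels = [z.index(max(z)) for z in teacher_logits]  # toy labels = teacher argmax

# Student: a linear softmax classifier initialized at zero.
W_s = [[0.0] * n_in for _ in range(n_classes)]
b_s = [0.0] * n_classes
alpha, beta, T, lr = 0.1, 0.9, 2.0, 0.05


def mean_loss():
    return sum(kd_loss_and_grad(linear(W_s, b_s, x), z_t, y, alpha, beta, T)[0]
               for x, z_t, y in zip(data, teacher_logits, labels)) / len(data)


initial_loss = mean_loss()
for epoch in range(20):
    for x, z_t, y in zip(data, teacher_logits, labels):
        _, g = kd_loss_and_grad(linear(W_s, b_s, x), z_t, y, alpha, beta, T)
        for k in range(n_classes):       # SGD step on the student's parameters only
            for j in range(n_in):
                W_s[k][j] -= lr * g[k] * x[j]
            b_s[k] -= lr * g[k]
final_loss = mean_loss()
print(f"combined loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

Running it prints the mean combined loss before and after training, which should drop as the student’s temperature-scaled distribution moves toward the teacher’s.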

Our Experience and Conclusion

This line of investigation is still ongoing for us. In this section we briefly describe what we’ve learned so far; we hope to share the conclusions of our investigation in a future post. To validate the process end to end, we considered three student models: a simple 3-layer BiLSTM classifier using SentencePiece embeddings as our baseline, a randomly initialized 3-layer BERT model with 6 attention heads, and a pretrained 3-layer BERT model with 6 attention heads. We got the distillation proof of concept working and are now trying to figure out how to train a student model that can replicate our Dual Encoder teacher model.

Our key findings were:

  • We observed that the distilled models converged faster than the baseline model trained with the standard loss. For example, the BiLSTM took about 200 epochs to converge, while the distilled model took only about 50 epochs.
  • The BERT student model outperformed the baseline BiLSTM model. The distilled pretrained BERT model captured about 95% of the teacher’s performance, whereas the distilled BiLSTM model captured only about 87% of the teacher’s accuracy.
  • However, there was a stark difference between the pretrained and randomly initialized BERT students. For the pretrained model, the performance gap between the distilled model and the baseline pretrained BERT model was negligible: distillation took fewer epochs to train, but the overall performance was nearly identical to that of the baseline fine-tuned model.
  • We’re not yet sure whether the student architecture needs to match the teacher architecture to get the full benefits of distillation, or whether more data or slightly higher-capacity models would suffice in this case.

Hopefully this post helps explain what model distillation is and how to implement it. In a future blog post, we hope to share more of our lessons learned as we attempt to train a student model that learns the ConveRT embedding weights, as well as a unified student model that can replicate the performance of our Dual Encoder model.

If you’re interested in solving interesting problems like these, apply to join our team and help us build the future of conversational AI!


Dhairya is an NLP research engineer and leads deep learning and NLP research at Posh Tech.