Knowledge Distillation in a Deep Neural Network

Distilling knowledge from a Teacher to a Student in a Deep Neural Network

Renu Khandelwal

In this article, you will learn:

  • An easy-to-understand explanation of Teacher-Student knowledge distillation in neural networks
  • Benefits of Knowledge distillation
  • Implementation of Knowledge distillation on the CIFAR-10 dataset

Large deep neural networks, or ensembles of deep neural networks, built for better accuracy are computationally expensive, resulting in longer inference times, which may not be ideal for the near-real-time inference required at the Edge.

The training requirements for a model are different from the requirements at inference. During training, an object recognition model must extract structure from very large, highly redundant datasets, but during inference it must operate in real time with stringent requirements on latency and computational resources.

To address the latency, accuracy, and computational needs at the inference time, we can use model compression techniques like

  • Model quantization: using low-precision arithmetic for inference, such as converting 32-bit floats to 8-bit unsigned integers.
  • Pruning: removing weights or activations that are close to zero, resulting in a smaller model.
  • Knowledge distillation: a large, complex model (the Teacher) distills its knowledge and passes it on to train a smaller network (the Student) to match its output. The Student network is trained to match the larger network's predictions and output distribution.

Knowledge distillation is a model-agnostic technique that compresses and transfers the knowledge from a computationally expensive, large deep neural network (the Teacher) to a single smaller neural network (the Student) with better inference efficiency.


Knowledge Distillation consists of two neural networks: Teacher and Student models.

  1. Teacher model: a large, cumbersome model, which can be an ensemble of separately trained models or a single very large model trained with a very strong regularizer such as dropout. The cumbersome Teacher model is trained first.
  2. Student model: a smaller model that uses the knowledge distilled from the Teacher network. It is trained with a different kind of training, referred to as “distillation,” which transfers the knowledge learned by the cumbersome model to the smaller Student model. The Student model is more suitable for deployment because it is computationally inexpensive while offering the same or better accuracy than the Teacher model.

Knowledge distillation extracts the knowledge from the large, cumbersome Teacher model and passes it on to the smaller Student model.

How is knowledge distilled from the Teacher to the Student model?

Knowledge is distilled from the large, complex Teacher model to the Student model using a distillation loss.

Knowledge distillation improves generalization by using soft targets, which mitigate the over-confidence issue of neural networks and improve model calibration.

In its simplest form, the distillation loss minimizes the squared difference between the logits produced by the cumbersome model and the logits produced by the small model; more generally, it matches the soft targets of the two models.

In knowledge distillation, knowledge is transferred to the distilled (Student) model by using the cumbersome model with a high temperature in its softmax to generate a soft target distribution. The same high temperature is used while training the distilled model, but after it has been trained, it uses a temperature of 1.

Knowledge distillation minimizes the KL divergence between the Teacher's and the Student's probabilistic outputs. The KL divergence constrains the Student model's outputs to match the soft targets of the large, cumbersome model.

What are Soft targets and Hard targets?

Hard targets are what the softmax produces at a temperature of 1: the model almost always assigns the correct class a very high probability, so the probabilities of the remaining classes are so close to zero that they have very little influence on the cross-entropy cost function during the transfer of knowledge from Teacher to Student.

One way to obtain soft targets is to use the logits, the inputs to the final softmax, rather than the softmax's probabilities as the targets for learning the small model.

When the soft targets have high entropy, they provide much more information per training case than hard targets. They also have less variance in the gradient between training cases.

Hard targets are generated by the softmax activation function, which converts the logit z_i computed for each class into a probability:

q_i = exp(z_i) / Σ_j exp(z_j)

Soft targets

Soft targets q_i are computed by comparing z_i with the other logits after dividing by a temperature T:

q_i = exp(z_i / T) / Σ_j exp(z_j / T)

When T = 1, this is the same as the standard softmax activation function. A higher value of T produces a softer probability distribution over classes, as shown below.
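As a quick illustration (a sketch added here, not the article's figure; the logit values are made up), the snippet below applies the softmax to the same three logits at T = 1 and T = 10:

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """q_i = exp(z_i / T) / sum_j exp(z_j / T)"""
    z = np.asarray(logits, dtype=np.float64) / T
    z -= z.max()                 # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = [8.0, 2.0, 0.5]         # hypothetical logits for a 3-class problem
print(softmax_with_temperature(logits, T=1))   # ~[0.997, 0.002, 0.001]  -> nearly a hard target
print(softmax_with_temperature(logits, T=10))  # ~[0.49, 0.27, 0.23]     -> a much softer distribution
```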

  • Soft targets contain valuable information about the rich similarity structure of the data, i.e., they tell us which 2s look like 3s and which look like 7s.
  • They provide better generalization and less variance in the gradients between training examples.
  • They allow the smaller Student model to be trained on much less data than the original cumbersome model, and with a much higher learning rate.
  • Can be deployed at the Edge: through knowledge distillation, the Student model inherits quality from the Teacher and is more efficient for inference because of its compactness, needing fewer computational resources.
  • Improves generalization: the Teacher produces “soft targets” for training the Student model. Soft targets have high entropy, providing more information than one-hot-encoded hard targets, and less variance in the gradient between training cases, allowing the Student network to be trained on much less data than the Teacher with a much higher learning rate.

Import required libraries
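The article's original code listings are not reproduced here, so the snippets that follow are minimal sketches for TensorFlow 2.x, consistent with the Keras example credited later; exact layer sizes and hyperparameters are assumptions. The imports below are shared by all later snippets:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
```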

Creating the CIFAR-10 train and test dataset
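A sketch of loading CIFAR-10 with Keras and scaling the pixel values to [0, 1] (the article's exact preprocessing may differ):

```python
# 50,000 training and 10,000 test images, 32x32x3, with integer labels 0-9
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
```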

Create the Cumbersome Teacher model
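The article reports a Teacher with roughly 1.51M parameters, but its exact layers are not shown; the architecture below is only a plausible stand-in of similar size:

```python
teacher = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(64, 3, padding="same", activation="relu"),
        layers.Conv2D(64, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(128, 3, padding="same", activation="relu"),
        layers.Conv2D(128, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(128, activation="relu"),
        layers.Dense(10),  # logits; no softmax, so they can be softened later
    ],
    name="teacher",
)
teacher.summary()
```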

Cumbersome Teacher Model

Create the simple Student model
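Again, the article's exact Student layers are not shown (it reports about 313K parameters); here is a comparably small sketch:

```python
student = keras.Sequential(
    [
        keras.Input(shape=(32, 32, 3)),
        layers.Conv2D(16, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, padding="same", activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(10),  # logits
    ],
    name="student",
)
student.summary()
```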

Simple Student Model

You can see that our Teacher model has approximately 1.51M parameters and will be computationally very expensive compared to the smaller, simpler Student model with just 313K parameters.

Train the Teacher Model
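A sketch of training and evaluating the Teacher for the 5 epochs used in the article (the choice of optimizer is an assumption):

```python
# Standard supervised training of the Teacher on hard (ground-truth) labels
teacher.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
teacher.fit(x_train, y_train, epochs=5)
teacher.evaluate(x_test, y_test)
```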

Teacher model Training and Evaluation

Distill the Knowledge from the Teacher Model to the Student Model

Create a Distiller class that distills the knowledge using the student loss and the distillation loss, as sketched below:

  • The student loss is the difference between the student predictions and the ground truth, computed with the softmax at T = 1.
  • The distillation loss is the difference between the soft student predictions and the soft teacher targets, both computed with the softmax at temperature T.
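The sketch below is adapted from the Keras knowledge-distillation example credited underneath; it wraps both networks in a custom keras.Model (TensorFlow 2.x API):

```python
class Distiller(keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.teacher = teacher
        self.student = student

    def compile(self, optimizer, metrics, student_loss_fn,
                distillation_loss_fn, alpha=0.1, temperature=3):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def train_step(self, data):
        x, y = data
        # Forward pass of the (frozen) teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of the student
            student_predictions = self.student(x, training=True)
            # Student loss: hard labels vs. student logits (T = 1)
            student_loss = self.student_loss_fn(y, student_predictions)
            # Distillation loss: soft teacher targets vs. soft student predictions
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        # Backward pass: gradients are computed for the student weights only
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()` and report both losses
        self.compiled_metrics.update_state(y, student_predictions)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss,
                        "distillation_loss": distillation_loss})
        return results

    def test_step(self, data):
        x, y = data
        # Evaluate the student against the ground truth only
        y_prediction = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, y_prediction)
        self.compiled_metrics.update_state(y, y_prediction)
        results = {m.name: m.result() for m in self.metrics}
        results.update({"student_loss": student_loss})
        return results
```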

Code adapted and inspired by: https://keras.io/examples/vision/knowledge_distillation/

In the above code, we override the compile, train_step, and test_step methods of the keras.Model class.

During training, we perform a forward pass on both the Teacher and the Student model and calculate the combined loss as

loss = alpha * student_loss + (1 - alpha) * distillation_loss

where alpha is a factor that weighs the student loss against the distillation loss. We then perform a backward pass.

We calculate the gradients only for the student weights, as only the student weights are updated.

During the test step, we evaluate the student model using the student prediction and the ground truth.

Compiling the distiller
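A sketch of compiling the Distiller; the alpha and temperature values are assumptions borrowed from the credited Keras example:

```python
distiller = Distiller(student=student, teacher=teacher)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    student_loss_fn=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),  # KL divergence on the softened outputs
    alpha=0.1,
    temperature=10,
)
```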

The Distiller uses a distillation loss that is minimized as the KL divergence between the temperature-softened probabilistic outputs of the Teacher and Student networks.

Distill the knowledge from the Teacher model to Student and Evaluate the distiller
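A sketch of running the distillation for the same 5 epochs and then evaluating the Student on the test set:

```python
distiller.fit(x_train, y_train, epochs=5)
distiller.evaluate(x_test, y_test)   # reports the student's accuracy on the test data
```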

We can see that the Teacher model takes, on average, 12 seconds per epoch to train, whereas the distiller takes about 9 seconds per epoch.

The accuracy of the Teacher model after 5 epochs is 68.8%, whereas the accuracy of the distillation model is 71.1%.

Conclusion:

Knowledge distillation is a model-agnostic compression technique that extracts the knowledge from the large, cumbersome Teacher model and passes it on to the smaller Student model. The distilled Student model uses soft targets and needs less training and inference time than the large, cumbersome Teacher model, while here it achieves higher accuracy.

References:

Distilling the Knowledge in a Neural Network: Geoffrey Hinton, Oriol Vinyals, and Jeff Dean

https://keras.io/examples/vision/knowledge_distillation/

https://www.youtube.com/watch?v=k63qGsH1jLo&t=258s
