Understanding Knowledge Distillation in Simple Steps

Satya
8 min read · Jan 10, 2023


Introduction:

Running a deep neural network on an edge device is challenging because such devices have limited computing power and storage space. There are several ways to optimize a neural network so that it can run on an edge device, such as:

  • Quantization: 32-bit weights are represented with 8 bits or fewer.
  • Parameter Pruning: Unnecessary connections that have little influence on the output are removed.
  • Knowledge Distillation: Learned knowledge is transferred from a big teacher network to a smaller student network. This is one of the most effective ways of optimizing a neural network, and a lot of research is happening in this area.

Please check my article on Quantization https://medium.com/@satya15july_11937/network-optimization-with-quantization-8-bit-vs-1-bit-af2fd716fcae.

Let’s discuss what Knowledge Distillation is and how to use it to optimize a neural network so that it can run efficiently on an edge device.

What is Knowledge Distillation:

The idea of Knowledge Distillation was popularized by Geoffrey Hinton and his team in the paper https://arxiv.org/abs/1503.02531. The main idea was:

  • Large (deep and wide) networks are good at solving complex computer vision tasks, and training them on a big dataset requires a high-end system with both CPU and GPU.
  • But inference does not need a big network once the problem has been solved with the big network, which required a GPU and a big dataset.
  • So the idea was to transfer the knowledge of the big teacher network to a small student network through training.

Paper Analysis:

This paper makes 2 important points:

  1. Modified Softmax Function with Temperature:
Fig.1- Softmax Function

When the outputs of the logit layer are passed through the softmax function, it assigns most of the probability to one output, the most probable target class (in the above case it’s cat). Hinton modified the softmax function as follows:

Fig.2- Equation for Softmax with Temperature(T)
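
For reference, the softmax with temperature is commonly written as below. This is a sketch in the usual notation, where z_i are the logits, q_i are the resulting probabilities and T is the temperature; the symbol names are an assumption and may differ from those in the figure.

```latex
q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}
```

With T = 1 this reduces to the ordinary softmax.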

But the question is why the logits are divided by a term called Temperature (T) inside the softmax function. Let’s try a small example:

import numpy as np

# Raw logits for three classes
logits = np.array([1., 2., 3.])
logits_exp = np.exp(logits)

print("logits_exp: {}".format(logits_exp))

# Temperatures to compare; T=1 is the standard softmax
T = [1., 5., 7., 10.]

for t in T:
    # Softmax with temperature: divide the logits by t before exponentiating
    logits_exp_normalized = np.exp(logits / t) / np.sum(np.exp(logits / t))
    print("temperature: {} : logits_exp_normalized:{}".format(t, logits_exp_normalized))

[Output]:
logits_exp: [ 2.71828183 7.3890561 20.08553692]
temperature: 1.0 : logits_exp_normalized:[0.09003057 0.24472847 0.66524096]
temperature: 5.0 : logits_exp_normalized:[0.2693075 0.32893292 0.40175958]
temperature: 7.0 : logits_exp_normalized:[0.28700357 0.33107727 0.38191915]
temperature: 10.0 : logits_exp_normalized:[0.30060961 0.33222499 0.3671654 ]

T=1 (Softmax): It gives most of the weight to the 3rd option: “[0.09003057 0.24472847 0.66524096]”.

T=5/7/10: The output is more evenly distributed across all three outputs, e.g. for T=5: “[0.2693075 0.32893292 0.40175958]”.

Fig.3- Output of Softmax function with T=5, 7, 10

As you can see from the sample, a temperature T>1 softens the output distribution, so it carries information about all the target classes. This information is squeezed out when you use the normal softmax function.

Hinton’s paper notes that “Matching logits is a special case of Knowledge Distillation”, and he used this temperature trick in his first knowledge distillation experiments.

  2. Soft-Labels and KL Divergence Loss: Dividing the logits by the temperature (T) before the softmax gives you soft labels, and these soft labels carry information about all the classes that the Teacher/Student model is trying to predict. The difference between the Teacher’s and the Student’s soft labels is minimized with the KL (Kullback-Leibler) Divergence loss, which is the distillation loss in this case.

Fig.4- Architecture of Knowledge Distillation proposed by Hinton
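
The soft-label loss described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the code from the paper or from my repository: the function names (softmax_t, distillation_loss, student_loss), the weighting factor alpha and the sample logits are all assumptions made for this example. The T² scaling follows the suggestion in Hinton’s paper to keep the gradient magnitude of the soft-label term comparable when T changes.

```python
import numpy as np

def softmax_t(logits, T=1.0):
    """Softmax with temperature T along the last axis (numerically stable)."""
    z = np.asarray(logits, dtype=np.float64) / T
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, T=4.0):
    """KL(teacher || student) on temperature-softened outputs, batch-averaged."""
    p_t = softmax_t(teacher_logits, T)
    p_s = softmax_t(student_logits, T)
    kl = np.sum(p_t * (np.log(p_t) - np.log(p_s)), axis=-1)
    return (T ** 2) * kl.mean()

def student_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Weighted sum of hard-label cross-entropy and the distillation loss.

    alpha is a hypothetical weighting hyperparameter chosen for this sketch.
    """
    p = softmax_t(student_logits, 1.0)
    ce = -np.log(p[np.arange(len(labels)), labels]).mean()
    kd = distillation_loss(student_logits, teacher_logits, T)
    return alpha * ce + (1.0 - alpha) * kd

# Made-up logits for a batch of 2 samples and 3 classes
teacher_logits = np.array([[1.0, 2.0, 3.0], [3.0, 1.0, 0.5]])
student_logits = np.array([[0.5, 1.5, 2.0], [2.0, 1.5, 0.0]])
labels = np.array([2, 0])

kd = distillation_loss(student_logits, teacher_logits)
total = student_loss(student_logits, teacher_logits, labels)
print("distillation loss:", kd)
print("total student loss:", total)
```

When the Student reproduces the Teacher’s logits exactly, the distillation term is zero, so only the hard-label term remains.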

Drawback:

Though the Student network is much faster than the Teacher network, there is an accuracy gap between them. This is because the method only considers the output layer of both Teacher and Student; the knowledge in the intermediate layers, which carry low- to high-level feature information, is not transferred to the Student network.

To overcome this problem, another paper, https://arxiv.org/pdf/1412.6550 from Romero and his team, addressed the above issue. Their idea was:

  • If you consider only the output layers of Teacher and Student, the intermediate layers of the Teacher are ignored. So if intermediate layer information can be passed to the Student network, the accuracy gap between Teacher and Student can be reduced further.
  • Then, how do you transfer the intermediate layer information to the Student?
    Romero and his team proposed a novel technique.
    - Pick only a middle layer from the Teacher, called the Hint Layer, and transfer its knowledge to a middle layer of the Student, called the Guided Layer.
    - So essentially, you use a small convolution layer that connects the Guided Layer to the Hint Layer, so that the Student learns to reproduce the Teacher’s intermediate features. Please check the diagram below.
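
The hint-based idea above can be sketched as follows. This is a simplified illustration under stated assumptions: the regressor is modelled as a 1x1 convolution (so Teacher and Student feature maps are assumed to share the same spatial size), the weights are random instead of learned, and the function names regressor_1x1 and hint_loss are made up for this example.

```python
import numpy as np

def regressor_1x1(guided, weight):
    """1x1 convolution: map (N, C_s, H, W) student features to (N, C_t, H, W)."""
    return np.einsum("oc,nchw->nohw", weight, guided)

def hint_loss(hint, guided, weight):
    """Half the squared L2 distance between the Teacher's hint features and
    the regressed Student features, averaged over the batch."""
    diff = hint - regressor_1x1(guided, weight)
    return 0.5 * np.sum(diff ** 2) / hint.shape[0]

rng = np.random.default_rng(0)
guided = rng.normal(size=(2, 4, 8, 8))    # Student Guided Layer: 4 channels
hint = rng.normal(size=(2, 16, 8, 8))     # Teacher Hint Layer: 16 channels
weight = rng.normal(size=(16, 4)) * 0.1   # regressor weights (learned in practice)

print("hint loss:", hint_loss(hint, guided, weight))
```

In practice the regressor weights and the Student layers up to the Guided Layer are trained together to minimize this loss, and the output-layer distillation is applied afterwards.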

This paper demonstrated that transferring both output layer and middle layer information from Teacher to Student made the Student network perform better than the Teacher network in terms of both accuracy and speed. Note that this method only transfers the information of one middle layer to the Student, not of all layers.

Fig.6- Data from https://arxiv.org/pdf/1412.6550

These 2 papers (https://arxiv.org/abs/1503.02531 and https://arxiv.org/pdf/1412.6550) gave an insight into what Knowledge Distillation is and paved the way for further exploration in this field. Understanding these 2 papers is crucial in your journey into Knowledge Distillation, and that’s why I gave a detailed analysis.

Here are some other knowledge distillation methods which were published after these 2 papers and showed remarkable performance gain.

Fig.7-Different Methods for Knowledge Distillation

Transfer Learning vs Transfer Knowledge(Knowledge Distillation):

Fig.8- Transfer Learning vs Transfer Knowledge(Knowledge Distillation)
  • [Transfer Learning]:
    Let’s say Model-A is trained on a Cat and Dog (2 classes) dataset. This model can be extended to solve Cat, Dog, Bear and Horse classes with the following modifications:
    1. Use the same architecture. Only remove the head and attach additional layers to solve 4 classes. (Note: Model-A was trained to solve 2 classes, but Model-B is extended to solve 4 classes.)
    2. During training, you only need to train the additional layers to solve the 4-class problem. (You can even fine-tune the backbone after training the additional layers.)
  • [Transfer Knowledge (Knowledge Distillation)]:
    Transfer Knowledge is totally different from Transfer Learning. The differences are:
    1. If the Teacher network was trained to solve only 2 classes (Cat and Dog), then the Student network can only solve those 2 classes; it cannot be extended to 4 classes as in Transfer Learning.
    2. The two networks (Teacher & Student) have totally different architectures, so you cannot directly share the Teacher’s weights with the Student.
    3. There are different strategies for training the Teacher and Student networks in order to transfer knowledge from Teacher to Student.

Knowledge Distillation In Detail:

Fig.9-Knowledge Distillation Overview

Please take a look at the above diagram, which gives a clear picture of Knowledge Distillation.

Let’s go through it one by one (from left to right).

Types of Knowledge:

Fig.10- Types of Knowledge

In Response based knowledge distillation, only the output layer is used, so on its own it is not the most effective approach, but it can be combined with other types of knowledge to improve performance.

In Feature based KD, intermediate layer information is used, which is very effective. In Relation based KD, the relationships between intermediate layers are captured and that knowledge is transferred to the Student network.

Fig.11-Difference between Response, Feature and Relation based knowledge
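
To make the relation-based case more concrete, here is a minimal sketch of one possible way to capture the relationship between two layers: a matrix of channel-to-channel correlations between two feature maps, which the Student is then trained to match. It is only an illustration under stated assumptions: the names layer_relation and relation_loss are hypothetical, the two layers are assumed to have the same spatial size, and the Teacher and Student are assumed to have matching channel counts at the chosen layers.

```python
import numpy as np

def layer_relation(f1, f2):
    """Channel-by-channel correlation between two feature maps.

    f1: (N, C1, H, W), f2: (N, C2, H, W)  ->  relation matrix (N, C1, C2)
    """
    _, _, h, w = f1.shape
    return np.einsum("nihw,njhw->nij", f1, f2) / (h * w)

def relation_loss(t1, t2, s1, s2):
    """Mean squared difference between Teacher and Student relation matrices."""
    diff = layer_relation(t1, t2) - layer_relation(s1, s2)
    return np.mean(diff ** 2)

rng = np.random.default_rng(0)
# Made-up feature maps from two layers of each network
t1, t2 = rng.normal(size=(2, 3, 8, 8)), rng.normal(size=(2, 5, 8, 8))
s1, s2 = rng.normal(size=(2, 3, 8, 8)), rng.normal(size=(2, 5, 8, 8))

print("relation matrix shape:", layer_relation(t1, t2).shape)
print("relation loss:", relation_loss(t1, t2, s1, s2))
```

Unlike the feature-based case, the Student is not asked to copy the Teacher’s features themselves, only how the features of one layer relate to those of another.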

How Teacher and Student Networks can be Trained (Distillation Scheme):

Here are the different schemes you can use for training the Teacher and Student networks.

Fig.12- Distillation Schemes

How to Transfer Knowledge from Teacher to Student Network (Distillation Algorithms):

Please refer to https://arxiv.org/abs/2006.05525 to learn more about distillation algorithms.

Implementation:

I have implemented the above scenario and my code is shared at https://github.com/satya15july/knowledge_distillation.

Evaluation:

Here is the data:

As you can see from the above data, the Student model is 12 times faster than the Teacher model, and after knowledge distillation the accuracy of the Student model improved by 4%. However, the Student model cannot reach the accuracy of the Teacher model (which is 59%).

Accuracy of Different Distillation Methods on CIFAR-100 dataset. Data taken from https://arxiv.org/abs/1910.10699

Please check the above data, which shows that the Student model performs better than the large Teacher model after applying knowledge distillation methods such as AT, AB, SP etc. So why is our Student model for semantic segmentation not reaching the same accuracy as the Teacher model? What is the difference?

You have to understand that these methods were developed for the classification task and were not designed for the Semantic Segmentation task. That’s why the Student model is still lagging behind the Teacher model in terms of accuracy.

Let’s focus on task specific knowledge distillation.

Task Specific Knowledge Distillation:

All the knowledge distillation methods shown in Fig.7 were tried on classification problems.

But the question is: can we use those methods for other tasks such as object detection, semantic segmentation etc.? The answer is yes, but there will usually be an accuracy gap between Teacher and Student.

Then how do you close this accuracy gap between Teacher and Student? The answer is that you need to understand the task specific knowledge present in the Teacher and transfer it to the Student network. Here are some examples:

Distilling Object Detector for Feature Richness (https://arxiv.org/abs/2111.00674)
Inter-class Feature Variation Distillation for Semantic Segmentation (https://arxiv.org/abs/2205.03650)

Giving details on how knowledge distillation for semantic segmentation can be improved further is beyond the scope of this article. To keep this short, I will discuss it in another article.

Conclusion:

I tried to give an overview of what Knowledge Distillation is and how the journey started with the Hinton and Romero papers, and then how distillation can be improved for task specific problems such as Object Detection and Semantic Segmentation.

I hope this article helps you in understanding Knowledge Distillation. Please do not forget to subscribe to my Medium channel. Thanks for reading.
