Knowledge Distillation for Object Detection 1: Start from simple classification model

Seungki Kim · Published in Analytics Vidhya · May 22, 2020
Concept of Knowledge Distillation

Motivation

Knowledge Distillation (KD) is a technique for improving the accuracy of a small network (the student) by transferring distilled knowledge produced by a large network (the teacher). We can also say that KD compresses a model (teacher → student) with minimal accuracy loss.

My final goal for this series is to apply this technique to my lightweight object detection model. In this story, as a first step, I will implement a simple classification model and test the power of KD.

The reference paper on this topic is Hinton et al., “Distilling the Knowledge in a Neural Network” ([1]). They introduced the basic concept of knowledge distillation and showed that classification accuracy improves by applying it. I’m going to re-implement the experiment from this paper and check whether it actually works.

Distilling Knowledge for Classifier: Soft Targets

In this paper, they showed that knowledge about “how to generalize” can be transferred from teacher to student through soft targets. Compared to conventional hard targets (one-hot labels), soft targets have scores for all classes.

*Reference: https://www.ttic.edu/dl/dark14.pdf ([3])

Loss for Soft Targets

To train a model with soft targets, they modified the conventional loss function using a softmax with temperature.

Softmax with temperature

At higher temperature (T) settings, the class scores change more smoothly, and a larger amount of knowledge is transferred. In this paper, they chose 20 as the temperature value.

Softmax for binary classification (Red: T=1, Blue: T=20)
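Concretely, the temperature-scaled softmax from the paper is q_i = exp(z_i / T) / Σ_j exp(z_j / T), where z_i are the logits. A minimal PyTorch sketch (the function name is mine, not from my repo) shows how the distribution flattens as T grows:

```python
import torch
import torch.nn.functional as F

def softmax_with_temperature(logits, T=1.0):
    """Temperature-scaled softmax: q_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    return F.softmax(logits / T, dim=-1)

logits = torch.tensor([2.0, -2.0])               # toy binary-classification logits
print(softmax_with_temperature(logits, T=1.0))   # sharp: ~[0.982, 0.018]
print(softmax_with_temperature(logits, T=20.0))  # smooth: ~[0.550, 0.450]
```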

To train the student with both soft and hard labels, they used a weighted sum of two losses. The first loss is the conventional cross-entropy loss between the predictions and the given hard labels. The second loss is the cross-entropy loss between the predictions and the given soft labels, computed with a higher softmax temperature. The weight for the first loss is 1, and the weight for the second loss is T² (the gradients of the soft term scale as 1/T², so multiplying by T² keeps both terms on a comparable scale, as explained in the paper).
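A rough PyTorch sketch of this combined loss could look like the following; this is my own illustration, and the implementation in my repo may differ in details:

```python
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, hard_labels, T=20.0):
    """Weighted sum of hard-label CE (weight 1) and soft-label CE (weight T^2)."""
    # Hard-label term: ordinary cross-entropy at T = 1
    hard_loss = F.cross_entropy(student_logits, hard_labels)
    # Soft-label term: cross-entropy between teacher and student at temperature T
    soft_targets = F.softmax(teacher_logits / T, dim=-1)
    log_probs = F.log_softmax(student_logits / T, dim=-1)
    soft_loss = -(soft_targets * log_probs).sum(dim=-1).mean()
    # The soft term is scaled by T^2 so its gradients stay comparable to the hard term
    return hard_loss + (T ** 2) * soft_loss
```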

Experiment: Does it really improve the accuracy?

This is the main part of this story. I wrote code to reproduce the results of paper [1]. You can find the full code for this experiment here: https://github.com/poperson1205/knowledge_distillation

1. Implement a large (teacher) and a small (student) network for classification

  • Teacher:
    784 → 1200 (ReLU) → 1200 (ReLU) → 10
    (dropout: 20% of the input, and 80% of each of the two hidden layers’ outputs)
  • Student:
    784 → 800 (ReLU) → 800 (ReLU) → 10
    (see the code sketch after this list)
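As a reference, a minimal PyTorch sketch of the two architectures might look like this. The layer sizes and dropout rates follow the list above; the class names are mine and not necessarily those used in my repo:

```python
import torch.nn as nn

class TeacherNet(nn.Sequential):
    # 784 → 1200 (ReLU) → 1200 (ReLU) → 10, with the dropout rates described above
    def __init__(self):
        super().__init__(
            nn.Dropout(0.2),                                     # drop 20% of the input
            nn.Linear(784, 1200), nn.ReLU(), nn.Dropout(0.8),    # drop 80% of hidden output
            nn.Linear(1200, 1200), nn.ReLU(), nn.Dropout(0.8),
            nn.Linear(1200, 10),
        )

class StudentNet(nn.Sequential):
    # 784 → 800 (ReLU) → 800 (ReLU) → 10, trained with vanilla backpropagation
    def __init__(self):
        super().__init__(
            nn.Linear(784, 800), nn.ReLU(),
            nn.Linear(800, 800), nn.ReLU(),
            nn.Linear(800, 10),
        )
```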

2. Train the networks on the MNIST dataset (teacher and student independently)

I basically followed the training settings described in paper [2] (see Appendix A).

The main differences between training the teacher and the student are as follows.

  • Teacher: jitter the input images, and constrain the weight norms to 15.0 (see the sketch after this list)
  • Student: vanilla backpropagation
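As an illustration only (not taken from my repo), the input jitter and the max-norm constraint for the teacher could be sketched as below, assuming the constraint is applied to each unit’s incoming weight vector as in the dropout paper [2]. Both function names are mine; apply_max_norm_ would be called after every optimizer step:

```python
import torch
import torch.nn as nn

def jitter(images, max_shift=2):
    """Randomly shift a batch of 28x28 images by up to `max_shift` pixels."""
    dx = int(torch.randint(-max_shift, max_shift + 1, (1,)))
    dy = int(torch.randint(-max_shift, max_shift + 1, (1,)))
    return torch.roll(images, shifts=(dy, dx), dims=(-2, -1))

def apply_max_norm_(model, max_norm=15.0):
    """Clamp the L2 norm of each unit's incoming weight vector to `max_norm`."""
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, nn.Linear):
                # each row of .weight is the incoming weight vector of one unit
                norms = module.weight.norm(dim=1, keepdim=True).clamp(min=1e-12)
                module.weight.mul_((max_norm / norms).clamp(max=1.0))
```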

3. Distill knowledge from teacher to student

As described earlier in this story, I changed the loss term for training the student network and provided soft labels generated by the teacher network.
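For illustration, one training step of the distillation stage might look roughly like this, reusing the distillation_loss sketch from earlier (teacher, student, optimizer, and train_loader are assumed to be defined already; this is not verbatim from my repo):

```python
teacher.eval()  # the teacher only provides soft targets; it is not updated
for images, labels in train_loader:
    inputs = images.view(images.size(0), -1)      # flatten 28x28 MNIST images to 784
    with torch.no_grad():
        teacher_logits = teacher(inputs)          # source of the soft labels
    student_logits = student(inputs)
    loss = distillation_loss(student_logits, teacher_logits, labels, T=20.0)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```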

4. Evaluate networks (teacher, student, student+distillation)

The resulting error rates for each network are as follows.

  • Teacher: 100 / 10000 (1.00%)
  • Student: 171 / 10000 (1.71%)
  • Student with KD: 111 / 10000 (1.11%)

We can see that the error rate of the student decreases significantly (1.71% → 1.11%) by applying KD.
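For completeness, the error counts above are simply the misclassifications over the 10,000 MNIST test images; a sketch of how they can be computed (my code, not necessarily what the repo does):

```python
@torch.no_grad()
def count_errors(model, test_loader):
    """Return the number of misclassified test images."""
    model.eval()
    errors = 0
    for images, labels in test_loader:
        logits = model(images.view(images.size(0), -1))
        errors += (logits.argmax(dim=-1) != labels).sum().item()
    return errors
```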

Conclusion

It was quite impressive to see the results of knowledge distillation through my own code. Now I believe it really works! My next step will be implementing a simple object detection network. If you have any questions about this implementation, please feel free to ask :)

You can get the full codes for this experiment here: https://github.com/poperson1205/knowledge_distillation

References

[1] Hinton et al. “Distilling the Knowledge in a Neural Network”. NIPS 2014 Deep Learning Workshop. https://arxiv.org/abs/1503.02531

[2] Hinton et al. “Improving neural networks by preventing co-adaptation of feature detectors”. https://arxiv.org/abs/1207.0580

[3] Presentation material for paper [1]: https://www.ttic.edu/dl/dark14.pdf
