Butterflies start life as caterpillars, their larval form. A caterpillar extracts the energy and nutrients it needs from its environment before taking on the adult form, which is optimized for very different requirements. Neural networks are similar. The purposes a neural network serves during training and during deployment are quite different, yet we use the same network for both. It doesn’t have to be this way. Inspired by this idea and by prior research from Caruana et al., Geoffrey Hinton, Oriol Vinyals and Jeff Dean set out to unleash the dark knowledge hidden in neural networks.
What is Dark Knowledge?
Let’s consider a simple multi-class classification problem. We have a bunch of images, and we want to classify each of them into one of the four classes
car. The neural network that we train to solve this problem will have a softmax layer as its final layer.
Let’s say that the image we are trying to classify is of a dog. So the true labels will be
Let’s say the probabilities computed for this image by a normal softmax function are
We can see that the probabilities for
car are negligible. But what if we soften these probabilities? To do so, we can change the softmax function to look like this:
$$\mathrm{softmax}(z)_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$
where T is the temperature and z is the vector of logits.
When T = 1, this is the normal softmax function. The higher the value of T, the softer the probabilities. With the new softmax function, our output changes to
The new probabilities calculated by the modified softmax function reveal much more information than the normal probabilities. Looking at the new output, we can say that an image of a dog has a very small chance of being mistaken for a cow. But this mistake is still an order of magnitude more probable than misclassifying the image as a
“The relative probabilities of the incorrect outputs tell us a lot about how the model tends to generalize.” (https://arxiv.org/pdf/1503.02531.pdf)
This is the dark knowledge! Softened probabilities reveal this dark (hidden) knowledge learnt by the model. Extracting and using this dark knowledge is known as distillation. We will revisit this concept in the next section with a real-world example.
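The temperature formula makes this precise. The derivation below was added for illustration and is not from the original post. It uses three symbols:

- q_i is the softened probability of class i.
- z_i is its logit.
- K is the number of classes.

The denominators cancel in a ratio, and a very high temperature flattens everything:

$$\frac{q_i}{q_j} = \frac{\exp(z_i / T)}{\exp(z_j / T)} = \exp\!\left(\frac{z_i - z_j}{T}\right)$$

$$\lim_{T \to \infty} q_i = \frac{1}{K}$$

Because exp is increasing, the ranking of the classes is the same for every T > 0; only the gaps between them shrink. For example, a class whose logit sits 10 units below another’s is e^10 ≈ 22,026 times less likely at T = 1, but only e^2.5 ≈ 12.2 times less likely at T = 4.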
How to use Dark Knowledge?
We can train a large model on the actual targets. Once this large model is trained, we can extract the dark knowledge hidden in it and then train a much smaller network on that dark knowledge. This is the high-level idea. Let’s understand it in more detail by writing some code.
Running the code
You can run the code shown below in your browser through this Colab Notebook.
Update (Sept 2019): A couple of readers have reported that the results shown in this Colab notebook are no longer reproducible. I have cross-checked and confirmed that this is true. My guess is that the libraries in Colab have been updated since I wrote this notebook. I will try to fix this when I get time. In the meantime, if someone out there knows the fix, please message me and I will update the notebook.
We will use TensorFlow and Keras. If you are running the code on Colab, no extra setup is required. If you are running it on your local machine, please install the latest version of TensorFlow.
We will be implementing the idea presented in the paper Distilling the Knowledge in a Neural Network by Geoffrey Hinton, Oriol Vinyals and Jeff Dean.
We won’t follow the architectural details presented in the paper exactly. I tried to replicate the results published in the paper, without success. However, the idea of distillation is promising, and I got good results with a different architecture. So let’s dive in.
We will describe the idea as we make progress. But first, let’s start with a very simple multi-class classification problem.
We will use the MNIST dataset. It comes built into Keras, which simplifies matters. Let’s create a function to load it. We will also normalize the dataset and one-hot encode the labels.
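Here is a rough plain-Python sketch of what the two preprocessing steps do. It assumes 8-bit pixel values in [0, 255] and integer labels from 0 to 9. The helper names normalize_pixels and one_hot are hypothetical. In Keras, the same work is usually done by dividing the image arrays by 255 and calling keras.utils.to_categorical on the labels.

```python
# Plain-Python sketch of MNIST preprocessing (hypothetical helper names).
# Assumes 8-bit grayscale pixels in [0, 255] and integer labels 0..9.


def normalize_pixels(image, max_value=255.0):
    """Scale pixel intensities from [0, 255] to [0.0, 1.0]."""
    return [[pixel / max_value for pixel in row] for row in image]


def one_hot(label, num_classes=10):
    """Turn an integer class label into a one-hot vector of floats."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label {label} is outside [0, {num_classes})")
    vector = [0.0] * num_classes
    vector[label] = 1.0
    return vector


# Usage on a tiny 2x2 "image" and the digit 5.
print(normalize_pixels([[0, 255], [51, 102]]))  # [[0.0, 1.0], [0.2, 0.4]]
print(one_hot(5))  # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

Scaling to [0, 1] keeps the inputs in a small range, which generally makes training with gradient descent better behaved. One-hot vectors are the target format that categorical cross-entropy expects.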
Next, let’s create a simple CNN with two convolution layers, with 32 and 64 filters respectively, followed by a max-pooling layer. We will use the Keras functional API for this purpose.
Here’s the model summary.
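The parameter counts in such a summary can be reproduced by hand. This requires assuming layer details the post does not spell out:

- 3×3 kernels with "valid" padding
- 2×2 max pooling
- a 128-unit hidden dense layer before the 10-unit logits layer

Under these assumptions, taken from the classic Keras MNIST CNN example, the total comes to 1,199,882, the CNN size quoted later in this post.

```python
# Hand-computed parameter counts for the CNN, ASSUMING 3x3 kernels with "valid"
# padding, 2x2 max pooling and a 128-unit hidden dense layer (assumptions, not
# details confirmed by the post).


def conv2d_params(in_channels, filters, kernel_size=3):
    # Each filter has kernel_size x kernel_size x in_channels weights plus one bias.
    return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(in_units, out_units):
    # Full weight matrix plus one bias per output unit.
    return in_units * out_units + out_units


def cnn_param_counts(image_size=28, kernel_size=3, pool_size=2):
    side = image_size
    counts = {}
    counts["conv_32"] = conv2d_params(1, 32, kernel_size)
    side -= kernel_size - 1          # "valid" convolution: 28 -> 26
    counts["conv_64"] = conv2d_params(32, 64, kernel_size)
    side -= kernel_size - 1          # 26 -> 24
    side //= pool_size               # max pooling has no parameters: 24 -> 12
    flattened = side * side * 64     # 12 * 12 * 64 = 9216 features
    counts["dense_128"] = dense_params(flattened, 128)
    counts["logits"] = dense_params(128, 10)
    # Activation, pooling and flatten layers add no parameters.
    return counts


counts = cnn_param_counts()
for name, n in counts.items():
    print(f"{name:>10}: {n:>9,}")
print(f"{'total':>10}: {sum(counts.values()):>9,}")  # the total is 1,199,882
```

Almost all of the parameters sit in the first dense layer, because it connects every one of the 9,216 pooled features to 128 units.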
Notice how we have defined the last two layers in the code above. The last dense layer, named logits, has no activation function. We have separately defined an activation layer named softmax. We could have simply done this in one step:
preds = layers.Dense(10, activation='softmax', name='Softmax')(x)
But there is a reason for doing it in two steps instead of one. We will discuss it in a while.
Let’s train this model for 15 epochs with a batch size of 512.
When I trained this on my Colab notebook, the training loss was 0.0232 and the training accuracy was
Now that the model is trained, let’s see how it performs on the test set.
(On running the above code, we get the following output.)
Great! We got an accuracy of 0.9918. This means we are misclassifying 82 of the 10,000 test images.
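The conversion from accuracy to error count is simple arithmetic. The sketch below assumes the standard MNIST test set of 10,000 images; misclassified is a hypothetical helper name.

```python
# Convert a test accuracy into the number of misclassified images.
# Assumes the standard MNIST test set size of 10,000 images.
MNIST_TEST_SIZE = 10_000


def misclassified(accuracy, n_examples=MNIST_TEST_SIZE):
    """Number of wrong predictions implied by an accuracy in [0, 1]."""
    # round() absorbs floating-point noise in the product.
    return round(n_examples * (1.0 - accuracy))


print(misclassified(0.9918))  # 82
```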
Please note that the actual model performance may differ with every run. So, when you run this code on your machine, you may get different results than what is mentioned here.
Softmax and temperature
We have already discussed this in the first section of this post. But now we have real data, a real model and real probabilities. Let’s see what the output of this model looks like with actual and softened probabilities.
First of all, let’s define the modified softmax function.
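Here is a minimal plain-Python sketch of such a function for a single vector of logits. It is not the post’s own code, which works on batches of Keras outputs. The name softmax_with_temperature is a hypothetical choice.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature for a single example.

    temperature=1.0 gives the ordinary softmax; larger values give softer outputs.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [z / temperature for z in logits]
    # Subtracting the largest value before exp() avoids overflow and does not
    # change the result, because the shift cancels between numerator and denominator.
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Usage with made-up logits for a 4-class problem.
logits = [8.0, 2.0, 0.0, -4.0]
print([round(p, 4) for p in softmax_with_temperature(logits, 1.0)])
print([round(p, 4) for p in softmax_with_temperature(logits, 4.0)])
```

With these made-up logits, the top probability is above 0.99 at temperature 1 but only about 0.71 at temperature 4. The ranking of the four classes stays the same.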
The inputs to the softmax layer are called logits. Let’s get these logits from the model. Remember that in the code above, we defined the final softmax layer in two steps instead of one: we separated the logits layer from the softmax layer. Had we done it in a single step, it would not have been straightforward for us to get the logits.
To get the logits, we will have to remove the softmax layer from our model. Once the softmax layer is removed, we can pass in the data as input and get the logits as output.
Let’s pass these logits to our newly defined softmax function. Initially, let’s keep temperature=1. This will give us the actual (unsoftened) probabilities.
Let’s take a look at the actual probabilities for the first image. Before that, let’s plot the image to see for ourselves which handwritten digit it contains.
OK! So the first image is the number 5! And what has our model predicted it to be?
Cool! So our model predicted it correctly. But the probabilities of the incorrect classes are all very, very small numbers.
So now, let’s soften the probabilities by increasing the softmax temperature to 4 and see what they look like.
Looking at the softened probabilities, it’s quite clear that this image of 5 resembles 3 the most and 9 the second most. It resembles 4 the least. As humans, we can derive this knowledge by simply looking at the image.
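Reading off the most similar, second most similar and least similar class amounts to sorting the softened probabilities and setting aside the predicted class. The helper resemblance_ranking below is hypothetical. Its input is a made-up 4-class example, not output from the post’s model.

```python
def resemblance_ranking(probabilities):
    """Class indices ordered from most to least probable, excluding the predicted class.

    Applied to softened probabilities, this lists which other classes the model
    thinks the input resembles, from most to least similar.
    """
    order = sorted(range(len(probabilities)), key=lambda i: probabilities[i], reverse=True)
    return order[1:]  # drop the top prediction itself


# Made-up softened probabilities for a 4-class example.
soft = [0.10, 0.60, 0.25, 0.05]
print(resemblance_ranking(soft))  # [2, 0, 3]
```

Here class 1 is the prediction, class 2 is the closest alternative, and class 3 is the least similar.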
Let’s do a similar exercise for another image: the 17th image from the training set.
The softened probabilities clearly state that this image of 2 resembles 7 the most and 3 the second most. It is least similar to
Transferring dark knowledge
We can now use this dark knowledge to train a much smaller model.
First, let’s train a smaller model without any regularization on the true labels (hard targets instead of softened probabilities) and see how it performs.
We will create a model with two fully connected layers with 128 units in each.
Let’s train this model for 50 epochs with a batch size of 512.
This model’s accuracy on the test data is 0.9785, which means it misclassifies 215 images. Also note that the model is heavily overfitting and its performance has already saturated.
Next, let’s train a model with exactly the same architecture on the softened probabilities. The small model should be trained at the same temperature that was used to soften the targets, so we first have to modify it: let’s add a temperature layer before the softmax layer in the small model.
Now let’s train this modified model for 100 epochs with a batch size of 512. Let’s also increase the learning rate to
Notice below that instead of passing y_train, we are passing softened_train_prob for training. For evaluation, we are still using the actual labels,
Ideally, we should use a held-out transfer set to train the smaller model. We should also use a weighted average of the losses on the soft and actual targets as the loss function. While predicting, we should reset the temperature to 1. But we will skip these steps for now.
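Here is a plain-Python sketch of the weighted loss the paper recommends, for a single example. The function names and the defaults temperature=4.0 and alpha=0.9 are hypothetical choices; this is not the post’s code. As described in the paper, the loss has two terms:

- The soft-target term compares teacher and student outputs at the same high temperature. It is multiplied by T², because its gradients otherwise shrink roughly as 1/T².
- The hard-target term uses the same student logits at T = 1 against the true label.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(targets, probs):
    """-sum(t * log(p)); terms with a zero target contribute nothing and are skipped."""
    return -sum(t * math.log(p) for t, p in zip(targets, probs) if t > 0)


def distillation_loss(student_logits, teacher_logits, true_label,
                      temperature=4.0, alpha=0.9):
    """Weighted combination of a soft-target loss and a hard-target loss.

    Soft term: cross-entropy between the teacher's and the student's softmax
    outputs at the same high temperature, scaled by temperature**2.
    Hard term: ordinary cross-entropy with the true label at temperature 1.
    """
    soft_targets = softmax(teacher_logits, temperature)
    soft_loss = cross_entropy(soft_targets, softmax(student_logits, temperature))
    hard_targets = [1.0 if i == true_label else 0.0 for i in range(len(student_logits))]
    hard_loss = cross_entropy(hard_targets, softmax(student_logits, 1.0))
    return alpha * temperature ** 2 * soft_loss + (1.0 - alpha) * hard_loss


# Usage with made-up logits for one 4-class example whose true label is 0.
teacher_logits = [8.0, 2.0, 0.0, -4.0]
student_logits = [5.0, 1.0, 0.5, -2.0]
print(distillation_loss(student_logits, teacher_logits, true_label=0))
```

Setting alpha=1 and ignoring the constant T² factor reduces this to plain cross-entropy against the softened targets, which is what the simplified setup in this post trains on.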
The test accuracy of this model is 0.9818, which is a significant improvement over the 0.9785 of the first small model.
- The CNN that we trained has 1,199,882 parameters. The smaller fully connected model has 118,282 parameters. The fully connected model is 90% smaller than the CNN!
- While the CNN has an accuracy of 0.9918, the smaller model’s accuracy is 0.9818. This means that we have reduced our model size by 90% while giving up just 1 percentage point of accuracy.
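The parameter count and the 90% figure for the small model can be checked by hand. This assumes the 28×28 inputs are flattened to 784 features and the network is Dense(128), then Dense(128), then a 10-unit output layer. The temperature and softmax layers have no trainable parameters.

```python
# Parameter count of the small fully connected model, ASSUMING 784 flattened
# inputs and Dense(128) -> Dense(128) -> Dense(10).


def dense_params(in_units, out_units):
    return in_units * out_units + out_units  # weights + biases


student = dense_params(784, 128) + dense_params(128, 128) + dense_params(128, 10)
teacher = 1_199_882  # CNN parameter count reported in the post

print(student)                                 # 118282
print(f"{1 - student / teacher:.1%} smaller")  # 90.1% smaller
```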
Let’s compare the two smaller models.
- The first smaller model quickly overfits, even when trained at a learning rate of 0.001. In contrast, the second model overfits much less, even after 100 epochs. This is despite the fact that the second model trains at a much higher learning rate of
- The loss in the second model is still reducing. So, there is scope for further improving its accuracy by tuning its hyperparameters.
- One thing that is quite evident is that training on soft targets also acts as a regularizer.
This type of distillation is popularly known as the student-teacher architecture. In our case, the CNN is the teacher model and the smaller one is the student model.
Once again, you can run the code presented in this post in your browser through this Colab Notebook.
What we have discussed in this blog post is just the tip of the iceberg. I haven’t included a lot of other material that is discussed in the paper, and I highly recommend reading it. It is available at https://arxiv.org/pdf/1503.02531.pdf. Geoffrey Hinton has also given a talk on this paper, which is available on YouTube. Slides for this talk are available here.