Using Transfer Learning and Bottlenecking to Capitalize on State of the Art DNNs
As a data scientist, I’m really interested in how accessible deep learning is for professionals. There needs to be a certain level of abstraction available in whatever deep learning framework is being used, so that the style or structure of your code matches your thinking. If you’re too bogged down with low-level details, it can be hard or impossible to be creative when you’re problem solving.
Currently, the TensorFlow team is working hard to make these high-level APIs available, along with client libraries for your programming language of choice. The 2017 TensorFlow Dev Summit was amazing to watch, and I couldn't be more excited about the platform's future. Google is also making the integration between TensorFlow and Keras, a high-level Python API for TensorFlow and Theano, as smooth as possible. However, the fact remains that even if writing a deep neural network is easy, training one from scratch may not be feasible.
There has been a lot of research on network architecture. New and better architectures appear every week, but there are also existing models that work well: well enough to put into production! If you're a data scientist or analytics professional, you probably don't need to reinvent the wheel, even if you could. This is where transfer learning comes into play.
Transfer Learning 101
So, what is transfer learning? Imagine Google built a convolutional neural network and then trained it on a massive amount of data using powerful GPUs until it was excellent at classifying images from that train/test dataset. This would look something like InceptionV3. Now imagine that you have a different dataset of images with different classes. In the lower convolutional layers of InceptionV3, the weights are finely tuned to do things like edge detection exceptionally well. Finally, imagine using the lower half of InceptionV3 to do the heavy lifting (edge detection, etc.) and then feeding those outputs into your own network, which has the proper number of output classes. You would be transferring the learning from InceptionV3 to your new model.
A Case Study in Transfer Learning
For this article, I am going to stick to the domain of image classification and look at a few successful models in that space: specifically, the InceptionV3, VGG16, and ResNet50 networks. All of these networks have been extremely successful in the ImageNet Large Scale Visual Recognition Challenge at one point or another, and all of them are available with pre-trained weights through Keras Applications (keras.applications). I'll be using these networks to classify the CIFAR-10 dataset.
There are four possibilities when thinking about transfer learning:
1) New data set is small and similar to the previous: Since the new data set is small, you run the risk of overfitting if you retrain everything. Instead, slice off the last fully connected layer and replace it with a new fully connected layer with the appropriate output size. This makes sense because the similarity of the observations (i.e. pictures) means both the low-level (e.g. edges) and high-level features (e.g. shapes) will be similar. Freeze the weights before the last layer and retrain!
2) New data set is large and similar to the previous: Since there is more data, there is less risk of overfitting by retraining. Freeze the low-level feature weights and retrain the high-level features to get better generalization. Don't forget to replace the last fully connected layer! *Optional: If your data set is large enough to handle it, you can initialize all the layers with their previous weights/biases and retrain the entire network.*
3) New data set is small and different from the previous: This is the most difficult situation to deal with. Intuitively, we know that the previous network is finely tuned at each layer. However, we do not want any of the high-level features, and we cannot afford to retrain them because we could overfit. Instead, remove all of the fully connected layers and all of the high-level convolutional layers. All that should remain are the first few low-level convolutional layers. Place a fully connected layer with the correct number of outputs, freeze the rest of the layers, and retrain.
4) New data set is large and different from the previous: Retrain the entire network. It's usually a good idea to initialize it with the previous model's weights/biases to speed up training (lots of the low-level convolutions will have similar weights/biases). Don't forget to replace the fully connected output layer.
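The four scenarios above reduce to a two-question decision table. The sketch below encodes that table. `scenario` and `RECIPES` are hypothetical names used only for illustration. Deciding whether a data set counts as "large" or "similar" is still a judgment call, and the code does not make it for you.

```python
# Hypothetical helper encoding the four transfer-learning scenarios as a lookup.
# "large" and "similar" are judgment calls; no numeric threshold is implied.

RECIPES = {
    1: "small + similar: freeze everything except a new output layer",
    2: "large + similar: freeze low-level layers, retrain high-level layers and a new output layer",
    3: "small + different: keep only the first few conv layers (frozen), add and train a new output layer",
    4: "large + different: start from the pre-trained weights and retrain the whole network",
}


def scenario(large_dataset: bool, similar_domain: bool) -> int:
    """Return the scenario number (1-4) for a data set's size and similarity."""
    if similar_domain:
        return 2 if large_dataset else 1
    return 4 if large_dataset else 3


# CIFAR-10 relative to ImageNet, as characterized in this article: large and similar
choice = scenario(large_dataset=True, similar_domain=True)
print(choice, RECIPES[choice])
```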
In our case, CIFAR-10 is large and very similar to ImageNet: it has 50,000 training observations and 10,000 for validation. We will use approach #2 for all of our transfer learning examples in this project.
```python
from keras.applications.inception_v3 import InceptionV3
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D

# import InceptionV3 with pre-trained ImageNet weights; do not include the fully connected layers
inception_base = InceptionV3(weights='imagenet', include_top=False)
```
In the code above, I've created an instance of the InceptionV3 model. Let's take a look at the arguments I passed.
weights='imagenet' loads the weights the model learned during its training on ImageNet.
include_top=False is the key for transfer learning. It instantiates the model without its fully connected layers. These layers hold all of the information for sorting the convolutional features into the correct classes. Since ImageNet has 1,000 different classes and CIFAR-10 has 10 (who would have guessed?), we don't need any of the information in these dense layers. Instead, we will train new dense layers, the last one having 10 nodes and a softmax activation.
```python
# add a global spatial average pooling layer
x = inception_base.output
x = GlobalAveragePooling2D()(x)
# add a fully connected layer
x = Dense(512, activation='relu')(x)
# and a fully connected output/classification layer (one unit per CIFAR-10 class)
predictions = Dense(10, activation='softmax')(x)
# create the full network so we can train on it (Keras 2 uses inputs=/outputs=)
inception_transfer = Model(inputs=inception_base.input, outputs=predictions)
```
Read this post about global average pooling. Essentially, this layer averages each feature map from the last convolutional layer over its height and width, producing one value per feature map (2,048 values for InceptionV3). Those maps used to be fed into the original fully connected layers, which we didn't import. Now they feed into our new fully connected layer instead. The important property of global average pooling is that its output size depends only on the number of feature maps, not on their spatial size. As a result, the network can accept images of a different size than the one it was originally trained on. InceptionV3 still requires inputs of at least 75×75 pixels.
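To make the shapes concrete, here is a NumPy stand-in for the new head built above: global average pooling, then Dense(512, relu), then Dense(10, softmax). It is a sketch under stated assumptions. The weights are random and untrained, so only the shapes and the fact that each row of probabilities sums to 1 are meaningful. The feature-map sizes are arbitrary examples of what different input sizes might produce.

```python
import numpy as np

rng = np.random.default_rng(0)
CHANNELS = 2048  # number of feature maps in InceptionV3's last convolutional output


def global_average_pool(feature_maps: np.ndarray) -> np.ndarray:
    """(batch, height, width, channels) -> (batch, channels): average each map over its spatial axes."""
    return feature_maps.mean(axis=(1, 2))


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


# Untrained, randomly initialized stand-ins for Dense(512, relu) and Dense(10, softmax)
W1, b1 = rng.normal(scale=0.01, size=(CHANNELS, 512)), np.zeros(512)
W2, b2 = rng.normal(scale=0.01, size=(512, 10)), np.zeros(10)


def head(feature_maps: np.ndarray) -> np.ndarray:
    pooled = global_average_pool(feature_maps)  # (batch, 2048), whatever height/width were
    hidden = np.maximum(pooled @ W1 + b1, 0.0)  # (batch, 512), ReLU
    return softmax(hidden @ W2 + b2)            # (batch, 10), each row sums to 1


# The same head works for feature maps of several spatial sizes
for h, w in [(1, 1), (3, 3), (8, 8)]:
    maps = rng.random((4, h, w, CHANNELS))  # a batch of 4 fake feature-map stacks
    probs = head(maps)
    print((h, w), "->", probs.shape, np.allclose(probs.sum(axis=1), 1.0))
```

The design point is that nothing after the pooling step depends on `h` or `w`. That independence is what lets the new dense layers accept different input image sizes.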
Training your new model
The model is almost ready! There is just one thing left to do: we need to freeze the convolutional layers. This is a critical step. We know InceptionV3 is already an excellent image classification network, so its convolutional layers do not need to be retrained.
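```python
# freeze every pre-trained layer so that only the new dense layers are updated during training
for layer in inception_base.layers:
    layer.trainable = False
```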
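The loop above freezes every layer in inception_base, which keeps all of the pre-trained weights fixed. Scenario #2 in the list above instead freezes only the low-level layers and leaves the high-level ones trainable. The plain-Python sketch below shows that split. `FakeLayer` is a hypothetical stand-in that exposes a `trainable` flag the way Keras layers do, and the cut-off `n_frozen` is a hypothetical parameter you would tune. With Keras, you would pass `inception_base.layers` instead of the fake list. You would then need to call `compile()` on the model again, because changes to `trainable` only take effect after the model is recompiled.

```python
from dataclasses import dataclass


@dataclass
class FakeLayer:
    """Hypothetical stand-in for a Keras layer; real Keras layers also have a `trainable` flag."""
    name: str
    trainable: bool = True


def freeze_lower_layers(layers, n_frozen: int) -> None:
    """Freeze layers[:n_frozen] (low-level features); leave the rest trainable (high-level features)."""
    for layer in layers[:n_frozen]:
        layer.trainable = False
    for layer in layers[n_frozen:]:
        layer.trainable = True


# Ten hypothetical layers, ordered from input to output
layers = [FakeLayer(f"conv_{i}") for i in range(10)]
freeze_lower_layers(layers, n_frozen=7)
print([layer.name for layer in layers if layer.trainable])
# -> ['conv_7', 'conv_8', 'conv_9']
```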
All that's left to do is compile the model with a loss function and an optimizer, and then train it! Be sure to check out the full implementation, along with all of the roadblocks and caveats, in my GitHub repository.
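For a 10-way softmax output with one-hot labels, the usual loss is categorical cross-entropy. In Keras, that would look something like `inception_transfer.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])`, where the optimizer choice is only an assumption. The NumPy sketch below computes the same quantity by hand. It also yields a useful sanity check: a head that predicts all 10 classes uniformly scores ln(10) ≈ 2.303. A freshly initialized head should therefore start out roughly near that value.

```python
import numpy as np


def categorical_crossentropy(y_true: np.ndarray, y_prob: np.ndarray, eps: float = 1e-7) -> float:
    """Batch mean of -sum_c y_c * log(p_c); y_true is one-hot and each row of y_prob sums to 1."""
    y_prob = np.clip(y_prob, eps, 1.0)  # avoid log(0)
    return float(-(y_true * np.log(y_prob)).sum(axis=1).mean())


labels = np.array([3, 0, 9])   # three example CIFAR-10 class indices
y_true = np.eye(10)[labels]    # one-hot encode -> shape (3, 10)

uniform = np.full((3, 10), 0.1)  # a head that has learned nothing
confident = y_true * 0.9 + 0.01  # 0.91 on the true class, 0.01 on each of the other nine
print(categorical_crossentropy(y_true, uniform))    # ln(10), about 2.3026
print(categorical_crossentropy(y_true, confident))  # -ln(0.91), about 0.0943
```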
If you have a large dataset or a large network (ResNet152 is a whopping 152 layers!), training the new fully connected layers can still be very expensive, because every epoch still requires a forward pass through the entire convolutional base. Bottlenecking is a technique that speeds up this process by excluding the pre-trained convolutional layers from the training loop. Instead, we ask those layers to predict on the new data set once. Predicting is much faster than training (there is no backpropagation), and the raw outputs can then be saved and reused as the input for training the fully connected layers. Finally, the trained dense layers are stacked on top of the convolutional layers to form the complete model.
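Here is a minimal NumPy sketch of bottlenecking, under loudly stated assumptions. The "pre-trained base" is a fixed random projection followed by ReLU, a hypothetical stand-in for InceptionV3's frozen convolutional layers. The data are synthetic 3-class clusters, not CIFAR-10. Standardizing the cached features is my own choice to keep plain gradient descent stable. The point is the control flow: the frozen base runs forward once, its outputs are cached, and only the new softmax head is trained, for many epochs, on those cached features. With the real model, the caching step would roughly correspond to calling `inception_base.predict(...)` on the training images and saving the result.

```python
import numpy as np

rng = np.random.default_rng(42)

# Synthetic stand-in data: 3 classes of 20-dimensional "images" (hypothetical)
n_per_class, n_classes, n_inputs = 100, 3, 20
centers = rng.normal(scale=3.0, size=(n_classes, n_inputs))
y = np.repeat(np.arange(n_classes), n_per_class)
X = centers[y] + rng.normal(size=(y.size, n_inputs))

# Frozen "pre-trained" base: fixed random projection + ReLU (stand-in for the conv layers)
W_frozen = rng.normal(scale=1 / np.sqrt(n_inputs), size=(n_inputs, 32))
W_frozen_before = W_frozen.copy()


def frozen_base(x: np.ndarray) -> np.ndarray:
    return np.maximum(x @ W_frozen, 0.0)


# Step 1: run the frozen base ONCE (forward pass only) and cache the bottleneck features
bottleneck = frozen_base(X)
mu, sigma = bottleneck.mean(axis=0), bottleneck.std(axis=0)
sigma[sigma == 0] = 1.0                # guard against ReLU units that are always zero
feats = (bottleneck - mu) / sigma      # standardized cached features, shape (300, 32)

# Step 2: train only the new softmax head on the cached features
W_head = np.zeros((feats.shape[1], n_classes))
b_head = np.zeros(n_classes)
Y = np.eye(n_classes)[y]               # one-hot labels
lr = 0.1
for epoch in range(300):
    logits = feats @ W_head + b_head
    logits -= logits.max(axis=1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=1, keepdims=True)
    grad = (probs - Y) / y.size        # gradient of the mean cross-entropy w.r.t. the logits
    W_head -= lr * feats.T @ grad
    b_head -= lr * grad.sum(axis=0)

accuracy = float(((feats @ W_head + b_head).argmax(axis=1) == y).mean())
print(f"training accuracy of the head: {accuracy:.2f}")
print("frozen base unchanged:", np.array_equal(W_frozen, W_frozen_before))
```

The expensive part, `frozen_base`, is called exactly once. Each epoch only touches the small cached `feats` array and the head's parameters.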
Here is another definition from the TensorFlow website:
‘Bottleneck’ is an informal term we often use for the layer just before the final output layer that actually does the classification. This penultimate layer has been trained to output a set of values that’s good enough for the classifier to use to distinguish between all the classes it’s been asked to recognize. That means it has to be a meaningful and compact summary of the images, since it has to contain enough information for the classifier to make a good choice in a very small set of values. The reason our final layer retraining can work on new classes is that it turns out the kind of information needed to distinguish between all the 1,000 classes in ImageNet is often also useful to distinguish between new kinds of objects.
If you’re interested in the implementation, be sure to check out my GitHub. Also, you can check out my website, https://galenballew.github.io/, for a D3 visualization of how InceptionV3, VGG16, and ResNet50 stack up against each other when transfer learning on CIFAR-10.