Using Transfer Learning to Classify Images with TensorFlow

t-SNE plot of CNN codes on CIFAR-10

This past weekend a new tutorial was added to the TensorFlow GitHub repo. It includes code and a detailed explanation of how transfer learning works in TensorFlow. I decided to see if I could apply transfer learning to a new set of images that the original Inception-v3 model had not seen. I was quite surprised by how well this technique works and thought I would share a few observations.

Experimenting with a new data set

I’m hoping to apply CNNs to various problems that I’m interested in. However, to get a good feel for TensorFlow, I decided to start with a well-understood data set: CIFAR-10.

A random sampling of data from CIFAR-10

CIFAR-10 provides 60,000 images across 10 classes, where each image belongs to a single class. The data is already well formed for the purpose of training a CNN, so very little data munging is required. Successfully applying transfer learning to CIFAR-10 is a great starting point for future applications.

The image recognition model included in TensorFlow is called Inception-v3. This model was trained for several weeks on multiple GPUs on ImageNet. The ImageNet subset used for training provides roughly 1.2 million images across 1,000 classes. For most people, training a model of this size is not going to be feasible.

Transfer learning requires less data

Transfer learning lets you adapt a pre-trained model to new classes of data, and one advantage is that it needs far less data. CIFAR-10 provides only 60,000 images (compared to roughly 1M from ImageNet), and in my experiments the model converged to a decent solution after training on just a fraction of the available training data.

Training and validation accuracy for the first 100 steps. The model achieves 80% validation accuracy after only 100 mini-batch iterations (1,000 images).

This is great news if you’re working with a limited dataset.

Transfer learning does not require GPUs to train

Training across the full training set (40,000 images) took less than a minute on my MacBook Pro without GPU support. This is not entirely surprising, as the final model is just a softmax regression. In fact, most of the time was spent upfront generating the CNN codes, which brings me to my next point.
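To make "just a softmax regression" concrete, here is a rough sketch in plain NumPy (an illustration, not the TensorFlow code used for this post). It assumes the CNN codes are already available as a float array of shape (num_images, num_features) with integer class labels. The function names, learning rate, and batch size are illustrative assumptions.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax; subtracting the row max keeps exp() from overflowing."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def train_softmax(features, labels, num_classes, lr=0.1, steps=100,
                  batch_size=10, seed=0):
    """Fit a linear softmax classifier on fixed features with mini-batch SGD.

    All hyperparameter defaults here are assumptions chosen for illustration.
    """
    rng = np.random.default_rng(seed)
    n, d = features.shape
    batch = min(batch_size, n)
    W = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    for _ in range(steps):
        # Sample one mini-batch of pre-computed features.
        idx = rng.choice(n, size=batch, replace=False)
        x, y = features[idx], labels[idx]
        # Gradient of the cross-entropy loss w.r.t. the logits: probs - one_hot(y).
        grad = softmax(x @ W + b)
        grad[np.arange(batch), y] -= 1.0
        W -= lr * (x.T @ grad) / batch
        b -= lr * grad.mean(axis=0)
    return W, b


def predict(features, W, b):
    """Return the most likely class index for each row of features."""
    return np.argmax(features @ W + b, axis=1)
```

With real pool_3 codes from Inception-v3, `features` would have 2048 columns and `num_classes` would be 10. Because only `W` and `b` are trained, each step is a pair of small matrix multiplications, which is why a CPU is enough.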

Pre-compute CNN Codes

I highly recommend writing a bit of extra code to pre-compute your pool_3 CNN codes, as this will save you a ton of time in the long run. I started off by fetching pool_3 features on the fly, which would have resulted in hours of training time. Save the pool_3 features upfront to save yourself time as you experiment with your model.
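The caching idea can be sketched roughly as follows. This is a minimal illustration rather than the code used for this post: `extract_fn` is a hypothetical stand-in for whatever runs one image through Inception-v3 and returns its pool_3 vector, and `load_or_compute_codes` is an assumed helper name.

```python
import os

import numpy as np


def load_or_compute_codes(images, extract_fn, cache_path):
    """Return CNN codes for `images`, computing them only if no cache exists.

    `extract_fn` is a hypothetical callable mapping one image to a 1-D
    feature vector. `cache_path` should end in ".npy", because np.save
    appends that suffix when it is missing and the existence check below
    would then never find the file.
    """
    if os.path.exists(cache_path):
        # Cheap path: reuse the features computed on an earlier run.
        return np.load(cache_path)
    # Expensive path: one forward pass per image, done only once.
    codes = np.stack([extract_fn(img) for img in images]).astype(np.float32)
    np.save(cache_path, codes)
    return codes
```

Every later experiment with the classifier then starts from the saved array instead of re-running the network.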

A bit of code to pre-compute CNN codes

Pre-computing the CNN codes took about 3–4 hours on my laptop, but thankfully it was just a one-time cost. There is an outstanding request to accept batches of inputs, which should eventually speed this part up.

(UPDATE 2016-06-05: A patch was released to address the issue with batches.)

Visualize your CNN Codes

I had a lot of fun with this part and was surprised by what a great plot it produced. The CNN codes separate the CIFAR-10 classes into distinct clusters remarkably well. This is a great indication that a simple multinomial model should be able to fit this data.

The final plot is at the top and here is a snippet of code I used to draw the scatter plot using t-SNE and Seaborn.
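As a rough sketch of this kind of plot (an illustration, not the exact snippet used for the figure), the following assumes scikit-learn's `TSNE` for the embedding and draws the scatter with plain matplotlib rather than Seaborn to keep dependencies minimal. The function names and parameter values are assumptions.

```python
import matplotlib

matplotlib.use("Agg")  # render to a file so no display is needed

import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE

CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def embed_codes(codes, perplexity=30.0, seed=0):
    """Project CNN codes of shape (n, d) down to shape (n, 2) with t-SNE."""
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca",
                random_state=seed)
    return tsne.fit_transform(codes)


def plot_embedding(points, labels, class_names, out_path):
    """Save a scatter plot of 2-D points with one color per class."""
    fig, ax = plt.subplots(figsize=(8, 8))
    for class_id, name in enumerate(class_names):
        mask = labels == class_id
        ax.scatter(points[mask, 0], points[mask, 1], s=5, label=name)
    ax.legend(markerscale=3)
    ax.set_title("t-SNE of CNN codes")
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
```

Running this on a small random subsample first is a cheap way to check the plot before committing to the full set.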

Generating the t-SNE values took several hours. I can’t say exactly how many because I fell asleep :)

The Results?

The final softmax model achieved 87.7% accuracy on the test set, which blew me away considering how little work was required on my part. This is a nice baseline if I want to try image distortions or hyperparameter tuning to eke out a few more percentage points.

Feel free to have a look at my code on GitHub.

Update 2016-06-16: Check out my talk from the TensorFlow meetup on applying this strategy to an e-commerce dataset from

Like what you read? Give Scott Thompson a round of applause.
