Using Transfer Learning to Classify Images with TensorFlow
This past weekend a new tutorial was added to the TensorFlow GitHub repo. It includes code and a detailed explanation of how transfer learning works in TensorFlow. I decided to see whether I could apply transfer learning to a new set of images that the original Inception-v3 model has never seen. I was quite surprised by how well this technique works and thought I would share a few observations.
Experimenting with a new data set
I’m hoping to apply CNNs to various problems I’m interested in; however, to get a good feel for TensorFlow, I decided to start with a well-understood data set: CIFAR-10.
CIFAR-10 provides 60,000 images, each belonging to exactly one of ten classes. The data is already well-formed for training a CNN, so very little munging is required. Successfully applying transfer learning to CIFAR-10 is a great starting point for future applications.
The image recognition model included in TensorFlow is called Inception-v3. This model was trained for several weeks on multiple GPUs on ImageNet. ImageNet provides 1,000,000 images across 1,000 classes. For most people, training a model of this size is not going to be feasible.
Transfer learning requires less data
Transfer learning lets you adapt a pre-trained model to new classes of data, and it comes with several advantages. CIFAR-10 provides only 60,000 images (compared to 1M from ImageNet), and in my experiments the model converged to a decent solution after training on just a fraction of the available training data.
This is great news if you’re working with a limited dataset.
Transfer learning does not require GPUs to train
Training across the full training set (40,000 images) took less than a minute on my MacBook Pro without GPU support. This is not entirely surprising, though, since the final model is just a softmax regression. In fact, most of the time was spent upfront generating the CNN codes, which brings me to my next point.
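To make the "just a softmax regression" point concrete, here is a minimal NumPy sketch of that final training step. The 2048-dimensional features below are random stand-ins for the pre-computed pool_3 CNN codes (the array names, class means, and dimensions are illustrative, not taken from the tutorial):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for pre-computed pool_3 CNN codes: 2048-d features, 10 classes.
n, d, k = 1000, 2048, 10
y = rng.integers(0, k, size=n)
# Give each class its own mean so the synthetic data is linearly separable.
class_means = rng.normal(size=(k, d))
X = rng.normal(size=(n, d)) + 3.0 * class_means[y]

# Multinomial (softmax) regression trained with plain gradient descent.
W = np.zeros((d, k))
b = np.zeros(k)
lr = 0.1

for _ in range(200):
    logits = X @ W + b
    logits -= logits.max(axis=1, keepdims=True)   # numerical stability
    p = np.exp(logits)
    p /= p.sum(axis=1, keepdims=True)             # softmax probabilities
    grad = p - np.eye(k)[y]                       # d(cross-entropy)/d(logits)
    W -= lr * (X.T @ grad) / n
    b -= lr * grad.mean(axis=0)

acc = (np.argmax(X @ W + b, axis=1) == y).mean()
```

With the heavy lifting already done by Inception-v3, a linear model this simple is all that remains to train, which is why CPU-only training takes under a minute.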
Pre-compute CNN Codes
I highly recommend writing a bit of extra code to pre-compute your pool_3 CNN codes, as this will save you a ton of time in the long run. I started off computing the pool_3 features on the fly, which would have resulted in hours of training time. Save the pool_3 features upfront to save yourself time as you experiment with your model.
Pre-computing the CNN codes took about 3–4 hours on my laptop, but thankfully this was a one-time cost. There is an open request for the model to accept batches of inputs, which should eventually speed this part up.
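The caching pattern itself is simple: run each image through the network once, save the resulting feature vector to disk, and load it from disk on every later run. A rough sketch of that idea is below; `extract_fn` stands in for whatever function returns the pool_3 features for one image, and the cache layout is my own choice, not the tutorial's:

```python
import os
import numpy as np

def precompute_codes(image_paths, extract_fn, cache_dir):
    """Compute (or load cached) CNN codes, one feature vector per image."""
    os.makedirs(cache_dir, exist_ok=True)
    codes = []
    for path in image_paths:
        cache_file = os.path.join(cache_dir, os.path.basename(path) + ".npy")
        if os.path.exists(cache_file):
            code = np.load(cache_file)   # cache hit: skip the forward pass
        else:
            code = extract_fn(path)      # one pass through the network
            np.save(cache_file, code)
        codes.append(code)
    return np.stack(codes)
```

Once the cache is warm, re-running an experiment only pays the cost of `np.load`, so you can iterate on the softmax model freely.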
(UPDATE: 2016–06–05: There was a patch released to address the issue with batches)
Visualize your CNN Codes
I had a lot of fun with this part and was very surprised by how well the CNN codes separate the CIFAR-10 classes into distinct clusters. This is a great indication that a simple multinomial model should be able to fit the data.
The final plot is at the top of this post, and here is a snippet of the code I used to draw the scatter plot with t-SNE and Seaborn.
The Results?
The final softmax model achieved 87.7% accuracy on the test set which blew me away considering how little work was required on my part. This is a nice baseline if I want to try image distortions or hyperparameter tuning to eke out a few more percentage points.
Feel free to have a look at my code on GitHub.
Update 2016–06–16: Check out my talk from the TensorFlow meetup on applying this strategy to an e-commerce dataset from Gilt.com.