How to use transfer learning and fine-tuning in Keras and Tensorflow to build an image recognition system and classify (almost) any object

Greg Chu
Deep Learning Sandbox
May 1, 2017
Example images from the CompCars dataset (163 car makes, 1713 car models)

Go straight to the code on GitHub here!

In the last post, I covered how to use Keras to recognize any of the 1000 object categories in the ImageNet visual recognition challenge. More often than not, however, the categories we are interested in predicting are not in that list.

So what do we do if we want to classify different models of sunglasses? Or shoes? Or facial expressions? Or different models of cars? Or different types of lung disease in X-ray images?

In this post, I will show you how to use transfer learning and fine-tuning to classify images into object categories of your own choosing! To recap, here is the blog post series we’ll be following:

  1. Build an image recognition system for 1000 everyday object categories (ImageNet ILSVRC) using Keras and TensorFlow
  2. Build an image recognition system for any customizable object categories using transfer learning and fine-tuning in Keras and TensorFlow (this post)
  3. Build a real-time bounding-box object detection system for hundreds of everyday object categories (PASCAL VOC, COCO)
  4. Build a web service for any image recognition or object detection system

Why use transfer learning/fine-tuning?

It’s well known that convolutional networks require significant amounts of data and resources to train. For example, the ImageNet ILSVRC models were trained on 1.2 million images over 2–3 weeks across multiple GPUs.

It has become the norm, not the exception, for researchers and practitioners alike to use transfer learning and fine-tuning, that is, to transfer network weights trained on a previous task, such as ImageNet, to a new task.

And it does astoundingly well! Razavian et al. (2014) showed that simply using features extracted by an ImageNet ILSVRC-trained model achieved state-of-the-art or near state-of-the-art performance on a wide variety of computer vision tasks.

There are two approaches we can take:

  1. Transfer learning: take a ConvNet that has been pre-trained on ImageNet, remove the last fully-connected layer, then treat the rest of the ConvNet as a feature extractor for the new dataset. Once you extract the features for all images, train a classifier for the new dataset.
  2. Fine-tuning: replace and retrain the classifier on top of the ConvNet, and also fine-tune the weights of the pre-trained network via backpropagation.

Which to use?

There are two main factors that will affect your choice of approach:

  1. Your dataset size
  2. Similarity of your dataset to the pre-trained dataset (typically ImageNet)
Taken from http://cs231n.github.io/

*You should also experiment with training from scratch.

The figure above and the bullets below give some general advice on which approach to choose.

  • Similar & small dataset: avoid overfitting by not fine-tuning the weights on a small dataset, and use extracted features from the highest levels of the ConvNet to leverage dataset similarity.
  • Different & small dataset: avoid overfitting by not fine-tuning the weights on a small dataset, and use extracted features from lower levels of the ConvNet, which are more generalizable.
  • Similar & large dataset: with a large dataset we can fine-tune the weights with less of a chance to overfit the training data.
  • Different & large dataset: with a large dataset we can again fine-tune the weights with less of a chance to overfit.
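As a rough illustration, the four rules of thumb above boil down to a small decision function. This is a hypothetical helper, not part of Keras, and whether your dataset counts as "large" or "similar to ImageNet" is a judgment call you pass in rather than a fixed threshold:

```python
def suggest_strategy(large_dataset: bool, similar_to_imagenet: bool) -> str:
    """Hypothetical helper encoding the cs231n rules of thumb listed above."""
    if not large_dataset:
        # Small dataset: don't fine-tune the weights, to avoid overfitting.
        if similar_to_imagenet:
            return "train a classifier on features from the highest ConvNet layers"
        # Dissimilar data: lower-level features are more generic.
        return "train a classifier on features from lower ConvNet layers"
    # Large dataset: fine-tuning is less likely to overfit, whatever the similarity.
    return "fine-tune the pre-trained weights"


print(suggest_strategy(large_dataset=False, similar_to_imagenet=True))
```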

Data augmentation

A powerful and common tool for increasing the dataset size and model generalizability is data augmentation. In fact, virtually every competition-winning ConvNet uses data augmentation. Essentially, data augmentation is the process of artificially increasing the size of your dataset by applying transformations to the existing images.

Most deep learning libraries have ready-made functions for typical transformations. For our image recognition system, your task is to decide which transformations make sense for your data (for example, X-ray images should probably not be rotated by more than 45 degrees, because a real image rotated that far would indicate an error in the image acquisition step).

Data augmentation via horizontal flipping and random cropping

Example transformations: pixel color jitter, rotation, shearing, random cropping, horizontal flipping, stretching, lens correction.
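To make the two transformations in the figure concrete, here is a minimal NumPy sketch of a random horizontal flip followed by a random crop on an HxWxC image array. It is an illustration only, not how Keras implements augmentation; the function name and the crop size are assumptions:

```python
import numpy as np


def random_flip_and_crop(image, crop_h, crop_w, rng):
    """Randomly flip an HxWxC image horizontally, then take a random crop_h x crop_w crop."""
    if rng.random() < 0.5:
        image = image[:, ::-1, :]  # horizontal flip: reverse the width axis
    h, w = image.shape[:2]
    # Pick a random top-left corner so the crop fits inside the image.
    top = rng.integers(0, h - crop_h + 1)
    left = rng.integers(0, w - crop_w + 1)
    return image[top:top + crop_h, left:left + crop_w, :]


rng = np.random.default_rng(0)
image = np.arange(4 * 6 * 3).reshape(4, 6, 3)  # tiny stand-in for a real image
augmented = [random_flip_and_crop(image, 3, 4, rng) for _ in range(5)]
print([a.shape for a in augmented])
```

Because each call draws new random numbers, the same training image yields a different variant every epoch, which is what effectively enlarges the dataset.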

Transfer learning and fine-tuning implementation


Data preparation

Sample images from Kaggle’s Dogs vs. Cats dataset

We’ll use Kaggle’s Dogs vs. Cats dataset as our example, and set up our data with a training directory and a validation directory like this:

```
train_dir/
    dog/
    cat/
val_dir/
    dog/
    cat/
```
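The training call later in the program needs the number of training and validation samples (nb_train_samples, nb_val_samples). One way to get them from this layout is to count the image files in each class folder. A minimal sketch, assuming the images use one of the listed extensions (check this set against the formats your Keras version’s flow_from_directory accepts); count_images is a hypothetical helper:

```python
from pathlib import Path

# Assumed set of image extensions; adjust it to match your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images(root):
    """Return {class_name: n_images} for a root/<class_name>/<image> layout."""
    counts = {}
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

For example, nb_train_samples could be set to sum(count_images(args.train_dir).values()).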

Implementation

Let’s start by preparing our data generators:

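```python
from keras.applications.inception_v3 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator

# IM_WIDTH, IM_HEIGHT, batch_size and args (the parsed command-line
# arguments) are defined elsewhere in the full program.

# Training images: InceptionV3 preprocessing plus random augmentation.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True
)
# Validation images: same preprocessing but no augmentation, so validation
# metrics are measured on unmodified images.
test_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input
)
# Note: target_size is (height, width); the order only matters for non-square sizes.
train_generator = train_datagen.flow_from_directory(
    args.train_dir,
    target_size=(IM_WIDTH, IM_HEIGHT),
    batch_size=batch_size,
)
validation_generator = test_datagen.flow_from_directory(
    args.val_dir,
    target_size=(IM_WIDTH, IM_HEIGHT),
    batch_size=batch_size,
)
```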

Recall from our previous blog post on image recognition the importance of the preprocessing step. This is set by preprocessing_function = preprocess_input where preprocess_input is from the keras.applications.inception_v3 module.
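For reference, InceptionV3’s preprocess_input scales pixel values from [0, 255] to [-1, 1]. The NumPy function below is a sketch of that arithmetic for illustration (inception_preprocess is a hypothetical name), not a replacement for the Keras function:

```python
import numpy as np


def inception_preprocess(x):
    """Scale pixel values from [0, 255] to [-1, 1], as InceptionV3 preprocessing does."""
    x = np.asarray(x, dtype=np.float32)
    return x / 127.5 - 1.0


pixels = np.array([0.0, 127.5, 255.0])
print(inception_preprocess(pixels))  # [-1.  0.  1.]
```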

The rotation, shifting, shearing, zooming, and flipping parameters set the ranges of their respective random data augmentation transformations.

Next, we’ll instantiate the InceptionV3 network from the keras.applications module.

```python
from keras.applications.inception_v3 import InceptionV3

base_model = InceptionV3(weights='imagenet', include_top=False)
```

We use the flag include_top=False to leave out the top of the network, i.e. the final fully connected layer, which is specific to the 1000 ImageNet competition classes the weights were trained on. We’ll add and initialize new top layers:

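```python
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model

FC_SIZE = 1024  # width of the new fully connected layer


def add_new_last_layer(base_model, nb_classes):
    """Add a new classification head on top of the convnet.

    Args:
        base_model: keras model excluding top
        nb_classes: # of classes
    Returns:
        new keras model with the new head
    """
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(FC_SIZE, activation='relu')(x)
    predictions = Dense(nb_classes, activation='softmax')(x)
    # Positional (inputs, outputs) arguments work in both Keras 1 and Keras 2.
    model = Model(base_model.input, predictions)
    return model
```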

GlobalAveragePooling2D converts the MxNxC tensor output into a 1xC tensor, where C is the number of channels, by averaging each channel over its MxN spatial positions.

Then we add a fully connected Dense layer of size 1024, followed by a softmax output layer, which maps the outputs to values in [0, 1] that sum to 1, i.e. a probability for each class.
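To see the shapes involved, here is a NumPy sketch of the forward pass through the new head for a single image (no batch dimension): global average pooling, a ReLU Dense layer of size 1024, then a softmax Dense layer. The 8x8x2048 feature map matches InceptionV3’s final convolutional output for a 299x299 input; the random weights are placeholders, not what Keras would initialize or learn:

```python
import numpy as np


def global_average_pool(feature_map):
    """(M, N, C) feature map -> (C,) vector: mean over the M x N spatial positions."""
    return feature_map.mean(axis=(0, 1))


def softmax(logits):
    """Map logits to probabilities in [0, 1] that sum to 1."""
    exp = np.exp(logits - logits.max())  # subtract the max for numerical stability
    return exp / exp.sum()


def head_forward(feature_map, w1, b1, w2, b2):
    """GAP -> Dense(ReLU) -> Dense(softmax), mirroring add_new_last_layer."""
    pooled = global_average_pool(feature_map)
    hidden = np.maximum(0.0, pooled @ w1 + b1)
    return softmax(hidden @ w2 + b2)


rng = np.random.default_rng(0)
M, N, C = 8, 8, 2048  # InceptionV3's final feature map for a 299x299 input
FC_SIZE, NB_CLASSES = 1024, 2

feature_map = rng.standard_normal((M, N, C))
# Random placeholder weights purely for illustration.
w1, b1 = rng.standard_normal((C, FC_SIZE)) * 0.01, np.zeros(FC_SIZE)
w2, b2 = rng.standard_normal((FC_SIZE, NB_CLASSES)) * 0.01, np.zeros(NB_CLASSES)

probs = head_forward(feature_map, w1, b1, w2, b2)
print(global_average_pool(feature_map).shape)  # (2048,)
print(probs.shape, probs.sum())
```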

In this program, I’ll demonstrate both transfer learning and fine-tuning. You can use either or both if you like.

  1. Transfer learning: freeze all layers of the pre-trained base and train only the newly added layers
  2. Fine-tuning: unfreeze the top convolutional layers (the top 2 inception blocks) and retrain them together with the new layers

Doing both, in that order, makes training more stable and consistent. The randomly initialized weights of the new layers trigger large gradient updates, which could wreck the learned weights in the convolutional base if it were not frozen. Once the new layers have stabilized (transfer learning), we move on to retraining more layers (fine-tuning).

Transfer learning

```python
def setup_to_transfer_learn(model, base_model):
    """Freeze all layers of the pre-trained base and compile the model."""
    for layer in base_model.layers:
        layer.trainable = False
    # Changes to `trainable` only take effect once the model is (re)compiled.
    model.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
```

Fine-tune

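```python
from keras.optimizers import SGD


def setup_to_finetune(model):
    """Freeze the bottom NB_IV3_LAYERS_TO_FREEZE layers and retrain the
    remaining top layers.

    note: the layers from index NB_IV3_LAYERS_TO_FREEZE onward correspond to
    the top 2 inception blocks in the InceptionV3 architecture
    (NB_IV3_LAYERS_TO_FREEZE is defined elsewhere in the full program).

    Args:
        model: keras model
    """
    for layer in model.layers[:NB_IV3_LAYERS_TO_FREEZE]:
        layer.trainable = False
    for layer in model.layers[NB_IV3_LAYERS_TO_FREEZE:]:
        layer.trainable = True
    # A low learning rate avoids wrecking the pre-trained weights;
    # `lr` is spelled `learning_rate` in recent versions of Keras.
    model.compile(optimizer=SGD(lr=0.0001, momentum=0.9),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])  # needed for history.history['acc']
```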
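Both setup functions work by toggling layer.trainable and then recompiling, since Keras only picks up changes to trainable when the model is compiled. The plain-Python sketch below mimics the two phases with a hypothetical FakeLayer class (not a Keras class) to show which layers end up trainable. Because model.layers lists the base layers followed by the new head, the head stays trainable in both phases:

```python
from dataclasses import dataclass


@dataclass
class FakeLayer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""
    name: str
    trainable: bool = True


def freeze_base(base_layers):
    # Phase 1 (transfer learning): freeze every pre-trained layer.
    for layer in base_layers:
        layer.trainable = False


def freeze_bottom(all_layers, n_frozen):
    # Phase 2 (fine-tuning): bottom n_frozen layers frozen, the rest trainable.
    for layer in all_layers[:n_frozen]:
        layer.trainable = False
    for layer in all_layers[n_frozen:]:
        layer.trainable = True


base = [FakeLayer(f"conv{i}") for i in range(6)]
head = [FakeLayer("gap"), FakeLayer("fc"), FakeLayer("softmax")]
model_layers = base + head  # like model.layers: base layers first, new head last

freeze_base(base)
print([layer.name for layer in model_layers if layer.trainable])  # ['gap', 'fc', 'softmax']
freeze_bottom(model_layers, 4)
print([layer.name for layer in model_layers if layer.trainable])  # ['conv4', 'conv5', 'gap', 'fc', 'softmax']
```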

When fine-tuning, it’s important to use a lower learning rate (here lr=0.0001) than you would when training from scratch; otherwise, the optimization could destabilize and the loss could diverge.
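Why can too high a learning rate make the loss diverge? A toy example helps: for the quadratic loss f(w) = a/2 * w^2, each gradient descent step multiplies w by (1 - lr * a), so it converges only when lr < 2/a. This is a one-parameter illustration with made-up values, not a model of InceptionV3:

```python
def gradient_descent(w0, a, lr, steps):
    """Minimize f(w) = a / 2 * w**2 (gradient a * w) and return the loss after each step."""
    w = w0
    losses = []
    for _ in range(steps):
        w -= lr * a * w  # each step multiplies w by (1 - lr * a)
        losses.append(0.5 * a * w * w)
    return losses


# With a = 10 the stability threshold is lr < 2 / a = 0.2.
small = gradient_descent(w0=1.0, a=10.0, lr=0.01, steps=20)
large = gradient_descent(w0=1.0, a=10.0, lr=0.3, steps=20)
print(small[-1] < small[0])  # True: |1 - 0.1| < 1, the loss shrinks
print(large[-1] > large[0])  # True: |1 - 3.0| > 1, the loss diverges
```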

Training

Now we’re all set for training. We use fit_generator for both the transfer learning and the fine-tuning phase.

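```python
# nb_train_samples, nb_val_samples and nb_epoch are defined elsewhere.
# This is the Keras 1 API; in Keras 2, samples_per_epoch became steps_per_epoch
# (counted in batches), nb_epoch became epochs, and nb_val_samples became
# validation_steps. class_weight expects a dict {class_index: weight}, so the
# string 'auto' used originally is not a documented value and is omitted here.
history = model.fit_generator(
    train_generator,
    samples_per_epoch=nb_train_samples,
    nb_epoch=nb_epoch,
    validation_data=validation_generator,
    nb_val_samples=nb_val_samples,
)
model.save(args.output_model_file)
```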
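If your classes are imbalanced, Keras’s class_weight argument expects a dict mapping class indices to weights. One common heuristic (the same formula scikit-learn uses for class_weight='balanced') weights each class c by n_samples / (n_classes * n_c). A minimal sketch; balanced_class_weights is a hypothetical helper and the counts are made up:

```python
def balanced_class_weights(counts):
    """Map class index -> n_samples / (n_classes * n_c), so rarer classes weigh more."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * n) for cls, n in counts.items()}


# Hypothetical imbalanced training set: class 0 (cat) has 3x the images of class 1 (dog).
weights = balanced_class_weights({0: 9000, 1: 3000})
print(weights)  # {0: 0.6666666666666666, 1: 2.0}
```

The resulting dict could then be passed as class_weight=weights, with the class indices taken from train_generator.class_indices.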

We’ll use an Amazon EC2 g2.2xlarge instance for training. If you’re unfamiliar with AWS, check out this tutorial (instead of using their prescribed AMI, search the community AMIs for “deep learning”; for this post, I used ami-638c1eo3 in US-West-Oregon).

We can plot the training and validation accuracy and loss using the history object:

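```python
import matplotlib.pyplot as plt


def plot_training(history):
    # Keras 1 and early Keras 2 record accuracy under 'acc'/'val_acc'
    # (newer versions use 'accuracy'/'val_accuracy').
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(len(acc))

    plt.plot(epochs, acc, 'r.', label='training')
    plt.plot(epochs, val_acc, 'r-', label='validation')
    plt.title('Training and validation accuracy')
    plt.legend()

    plt.figure()
    plt.plot(epochs, loss, 'r.', label='training')
    plt.plot(epochs, val_loss, 'r-', label='validation')
    plt.title('Training and validation loss')
    plt.legend()
    plt.show()
```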

Prediction

Now that we have a saved Keras model, we can modify the same predict() function we wrote in the last blog post to predict the class of a local image file or of an image at a web URL. Check out the GitHub repo for the full program.

```shell
python predict.py --image dog.001.jpg --model dc.model
python predict.py --image_url https://goo.gl/Xws7Tp --model dc.model
```
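To turn the model’s softmax output into a label, the class indices must be mapped back to names. flow_from_directory numbers the class subdirectories in alphanumeric order (exposed as train_generator.class_indices), so for this dataset cat is 0 and dog is 1. A minimal sketch with a hypothetical decode_prediction helper and made-up probabilities:

```python
import numpy as np


def decode_prediction(probs, class_names):
    """Return (label, probability) for the most likely class."""
    idx = int(np.argmax(probs))
    return class_names[idx], float(probs[idx])


# flow_from_directory assigns indices in alphanumeric order of the
# subdirectory names, so for train_dir/{cat,dog}: cat -> 0, dog -> 1.
class_names = sorted(["dog", "cat"])
print(decode_prediction(np.array([0.08, 0.92]), class_names))  # ('dog', 0.92)
```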

We’re done!

As an example, I trained a model on the dogs-vs-cats dataset using 24000 images for training and 1000 images for validation for 2 epochs. Even after only 2 epochs, the performance is pretty high:

Metrics log after 2 epochs

Download the trained model here*

*model compatible with keras==1.2.2

Examples

```shell
python predict.py --image_url https://goo.gl/Xws7Tp --model dc.model
python predict.py --image_url https://goo.gl/6TRUol --model dc.model
```

Stay tuned for the next post in the series:

  1. Build an image recognition system for 1000 everyday object categories (ImageNet ILSVRC) using Keras and TensorFlow
  2. Build an image recognition system for any customizable object categories using transfer learning and fine-tuning in Keras and TensorFlow (this post)
  3. Build a real-time bounding-box object detection system for hundreds of everyday object categories (PASCAL VOC, COCO)
  4. Build a web service for any image recognition or object detection system

If you have any questions, contact me at greg.ht.chu@gmail.com or message me on LinkedIn!
