Transfer Learning with a Convolutional Model in TensorFlow Keras
Transfer learning is one of the most widely used techniques in deep learning model development. Implemented efficiently, it can bring model training time down from multiple days to a few hours.
In this blog post, we will walk through the crucial steps involved in training a convolutional model-based image classifier using transfer learning from a pre-trained model.
A pre-trained model is a saved network that was previously trained on a large dataset, typically on a large-scale image-classification task. With transfer learning, you customize this pre-trained model for a given task.
The intuition behind transfer learning for image classification is that if a model is trained on a large and general enough dataset, this model will effectively serve as a generic model of the visual world. You can then take advantage of these learned feature maps without having to start from scratch by training a large model on a large dataset.
Let's go through the steps involved in applying transfer learning to a custom dataset.
NOTE: If you want to try these techniques in practice, try running this colab notebook.
Install the latest version of TensorFlow
!pip install -U tensorflow
Data preprocessing
We will use a dataset of dog and cat images to build the model.
import os
import matplotlib.pyplot as plt
import tensorflow as tf

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')
train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')
BATCH_SIZE = 64
IMG_SIZE = (224, 224)

train_dataset = tf.keras.utils.image_dataset_from_directory(
    train_dir, shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
validation_dataset = tf.keras.utils.image_dataset_from_directory(
    validation_dir, shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)
Let's visualize some samples from the dataset.
class_names = train_dataset.class_names
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for i in range(16):
        ax = plt.subplot(4, 4, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
To test the model's performance we need a test dataset. Since this dataset has no separate test split, we will set aside 25% of the validation batches as the test dataset and keep the rest for validation.
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 4)
validation_dataset = validation_dataset.skip(val_batches // 4)
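To make the split concrete, here is a minimal sketch of the batch arithmetic in plain Python. It assumes the validation folder holds 1000 images (a figure not stated in this post) and uses the batch size of 64 from above; the helper name `split_batch_counts` is made up for illustration.

```python
import math

def split_batch_counts(num_images, batch_size, test_divisor=4):
    """Return (total, test, validation) batch counts for a take/skip split."""
    # The last batch may be partial, so round up.
    total_batches = math.ceil(num_images / batch_size)
    # Same integer division as `val_batches // 4` in the split above.
    test_batches = total_batches // test_divisor
    return total_batches, test_batches, total_batches - test_batches

# Assumed: 1000 validation images, batch size 64.
total, test, val = split_batch_counts(num_images=1000, batch_size=64)
print(total, test, val)  # 16 4 12
```

Because the split is done on whole batches, the test share is exactly 25% only when the number of batches is divisible by 4.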
Configure the dataset for performance
Use buffered prefetching to load images from disk without having I/O become blocking.
AUTOTUNE = tf.data.AUTOTUNE
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)
Use data augmentation
Since the dataset is small, we can use data augmentation to create artificial variety and achieve better generalization performance. Random augmentation is useful for training only; the Keras augmentation layers used here are inactive when the model runs in inference mode, for example in model.evaluate or model.predict.
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomFlip('vertical'),
    tf.keras.layers.RandomRotation(0.2),
])
Let’s repeatedly apply these layers to the same image and see the result.
for image, _ in train_dataset.take(1):
    plt.figure(figsize=(10, 10))
    first_image = image[0]
    for i in range(16):
        ax = plt.subplot(4, 4, i + 1)
        augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
        plt.imshow(augmented_image[0] / 255)
        plt.axis('off')
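For intuition about what the random flips do, here is a rough NumPy sketch. It is not how the Keras layers are implemented; `random_flip` is a hypothetical helper and the tiny array stands in for a real photo.

```python
import numpy as np

def random_flip(image, rng):
    """Flip an (H, W, C) image left-right and/or up-down, each with probability 0.5."""
    if rng.random() < 0.5:
        image = np.fliplr(image)  # mirror along the width axis
    if rng.random() < 0.5:
        image = np.flipud(image)  # mirror along the height axis
    return image

rng = np.random.default_rng(seed=0)
image = np.arange(2 * 3 * 3).reshape(2, 3, 3)  # tiny stand-in for a photo
augmented = [random_flip(image, rng) for _ in range(8)]

# Every variant keeps the original shape and pixel values; only positions change.
print({a.shape for a in augmented})
```

The label of an image does not change under a flip, which is why these transformations are safe to apply to cats and dogs.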
Create the base model from the pre-trained convolutional network
We will use the EfficientNetB0 model for the transfer learning task. It is pre-trained on the ImageNet dataset, a large dataset consisting of 1.4M images and 1000 classes.
# Create the base model from the pre-trained EfficientNetB0 model
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.EfficientNetB0(input_shape=IMG_SHAPE,
                                                  include_top=False,
                                                  weights='imagenet')
This feature extractor converts each 224x224x3 image into a 7x7x1280 block of features.
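The 7x7 spatial size follows from the network's overall downsampling factor. As a small worked example, assuming EfficientNetB0 reduces height and width by a total factor of 32 and ends in 1280 channels (consistent with the shapes quoted above), the arithmetic looks like this. The function name is illustrative only.

```python
def output_feature_shape(height, width, total_stride=32, channels=1280):
    """Feature-map shape for inputs whose sides are multiples of the total stride."""
    if height % total_stride or width % total_stride:
        raise ValueError("this sketch only covers sizes divisible by the total stride")
    return (height // total_stride, width // total_stride, channels)

print(output_feature_shape(224, 224))  # (7, 7, 1280)
```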
Transfer learning
Because the ImageNet dataset that the original model was trained on already contains images of cats and dogs, we don't need to retrain the entire model. Retraining a select few top layers is enough to improve the overall performance.
Un-freeze the top layers of the model
All you need to do is unfreeze the base_model and set the bottom layers to be untrainable. Then compile the model (necessary for these changes to take effect) and train it.
base_model.trainable = True
print("Number of layers in the base model: ", len(base_model.layers))
fine_tune_at = 230
# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
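To see the effect of this pattern without downloading any weights, here is a minimal sketch on a toy model. The three Dense layers and the cut-off index `freeze_below` are hypothetical stand-ins for the pre-trained base and `fine_tune_at`.

```python
import tensorflow as tf

# Hypothetical stand-in for a pre-trained base model: three small Dense layers.
toy_base = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, name="bottom"),
    tf.keras.layers.Dense(8, name="middle"),
    tf.keras.layers.Dense(2, name="top"),
])

# Unfreeze everything, then freeze the layers below the cut-off index.
toy_base.trainable = True
freeze_below = 2  # hypothetical cut-off, plays the role of fine_tune_at
for layer in toy_base.layers[:freeze_below]:
    layer.trainable = False

trainable_flags = {layer.name: layer.trainable for layer in toy_base.layers}
print(trainable_flags)
# Only the kernel and bias of the "top" layer remain trainable.
print("Trainable weight tensors:", len(toy_base.trainable_weights))
```

Checking `trainable_weights` like this is a quick sanity test before compiling the real model.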
Use a global average pooling layer to reduce each 7x7 feature map to a single value, giving one 1280-element vector per image, before feeding it into the dense classification layer.
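global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
# Run one batch of images through the base model to get a batch of features
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)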
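Global average pooling is just a mean over the two spatial axes. A minimal NumPy sketch, using random numbers in place of real EfficientNetB0 features and a made-up batch of 4 images:

```python
import numpy as np

def global_average_pool(feature_maps):
    """Average an (N, H, W, C) array over H and W, giving (N, C)."""
    return feature_maps.mean(axis=(1, 2))

# Random stand-in for a batch of 4 feature blocks of shape 7x7x1280.
features = np.random.default_rng(0).normal(size=(4, 7, 7, 1280))
pooled = global_average_pool(features)
print(pooled.shape)  # (4, 1280)
```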
Add a classification layer. For binary classification, a single output node is enough.
prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
It is important to note that during the transfer learning phase we don't want to update the moving mean and moving variance of the Batch Normalization layers. To prevent that, pass training=False when calling the base model.
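The effect of the training flag can be checked on a single BatchNormalization layer. This is a small sketch on random data, not on the EfficientNetB0 base model itself.

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
# Random batch whose per-feature mean is clearly non-zero.
batch = tf.random.normal((32, 4), mean=5.0, stddev=2.0, seed=0)

# Inference mode: the layer uses its stored statistics and leaves them untouched.
_ = bn(batch, training=False)
mean_after_inference = bn.moving_mean.numpy().copy()

# Training mode: the layer nudges its stored statistics towards the batch statistics.
_ = bn(batch, training=True)
mean_after_training = bn.moving_mean.numpy().copy()

print("after training=False:", mean_after_inference)
print("after training=True: ", mean_after_training)
```

The moving mean starts at zero, stays there after the inference-mode call, and moves away from zero only after the training-mode call.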
Final model
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
Compile the model
As you are training a much larger model and want to readapt the pre-trained weights, it is important to use a lower learning rate at this stage. Otherwise, your model could overfit very quickly.
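base_learning_rate = 0.0001
# The model outputs logits, so accuracy is thresholded at 0 (probability 0.5)
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate / 10),
              metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')])
model.summary()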
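The from_logits=True setting tells the loss that the model outputs raw scores rather than probabilities. The NumPy sketch below illustrates the idea: binary cross-entropy computed directly from logits matches cross-entropy computed on sigmoid probabilities. The labels and logits are made-up values, and this is a sketch of the math rather than the Keras implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce_from_probabilities(y_true, p):
    """Textbook binary cross-entropy on probabilities."""
    return -np.mean(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))

def bce_from_logits(y_true, z):
    """Numerically stable form: max(z, 0) - z * y + log(1 + exp(-|z|))."""
    return np.mean(np.maximum(z, 0.0) - z * y_true + np.log1p(np.exp(-np.abs(z))))

labels = np.array([0.0, 1.0, 1.0, 0.0])    # made-up targets
logits = np.array([-2.0, 0.5, 3.0, 1.0])   # made-up raw model outputs

loss_a = bce_from_probabilities(labels, sigmoid(logits))
loss_b = bce_from_logits(labels, logits)
print(loss_a, loss_b)
```

Working from logits avoids taking the log of probabilities that have been rounded to exactly 0 or 1.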
Train the model
total_epochs = 20
history = model.fit(train_dataset,
                    epochs=total_epochs,
                    initial_epoch=0,
                    validation_data=validation_dataset)
Let’s take a look at the learning curves of the training and validation accuracy/loss when fine-tuning the last few layers of the EfficientNetB0 base model and training the classifier on top of it. The validation loss is much higher than the training loss, which suggests some overfitting.
Some overfitting is also to be expected because the new training set is relatively small and similar to the ImageNet data the EfficientNetB0 base model was originally trained on.
After transfer learning, the model nearly reaches 98% accuracy on the validation set.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()
Evaluation and prediction
Finally, you can verify the performance of the model on new data using the test set.
loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
And now you are all set to use this model to predict if your pet is a cat or dog.
# Retrieve a batch of images from the test set
image_batch, label_batch = next(test_dataset.as_numpy_iterator())
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(16):
    ax = plt.subplot(4, 4, i + 1)
    plt.imshow(image_batch[i].astype("uint8"))
    plt.title(class_names[predictions[i]])
    plt.axis("off")
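The post-processing step above can be sketched without TensorFlow. The example below turns raw logits into class names with NumPy; the class order ["cats", "dogs"] is an assumption based on alphabetical folder names, and the logits are made-up values.

```python
import numpy as np

CLASS_NAMES = ["cats", "dogs"]  # assumed order: index 0 = cats, index 1 = dogs

def logits_to_labels(logits, threshold=0.5):
    """Apply a sigmoid to raw logits, threshold the probabilities, map to class names."""
    probabilities = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=float)))
    class_ids = np.where(probabilities < threshold, 0, 1)
    return [CLASS_NAMES[i] for i in class_ids]

print(logits_to_labels([-3.2, 0.1, 4.0, -0.4]))
# ['cats', 'dogs', 'dogs', 'cats']
```

With a threshold of 0.5, this is equivalent to checking the sign of the logit: negative logits map to the first class and non-negative logits to the second.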