Understanding Transfer Learning & Image Augmentation

Naivedh Shah
Oct 29 · 6 min read

Have you ever taken part in an image classification competition and felt that your model isn’t as good as the ones at the top of the leaderboard? Then I think this blog is for you.

Table of Contents

  • What is Transfer Learning?
  • Preprocessing
  • Image Augmentation
  • Transfer learning using ResNet101
  • Evaluation
  • End Notes

What is Transfer Learning?

Transfer learning is the reuse of a model pre-trained on one task as the starting point for a new, related task, so that the features it has already learned don’t have to be trained from scratch.


In this step, we will create an images directory and unzip our data into it.

!mkdir images
!unzip code_warriors_game_of_data_ai_challenge-dataset.zip -d images/

Now, we will import all the required libraries.

import os, shutil
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.preprocessing import image
from keras.applications.resnet import ResNet101
from keras.layers import Dense
from keras.models import Model, Sequential

In the next step, we will split the images into train images and validation images.

Step 1: Create a val_images directory to transfer images.

if not os.path.isdir("val_images"):
    os.mkdir("val_images")

Step 2: Create a list of classes that the data can be classified into.

classes = ['Bread','Dairy product','Dessert','Egg','Fried food','Meat','Noodles-Pasta','Rice','Seafood','Soup','Vegetable-Fruit']

Step 3: We will create sub-directories to store images of each category. The code below goes through every element of the list and creates a folder for the class if one does not already exist.

for c in classes:
    if not os.path.isdir("val_images/" + c):
        os.mkdir("val_images/" + c)

Step 4: This step might be a little difficult to understand, so bear with me.

In the code below, we assign 0.9 to the variable split because we want to split the data in a 90:10 ratio (train : validate). Next, we loop through all the sub-folders in the train folder and build a path variable for each. os.listdir() returns all the file names in the given folder, so len() gives us the number of images, and multiplying it by split gives the split size, i.e. the number of training images we need.

In the following step, we create a variable files_to_move holding the file names indexed from split_size to the end, i.e. the last 10% of the data. Finally, we build the source and destination paths using os.path.join() and move the files using shutil’s move() function.

Suggestion: If you get confused, add print(variable) at any step; it might help you get a better understanding.

split = 0.9
for dirc in os.listdir("/content/images/train"):
    path = "/content/images/train/" + dirc
    images = os.listdir(path)
    split_size = int(len(images) * split)
    files_to_move = images[split_size:]
    for f in files_to_move:
        src = os.path.join(path, f)  # path + file
        dest = os.path.join("val_images/", dirc)
        shutil.move(src, dest)
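The slicing logic can be illustrated with a toy example (the file names here are hypothetical stand-ins for what os.listdir() would return):

```python
# Pretend os.listdir() returned these 10 file names.
images = [f"img{i}.jpg" for i in range(10)]

split = 0.9
split_size = int(len(images) * split)  # 9 -> the first 9 files stay for training
files_to_move = images[split_size:]    # the last 10% go to validation

print(split_size)     # 9
print(files_to_move)  # ['img9.jpg']
```

So for every class folder, roughly one in ten images ends up in val_images.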

To verify that the move worked, we will check the contents of the directories.

We will loop through each sub-directory of the train directory and print the number of images in each category.

for dirc in os.listdir("/content/images/train"):
    path = "/content/images/train/" + dirc
    img = os.listdir(path)
    print(dirc, len(img))

We will loop through each sub-directory of the val_images directory and print the number of images in each category.

for dirc in os.listdir("val_images/"):
    path = "val_images/" + dirc
    img = os.listdir(path)
    print(dirc, len(img))

Image Augmentation

Image augmentation is a way to increase the amount of data we already have by creating modified versions of it. In the image below, you can see that from one cat image we can create multiple images by changing the width, height, zoom, shear, etc.

Image Augmentation
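To make the idea concrete, here is a minimal sketch of one such augmentation, a horizontal flip, applied to a toy 2×3 “image” of raw pixel values (a real image would be an array of shape (height, width, 3)):

```python
# A tiny 2x3 "image" of pixel values.
img = [[1, 2, 3],
       [4, 5, 6]]

# Horizontal flip: reverse each row (mirror the image left-to-right).
flipped = [row[::-1] for row in img]

print(flipped)  # [[3, 2, 1], [6, 5, 4]]
```

The original and the flipped copy show the same subject, but to the model they are two distinct training examples.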

So, to implement this we will create an ImageDataGenerator object for the training data and set the properties that new images will be generated from, such as rotation_range, width_shift_range, height_shift_range, shear_range, zoom_range, and horizontal_flip.

train_gen = image.ImageDataGenerator(rotation_range=25,
                                     width_shift_range=0.3,
                                     height_shift_range=0.25,
                                     shear_range=0.2,
                                     zoom_range=0.3,
                                     horizontal_flip=True)

We will also create an ImageDataGenerator object for the validation data, but we won’t pass any augmentation properties, because we do not want to generate modified images for data that is only used for validation.

val_datagen = image.ImageDataGenerator()

In the upcoming step, we will generate batches of images from the object we just created. We are using flow_from_directory; you can also use flow_from_dataframe as per your requirements. We pass the directory containing the training images. The target size is specified as (224, 224) because we will use a ResNet model, which is trained on images of that size.

train_generator = train_gen.flow_from_directory("/content/images/train",
                                                target_size=(224, 224),
                                                class_mode="categorical",
                                                shuffle=True,
                                                batch_size=32)

We will also use the val_datagen object to process the validation images in the same way.

val_generator = val_datagen.flow_from_directory("val_images/",
                                                target_size=(224, 224),
                                                class_mode="categorical")

Transfer learning using ResNet101

We will now create a ResNet101 object. include_top is set to True because we want the final Dense layer, and the weights parameter is set to "imagenet" so we get pre-trained weights.

resnet = ResNet101(include_top=True, weights="imagenet")

In the upcoming step, we take the output of the second-to-last ResNet layer and add a Dense layer at the end. The activation is set to softmax because we want to perform classification, and 11 is specified because we have 11 classes.

# Take the output of the second-to-last layer
x = resnet.layers[-2].output
fun = Dense(11, activation="softmax")(x)

Now, we will create our model which has the inputs of the resnet model and the output fun that we just created.

model = Model(inputs=resnet.input, outputs = fun)

In this next step, we will freeze all layers except the last 30, i.e. we will set their trainable parameter to False so the learned weights won’t change. We train only the last 30 layers because the ResNet101 model was trained on general data (ImageNet), whereas we have food data, so our model needs to adapt accordingly.

The model.compile method is used to compile the model before training. Here we are using sgd, i.e. Stochastic Gradient Descent; you could also use the Adam optimizer or any other optimizer. We have taken categorical_crossentropy as the loss; you can use any other loss function as the task requires. The metric we will be using is accuracy.

# Freeze all but the last 30 layers
for l in model.layers[:-30]:
    l.trainable = False

model.compile(optimizer="sgd", loss="categorical_crossentropy",
              metrics=["accuracy"])

The model.summary() method is used to get a look at all the layers. You can try it out; I won’t be able to show it as an image.


Next comes the training part; the fit method is used to train our model. We pass train_generator for training. steps_per_epoch can be calculated as the number of training images divided by the batch size, and validation_steps as the number of validation images divided by the batch size. We pass val_generator as validation_data for validation.

hist = model.fit(train_generator,
                 steps_per_epoch=len(train_generator),
                 validation_data=val_generator,
                 validation_steps=len(val_generator),
                 epochs=10)  # the number of epochs is a choice; tune as needed
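As a worked example of those formulas (the image counts here are hypothetical; use your own directory counts), with a batch size of 32:

```python
import math

batch_size = 32
num_train_images = 900  # hypothetical count of training images
num_val_images = 100    # hypothetical count of validation images

# Round up so the last, partially filled batch is still used.
steps_per_epoch = math.ceil(num_train_images / batch_size)
validation_steps = math.ceil(num_val_images / batch_size)

print(steps_per_epoch, validation_steps)  # 29 4
```

Note that len() of a Keras directory generator gives you this same ceiling-divided batch count directly.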


We will use the .evaluate() method to get the accuracy of our model. I would suggest you play with the hyper-parameters and try to push the accuracy to 90% or above.
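Calling model.evaluate(val_generator) returns the loss and the accuracy. As a quick illustration of what the accuracy metric measures (the class indices below are hypothetical):

```python
# Hypothetical true and predicted class indices for 6 validation images.
y_true = [0, 3, 7, 7, 2, 10]
y_pred = [0, 3, 5, 7, 2, 10]

# Accuracy = fraction of predictions that match the true class.
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

print(round(accuracy, 3))  # 0.833
```

Here 5 of the 6 predictions are correct, so the accuracy is about 83%.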


End Notes

In this blog we learned transfer learning using ResNet101. I would suggest you try different models such as VGG, DenseNet, Xception, MobileNet, and many others.

For more blogs on Machine Learning and Data Science do follow me and let me know if there is any topic that you would like to know more about.

Hey Readers, thank you for your time. If you liked the blog, don’t forget to appreciate it with a clap👏 and if in case you loved ❤ it, you can give 50 👏

Data Science Enthusiast | ML Enthusiast | TCS CA | Coding Blocks CA | Blogger | Community Member | Public Speaker

If you have any queries or suggestions feel free to contact me on




Analytics Vidhya

Analytics Vidhya is a community of Analytics and Data Science professionals. We are building the next-gen data science ecosystem https://www.analyticsvidhya.com
