Application of Transfer Learning to solve Real-World Problems in Deep Learning
Introduction
One of the most important properties of deep neural networks, and a big reason for their popularity in every domain, is their ability to approximate arbitrarily complex functions. But neural networks are also voracious consumers of data, and training them on large datasets requires substantial computation power. Over-fitting is very common in neural networks when we do not have enough data to train on. A lack of data or of computation resources may therefore keep many of us from using deep learning to solve our problems. Transfer learning can be a great help in such cases.
Transfer learning can be simply understood as the act of utilizing a model trained on one domain as the starting point for training on another domain. The more related the domains are, the better transfer learning performs. More specifically, transfer learning trains machine learning models starting from the weights of pre-trained state-of-the-art models rather than from randomly initialized weights. Transfer learning can be understood through a simple analogy with how we humans learn. We don't learn everything from scratch; whenever we need to learn something new, we use our prior knowledge and experience as a base. When a doctor wants to study literature in her native language, she doesn't start from scratch. She already knows the alphabet, words, and sentence structure, and she uses that prior knowledge to learn literature with far less effort. We can apply the same idea in deep learning: by using pre-trained state-of-the-art models as a starting point, our model can learn the patterns of a new domain with very little effort and very little training data.
Need for Transfer Learning:
Training a neural network from scratch to obtain accurate results is a very tedious task. We may need to collect millions of labeled training examples for a neural network to perform well, which takes a lot of time and effort. In some cases it is very difficult to collect that much data at all, due to lack of availability or proprietary issues. Similarly, a model's performance degrades badly when it is used for prediction in a domain different from the one it was trained on, as it does not know how to deal with data it has never seen before. Even state-of-the-art models exhibiting superhuman performance on a particular domain or task break down easily when the domain changes. This clearly shows the necessity of transfer learning, which aims to break the boundaries of the specific domain a model was trained on and to leverage knowledge from one domain or task to solve problems in another. Andrew Ng, in a tutorial at NIPS 2016 entitled Nuts and Bolts of Building AI Applications Using Deep Learning, mentioned:
“After supervised learning — Transfer Learning will be the next driver of ML commercial success.”
Transfer Learning with Practical Example of Real World Application
In this section, we will see how transfer learning can be applied to a real-world problem. We will build a classifier that categorizes images of food items into seven different classes. I have made a dataset consisting of images of seven food items popular in Nepal: Chhoila, Rice, Kathi-Roll, Laphing, Mo: Mo, Paani-Puri, and Yomari. I collected random images of these items from the internet and placed them into folders named after each item, then used the Python package imgaug to augment the images. As this post is focused on transfer learning, I won't discuss the data collection and preprocessing steps in detail. If you want to explore more, you can check my GitHub repository.
Using custom Neural Network Architecture for classification:
We will start by building a simple neural network with four convolution layers, each accompanied by a max-pooling layer for down-sampling of features. A flatten layer is added after the convolution layers to flatten out the feature maps from the fourth convolution layer. Two dense layers are placed after the flatten layer, followed by an output layer. The output layer has 7 nodes, since the number of output classes in our dataset is 7.
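A minimal Keras sketch of this architecture follows. The filter counts, kernel sizes, and input image size are my assumptions for illustration; the post does not specify them. The two dense layers use 256 and 512 nodes, matching the sizes the regularization section later says were reduced.

```python
from tensorflow.keras import layers, models

def build_custom_cnn(num_classes=7, input_shape=(224, 224, 3)):
    """Four conv + max-pool blocks, a flatten layer, two dense layers, and a softmax output."""
    return models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation="relu"),   # filter counts are illustrative
        layers.MaxPooling2D((2, 2)),                    # down-sample feature maps
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),                               # flatten the 4th block's feature maps
        layers.Dense(256, activation="relu"),           # two dense layers
        layers.Dense(512, activation="relu"),
        layers.Dense(num_classes, activation="softmax"),  # 7 output classes
    ])

model = build_custom_cnn()
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
```

The model would then be trained with `model.fit(...)` on the prepared image batches.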
From the above plot of accuracy values on the training and validation sets, it is clear that the model is severely suffering from over-fitting. To get rid of this problem, we can try adding dropout and decreasing the number of nodes in the dense layers.
Use Regularization techniques in custom Neural Network Architecture to get rid of over-fitting
Dropout is a regularization method for neural networks that randomly eliminates neurons during training. By randomly eliminating neurons, we reduce the interdependency among them, which helps to reduce over-fitting. We will add a dropout of 0.3 after each convolution layer and fully-connected layer. This will randomly eliminate 30% of the neurons in each layer except the output layer by setting their outputs to zero.
To make the model's learning simpler, we can also reduce the number of nodes: from 256 to 128 in the first dense layer and from 512 to 128 in the second. This might prevent the model from learning complex mapping functions specific to the training set.
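A sketch of the regularized variant in Keras, under the same illustrative filter counts as before, with Dropout(0.3) inserted after every convolution and dense layer and the dense layers shrunk to 128 nodes each:

```python
from tensorflow.keras import layers, models

def build_regularized_cnn(num_classes=7, input_shape=(224, 224, 3), drop_rate=0.3):
    """Same four-block CNN, with dropout after every conv/dense layer
    and both dense layers reduced to 128 nodes."""
    return models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(drop_rate),          # randomly zero out 30% of activations
        layers.Conv2D(64, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(drop_rate),
        layers.Conv2D(128, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(drop_rate),
        layers.Conv2D(128, (3, 3), activation="relu"),
        layers.MaxPooling2D((2, 2)),
        layers.Dropout(drop_rate),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),   # reduced from 256
        layers.Dropout(drop_rate),
        layers.Dense(128, activation="relu"),   # reduced from 512
        layers.Dropout(drop_rate),
        layers.Dense(num_classes, activation="softmax"),  # no dropout on the output layer
    ])
```

Dropout layers are only active during training; at inference time Keras disables them automatically.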
The accuracy plot shows that the model is still suffering from over-fitting, and the reason is the small amount of training data. To solve this problem, we can leverage transfer learning with state-of-the-art pre-trained models.
Transfer Learning with Pre-trained state-of-the-art models
We can overcome the problem of over-fitting using the concept of transfer learning. In neural network architectures used for computer vision tasks, the initial layers detect edges, the middle layers detect shapes, and domain- or task-specific features are only detected by the final layers. Hence, we can keep the weights of a state-of-the-art model for the early and middle layers and only tune the weights of the final layers.
There are various state-of-the-art architectures available for computer vision; some of the popular ones are VGG, ResNet, Inception, and MobileNet.
However, I'm going to use MobileNetV2, as it is a lightweight model that can run smoothly even on mobile phones.
Understanding MobileNetV2
MobileNetV2 is a lightweight model released by Google in 2018 to power the next generation of mobile vision applications.
(Figure: MobileNetV2 Architecture)
MobileNetV2 as a Feature Extractor
We start the journey of transfer learning by freezing all the convolution layers of MobileNetV2 pre-trained on the ImageNet dataset. As our dataset has only 7 output classes while ImageNet has 1000, we remove the output layer of MobileNetV2 and insert a new output layer with 7 nodes. Similarly, we also add 2 fully-connected layers to map the representations given by the convolution blocks to our output classes. A flatten layer is added between the convolution blocks and the fully-connected layers to convert the representations into a 1-D array.
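The steps above can be sketched in Keras as follows. The sizes of the two fully-connected layers are my assumptions; the post does not state them. Pass `weights="imagenet"` (the default here) to get the pre-trained convolution weights.

```python
import tensorflow as tf
from tensorflow.keras import layers, models

def build_transfer_model(num_classes=7, input_shape=(224, 224, 3), weights="imagenet"):
    # MobileNetV2 without its 1000-class ImageNet head (include_top=False).
    base = tf.keras.applications.MobileNetV2(
        input_shape=input_shape, include_top=False, weights=weights)
    base.trainable = False   # freeze all convolution blocks

    return models.Sequential([
        base,
        layers.Flatten(),                       # conv feature maps -> 1-D array
        layers.Dense(256, activation="relu"),   # two fully-connected layers
        layers.Dense(128, activation="relu"),   # (sizes are illustrative)
        layers.Dense(num_classes, activation="softmax"),  # new 7-node output layer
    ])
```

With the base frozen, only the new dense layers are updated during training, which is why this works even with a small dataset.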
It is amazing how transfer learning spikes the accuracy on the validation set from 48% to 88%. However, the model is still over-fitting, because the dataset is very small.
Training inner layers of the base model
We already used MobileNetV2 as a feature extractor and saw how it boosts performance. We can also experiment with training the inner layers of the base model. We will freeze all the blocks of MobileNetV2 before block_15 and fine-tune the blocks from block_15 onward. This also increases our accuracy by a small amount.
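One way to express this partial unfreezing is to walk the base model's layers and flip a flag at the chosen boundary. Keras names MobileNetV2's blocks `block_1_…` through `block_16_…`, so matching on the `block_15` prefix freezes everything before it. This is a sketch, not the post's exact code; `weights=None` is used below only to avoid the weight download, whereas in practice you would fine-tune the `"imagenet"` weights.

```python
import tensorflow as tf

def unfreeze_from(base, boundary="block_15"):
    """Freeze every layer before the first one whose name starts with
    `boundary`; make that layer and everything after it trainable."""
    base.trainable = True
    trainable = False
    for layer in base.layers:
        if layer.name.startswith(boundary):
            trainable = True
        layer.trainable = trainable
    return base

# Demo with random weights; use weights="imagenet" for real fine-tuning.
base = tf.keras.applications.MobileNetV2(include_top=False, weights=None)
unfreeze_from(base, boundary="block_15")
```

After changing the `trainable` flags, the full model must be compiled again (typically with a lower learning rate) for the change to take effect in training.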
Conclusion:
From the above experiments and results, we can conclude that transfer learning can be a great ally when dealing with real-world problems under constraints such as a lack of data or computation resources. Let's hope transfer learning brings many more advances to the field of machine learning in the days to come.
Check the Project at GitHub:
P.S. I am just a novice ML practitioner. Feel free to reach out to me if there are mistakes in this blog post. Leave your feedback in the comment section, or you can email me.
Get in touch with me: