Tackling Multi-class Image Classification Problem with Transfer Learning using PyTorch

Aneesh Dalvi
Walmart Global Tech Blog
6 min read · Feb 8, 2021

Image Classification is a Supervised Learning problem in which a model is trained to recognize images. The objective of this classification is to identify and digitally analyze the features of an image. It is fascinating that machines can classify images that even humans have a hard time classifying. Image Classification has applications in retail, healthcare, security, the automotive industry, and almost every other field.

Here we will attempt to solve an Image Classification problem and generalize the approach so it can be applied to similar problems. Since our data consists entirely of images, the best way to solve this is to build a Convolutional Neural Network (CNN). We will build our CNN on top of a pre-trained ResNet34 model using the PyTorch library, which is very popular for building Deep Learning models.

I have chosen the Intel Image Classification dataset, which is taken from the Kaggle website. You can find many such interesting datasets on Kaggle.

Before starting to work on the problem, it is always a good idea to take a look at the dataset. We may find exciting revelations or anomalies just by glancing at it. This dataset has 25,000 labeled images of buildings, forests, glaciers, mountains, seas, and streets, and we will classify the images into these 6 distinct classes. Most of the images are 150 x 150 pixels.

Initial Setup

Let's start now by importing all the necessary libraries.
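Something like the following should cover everything we need (a sketch; install opendatasets and matplotlib in your environment if they are missing):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.datasets import ImageFolder
import torchvision.transforms as tt
import matplotlib.pyplot as plt
import opendatasets as od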

Opendatasets is an open-source library that makes it easy to download online datasets into your notebook.

od.download('https://www.kaggle.com/puneet6060/intel-image-classification')
Distribution of Images

This is the distribution of our images across the classes. Each class has its own directory, and the number of images per class is uneven.

Now we need to load our data and apply some transformations to the images to make them more useful for training our model. PyTorch offers many image transformations; we will use random crop and random horizontal flip, which work well with this dataset. We will resize all images to 128 x 128 pixels and normalize the pixel values with the ImageNet pre-trained stats. I recommend loading the images in batches of 128 so that training fits within the available computing power. A sketch of this loading step is shown below.
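A minimal sketch of the transforms and loaders, assuming the dataset's seg_train/seg_test folder layout after download (adjust the paths if yours differs):

# ImageNet channel-wise mean and std used for normalization
stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

train_tfms = tt.Compose([
    tt.Resize((128, 128)),
    tt.RandomCrop(128, padding=4, padding_mode='reflect'),
    tt.RandomHorizontalFlip(),
    tt.ToTensor(),
    tt.Normalize(*stats)
])
valid_tfms = tt.Compose([
    tt.Resize((128, 128)),
    tt.ToTensor(),
    tt.Normalize(*stats)
])

# Paths below assume the Kaggle download location; adjust as needed
train_ds = ImageFolder('./intel-image-classification/seg_train/seg_train', train_tfms)
valid_ds = ImageFolder('./intel-image-classification/seg_test/seg_test', valid_tfms)

batch_size = 128
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=2, pin_memory=True)
valid_dl = DataLoader(valid_ds, batch_size * 2, num_workers=2, pin_memory=True)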

With this done, we have successfully loaded our training and validation data through PyTorch DataLoaders, and each image is now a tensor of pixel values normalized with the ImageNet pre-trained stats.

GPU Utilities

Convolutional Neural Nets require a lot of computing power to train, as weights are calculated against every pixel value of every image. It might take hours or days to train a model on a CPU. Thankfully, online platforms such as Kaggle and Google Colab offer a free GPU setup, which improves training time drastically. Let's create some useful code to push our data to the GPU.
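A minimal sketch of these utilities, following the common PyTorch pattern (the helper names are illustrative, apart from DeviceLoader, which the article refers to):

def get_default_device():
    # Pick the GPU if one is available, otherwise fall back to the CPU
    return torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def to_device(data, device):
    # Move a tensor (or a list/tuple of tensors) to the chosen device
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)

class DeviceLoader:
    # Wraps a DataLoader and moves each batch to the device on the fly
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dl)

device = get_default_device()
train_dl = DeviceLoader(train_dl, device)
valid_dl = DeviceLoader(valid_dl, device)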

We have now created a Python class DeviceLoader which checks whether a GPU is available in the environment and moves the data to it.

Train our model

Now we are all set to start coding our model training. We will use ResNet34, a model pre-trained on image data with 1,000 classes. Why use a pre-trained CNN model? The initial layers of a CNN learn only low-level and mid-level features such as edges, lines, and borders, and all kinds of images contain these features. These characteristics make a pre-trained CNN very reusable. Hence it makes sense to use models that have already been trained on large datasets, into which many companies have invested a lot of money. This is called Transfer Learning.

Source: How a CNN learns
Source: Layer visualization

To use a pre-trained model, we keep all the earlier layers as they are and change only the final layer according to our use case. ResNet34 has been trained on 1,000 image classes, while our problem has only 6 (buildings, forests, glaciers, mountains, seas, and streets). Hence, we will modify the last layer of ResNet34 to output these 6 classes. Transfer Learning saves a lot of training time and engineering effort.

Source: Transfer Learning

PyTorch has an nn.Module class which can be inherited to build our model class. We will write a reusable class ImageClassificationBase consisting of helper functions that can be used with any model. Then we will create our model class MyModel, which inherits from ImageClassificationBase, and apply Transfer Learning with ResNet34 there. A sketch of both classes follows.
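The following is a minimal sketch of these two classes, assuming cross-entropy loss and a simple accuracy metric (the exact helpers in the original notebook may differ):

def accuracy(outputs, labels):
    # Fraction of predictions that match the labels
    _, preds = torch.max(outputs, dim=1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

class ImageClassificationBase(nn.Module):
    def training_step(self, batch):
        images, labels = batch
        out = self(images)                    # forward pass
        return F.cross_entropy(out, labels)   # training loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        return {'val_loss': loss.detach(), 'val_acc': accuracy(out, labels)}

    def validation_epoch_end(self, outputs):
        # Average the per-batch metrics over the whole validation set
        losses = torch.stack([x['val_loss'] for x in outputs]).mean()
        accs = torch.stack([x['val_acc'] for x in outputs]).mean()
        return {'val_loss': losses.item(), 'val_acc': accs.item()}

class MyModel(ImageClassificationBase):
    def __init__(self, num_classes):
        super().__init__()
        # Pre-trained ResNet34 backbone
        self.network = models.resnet34(pretrained=True)
        # Swap the final 1000-way layer for our 6-way layer
        self.network.fc = nn.Linear(self.network.fc.in_features, num_classes)

    def forward(self, xb):
        return self.network(xb)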

The last layer of our model has now been changed from the original 1,000 classes to our 6 classes. We could add more layers at the end if required, but one linear layer works fine here.

self.network.fc = nn.Linear(self.network.fc.in_features, num_classes)

Hurray! We have now successfully created our CNN model using Transfer Learning. Next, we need a training loop in which we evaluate the model, tune hyperparameters such as the learning rate and number of epochs, and set up optimizers for both training and validation. We will add techniques such as the One Cycle Learning Rate Policy, weight decay, and gradient clipping to our fit function to improve training.
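A sketch of such a fit function, assuming PyTorch's built-in OneCycleLR scheduler and gradient value clipping (the author's exact implementation may differ):

@torch.no_grad()
def evaluate(model, val_loader):
    model.eval()
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.Adam):
    history = []
    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)
    # One-cycle policy: ramp the learning rate up to max_lr, then anneal it down
    sched = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_loader))
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            if grad_clip:
                # Clip gradients to keep updates stable
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            sched.step()  # the schedule advances every batch, not every epoch
        result = evaluate(model, val_loader)
        print(f"Epoch {epoch}: val_loss {result['val_loss']:.4f}, "
              f"val_acc {result['val_acc']:.4f}")
        history.append(result)
    return history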

With this, we have everything ready for training. Now we just need to call the function we developed with different hyperparameters and train the model to its best. After a few iterations with different hyperparameters, the model trains to an accuracy above 90% on the validation data within 10-15 minutes. This is the advantage of applying Transfer Learning and training on a GPU: good accuracy is reached quickly. This approach comes in very handy for industry applications where waiting a long time for training is not an option.
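For example, a training run might look like this (the hyperparameter values here are illustrative, not necessarily the ones used for the results below):

model = to_device(MyModel(num_classes=6), device)

history = fit_one_cycle(epochs=10, max_lr=0.01, model=model,
                        train_loader=train_dl, val_loader=valid_dl,
                        weight_decay=1e-4, grad_clip=0.1)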

Let's take a look at our accuracies, learning rates, and losses with respect to the number of epochs for the training and validation sets. Data visualization always helps to understand data more effectively.
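Since the fit function returns a history of per-epoch metrics, a simple plot like this sketch is enough to visualize, for instance, the validation accuracy:

def plot_accuracies(history):
    # Validation accuracy recorded at the end of each epoch
    accuracies = [x['val_acc'] for x in history]
    plt.plot(accuracies, '-x')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.title('Accuracy vs. number of epochs')
    plt.show()

plot_accuracies(history)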

Rise in Accuracy

Interestingly, the accuracy was very low at the beginning because of the random initial weights, but it reached 70% by the second epoch. The accuracy eventually improved to 92% before flattening out.

The validation loss shows some ups and downs at the beginning, but towards the end it converges with the training loss. If it diverged away from the training loss, that would imply overfitting of the model.

The One Cycle Learning Rate Policy we used increases the learning rate for the initial batches of data up to a peak and then decreases it for the later batches.

There's one last thing left: we need to see how our model performs on unlabelled inference data. The Kaggle dataset we used already has a folder with unlabelled images. Let's predict! We will randomly select images from that inference set and run our predictions on them.
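A sketch of the prediction step (the predict_image helper and the seg_pred path are assumptions based on the dataset's layout; class names come from the training folder names):

import random

def predict_image(img, model, classes):
    # Add a batch dimension, move to the device, and pick the top class
    model.eval()
    xb = to_device(img.unsqueeze(0), device)
    with torch.no_grad():
        preds = model(xb)
    _, idx = torch.max(preds, dim=1)
    return classes[idx.item()]

# The unlabelled images sit under seg_pred; ImageFolder treats that
# single subfolder as a dummy class, which is fine for inference
pred_ds = ImageFolder('./intel-image-classification/seg_pred', valid_tfms)
img, _ = pred_ds[random.randrange(len(pred_ds))]
print(predict_image(img, model, train_ds.classes))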

Our model was able to correctly predict the classes of the images we sampled from the inference set.

Congratulations! We have now successfully built a CNN model using Transfer Learning with the PyTorch library. With this approach, any Multi-class Image Classification problem can be tackled with good accuracy in a short span of time.

Aneesh Dalvi
Data Engineer 3, Data Strategy & Insights, Walmart Global Tech