Flower Classification using Transfer Learning and CNN (Step-by-Step)

Michael Chia Yin
6 min read · Jul 1, 2020

Today our objective is to compare the accuracy we get from three different methods of training our image classification model:

  1. Basic Convolutional Neural Network
  2. Transfer learning using ResNet50
  3. Transfer learning using ResNet18

Feel free to check out my notebook: https://jovian.ml/edsenmichaelcy/flower-classification/v/12

Contents

Step 1: Import the files we need and put the dataset into data_dir
Step 2: Data augmentation & normalization
Step 3: Check the dataset classes and label them
Step 4: Functions to show a single picture and a batch of pictures
Step 5: Split the training data and the validation data
Step 6: Choose the batch size, put in DataLoader and show the batch
Step 7: Get the GPU up and running
Step 8: Training the image classification model using a basic CNN
Step 9: Training and Validation Datasets
Step 10: Training the model with CNN
Step 11: Predict and test the model
Step Resnet50: Transfer Learning method using Resnet50 (Pre-trained)
Step Resnet18: Transfer Learning method using Resnet18 (Pre-trained)
Final Step: Conclusion comparing the results

Flower Classification

Today we are going to classify flowers. There are many types of flowers in the world, and botanists need specialized knowledge to recognize each type. Here we will use deep learning to help botanists identify the type of a flower. Let's get started!

Step 1: Import the files we need and put the dataset into data_dir

Step 2: Data augmentation & normalization

To understand data augmentation and normalization, we first need to understand why we need them. Data augmentation helps prevent the model from overfitting, because the model sees slightly different versions of each image during training. Normalization keeps the input data within a consistent numeric range, which makes training better behaved. For example, channel-wise data normalization rescales each color channel using a per-channel mean and standard deviation, as you can see in the code: mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]. Randomized transformations are one way to augment the data. For example, we can apply RandomHorizontalFlip and RandomVerticalFlip to randomly flip the images horizontally and vertically.

You may ask: what is overfitting? Overfitting is when a model starts to specialize in (memorize) its training data instead of learning patterns that generalize to new images. We always want to train our model to generalize, not to specialize.

Step 3: Check the dataset classes and label them

This dataset contains 5 classes: daisy, dandelion, rose, sunflower, and tulip.

We label these classes 0–4 (daisy, dandelion, rose, sunflower, and tulip, in that order) and use these labels in the prediction functions.
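The label mapping can be sketched in plain Python. torchvision's ImageFolder assigns indices to the sorted folder names. The five class names are already in alphabetical order, so daisy gets 0 and tulip gets 4. The dictionaries below are an illustration, not the notebook's code.

```python
# The five classes; sorting matches the order torchvision's ImageFolder uses.
classes = ["daisy", "dandelion", "rose", "sunflower", "tulip"]

# name -> label (0-4), used when building targets
class_to_idx = {name: idx for idx, name in enumerate(sorted(classes))}
# label -> name, used to turn a predicted index back into a flower name
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

print(class_to_idx)
```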

Step 4: Functions to show a single picture and a batch of pictures

Step 5: Split the training data and the validation data

We use 500 images as our validation set. The dataset contains 4323 images, so the training size is 4323 - 500 = 3823.

Next, we use the function random_split to split the data randomly. As you can see, train_ds has 3823 images and val_ds has 500.
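Here is a sketch of the split with torch.utils.data.random_split. The seeded generator is an added assumption so that the same split is reproduced on every run. A dummy TensorDataset of 4323 items stands in for the real flower dataset.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_dataset(dataset, val_size=500, seed=42):
    """Randomly split dataset into (train_ds, val_ds) with val_size validation items."""
    train_size = len(dataset) - val_size             # e.g. 4323 - 500 = 3823
    generator = torch.Generator().manual_seed(seed)  # fixed seed -> reproducible split
    return random_split(dataset, [train_size, val_size], generator=generator)


# Stand-in dataset with the same length as the flower dataset (4323 items).
dummy = TensorDataset(torch.arange(4323))
train_ds, val_ds = split_dataset(dummy)
print(len(train_ds), len(val_ds))  # 3823 500
```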

Step 6: Choose the batch size, put in DataLoader and show the batch

We use a batch size of 32 to train on our data.

Next, we put train_ds, val_ds, and test_ds into DataLoaders so that we can iterate over them in batches during training.

With the batch size chosen and the data loaded into the data loaders, we can display one batch, which contains 32 pictures.
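Here is a sketch of the DataLoader setup with batch_size = 32. Random tensors of shape (3, 64, 64) stand in for the real train_ds and val_ds; that image size is an assumption. Shuffling only the training loader is a common choice, not something the article states.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

batch_size = 32

# Stand-ins for train_ds / val_ds: random "images" of shape (3, 64, 64) with labels 0-4.
train_ds = TensorDataset(torch.randn(100, 3, 64, 64), torch.randint(0, 5, (100,)))
val_ds = TensorDataset(torch.randn(40, 3, 64, 64), torch.randint(0, 5, (40,)))

# Shuffle the training data every epoch; keep the validation order fixed.
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size)

images, labels = next(iter(train_dl))
print(images.shape)  # torch.Size([32, 3, 64, 64])
```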

Step 7: Get the GPU up and running

Here is the function to get the GPU up and running.

As you can see here, the device type is cuda, which means the GPU is up and running.
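Here is a sketch of GPU helpers following a common PyTorch tutorial pattern. get_default_device, to_device, and DeviceDataLoader are hypothetical names here. They pick cuda when torch.cuda.is_available() is true and move each batch to that device as it is loaded.

```python
import torch


def get_default_device():
    """Pick the GPU if one is available, otherwise fall back to the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_device(data, device):
    """Move a tensor or model, or a list/tuple of tensors, to the chosen device."""
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    return data.to(device, non_blocking=True)


class DeviceDataLoader:
    """Wrap a DataLoader so that every batch it yields is already on `device`."""

    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        return len(self.dl)


device = get_default_device()
print(device.type)  # "cuda" when a GPU is available, otherwise "cpu"
```

With these helpers, `train_dl = DeviceDataLoader(train_dl, device)` and `model = to_device(model, device)` put the data and the model on the same device.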

Step 8: Training the image classification model using a basic CNN

Here we define the skeleton of our deep learning model as a class that inherits from PyTorch's nn.Module.
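Here is a sketch of what such a base class can contain: shared training and validation steps plus an accuracy helper. With this in place, the basic CNN and the ResNet variants only need to define their layers. ImageClassificationBase and its method names are assumptions, not the notebook's confirmed code. The loss is cross-entropy, the standard choice for multi-class classification.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def accuracy(outputs, labels):
    """Fraction of predictions (argmax over class scores) that match the labels."""
    preds = torch.argmax(outputs, dim=1)
    return (preds == labels).float().mean()


class ImageClassificationBase(nn.Module):
    """Hypothetical base class holding the training/validation logic shared by all models."""

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                   # forward pass
        return F.cross_entropy(out, labels)  # loss used for gradient descent

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)
        loss = F.cross_entropy(out, labels)
        return {"val_loss": loss.detach(), "val_acc": accuracy(out, labels)}

    def validation_epoch_end(self, outputs):
        # Average the per-batch results into one number per metric for the epoch.
        loss = torch.stack([x["val_loss"] for x in outputs]).mean().item()
        acc = torch.stack([x["val_acc"] for x in outputs]).mean().item()
        return {"val_loss": loss, "val_acc": acc}
```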

Here we define the CNN model and its activation functions. nn.Conv2d is a 2D convolution layer, and we can stack as many of them as we want. Each nn.Conv2d is followed by nn.ReLU(), a nonlinear activation function. Without a nonlinearity between them, consecutive convolutions would be equivalent to a single linear operation.
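Here is a sketch of a small CNN in this style, assuming 3x64x64 inputs and 5 output classes. The layer widths and the max-pooling after each convolution block are illustrative choices. FlowerCNN is a hypothetical name, not the article's exact FlowerModel.

```python
import torch
import torch.nn as nn


class FlowerCNN(nn.Module):
    """A small CNN sketch for 5 flower classes, assuming 3x64x64 input images."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.network = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),    # 3 x 64 x 64 -> 32 x 64 x 64
            nn.ReLU(),
            nn.MaxPool2d(2),                               # -> 32 x 32 x 32
            nn.Conv2d(32, 64, kernel_size=3, padding=1),   # -> 64 x 32 x 32
            nn.ReLU(),
            nn.MaxPool2d(2),                               # -> 64 x 16 x 16
            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # -> 128 x 16 x 16
            nn.ReLU(),
            nn.MaxPool2d(2),                               # -> 128 x 8 x 8
            nn.Flatten(),                                  # -> 8192 features
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),                   # one score per class
        )

    def forward(self, x):
        return self.network(x)


model = FlowerCNN()
print(model(torch.randn(2, 3, 64, 64)).shape)  # torch.Size([2, 5])
```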

Step 9: Training and Validation Datasets

  1. Training set — used to train the model i.e. to compute the loss and adjust the weights of the model using gradient descent.
  2. Validation set — used to evaluate the model while training, adjust hyperparameters (learning rate, etc.), and pick the best version of the model.
  3. Test set — used to compare different models, or different types of modeling approaches, and report the final accuracy of the model.

Once the GPU is running, we wrap train_dl and val_dl in the device data loader so that each batch is moved to the GPU.

After defining FlowerModel(), we create an instance of the model and move it to the device so that we are able to train it.

Step 10: Training the model with CNN

We train for 10 epochs, using torch.optim.Adam as our optimizer with a learning rate of 0.001.
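Here is a sketch of a training loop that uses these settings. It trains with cross-entropy loss, evaluates on the validation set after every epoch, and returns a per-epoch history that can be plotted as in Step 11. The fit and evaluate names are hypothetical.

```python
import torch
import torch.nn.functional as F


def evaluate(model, val_loader):
    """Average validation loss and accuracy over all batches, without gradients."""
    model.eval()
    losses, accs = [], []
    with torch.no_grad():
        for images, labels in val_loader:
            out = model(images)
            losses.append(F.cross_entropy(out, labels).item())
            accs.append((out.argmax(dim=1) == labels).float().mean().item())
    return {"val_loss": sum(losses) / len(losses), "val_acc": sum(accs) / len(accs)}


def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.Adam):
    """Train for `epochs` epochs and return one result dict per epoch."""
    optimizer = opt_func(model.parameters(), lr)
    history = []
    for epoch in range(epochs):
        model.train()
        train_losses = []
        for images, labels in train_loader:
            loss = F.cross_entropy(model(images), labels)
            loss.backward()        # compute gradients
            optimizer.step()       # update weights
            optimizer.zero_grad()  # reset gradients for the next batch
            train_losses.append(loss.item())
        result = evaluate(model, val_loader)
        result["train_loss"] = sum(train_losses) / len(train_losses)
        print(f"Epoch {epoch}: train_loss {result['train_loss']:.4f}, "
              f"val_loss {result['val_loss']:.4f}, val_acc {result['val_acc']:.4f}")
        history.append(result)
    return history


# Settings from the article: 10 epochs, Adam, learning rate 0.001.
# history = fit(10, 0.001, model, train_dl, val_dl, torch.optim.Adam)
```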

You can see that the val_acc of the basic CNN is only 67%.

Step 11: Predict and test the model

Here you can see that the accuracy increases as the number of epochs increases.

It is good to see that the validation loss decreases along with the training loss. If the validation loss curve starts to increase at some point while the training loss keeps falling, the model is starting to overfit.
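The rule of thumb above can be written as a small check, which is the idea behind early stopping. The check tracks the epoch with the lowest validation loss and flags overfitting once the loss has not improved for a few epochs. The loss values are made up for illustration, and patience=2 is an arbitrary assumption.

```python
def best_epoch(val_losses):
    """Index of the epoch with the lowest validation loss."""
    return min(range(len(val_losses)), key=lambda i: val_losses[i])


def overfitting_started(val_losses, patience=2):
    """True if the validation loss has not beaten its best value for `patience` epochs."""
    return len(val_losses) - 1 - best_epoch(val_losses) >= patience


losses = [1.20, 0.95, 0.80, 0.78, 0.85, 0.93]  # illustrative numbers only
print(best_epoch(losses))           # 3
print(overfitting_started(losses))  # True
```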

As you can see, the model is not accurate enough to reliably predict the right answer.
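Here is a sketch of a single-image prediction helper. The model expects a batch, so the image gets a batch dimension before the forward pass. The highest-scoring index is then mapped back to a class name using the 0–4 labels from Step 3. predict_image is a hypothetical name. On a GPU, the image must first be moved to the same device as the model.

```python
import torch


def predict_image(img, model, classes):
    """Return the predicted class name for one image tensor of shape (C, H, W)."""
    model.eval()
    with torch.no_grad():
        xb = img.unsqueeze(0)           # add a batch dimension: (1, C, H, W)
        yb = model(xb)                  # scores for each class: (1, num_classes)
        pred = yb.argmax(dim=1).item()  # index of the highest score
    return classes[pred]
```

For example, `img, label = test_ds[0]` followed by `predict_image(img, model, dataset.classes)` returns a flower name that can be compared with `dataset.classes[label]`.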

Step Resnet50: Transfer Learning method using Resnet50 (Pre-trained)

We reuse the same steps from “Step 8” to define the skeleton of our deep learning model, but change FlowerModel() to FlowerModelResnet50().

Here we use the pre-trained resnet50 as the base of our model, in the hope of getting a better result.

We keep the same number of epochs, optimizer, and learning rate, and see what val_acc we get.

With transfer learning (resnet50), the val_acc is a bit better than with the basic CNN, at around 73%.

As you can see, the accuracy increases as more epochs are run.

This model is not ideal, because the validation loss starts to increase at some point during training. This means the model is starting to memorize the training data rather than learn from it; this phenomenon is called overfitting.

The predictions are more accurate than those of the basic CNN model. Now let's try another pre-trained model, resnet18.

Step Resnet18: Transfer Learning method using Resnet18 (Pre-trained)

Again, we reuse the same steps from “Step 8” to define the skeleton of our deep learning model, but change FlowerModel() to FlowerModel18.

Here we use the pre-trained resnet18 as the base of our model. Let's see the final result.

This time we use only a small number of epochs to reach a high val_acc.

The result is promising! The val_acc reaches 85%, which is better than both resnet50 and the basic CNN.

As you can see, the prediction is very accurate.

Final Step: Conclusion comparing the results

After two weeks of training with the basic CNN and with transfer learning (resnet50 and resnet18), I found the following val_acc ranges for each model:

Basic CNN: the val_acc ranges from around 50% to 65%.

Transfer learning (resnet50): the val_acc ranges from 60% to 75%.

Transfer learning (resnet18): the val_acc ranges from 70% to 89%.

Reference Links

Classifying CIFAR10 images using ResNet and Regularization techniques in PyTorch: https://jovian.ml/aakashns/05b-cifar10-resnet-live

A Comprehensive Guide to Transfer Learning: https://www.kaggle.com/rajmehra03/a-comprehensive-guide-to-transfer-learning

Flowers Recognition dataset: https://www.kaggle.com/alxmamaev/flowers-recognition/kernels

Do check me out on Jovian, LinkedIn, and Kaggle.
