Wild Animal Image Classifier

Yogesh Gurjar
9 min read · Dec 1, 2022

This blog post walks through how we built an image classifier for the wild animals in our dataset.

Dataset

Data is crucial to building a machine-learning model, so we first need to understand ours. The dataset comes from Kaggle and is named Wild Animals Images. It contains pictures of various animals, which we will categorize using our best classifier.

The dataset consists of the following 6 categories -

  • Cheetah
  • Fox
  • Hyena
  • Lion
  • Tiger
  • Wolf

Let’s look at a sample of our dataset -

Dataset sample

Data Preprocessing

Data Cleaning

When building the image preprocessing pipeline, I ran it over every image in the dataset to detect corrupted files before they could cause problems during training.

I found two corrupted images in the dataset and deleted them.
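A check like this can be sketched as a small scan that tries to decode every file and records the ones that fail. This is a minimal sketch, not the exact pipeline used here: the function names are hypothetical, and the default loader assumes Pillow is installed.

```python
from pathlib import Path
from typing import Callable, Iterable, List


def pil_loader(path: Path) -> None:
    """Fully decode an image with Pillow (assumed installed); raises on corrupt files."""
    from PIL import Image  # imported lazily so the scan itself has no hard dependency

    with Image.open(path) as img:
        img.load()  # force a full decode, not just a header read


def find_corrupted(
    paths: Iterable[Path],
    loader: Callable[[Path], None] = pil_loader,
) -> List[Path]:
    """Return the paths for which the loader raises an exception."""
    bad = []
    for path in paths:
        try:
            loader(path)
        except Exception:
            bad.append(path)
    return bad


# Hypothetical usage: scan every JPEG under a "data" folder.
# corrupted = find_corrupted(Path("data").rglob("*.jpg"))
```

Running the scan once before training means a broken file shows up as a short list of paths instead of a crash in the middle of an epoch.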

Dataset Splits

We need to preprocess our data before we can feed it to a model for training. Before that, we build the following data splits -

  • Train Set: to train our model
  • Validation Set: to choose hyperparameters based on the bias-variance trade-off
  • Test Set: to measure the real performance of our classifier on unseen data
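The three splits above can be sketched with a seeded shuffle. The 70/15/15 ratio and the function name are assumptions for illustration; the post does not state the ratio that was actually used.

```python
import random
from typing import List, Sequence, Tuple


def split_dataset(
    items: Sequence,
    val_frac: float = 0.15,   # assumed fraction, not taken from the post
    test_frac: float = 0.15,  # assumed fraction, not taken from the post
    seed: int = 42,
) -> Tuple[List, List, List]:
    """Shuffle once with a fixed seed, then cut into train / validation / test."""
    items = list(items)
    random.Random(seed).shuffle(items)
    n_test = round(len(items) * test_frac)
    n_val = round(len(items) * val_frac)
    test = items[:n_test]
    val = items[n_test:n_test + n_val]
    train = items[n_test + n_val:]
    return train, val, test


train, val, test = split_dataset(range(100))
print(len(train), len(val), len(test))  # 70 15 15
```

Fixing the seed keeps the test set identical across experiments, so accuracy numbers from different models stay comparable.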

Transformation and Dataloader

We need different kinds of transformations to input our data into our models. These transformations can be

  • resize
  • reshape
  • grayscale
  • convert to tensor, etc.

We used torchvision transforms for these steps.

After that, data loaders, which are iterables, shuffle the data if needed and group it into batches that we feed into our models.
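To make the idea of a data loader concrete, here is a plain-Python sketch of the shuffle-and-batch behaviour. It is only an illustration with a hypothetical function name; in PyTorch the same role is played by `torch.utils.data.DataLoader` with its `batch_size` and `shuffle` arguments.

```python
import random
from typing import Iterator, List, Sequence


def iter_batches(
    samples: Sequence,
    batch_size: int,
    shuffle: bool = False,
    seed: int = 0,
) -> Iterator[List]:
    """Yield consecutive batches; the last batch may be smaller."""
    order = list(range(len(samples)))
    if shuffle:
        random.Random(seed).shuffle(order)  # reshuffle the visiting order only
    for start in range(0, len(order), batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]


for batch in iter_batches(list(range(10)), batch_size=4):
    print(batch)
# [0, 1, 2, 3]
# [4, 5, 6, 7]
# [8, 9]
```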

Baseline Model

We start with a basic convolutional neural network (CNN). But what does the architecture of a convolutional neural network look like?

The CNN architecture is made up of various layers, but we can say that there are two main components, which are as follows:

Feature Extraction: This section is primarily in charge of identifying visual features like different edges, shapes, or contours. It consists of convolution layers and pooling layers.

Classification: The main task of this section is to predict the class of the input images. It learns which feature combination will be responsible for which class. It consists of fully connected layers.

We can understand the CNN architecture from the following figure -

(Fig 1) CNN architecture. Source: [Source]

Our Model Architecture —

  1. Convolution (filter size 3, 16 filters)
  2. Max Pooling (filter size 2, stride 2)
  3. Convolution (filter size 5, 32 filters)
  4. Max Pooling (filter size 2, stride 2)
  5. Convolution (filter size 5, 64 filters)
  6. Max Pooling (filter size 2, stride 2)
  7. Flatten
  8. Dense (9216, 128)
  9. Dense (128, 6)
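The 9216 inputs of the first dense layer follow from the spatial size left after the last pooling step. The arithmetic below is a sketch under two assumptions that the post does not state: a 128×128 input and convolutions with stride 1 and no padding. That is one combination consistent with 9216; others exist, such as a 96×96 input with "same" padding.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output width/height of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_size(input_size: int) -> int:
    """Number of features after three conv + max-pool stages and a flatten."""
    size = input_size
    channels = 3  # RGB input (assumed)
    for kernel, filters in [(3, 16), (5, 32), (5, 64)]:
        size = conv_out(size, kernel)       # convolution, stride 1, no padding (assumed)
        size = conv_out(size, 2, stride=2)  # max pooling, filter size 2, stride 2
        channels = filters
    return channels * size * size


print(flattened_size(128))  # 9216 = 64 channels * 12 * 12
```

The spatial size goes 128 → 126 → 63 → 59 → 29 → 25 → 12, so the flatten step produces 64 × 12 × 12 = 9216 values.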

Parameters —

  • Epochs: 10
  • Batch size: 16
  • Optimizer: Adam
  • Learning rate: 0.001

Test Accuracy —

The baseline reaches only 58% accuracy, which is not very impressive. Still, it is well above random guessing, which would be about 16.7% (100/6) accurate if the six classes are balanced.

Problems -

  • We don’t have much data from which to learn the parameters.
  • Our model is overfitting because our data doesn’t contain many variations of each animal.

With these problems identified, our next models try to address them with the following method -

Data Augmentation -

Our model needs to see more variations of our images, but we don't have more images. Instead, we can generate additional data from the data we already have. Data augmentation acts as a regularizer, which helps keep the model from overfitting the training data.

"In data augmentation, we apply various transformations to already-existing data to create new data samples."

It doesn’t require us to collect more data, which is not always feasible, for example in medical imaging. It is also useful when creating more data is very expensive (as in a costly manufacturing process).

In Data augmentation, we will apply various transformations like -

  • flipping our image
  • shearing of image
  • changes in hue, saturation, etc.
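As a toy illustration of the idea, the sketch below randomly flips a tiny image stored as nested lists. It is not the pipeline used here: the post relies on torchvision, where classes such as `RandomHorizontalFlip` and `ColorJitter` cover these transformations, and the exact set that was used is not stated. All names below are hypothetical.

```python
import random
from typing import List

Pixels = List[List[int]]  # toy grayscale image: rows of pixel values


def horizontal_flip(image: Pixels) -> Pixels:
    """Mirror the image left to right."""
    return [row[::-1] for row in image]


def augment(image: Pixels, rng: random.Random, flip_prob: float = 0.5) -> Pixels:
    """Return a randomly flipped copy; the class label stays the same."""
    if rng.random() < flip_prob:
        return horizontal_flip(image)
    return [row[:] for row in image]


image = [[1, 2, 3],
         [4, 5, 6]]
print(horizontal_flip(image))  # [[3, 2, 1], [6, 5, 4]]
```

Because the transformation is drawn at random each time an image is loaded, the model sees a slightly different version of the same picture in every epoch.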

Let’s look at a sample of our images when we apply data augmentation to them.

Data Augmentation

CNN with Data augmentation

Now we will train our baseline model with our data augmentation. We didn’t change anything in our baseline model in terms of parameters or architecture.

Performance -

Yes, data augmentation improved our model's performance: we reached 60% accuracy. Still, that is not very good.

Problems -

  • Our model is not very complex, but a larger model has more parameters to train.
  • That brings back the same problem: more parameters need more data to train.

Next, we use the following method to improve our model further without any changes to the data. But how? Let’s see -

Transfer Learning -

In human intelligence, we have this feature of transfer learning. We can use our prior knowledge about other things to learn new things, and this makes us very efficient when we try to learn new things.

For example, when babies have knowledge about shapes and colors they use that knowledge inherently to learn to identify new things like fruits, vegetables, etc.

Along the same lines, researchers tried training large models on big, readily available datasets and then using those models as feature extractors when training new models for similar tasks. Models built this way can be trained accurately with a very small amount of data.

So transfer learning means using the knowledge of large models pre-trained on big datasets to train new models on similar tasks with less data.

Image classification is a good fit for transfer learning for the following reasons -

  • Basic features like edge and contour detection help in any image classification task.
  • Knowing how to identify shapes like circles and corners is also very useful knowledge to pass on to other models.
  • As Figure 1 shows, the feature extractor part of a pre-trained CNN can give us good features without any further training.
  • Even if we want to train all parameters rather than rely on the frozen pre-trained model, its weights provide good initial parameters to start training from.

CNN with Transfer Learning -

We used the VGG13 network as our pre-trained model; it was trained on the ImageNet dataset, which has more than a million images across 1,000 categories. So we can rely on this pre-trained model, as it already has a good feature extractor for images.

We can see the architecture of VGG13 in the following diagram -

VGG 13 architecture [Source](1)

Now, what do we need to change so we can train our classifier with this pre-trained model?

  • Load the VGG13 model with its pre-trained parameters (weights).
  • Replace the last fully connected layer, which has 1000 nodes, with a 6-node fully connected layer, as we have only 6 classes rather than 1000.
  • Freeze the weights of all layers except the new last layer, so our optimizer won’t update the weights of any other layer.
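The three steps above can be sketched in PyTorch. To keep the example small and free of downloads, it uses a tiny stand-in network rather than VGG13, so the layer sizes are hypothetical; with torchvision's VGG13 the layer to replace would be the last entry of `model.classifier`, a `Linear(4096, 1000)`.

```python
import torch
from torch import nn

# Tiny stand-in for a pre-trained network (hypothetical; the post uses VGG13).
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 8 * 8, 32),
    nn.ReLU(),
    nn.Linear(32, 1000),  # original 1000-class output layer
)

# Freeze every existing parameter so the optimizer cannot change it.
for param in model.parameters():
    param.requires_grad = False

# Replace the last layer with a fresh 6-class layer (trainable by default).
model[-1] = nn.Linear(model[-1].in_features, 6)

# Hand the optimizer only the parameters that are still trainable.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)

logits = model(torch.randn(16, 3, 8, 8))  # batch of 16 fake images
print(logits.shape)
```

Only the weight and bias of the new layer remain trainable, which is why training takes little time even though the full network is large.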

We can understand this by the following diagram -

the last layer changed to have 6 neurons

Parameters —

  • Batch size: 16
  • Epochs: 5
  • Learning rate: 1e-3

Performance —

We achieved 99.2% accuracy with this new model. It’s very impressive.

Now let’s analyze better: do we really need a neural network, or can we use the features from our pre-trained model and train classical ML models like KNN and SVM to achieve good performance?

KNN and SVM use the 4096-dimensional VGG features

SVM with VGG feature extractor -

Support-vector machines were a very popular choice before neural networks, as they achieve good performance even with nonlinear boundaries and generalize well.

In short, an SVM tries to build a classification boundary that keeps the maximum possible distance from both classes. To find complex nonlinear boundaries, it uses the kernel trick, which implicitly computes features from a transformation of the existing features. So it's a good choice for both simple and complex problems.

Performance -

By using VGG as a feature extractor and training a straightforward SVM model on that, we are able to achieve 99.8% accuracy.

KNN with VGG feature extractor -

KNN, or K-nearest neighbors, is one of the simplest models: it classifies a sample on the basis of the labels of the K training samples closest to it. We can use different distance measures, such as Euclidean, cosine, or Manhattan, to find the closest samples.

As training samples for our KNN model, we used features taken from VGG13.
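The voting rule is short enough to sketch in plain Python. This is an illustration with hypothetical names and made-up 2-dimensional feature vectors; in our setting each vector would be a 4096-dimensional VGG feature, and a library implementation would normally be used instead.

```python
import math
from collections import Counter
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]


def euclidean(a: Vector, b: Vector) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def cosine_distance(a: Vector, b: Vector) -> float:
    """1 - cosine similarity; assumes neither vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def knn_predict(
    train: List[Tuple[Vector, str]],
    query: Vector,
    k: int = 3,
    distance: Callable[[Vector, Vector], float] = euclidean,
) -> str:
    """Return the majority label among the k training samples closest to query."""
    neighbours = sorted(train, key=lambda item: distance(item[0], query))[:k]
    votes = Counter(label for _, label in neighbours)
    return votes.most_common(1)[0][0]


# Made-up feature vectors for illustration.
train = [
    ([0.0, 0.1], "fox"), ([0.1, 0.0], "fox"), ([0.2, 0.2], "fox"),
    ([1.0, 1.0], "lion"), ([0.9, 1.1], "lion"),
]
print(knn_predict(train, [1.0, 0.9], k=3))  # lion
```

KNN has no real training step: "training" just stores the feature vectors, which is why it is so fast once the VGG features are extracted.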

Performance

By using VGG as a feature extractor and training a straightforward KNN model on those features, we are able to achieve 98.8% accuracy.

Compare All Models

Performance —

First, we can analyze the performance of all our models to decide which one is best in terms of accuracy.

  • As we can see, the CNN with transfer learning (VGG_transfer_learning) is the best model for us.
  • But the performance of KNN and SVM is also similar because of the good features from VGG.

Training Time -

We will compare how much time it took us to train all our models.

  • As we can see, KNN and SVM took very little time to train and gave us good accuracy.
  • VGG is a large model as compared to the baseline, but most of the weights are frozen and we ran it for only 5 epochs so it took less time.

Extras

MLflow

I find experiment tracking tedious. When we run many experiments, it is easy to lose track of what worked and what did not. So I used MLflow, an open-source tool, in my experiments to track -

  • Parameters
  • Metrics
  • Loss
  • Model weights

I suggest using this tool not only because it helps you track and understand all of your experiments, but also because it makes it easier to reproduce your results, which is a common problem in ML experimentation.

all experiment runs summary
All details we logged during experiment runs

Build an App

I also developed a web app so we can deploy our models as a product that anyone can try, which lets the work add value in real terms.

I used Streamlit to build this app. I would suggest taking the following things into consideration when you create an app for your model -

  • Use the same version of the libraries that you used to serialize (pickle) and save your model.
  • Apply the same transformations to the input data that you used when training your model; otherwise results can be very bad.
  • Keep the interface easy to use and interactive.

Example

As you can see, our app reports the predicted class and the prediction probability for any image given by a user.

Contribution

  • Built the code for the baseline CNN models and helper functions to train, validate, and test.
  • Wrote code to integrate MLflow for tracking experiments, which can then be easily accessed through a web portal.
  • Ran experiments on data augmentation, transfer learning, and classical ML models with the VGG13 model.
  • Built a web application so users can run our models on the images they want to classify.
  • Wrote a blog post to explain why we use particular techniques.

Challenges-

  • While building the baseline model, I struggled with convergence and tried different optimizers, batch sizes, learning rates, and layers to arrive at a working baseline.
  • While building the feature extractor from VGG, I had difficulty working out how to modify nn.Sequential to keep some layers while removing others.

References -

  1. Figure 1 — Medium blog from Sai Balaji
  2. GitHub repository for the code
  3. Streamlit docs
  4. MLflow docs
