Building an Image Classification Model with a Limited and Unbalanced Dataset via Contrastive Learning

Natthasit Wongsirikul
8 min read · Jan 23, 2023

This post is a deep dive into how I created a classification model to predict the state that a crane is in. The details of the problem statement can be found in my other post, but for now think of this as a 3-class image classification problem.

As a quick recap: I want to monitor crane usage, so I need to know which of the 3 states the crane is in (mobile, extended outrigger, or open boom).

Left (mobile state), middle (extended outrigger), right (open boom)

Dataset

Images of the crane were sourced and grouped into 3 classes: mobile, extended-outrigger, and open-boom. The number of images in each class is summarized below.

Distribution of classes used to train the classifier. Few in number and unbalanced.

There was a class imbalance problem: the majority of the dataset was made up of the open-boom class. This was because the crane only entered the scene when it was being used, and it was in the open-boom state most of that time. Examples of the mobile and extended-outrigger states were available only during the short period when the crane entered the frame and moved to the work spot. Class imbalance of this kind tends to hurt supervised learning, because the model sees far fewer examples of the minority classes.

Multi-Class Image Classification via Transfer Learning

The first thing I try for any image classification task is transfer learning with a state-of-the-art model, which at the time was EfficientNet-B7. I loaded the model with weights pre-trained on ImageNet, froze the backbone layers, and allowed only the final few layers to update their weights. However, the result wasn’t good. It seemed like the model wasn’t able to extract features that could differentiate the 3 crane states well. Adding more layers to the model didn’t help and ran the risk of overfitting.

Training loss goes down but validation accuracy doesn’t increase.

Contrastive Learning with Triplet Loss Network

Faced with a limited dataset and an unsatisfying transfer learning result, I looked for a different approach. I started exploring one-shot learning and face recognition problems, which rely on contrastive learning. The best-known architectures are the Siamese network and the triplet loss network, where an anchor is selected and compared to a positive (same class) example and a negative (different class) example. One advantage of one-shot learning is that it doesn’t require a lot of training data, which is exactly the issue I have.

Triplet Loss
Encoder model trained using triplet loss

The key idea is that the model learns to embed images as feature descriptors, clustering members of the same class close together while pushing members of different classes as far apart as possible. Below is an example of such a model being used to tackle the classic MNIST problem.

2D Vector representation of MNIST digits encoded by model

Batch Triplet Selection Strategy

For the training setup, I used online triplet selection: for each batch, the dataloader randomly selects n sample images from each class, then an anchor is chosen along with a positive and a negative example. Below is an illustration of a possible batch where the L2 distance between each example’s embedding and the anchor’s embedding has been calculated.
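The per-class sampling step can be sketched in plain Python as follows. This is a minimal sketch and not my actual dataloader: the function name `sample_balanced_batch`, the fallback to sampling with replacement for very small classes, and the toy label list are all assumptions made for illustration.

```python
import random
from collections import defaultdict

def sample_balanced_batch(labels, n_per_class, rng=random):
    """Return dataset indices holding n_per_class random samples of every class."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    batch = []
    for indices in by_class.values():
        if len(indices) >= n_per_class:
            # Enough images: draw without replacement.
            batch.extend(rng.sample(indices, n_per_class))
        else:
            # Assumption: a class with too few images is sampled with replacement.
            batch.extend(rng.choices(indices, k=n_per_class))
    return batch

# Toy, made-up label list that only mimics an unbalanced dataset.
labels = ["mobile"] * 8 + ["extended-outrigger"] * 12 + ["open-boom"] * 60
batch = sample_balanced_batch(labels, n_per_class=5, rng=random.Random(0))
```

Because every class contributes the same number of images to a batch, the open-boom majority no longer dominates what the model sees in each training step.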

The anchor image (blue box) is of class A, so the positive images (green box) are also of class A. The images from the remaining classes, B and C, are the negative images (red box). The decimal number underneath each image is the L2 distance computed against the anchor image (blue box). The smaller the number, the more alike the two images are.

The bigger this number is, the larger the difference between the images. Remember that it is desirable for the L2 distance between the positive and the anchor to be as close to zero as possible, and for the distance between the negative and the anchor to be as large as possible. Note that the L2 distances in these illustrations are made up purely to illustrate the concept.

The loss function is defined as below.

Triplet loss function with margin used in contrastive learning, where a is the anchor, p is the positive, and n is the negative. d() is the L2 distance.

Below is an illustration of a triplet that is easy to learn, or that the model has already optimized.

Triplet selection for training that is considered “Easy” because the L2 distance between the positive image and the anchor (0.2) is already smaller than the distance between the negative image and the anchor (2.3).

Below is an illustration of a triplet that is harder to learn.

Triplet selection for training that is considered “Harder” because the L2 distance between the positive image and the anchor (0.5) is larger than the distance between the negative image and the anchor (0.3).
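The two example triplets above can be checked numerically. The sketch below assumes the standard formulation of the triplet loss, L(a, p, n) = max(d(a, p) - d(a, n) + margin, 0), and uses the margin of 1 that I used for training. The function names are mine and are only for illustration.

```python
import math

def l2_distance(u, v):
    """Euclidean (L2) distance between two embedding vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_loss(d_ap, d_an, margin=1.0):
    """max(d(a, p) - d(a, n) + margin, 0) for precomputed distances."""
    return max(d_ap - d_an + margin, 0.0)

def triplet_loss_from_embeddings(anchor, positive, negative, margin=1.0):
    """Same loss, computed directly from three embedding vectors."""
    return triplet_loss(l2_distance(anchor, positive),
                        l2_distance(anchor, negative),
                        margin)

# The "Easy" triplet: the negative is already far beyond the margin.
easy_loss = triplet_loss(d_ap=0.2, d_an=2.3)   # 0.0, nothing left to learn

# The "Harder" triplet: the negative is closer to the anchor than the positive.
hard_loss = triplet_loss(d_ap=0.5, d_an=0.3)   # about 1.2, still to be learned
```

The easy triplet contributes no gradient, which is why the choice of triplets matters so much for training.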

There are several strategies for picking which triplets from the batch to train on. First, the triplet loss of every example in the batch against the anchor was calculated, and the losses were ordered from lowest to highest. The ordered losses were then separated into 3 bins: hard negatives, semi-hard negatives, and easy negatives. Depending on the selection strategy, one can choose triplets from the hard negatives, from the semi-hard negatives, or at random. Below is an illustration of how the triplets are organized: each blue circle represents a triplet (3 images: anchor, positive, and negative) and the number is its triplet loss.
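The binning can be sketched as follows. The thresholds are an assumption on my part: they follow the commonly used definition in which a hard negative is closer to the anchor than the positive, a semi-hard negative is farther than the positive but still inside the margin, and an easy negative is beyond the margin. The distances are made up, apart from the two pairs reused from the illustrations above.

```python
def triplet_loss(d_ap, d_an, margin=1.0):
    """max(d(a, p) - d(a, n) + margin, 0) for precomputed distances."""
    return max(d_ap - d_an + margin, 0.0)

def triplet_bin(d_ap, d_an, margin=1.0):
    """Assign a triplet to a bin (assumed, commonly used thresholds)."""
    if d_an < d_ap:
        return "hard"        # negative is closer than the positive, loss > margin
    if d_an < d_ap + margin:
        return "semi-hard"   # negative is farther, but still inside the margin
    return "easy"            # negative is beyond the margin, loss is 0

# (d_ap, d_an) pairs for one anchor; illustrative numbers only.
candidates = [(0.5, 0.3), (0.2, 2.3), (0.4, 0.9)]

# Order by loss from lowest to highest, then drop each triplet into its bin.
ranked = sorted(candidates, key=lambda t: triplet_loss(*t))
bins = {"easy": [], "semi-hard": [], "hard": []}
for d_ap, d_an in ranked:
    bins[triplet_bin(d_ap, d_an)].append((d_ap, d_an))
```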

For my selection strategy, I only chose triplets from the hard negative bin. Within this bin, I tried two approaches. First, I tried hardest-negative selection, which meant always picking the triplet with the highest loss. The second approach, which worked better, was randomly selecting any triplet within the bin. I found that with hardest-negative selection, training fixated on a small sample of images (the hardest to learn), so the model did not learn to generalize. I also went with online triplet mining instead of offline mining because it was more efficient.
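The difference between the two approaches can be sketched like this. It is a simplified illustration rather than my training code: the helper names are hypothetical, the distances are made up, and the hard bin is assumed to contain the triplets whose negative is closer to the anchor than the positive.

```python
import random

def triplet_loss(d_ap, d_an, margin=1.0):
    """max(d(a, p) - d(a, n) + margin, 0) for precomputed distances."""
    return max(d_ap - d_an + margin, 0.0)

def hard_bin(candidates):
    """Keep triplets whose negative is closer to the anchor than the positive."""
    return [(d_ap, d_an) for d_ap, d_an in candidates if d_an < d_ap]

def pick_hardest(candidates, margin=1.0):
    """Approach 1: always take the triplet with the highest loss."""
    hard = hard_bin(candidates)
    if not hard:
        return None
    return max(hard, key=lambda t: triplet_loss(t[0], t[1], margin))

def pick_random_hard(candidates, rng=random):
    """Approach 2: take any triplet from the hard bin, uniformly at random."""
    hard = hard_bin(candidates)
    return rng.choice(hard) if hard else None

# (d_ap, d_an) pairs for one anchor; illustrative numbers only.
candidates = [(0.5, 0.3), (0.9, 0.1), (0.7, 0.6), (0.2, 2.3)]

hardest = pick_hardest(candidates)   # the same triplet on every call
rng = random.Random(0)
random_picks = {pick_random_hard(candidates, rng) for _ in range(200)}
```

With the first approach the same triplet is returned every time until the model fixes it, which matches the fixation I saw in training. The random version spreads the updates over the whole hard bin.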

Model Architecture

As for the model, I designed my own embedding model. I experimented with hyperparameters such as model depth and width, as well as with submodules from well-known architectures, such as the skip connections from ResNet and the chained convolution layers from DenseNet. In the end, I found that the multi-kernel convolution inspired by the Inception-V3 model was the most successful feature-extractor submodule. Below is the final model architecture.

Model architecture, where I borrow concepts from the Inception model with parallel multi-kernel convolutional operations which create …

The two key features were the first two modules. The first module consisted of 2D convolutions with 4 different kernel sizes for extracting features at different scales, after which the outputs were concatenated. The second module was inspired by the InceptionV3 module. The 1x1 depth-wise separable convolution layer was designed to minimize the number of parameters within the model. The remaining modules followed the traditional Conv, BatchNorm, ReLU, and MaxPool stack. Global average pooling was used to spatially reduce the feature map into a feature vector before the fully connected layer. The output of the model is a 2-dimensional vector for easy visualization of the embedding space. The model is also small, containing only 2,243,618 parameters. Compared to other state-of-the-art models, it is about the same size as MobileNetV2.

Comparison of model sizes, where Custom is my model.
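To give a feel for why depth-wise separable convolutions keep the parameter count down, here is a small back-of-the-envelope calculation. The channel counts and kernel size are hypothetical and are not taken from my model, and bias terms are ignored for simplicity.

```python
def conv2d_params(c_in, c_out, k):
    """Weights in a standard k x k convolution (bias ignored)."""
    return k * k * c_in * c_out

def separable_conv2d_params(c_in, c_out, k):
    """Weights in a depth-wise k x k convolution followed by a 1x1 point-wise one."""
    depthwise = k * k * c_in    # one k x k filter per input channel
    pointwise = c_in * c_out    # 1x1 convolution that mixes the channels
    return depthwise + pointwise

# Hypothetical layer: 64 input channels, 128 output channels, 3x3 kernel.
standard = conv2d_params(64, 128, 3)              # 73728
separable = separable_conv2d_params(64, 128, 3)   # 8768
```

For this made-up layer the separable version needs roughly 8 times fewer weights, which is the kind of saving that makes a small embedding model possible.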

Training Method

I split the dataset into a train set and a test set, separated by the month in which the frames were captured. The training set consisted of a small sample of images collected over the months designated for training. I selected them manually, making sure that they covered all possible crane orientations as well as all crane brands. The remaining images were combined with the images from the test month to create a large test set.

Small samples of images are collected from 3 months (designated in blue as training dataset) while the remaining images are combined with the test-set month (green)

The dataset is summarized in the table below, followed by some sample images from the train set.

Dataset summary. 3 classes: Mobile, Extended Outrigger, and Open Boom.

The table below shows sample images for each class from the training set.

Example images from 3 classes.

I trained the model using the Adam optimizer with a stepped learning rate scheduler. I set the margin to 1. The number of sample images per class in a batch was 5. I trained the model for 500 epochs. As the evaluation metric, I counted the average number of non-zero triplets (triplets with a loss greater than zero) found in a batch. The fewer non-zero triplets per batch the better, because it means that the model is producing embeddings in which images of the same class are closer to each other than to images of other classes. To put it another way, as training continues, the number of triplets with a loss of 0 (triplets belonging to the easy negatives) should increase.
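The evaluation metric can be sketched as below. This is a simplified illustration under the assumption that every valid (anchor, positive, negative) combination in the batch is counted, and the 2D embeddings are made up. It relies on `math.dist`, which requires Python 3.8 or newer.

```python
import math

def count_nonzero_triplets(embeddings, labels, margin=1.0):
    """Count (anchor, positive, negative) triplets in a batch whose loss is > 0."""
    n = len(labels)
    count = 0
    for a in range(n):
        for p in range(n):
            if p == a or labels[p] != labels[a]:
                continue  # the positive must be a different image of the same class
            d_ap = math.dist(embeddings[a], embeddings[p])
            for neg in range(n):
                if labels[neg] == labels[a]:
                    continue  # the negative must come from another class
                d_an = math.dist(embeddings[a], embeddings[neg])
                if d_ap - d_an + margin > 0:
                    count += 1
    return count

def average_nonzero_triplets(batches, margin=1.0):
    """Average the count over a list of (embeddings, labels) batches."""
    counts = [count_nonzero_triplets(e, l, margin) for e, l in batches]
    return sum(counts) / len(counts)

labels = ["A", "A", "B", "B"]
separated = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.1, 5.0)]   # tight, distant clusters
overlapping = [(0.0, 0.0), (0.1, 0.0), (0.5, 0.0), (0.6, 0.0)]  # clusters inside the margin

well_trained = count_nonzero_triplets(separated, labels)      # 0
under_trained = count_nonzero_triplets(overlapping, labels)   # 8, every triplet
```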

Results

The loss and the evaluation metric for the train and test set are shown below.

training logs

Below are the output embeddings of the whole training dataset. Here A is mobile, BC is extended outrigger, and D is open boom.

A is mobile, BC is extended outrigger, and D is open boom. Training dataset

The result showed that there seems to be a distinct cluster for each class. To turn this into a classifier, I used a Support Vector Machine (SVM) to find the optimal hyperplane (in this case just a line, because the embedding is only 2D) that separates the regions by class.

SVM plane separation on the training dataset.
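The SVM step can be sketched with scikit-learn as follows. This is only an illustration under several assumptions: the 2D embeddings are synthetic Gaussian clusters rather than the outputs of my encoder, and the linear kernel and `C=1.0` are choices made for the example, not settings I am reporting.

```python
import random
from sklearn.svm import SVC

rng = random.Random(0)

# Hypothetical cluster centres in the 2D embedding space, one per class.
centers = {
    "mobile": (-3.0, 0.0),
    "extended-outrigger": (0.0, 3.0),
    "open-boom": (3.0, 0.0),
}

def fake_embeddings(n_per_class, spread=0.3):
    """Generate synthetic 2D embeddings scattered around each class centre."""
    X, y = [], []
    for label, (cx, cy) in centers.items():
        for _ in range(n_per_class):
            X.append([cx + rng.gauss(0, spread), cy + rng.gauss(0, spread)])
            y.append(label)
    return X, y

X_train, y_train = fake_embeddings(20)
X_test, y_test = fake_embeddings(10)

# Fit the SVM on the training embeddings only, then score the held-out ones.
clf = SVC(kernel="linear", C=1.0)
clf.fit(X_train, y_train)
accuracy = clf.score(X_test, y_test)
prediction = clf.predict([[-3.0, 0.0]])[0]
```

The important design point is that the SVM is fitted on the training embeddings alone, so the test embeddings are only ever passed through `predict` or `score`.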

Below is a plot of the prediction on the test set overlaid on the SVM plane calculated from the training set.

Encoding of the test-set
SVM plane separation on the test-set.

The classifier performance on the test set is summarized in the table below along with the confusion matrix.

Recall, Precision, and F1 for the 3-classes classifier model.
Confusion matrix of prediction on the test-set. GT is ground truth. M is class mobile, E is extended-outrigger, and O is open boom.
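For reference, precision, recall, and F1 can be derived from a confusion matrix as sketched below. The matrix in the example is a toy one with made-up counts, not my actual results, and it assumes that rows are the ground truth and columns are the predictions.

```python
def per_class_metrics(confusion, class_names):
    """confusion[i][j] = number of images with ground truth i predicted as j."""
    metrics = {}
    for i, name in enumerate(class_names):
        tp = confusion[i][i]
        fn = sum(confusion[i]) - tp                  # rest of the ground-truth row
        fp = sum(row[i] for row in confusion) - tp   # rest of the prediction column
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics[name] = {"precision": precision, "recall": recall, "f1": f1}
    return metrics

# Toy confusion matrix with made-up counts (M = mobile, E = extended-outrigger, O = open boom).
toy_confusion = [
    [8, 2, 0],   # GT = M
    [1, 7, 2],   # GT = E
    [0, 1, 9],   # GT = O
]
toy_metrics = per_class_metrics(toy_confusion, ["M", "E", "O"])
```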

The result was pretty good: the classifier demonstrated the ability to extract discriminating features and map them to an embedding space that was separable. When I added just a couple of examples from the test months to the training set, the model performance improved dramatically.

Leaking a couple of images from the test set drastically improves the model performance.

Error Analysis

Let’s take a look at some of the misclassified images.

Predicted Open Boom but should be Extended Outrigger. In some of these images the boom seems to be partially lifted, and the crane is seen from different perspectives, both of which may contribute to this misclassification.

GT = Extended Outrigger but Prediction = Open-Boom

Predicted Mobile but should be Extended Outrigger. The misclassified images come from views where the outriggers are not yet clearly visible. For example, the crane is just starting to extend the outriggers, the crane’s orientation blocks the view of the outriggers, or the outriggers are not clearly visible due to occlusion or shadows.

GT = Extended Outrigger but Prediction = Mobile

Predicted Mobile but should be Open Boom. These misclassifications may be due to different perspectives of the crane. When the images are cropped, without the surrounding context it can sometimes be hard to tell whether the crane is lying flat along the ground plane or is erected. Another possible source of misclassification is occlusion of the crane’s body, which may confuse the model.

GT = Open boom but prediction = Mobile

Natthasit Wongsirikul

I'm a computer vision engineer. My interests span from UAV imaging to AI CCTV applications.