Augmentation for Image Classification
One of the issues we come across when dealing with image data is inconsistency among the images (some are too big or too small, some are rectangular instead of square, etc.). Another frequent problem is having too few images in the training set, which often results in overfitting. To deal with these issues, I outline a technique that uses augmentation transforms: the images in the training set are transformed so as to improve the model's ability to recognize different versions of an image. This broadens the information the model learns from, making it better at recognizing target objects in images of varied contrast and size, taken from different angles, and so on.
To show how augmentation works, we look at the Dogs vs Cats dataset and use the deep learning library fast.ai by Jeremy Howard and Rachel Thomas. This post is inspired by the fast.ai Deep Learning Part 1 v2 course.
I. Baseline Model
To classify images as either dog or cat, we train a ResNet34 model (more about the ResNet architecture in this awesome blog by Apil Tamang!). We first train the model without data augmentation, using a learning rate of 0.03 for 1 epoch.
With this we see a validation accuracy of
Here’s a look at the confusion matrix:
Thus 26 of the 2000 images (20 cats and 6 dogs) have been misclassified.
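Both the misclassification count and the accuracy come straight from the confusion matrix, where correct predictions sit on the diagonal. Here is a minimal Python sketch of that bookkeeping on a made-up eight-image example; the labels and helper names are hypothetical and are not fast.ai's API.

```python
from collections import Counter


def confusion_matrix(actual, predicted, classes):
    """Rows are true classes, columns are predicted classes."""
    counts = Counter(zip(actual, predicted))
    return [[counts[(t, p)] for p in classes] for t in classes]


def misclassified(cm):
    """Everything off the diagonal is a wrong prediction."""
    total = sum(sum(row) for row in cm)
    return total - sum(cm[i][i] for i in range(len(cm)))


def accuracy(cm):
    """Fraction of samples on the diagonal (correctly classified)."""
    total = sum(sum(row) for row in cm)
    return sum(cm[i][i] for i in range(len(cm))) / total


classes = ["cat", "dog"]
actual = ["cat", "cat", "cat", "dog", "dog", "dog", "dog", "cat"]
predicted = ["cat", "dog", "cat", "dog", "dog", "cat", "dog", "cat"]
cm = confusion_matrix(actual, predicted, classes)
print(cm)                               # [[3, 1], [1, 3]]
print(misclassified(cm), accuracy(cm))  # 2 0.75
```

Applied to the baseline counts above (26 of 2000 images misclassified), the same count-based formula gives 1 - 26/2000 = 0.987.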
II. Applying Augmentation Transforms
To reduce this misclassification error, we now augment the training data and see if there is an improvement. We can choose either top-down or side-on transformations. Here's a quick look at the types involved:
- basic transforms: changes in angle (rotation) and lighting
- side-on transforms: changes in angle and lighting, plus flipping about the vertical axis (a left-right mirror)
- top-down transforms: changes in angle and lighting, plus flipping about the horizontal axis and rotating by 90, 180, or 270 degrees
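To make the difference between these families concrete, here is a toy sketch in plain Python that represents an image as a nested list of pixel values. It covers only the flips and 90-degree rotations; the small rotations and lighting changes shared by all three families are left out. The helper names are hypothetical, and this is not fast.ai's implementation.

```python
def hflip(img):
    """Mirror left-right, i.e. flip about the vertical axis (as in side-on transforms)."""
    return [row[::-1] for row in img]


def vflip(img):
    """Mirror top-bottom, i.e. flip about the horizontal axis."""
    return [list(row) for row in img[::-1]]


def rot90(img):
    """Rotate the image 90 degrees clockwise."""
    return [list(row) for row in zip(*img[::-1])]


def side_on_variants(img):
    """Geometric versions a side-on scheme can produce: the original and its mirror."""
    return [img, hflip(img)]


def top_down_variants(img):
    """All 8 rotations by 0/90/180/270 degrees, each with and without a mirror flip."""
    variants, current = [], img
    for _ in range(4):
        variants.append(current)
        variants.append(hflip(current))
        current = rot90(current)
    return variants


img = [[1, 2],
       [3, 4]]
print(hflip(img))                   # [[2, 1], [4, 3]]
print(vflip(img))                   # [[3, 4], [1, 2]]
print(rot90(img))                   # [[3, 1], [4, 2]]
print(len(side_on_variants(img)))   # 2
print(len(top_down_variants(img)))  # 8
```

The side-on scheme yields only 2 geometric versions of an image, while the top-down scheme yields 8, including the upside-down vertical flip.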
We use side-on transformations here because the pictures of dogs and cats are taken from the side rather than from above, so they only need to be flipped horizontally (left to right), not vertically. Here's a look at 6 random side-on transformations of a cat image:
Top-down transformations are not appropriate here because of the nature of the images: upside-down images of a cat or dog are rare!
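Random side-on variants like the ones pictured can be generated roughly as follows: each call mirrors the image left to right with probability 0.5 and applies a small random brightness and contrast change. This is a sketch under stated assumptions: the image is a toy grayscale array with values in [0, 1], the jitter limits of 0.05 are arbitrary, small rotations are omitted, and the function names are hypothetical.

```python
import random


def hflip(img):
    """Mirror left-right (flip about the vertical axis)."""
    return [row[::-1] for row in img]


def adjust_lighting(img, brightness, contrast):
    """Scale contrast around mid-gray (0.5), shift brightness, clip to [0, 1]."""
    return [[min(1.0, max(0.0, (p - 0.5) * contrast + 0.5 + brightness))
             for p in row] for row in img]


def random_side_on(img, rng, max_brightness=0.05, max_contrast=0.05):
    """One random side-on variant: a coin-flip mirror plus small lighting jitter.

    Vertical flips are never applied; small rotations are omitted for brevity.
    """
    out = hflip(img) if rng.random() < 0.5 else img
    brightness = rng.uniform(-max_brightness, max_brightness)
    contrast = 1.0 + rng.uniform(-max_contrast, max_contrast)
    return adjust_lighting(out, brightness, contrast)


rng = random.Random(0)
cat = [[0.1, 0.4, 0.9],   # toy 2x3 grayscale "cat", values in [0, 1]
       [0.2, 0.5, 0.8]]
variants = [random_side_on(cat, rng) for _ in range(6)]
for v in variants:
    print([[round(p, 2) for p in row] for row in v])
```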
While training this network, we keep the learning rate the same so that any difference in accuracy is due only to augmentation. When we train with augmentation, a new transformation of every image is generated in every epoch. The model therefore sees the same number of images per epoch as there are in the original training data, but a new version of each image every time, so the range of images it has seen grows with every epoch.
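This per-epoch behaviour can be sketched with a small generator: every epoch visits each training image exactly once, but hands the model a freshly augmented copy. The toy 1-D "images" and the augment function below are hypothetical, chosen only so that the number of distinct versions is easy to count.

```python
import random


def augment(img, rng):
    """Hypothetical augmentation: random mirror plus a brightness shift of -1, 0 or +1."""
    flipped = img[::-1] if rng.random() < 0.5 else img
    shift = rng.choice([-1, 0, 1])
    return tuple(p + shift for p in flipped)


def run_epoch(train, rng):
    """Yield (index, freshly augmented copy) once per training image, in shuffled order."""
    order = list(range(len(train)))
    rng.shuffle(order)
    for i in order:
        yield i, augment(train[i], rng)


rng = random.Random(42)
train = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]   # three toy 1-D "images"
seen = {i: set() for i in range(len(train))}
images_per_epoch, distinct_seen = [], []
for epoch in range(5):
    n = 0
    for i, version in run_epoch(train, rng):
        seen[i].add(version)
        n += 1
    images_per_epoch.append(n)
    distinct_seen.append(sum(len(s) for s in seen.values()))

print(images_per_epoch)  # [3, 3, 3, 3, 3]: same count every epoch
print(distinct_seen)     # non-decreasing: new versions accumulate
```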
Networks such as ResNet34 are pretrained: they come with weights already learned on ImageNet, and only the new fully connected layers added at the end for our two classes start from scratch. When we trained the baseline model, we used precomputed activations: the outputs of the frozen pretrained layers were computed once for the original training images and cached. This time, because every epoch produces new versions of the training images, we set precompute=False so that activations are recomputed for each transformed image. Without this, the model would keep using the cached activations of the original, untransformed images, and augmentation would give us little or no improvement in accuracy.
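Why the precompute setting matters can be illustrated with a toy model in which a fixed function plays the role of the frozen pretrained layers. This shows the idea only; the functions, the cache, and the precompute argument below are hypothetical and do not reproduce fast.ai's internals.

```python
import random


def frozen_layers(img):
    """Stand-in for the frozen pretrained layers: a fixed function of the pixels."""
    return (sum(img), img[0] - img[-1])


def augment(img, rng):
    """Hypothetical augmentation: mirror the 1-D image with probability 0.5."""
    return img[::-1] if rng.random() < 0.5 else img


def epoch_activations(train, rng, precompute, cache):
    """Activations the trainable head would receive in one epoch."""
    acts = []
    for i, img in enumerate(train):
        if precompute:
            # Computed once on the ORIGINAL image and reused: augmentation never reaches the model.
            if i not in cache:
                cache[i] = frozen_layers(img)
            acts.append(cache[i])
        else:
            # Recomputed every epoch on a freshly augmented copy of the image.
            acts.append(frozen_layers(augment(img, rng)))
    return acts


train = [(1, 2, 3), (4, 5, 9), (7, 0, 2)]
rng, cache = random.Random(1), {}
with_cache = [epoch_activations(train, rng, True, cache) for _ in range(20)]
without_cache = [epoch_activations(train, rng, False, {}) for _ in range(20)]
print(len({tuple(a) for a in with_cache}))  # 1: identical every epoch
print(len({tuple(a) for a in without_cache}))
```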
Now we get a validation accuracy of 0.98779, which is an improvement over the previous model.
From the confusion matrix above, we see that 22 images have been misclassified this time, down from 26 before. Augmentation has therefore increased the predictive power of the model.
III. Test Time Augmentation
While augmentation gave us a better model, prediction accuracy can be improved further with what is called Test Time Augmentation (TTA). To understand why this is needed, let us first look at some of the misclassified images:
We see that a few of the images have been misclassified because of poor contrast, because they are rectangular rather than square, or because the dog or cat occupies only a small portion of the image. Take, for example, the rectangular image of the dog. When the model makes a prediction for this image, it sees only the center of the image (cropping is centered by default), so it cannot tell whether the image shows a dog or a cat.
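Center cropping can be sketched in a few lines: keep the largest centered square of the image. In this hypothetical example a wide image has the dog (marked "D") at its right edge, so the crop keeps only background. Real pipelines also resize the image, which is omitted here.

```python
def center_crop(img):
    """Crop the largest centered square from a (rows x cols) image."""
    h, w = len(img), len(img[0])
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return [row[left:left + side] for row in img[top:top + side]]


# A wide 2x6 "photo": the dog ('D') sits at the right edge.
photo = [
    [".", ".", ".", ".", ".", "D"],
    [".", ".", ".", ".", "D", "D"],
]
crop = center_crop(photo)
print(crop)  # [['.', '.'], ['.', '.']]: the dog is cropped away
```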
To mitigate errors such as these, we use TTA: we predict the class for the original test image along with 4 random transforms of the same image, and then average the predictions to decide which class the image belongs to. The intuition is that even if the original test image is hard to classify, the transformed versions give the model a better chance of capturing the dog's or cat's shape and predicting accordingly.
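The averaging step at the heart of TTA can be sketched as follows, assuming a model that returns [p_cat, p_dog] for a view of the image. The original view is averaged with 4 augmented ones, matching the count in the text. The toy classifier (which looks only at the two center pixels, to mimic center cropping) and the wrap-around shift used as the augmentation are both hypothetical.

```python
import random


def tta_predict(predict_proba, image, augment, n_aug=4, rng=None):
    """Average class probabilities over the original image plus n_aug augmented copies."""
    if rng is None:
        rng = random.Random(0)
    views = [image] + [augment(image, rng) for _ in range(n_aug)]
    probs = [predict_proba(v) for v in views]
    return [sum(p[k] for p in probs) / len(probs) for k in range(len(probs[0]))]


def toy_model(view):
    """Hypothetical classifier returning [p_cat, p_dog]; it only 'sees' the two center pixels."""
    center = view[len(view) // 2 - 1: len(view) // 2 + 1]
    p_dog = 0.9 if "D" in center else 0.4
    return [1 - p_dog, p_dog]


def random_shift(view, rng):
    """Hypothetical augmentation: shift left by 2 or 3 pixels, wrapping around."""
    k = rng.choice([2, 3])
    return view[k:] + view[:k]


photo = [".", ".", ".", ".", "D", "D"]  # dog at the right edge
single = toy_model(photo)                      # leans towards cat
averaged = tta_predict(toy_model, photo, random_shift)  # leans towards dog
print(single, averaged)
```

On its own, the toy model sees only background in the center and leans towards cat; averaged over the shifted views, the dog probability rises to about 0.8 and the prediction flips to dog.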
On using TTA, we now get a validation accuracy of 0.99199, with just 16 misclassified images.
Although the baseline ResNet34 model already fits the data well and gives pretty good results, applying augmentation transforms and TTA reduces misclassification errors and improves the model's accuracy. What remains to be seen is how augmentation can be applied to structured data to boost accuracy, just as it does for unstructured data.