Image data augmentation to balance dataset in classification tasks

Try an image classification model with an unbalanced dataset, and improve its accuracy through data augmentation techniques.

Gabriel Naya
Analytics Vidhya
5 min read · Dec 27, 2019


Photo by USGS on Unsplash

We will create an image classification model from a minimal and unbalanced dataset, then use data augmentation techniques to balance it and compare the results.

The dataset

Our dataset has 20 bird images and 200 flower images, a 1:10 ratio. To build this dataset we downloaded image URLs through the Google® search engine, as described step by step in a previous article.

Downloading and checking the images

Once we have the list of URLs for each category in our CSV files, we can run the code to download the photos and build our dataset.
The path ‘data/data_aug’ is our base directory. In that directory we have placed the two CSV files, one per category; let’s execute the code to create the subdirectories, download the images into them, and verify them.

%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *

classes = ['birds', 'flowers']
path = Path('data/data_aug')

# create one folder per category
for folder in classes:
    dest = path/folder
    dest.mkdir(parents=True, exist_ok=True)

# download the images listed in each category's CSV of URLs
file = 'urls_birds.csv'
dest = path/'birds'
download_images(path/file, dest, max_pics=20)
file = 'urls_flowers.csv'
dest = path/'flowers'
download_images(path/file, dest, max_pics=200)

# verify the downloads, deleting any image that cannot be opened
for c in classes:
    print(c)
    verify_images(path/c, delete=True, max_size=500)

Creating and visualizing the dataset

After the photos are downloaded and verified in the directory corresponding to each category, we can create a fast.ai ImageDataBunch to hold the labeled images and start visualizing and working with them. We reserve 20% of the images for the validation set.

np.random.seed(7)
data_gt = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)

And run the code to display a random 3-row batch:

data_gt.show_batch(rows=3, figsize=(7,8))
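
Before training, we can also confirm the imbalance numerically. A minimal sketch (assuming fastai v1, where the training labels are exposed through data_gt.train_ds.y.items) that counts the training images per class:

from collections import Counter
# count the training images per class to confirm the imbalance
print(data_gt.classes)
print(Counter(data_gt.classes[int(i)] for i in data_gt.train_ds.y.items))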

The classification model

For this exercise we are going to use a convolutional neural network with the resnet34 architecture¹.
Within fast.ai’s models module there is a set of predefined network architectures, with different structures and complexities².

learn_gt = cnn_learner(data_gt, models.resnet34, metrics=error_rate)
learn_gt.fit_one_cycle(4)
learn_gt.save('gt_stage-1')
learn_gt.load('gt_stage-1')
learn_gt.unfreeze()
learn_gt.lr_find()
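# plot the loss vs. learning rate curve from lr_find to choose max_lr below
learn_gt.recorder.plot()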
learn_gt.fit_one_cycle(2, max_lr=slice(1e-5,1e-2))
learn_gt.save('gt_stage-2')
learn_gt.load('gt_stage-2')

Results in the confusion matrix

interp = ClassificationInterpretation.from_learner(learn_gt)
interp.plot_confusion_matrix()

As we can see, the model is very ineffective at predicting the less-represented class (birds), which is heavily confused with flowers on the validation set.
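
Beyond the plot, the interpretation object can report the most frequently confused class pairs and the raw matrix; a short sketch using the same interp object:

# list (actual, predicted, count) pairs that were confused at least twice
print(interp.most_confused(min_val=2))
# the raw confusion matrix is also available as a numpy array
print(interp.confusion_matrix())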

Data augmentation

We must create 180 new images in the birds category. To do this, we’ll loop through each of the real images, creating ten additional pictures from each one with fast.ai’s apply_tfms method.

path = Path('data/data_aug')
path_hr = path/'birds'
il = ImageList.from_folder(path_hr)
tfms = get_transforms(max_rotate=25)

def data_aug_one(ex_img, prox, qnt):
    # generate qnt transformed copies of ex_img, numbering the files from prox
    for lop in range(qnt):
        image_name = str(prox).zfill(8) + '.jpg'
        dest = path_hr/image_name
        prox = prox + 1
        new_img = open_image(ex_img)
        new_img_fin = new_img.apply_tfms(tfms[0], size=224)
        new_img_fin.save(dest)

prox = 20
qnt = 10
for imagen in il.items:
    data_aug_one(imagen, prox, qnt)
    prox = prox + qnt
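
After the loop runs, a quick sanity check (a minimal sketch using only the standard library) confirms that the two categories now hold comparable numbers of files:

from pathlib import Path
path = Path('data/data_aug')
# count the files now present in each category folder
for c in ['birds', 'flowers']:
    n = len([p for p in (path/c).iterdir() if p.is_file()])
    print(c, ':', n, 'images')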

If we visualize any of the source images and their ten new images, we’ll find things like this:

On the left, the original picture; on the right, ten transformed images

Details about how the “apply_tfms” transform function works, and all its possibilities, can be found here.
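
For instance, a single transform can also be applied on its own; a minimal sketch (the file name is hypothetical) that rotates one image:

from fastai.vision import open_image, rotate
img = open_image('data/data_aug/birds/00000001.jpg')  # hypothetical file name
# apply a single 25-degree rotation instead of the full random pipeline
img.apply_tfms([rotate(degrees=25)]).show()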

The same model with balanced data

We create a new model with the balanced dataset:

np.random.seed(7)
tfms = get_transforms()
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=tfms, size=224, num_workers=4).normalize(imagenet_stats)
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
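
As a quick numeric comparison (a hedged sketch, assuming both learners are still in memory), each learner can report its validation loss and error rate:

# learn.validate() returns [val_loss, error_rate] for our metric setup
print('unbalanced model:', learn_gt.validate())
print('balanced model:', learn.validate())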

Okay, now let’s present a new set of images

We took twenty new images, ten from each category. The accuracy, although much improved over the initial version, still presents some problems when categorizing: out of 20 new images, the model predicted 18 correctly and got two wrong:

path = Path('data/data_aug_test')
defaults.device = torch.device('cpu')
img = open_image(path/'bird01.jpg')
img
pred_class,pred_idx,outputs = learn.predict(img)
pred_class

Category birds

for test in range(1, 11):
    image_name = 'bird' + str(test).zfill(2) + '.jpg'
    img = open_image(path/image_name)
    pred_class, pred_idx, outputs = learn.predict(img)
    print('For image ' + image_name + ' predicted class:')
    print(pred_class)
    image_name = 'flower' + str(test).zfill(2) + '.jpg'
    img = open_image(path/image_name)
    pred_class, pred_idx, outputs = learn.predict(img)
    print('For image ' + image_name + ' predicted class:')
    print(pred_class)
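
A slightly more compact variant of this loop (a sketch, assuming the file names encode the true class) tallies the score instead of printing every prediction:

correct, total = 0, 0
for test in range(1, 11):
    for true_class in ['bird', 'flower']:
        image_name = true_class + str(test).zfill(2) + '.jpg'
        img = open_image(path/image_name)
        pred_class, pred_idx, outputs = learn.predict(img)
        total += 1
        # predicted classes are 'birds'/'flowers', so prefix-match the name
        correct += int(str(pred_class).startswith(true_class))
print(correct, 'of', total, 'correct')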

The image flower07.jpg was predicted as a bird, and the image bird08.jpg as a flower.

Here are the two confused images:

bird08.jpg
flower07.jpg

Summary

As we could see, the difference between training the same model with a balanced dataset and with an unbalanced one is fundamental. However, if the dataset is minimal, the accuracy we reach may still not be enough.

In a previous article, we created a very similar model from 200 images of each category. When we presented it with the same 20 new test images, it showed the same confusion on the same photos.

We used the apply_tfms method from fast.ai to transform the few initial bird images, creating a dataset 10 times larger.

Sources and references

[1] — https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035

[2] — https://medium.com/@14prakash/understanding-and-implementing-architectures-of-resnet-and-resnext-for-state-of-the-art-image-cf51669e1624

Within the fast.ai course referred to above, this technique is inspired by: Francisco Ingham and Jeremy Howard / [Adrian Rosebrock]

https://course.fast.ai, lesson 2, Jupyter Notebook: lesson2-download


Gabriel Naya
Machine learning enthusiast and researcher at Kreilabs Uruguay. My profile: https://gnaya73.glitch.me/