Image data augmentation to balance dataset in classification tasks
Try an image classification model with an unbalanced dataset, and improve its accuracy through data augmentation techniques.
We will create an image classification model from a small, unbalanced dataset, then use data augmentation techniques to balance it and compare the results.
The dataset
Our dataset has 20 bird images and 200 flower images, a 1:10 ratio. To build it, we downloaded image URLs through the Google® search engine, as described step by step in the previous article.
Downloading and checking the images
Once we have the list of URLs for each category in our CSV files, we run the code to download the photos and build our dataset.
The base directory is ‘data/data_aug’. We have placed the two category CSV files in that directory; the code below creates the subdirectories, downloads the images into them, and verifies them.
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.vision import *

classes = ['birds','flowers']
path = Path('data/data_aug')

# creating folders
for folder in classes:
    dest = path/folder
    dest.mkdir(parents=True, exist_ok=True)

# downloading images from urls in csv
file = 'urls_birds.csv'
dest = path/'birds'
download_images(path/file, dest, max_pics=20)

file = 'urls_flowers.csv'
dest = path/'flowers'
download_images(path/file, dest, max_pics=200)

# verifying images
for c in classes:
    print(c)
    verify_images(path/c, delete=True, max_size=500)
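Because verify_images is called with delete=True, files that cannot be opened are removed, so the final class counts may differ from the max_pics values. A minimal sketch for checking the per-class counts with the standard library only; the helper name count_images and the list of extensions are our assumptions and are not part of fast.ai:

```python
from pathlib import Path

# Assumed set of image extensions; adjust to whatever the downloads produced.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}

def count_images(base_dir, classes):
    """Hypothetical helper: return {class_name: number of image files in base_dir/class_name}."""
    counts = {}
    for name in classes:
        folder = Path(base_dir) / name
        if not folder.is_dir():
            # A missing folder simply counts as zero images.
            counts[name] = 0
            continue
        counts[name] = sum(
            1 for f in folder.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts

counts = count_images('data/data_aug', ['birds', 'flowers'])
for name, n in counts.items():
    print(f'{name}: {n} images')
```

If the counts come out lower than expected, the CSV files can be extended with more URLs before continuing.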
Creating and visualizing the dataset
Once the photos are downloaded and verified in the directory for each category, we can create a fast.ai DataBunch to hold the labeled images and start visualizing and working with them. We reserve 20% of the images for the validation set.
np.random.seed(7)
data_gt = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=get_transforms(), size=224, num_workers=4).normalize(imagenet_stats)
Then we run the code to display a random batch in 3 rows:
data_gt.show_batch(rows=3, figsize=(7,8))
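With valid_pct=0.2 the holdout is drawn at random from all images rather than per class, so the minority class ends up with very few validation samples. A rough sketch of the effect, assuming the nominal counts of 20 birds and 200 flowers; the function simulate_split is hypothetical and only mimics a random 20% holdout, not fast.ai's internals:

```python
import random

def simulate_split(class_counts, valid_pct=0.2, seed=7):
    """Randomly hold out valid_pct of all items and count the holdout per class."""
    labels = [name for name, n in class_counts.items() for _ in range(n)]
    rng = random.Random(seed)
    rng.shuffle(labels)
    n_valid = int(len(labels) * valid_pct)
    valid = labels[:n_valid]
    return {name: valid.count(name) for name in class_counts}

valid_counts = simulate_split({'birds': 20, 'flowers': 200})
print(valid_counts)  # 44 validation images in total; on average about 4 of them are birds
```

With so few bird images in the validation set, a single misclassification changes the per-class result noticeably.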
The classification model
For this exercise, we are going to use a convolutional neural network with the resnet34 architecture¹.
The fast.ai models module offers a set of predefined network architectures with different structures and complexities².
learn_gt = cnn_learner(data_gt, models.resnet34, metrics=error_rate)
learn_gt.fit_one_cycle(4)
learn_gt.save('gt_stage-1')
learn_gt.load('gt_stage-1')
learn_gt.unfreeze()
learn_gt.lr_find()
learn_gt.fit_one_cycle(2, max_lr=slice(1e-5,1e-2))
learn_gt.save('gt_stage-2')
learn_gt.load('gt_stage-2')
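Passing slice(1e-5,1e-2) as max_lr asks fast.ai to train different layer groups with different learning rates: the lowest for the earliest layers and the highest for the head. As far as we understand fastai v1, the intermediate values are spaced geometrically. The sketch below reproduces that spacing in plain Python, assuming three layer groups (the usual split for a resnet34 learner); the function name spread_lrs is hypothetical:

```python
def spread_lrs(start, stop, n_groups):
    """Geometrically spaced learning rates from start to stop (inclusive)."""
    if n_groups == 1:
        return [stop]
    step = (stop / start) ** (1 / (n_groups - 1))
    return [start * step ** i for i in range(n_groups)]

lrs = spread_lrs(1e-5, 1e-2, 3)
for i, lr in enumerate(lrs):
    print(f'layer group {i}: lr = {lr:.2e}')
```

The pretrained early layers are therefore only nudged, while the newly added head is allowed to change much faster.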
Results in the confusion matrix
interp = ClassificationInterpretation.from_learner(learn_gt)
interp.plot_confusion_matrix()
As we can see, the model is very ineffective at predicting the less represented class (birds), where there is a lot of confusion on the validation set.
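With a 1:10 imbalance, the overall error rate can look acceptable even when the minority class is mostly misclassified, which is why the confusion matrix is more informative here. A small sketch of per-class recall computed from a confusion matrix; the numbers are made up for illustration and are not the results of this experiment:

```python
def per_class_recall(matrix, classes):
    """matrix[i][j] = number of images of actual class i predicted as class j."""
    recall = {}
    for i, name in enumerate(classes):
        total = sum(matrix[i])
        recall[name] = matrix[i][i] / total if total else 0.0
    return recall

def accuracy(matrix):
    """Fraction of all images that lie on the diagonal of the matrix."""
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    total = sum(sum(row) for row in matrix)
    return correct / total

classes = ['birds', 'flowers']
# Hypothetical validation results: 4 birds (1 correct), 40 flowers (all correct).
example = [[1, 3],
           [0, 40]]

print(f'accuracy: {accuracy(example):.3f}')
print(per_class_recall(example, classes))
```

In this made-up case the accuracy is above 90% although only one bird in four is recognized.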
Data augmentation
We need about 180 new images in the birds category to match the 200 flower images. To do this, we loop through each of the real images and create ten additional pictures from each one (slightly more than strictly needed) using fast.ai’s apply_tfms method.
path = Path('data/data_aug')
path_hr = path/'birds'
il = ImageList.from_folder(path_hr)
tfms = get_transforms(max_rotate=25)

def data_aug_one(ex_img, prox, qnt):
    # create qnt transformed copies of ex_img, numbered consecutively from prox
    for lop in range(0, qnt):
        image_name = str(prox).zfill(8) + '.jpg'
        dest = path_hr/image_name
        prox = prox + 1
        new_img = open_image(ex_img)
        new_img_fin = new_img.apply_tfms(tfms[0], new_img, xtra={tfms[1][0].tfm: {"size": 224}}, size=224)
        new_img_fin.save(dest)

prox = 20
qnt = 10
for imagen in il.items:
    data_aug_one(imagen, prox, qnt)
    prox = prox + qnt
If we visualize any of the source images and their ten new images, we’ll find things like this:
Details on how the apply_tfms transform function works, and all its options, can be found in the fast.ai documentation.
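The number of copies per image follows from the class counts. A short sketch of that arithmetic, together with the zero-padded file naming used in the augmentation loop; copies_per_image and augmented_names are hypothetical helper names, and the counts assume that no bird image was removed during verification:

```python
import math

def copies_per_image(minority, majority):
    """Smallest number of new copies per minority image needed to reach the majority count."""
    missing = max(majority - minority, 0)
    return math.ceil(missing / minority)

def augmented_names(start, n_sources, qnt):
    """File names generated for each source image, numbered consecutively from start."""
    names = []
    nxt = start
    for _ in range(n_sources):
        names.append([str(nxt + k).zfill(8) + '.jpg' for k in range(qnt)])
        nxt += qnt
    return names

needed = copies_per_image(20, 200)
print(needed)  # copies per image required for 20 originals to reach 200 in total

names = augmented_names(start=20, n_sources=20, qnt=10)
print(names[0][0], names[-1][-1])  # first and last generated file names
```

Nine copies per image would be enough to reach 200 birds; with ten copies per image, as used here, the birds class ends up slightly larger than the flowers class.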
The same model with balanced data
We create a new model with the balanced dataset:
np.random.seed(7)
tfms = get_transforms()
data = ImageDataBunch.from_folder(path, train=".", valid_pct=0.2, ds_tfms=tfms, size=224, num_workers=4).normalize(imagenet_stats)
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()
Okay, now let’s present a new set of images
We took twenty new images, ten from each category. The accuracy, although much improved over the initial version, still shows some problems: out of the 20 new images, the model classified 18 correctly and two incorrectly:
path = Path('data/data_aug_test')
defaults.device = torch.device('cpu')
img = open_image(path/'bird01.jpg')
img
pred_class,pred_idx,outputs = learn.predict(img)
pred_class
Category birds
for test in range(1,11):
    image_name = 'bird'+str(test).zfill(2)+'.jpg'
    img = open_image(path/image_name)
    pred_class,pred_idx,outputs = learn.predict(img)
    print('For image ' + image_name + ' predicted class: ')
    print(pred_class)

    image_name = 'flower'+str(test).zfill(2)+'.jpg'
    img = open_image(path/image_name)
    pred_class,pred_idx,outputs = learn.predict(img)
    print('For image ' + image_name + ' predicted class: ')
    print(pred_class)
The image flower07.jpg was predicted as a bird, and the image bird08.jpg was predicted as a flower.
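The tally can be reproduced by deriving the expected label from each file name and comparing it with the prediction. In the sketch below the predictions dictionary is filled in by hand to mirror the outcome reported above (in the notebook it would be collected from learn.predict); expected_label is a hypothetical helper:

```python
def expected_label(image_name):
    """Map 'bird01.jpg' -> 'birds' and 'flower01.jpg' -> 'flowers'."""
    return 'birds' if image_name.startswith('bird') else 'flowers'

# Predictions as reported in the text: every image correct except two.
predictions = {}
for test in range(1, 11):
    for prefix in ('bird', 'flower'):
        name = prefix + str(test).zfill(2) + '.jpg'
        predictions[name] = expected_label(name)
predictions['flower07.jpg'] = 'birds'
predictions['bird08.jpg'] = 'flowers'

wrong = sorted(n for n, p in predictions.items() if p != expected_label(n))
accuracy = 1 - len(wrong) / len(predictions)
print(f'{len(predictions) - len(wrong)}/{len(predictions)} correct ({accuracy:.0%})')
print('misclassified:', wrong)
```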
Here are the misclassified images:
Summary
As we have seen, training the same model with a balanced dataset rather than an unbalanced one makes a fundamental difference; however, if the dataset is very small, the accuracy we reach may still not be enough.
In a previous article, we created a very similar model from 200 images of each category. When we presented it with the same 20 new test images, it made the same mistakes on the same photos.
We used the apply_tfms method from fast.ai to transform the few initial bird images, creating ten new images from each original.
Sources and references
[1] — https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035
As noted in the fast.ai course, this technique is inspired by: Francisco Ingham and Jeremy Howard / [Adrian Rosebrock]
https://course.fast.ai, lesson 2, Jupyter Notebook: lesson2-download