How to deal with Unbalanced Image Datasets in less than 20 lines of code

ImageDataGenerator + RandomOverSampling

Arnaldo Gualberto
Analytics Vidhya
4 min read · Aug 30, 2019


As far as we know, in most cases, Deep Learning requires a large dataset to learn a specific problem. However, collecting that much data can be hard and expensive, especially for real-world problems. To deal with this, we can apply a well-known technique called Data Augmentation, which uses different random transformations to enlarge the dataset and increase its variability.
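In Keras, for instance, such transformations are declared through an ImageDataGenerator. Here is a minimal example; the parameter values are only illustrative:

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Each parameter enables a random transformation applied at training time
datagen = ImageDataGenerator(
    rotation_range=15,       # rotate up to ±15 degrees
    width_shift_range=0.1,   # shift horizontally up to 10% of the width
    height_shift_range=0.1,  # shift vertically up to 10% of the height
    zoom_range=0.1,          # zoom in/out up to 10%
    horizontal_flip=True,    # randomly mirror images
)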

On the other hand, data augmentation does not affect the distribution of labels in the original dataset: if your data is unbalanced, it will remain unbalanced after augmentation. In such cases, the algorithm will be biased toward predicting only the most frequent class(es).

So, I was wondering: is there a way to apply selective data augmentation to balance the dataset? 🤔 That’s what we’re going to find out.

Random OverSampling

There are many techniques to deal with unbalanced data. One of them is oversampling, which consists of re-sampling the less frequent classes to match the number of samples in the predominant ones. Although the idea is simple, implementing it is not so easy. Luckily, there's a Python module for that!

imbalanced-learn is a Python package offering several re-sampling techniques commonly used in datasets showing strong between-class imbalance. It is compatible with scikit-learn and is part of the scikit-learn-contrib projects. You can check the code or the documentation to install imbalanced-learn.
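To get a feel for what random oversampling does, here is a minimal, self-contained example on toy data (not images yet):

import numpy as np
from imblearn.over_sampling import RandomOverSampler

# Toy unbalanced dataset: 100 samples of class 0, only 10 of class 1
x = np.random.rand(110, 32 * 32)  # imbalanced-learn expects 2D data
y = np.array([0] * 100 + [1] * 10)

x_res, y_res = RandomOverSampler().fit_resample(x, y)
print(np.unique(y_res, return_counts=True))
# -> (array([0, 1]), array([100, 100])): class 1 was re-sampled up to 100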

So, if we have a module to oversample and Keras' ImageDataGenerator class to augment images, why not mix them together to balance datasets?

That’s what I did!

The Code

Update (2020–08–09): bug fixes in BalancedDataGenerator. Thanks to Sébastien Richoz!

I wrote a custom generator class that produces balanced, augmented batches on the fly. My class ensures that, given enough steps per epoch, the number of samples per class will follow a uniform distribution regardless of your batch size. All of it is done in less than 20 lines of code! Check it out:

Link to full code: https://gist.github.com/arnaldog12/16efc663c869b35e2479bd607d56c1da
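If you just want the gist of how it fits together, here is a minimal sketch along the same lines, built on the balanced_batch_generator helper that imbalanced-learn shipped in its imblearn.keras module at the time. Treat it as an approximation of the linked gist, not a verbatim copy:

from tensorflow.keras.utils import Sequence
from imblearn.over_sampling import RandomOverSampler
from imblearn.keras import balanced_batch_generator

class BalancedDataGenerator(Sequence):
    """Yields balanced, augmented batches: random oversampling picks the
    samples, then the ImageDataGenerator transforms them on the fly."""
    def __init__(self, x, y, datagen, batch_size=32):
        self.datagen = datagen
        self.batch_size = batch_size
        self._shape = x.shape
        # balanced_batch_generator works on 2D arrays, so flatten the images
        self.gen, self.steps_per_epoch = balanced_batch_generator(
            x.reshape(x.shape[0], -1), y,
            sampler=RandomOverSampler(), batch_size=batch_size)

    def __len__(self):
        return self.steps_per_epoch

    def __getitem__(self, idx):
        x_batch, y_batch = next(self.gen)
        x_batch = x_batch.reshape(-1, *self._shape[1:])  # back to image shape
        return next(self.datagen.flow(x_batch, y_batch, batch_size=self.batch_size))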

Ok, but how do I use it? Basically, you’ll only have to do 3 things:

  • Define your data augmentation transformations as usual;
  • Instantiate a BalancedDataGenerator object;
  • Retrieve its steps_per_epoch to feed to the fit_generator method.

In code, it comes down to 3 lines:

datagen = ImageDataGenerator(...)  # define your data augmentation
bgen = BalancedDataGenerator(x, y, datagen, batch_size=32)
steps_per_epoch = bgen.steps_per_epoch

If you want to confirm that my class does the job it's supposed to do, you can use the following code:

import numpy as np

# draw one epoch's worth of batches and count the labels per class;
# the counts should be (approximately) uniform
y_gen = [bgen.__getitem__(0)[1] for _ in range(steps_per_epoch)]
print(np.unique(y_gen, return_counts=True))

An end-to-end example

Here's a full example of training a model using the proposed class:

x, y = ...  # load your data
datagen = ImageDataGenerator()
balanced_gen = BalancedDataGenerator(x, y, datagen, batch_size=32)
steps_per_epoch = balanced_gen.steps_per_epoch
model = ... # define your model
model.compile(...) # define your compile parameters
model.fit_generator(balanced_gen, steps_per_epoch, ...)

Easy, isn't it?

Remember that you can also define different generators for the train and validation datasets. In this case, just instantiate two BalancedDataGenerator objects, one for each dataset.
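For example (the variable names here are illustrative):

train_gen = BalancedDataGenerator(x_train, y_train, train_datagen, batch_size=32)
val_gen = BalancedDataGenerator(x_val, y_val, val_datagen, batch_size=32)
model.fit_generator(train_gen, train_gen.steps_per_epoch,
                    validation_data=val_gen,
                    validation_steps=val_gen.steps_per_epoch)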

Final Considerations

My custom class is still a work in progress, and I already have some ideas to improve it. The simplest would be adding support for the other re-sampling techniques available in imbalanced-learn, including undersampling. Maybe a single parameter in the __init__ method would be enough, as sketched below. Second, it would be nice to have a method to work with images from a directory, like flow_from_directory in Keras. In that case, we would not need to load the whole dataset into memory beforehand.
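As a sketch of that first idea, a hypothetical sampler argument (not in the current gist) could let you plug in any imbalanced-learn sampler:

from imblearn.under_sampling import RandomUnderSampler

# Hypothetical API: __init__ would forward this sampler to
# balanced_batch_generator instead of hard-coding RandomOverSampler()
bgen = BalancedDataGenerator(x, y, datagen, batch_size=32,
                             sampler=RandomUnderSampler())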

However, I believe my class might help other people with unbalanced datasets as it is. By using it, I was able to improve the accuracy of a model on a highly unbalanced dataset (100+ classes) from 63% to 97%. Give it a try! It's worth it, isn't it? 😉
