How to deal with Unbalanced Image Datasets in less than 20 lines of code
ImageDataGenerator + RandomOverSampling
In most cases, Deep Learning requires a large dataset to learn a specific problem. However, collecting that much data can be hard and expensive, especially for real-world problems. A well-known way to deal with this is Data Augmentation, which applies random transformations (rotations, shifts, flips, zooms, and so on) to existing samples to enlarge the dataset and increase its variability.
On the other hand, data augmentation does not change the distribution of labels in the original dataset: if your data is unbalanced before augmentation, it will remain unbalanced afterwards. A model trained on such data tends to be biased toward predicting the most frequent class(es).
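To see why augmentation alone does not fix the imbalance, here is a toy NumPy sketch (the dataset and the flip-based augmentation are made up for illustration): every image receives one augmented copy with the same label, so the dataset doubles in size but the class proportions stay exactly the same.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy dataset: 90 images of class 0 and 10 of class 1 (8x8 grayscale).
x = rng.random((100, 8, 8))
y = np.array([0] * 90 + [1] * 10)

# Plain augmentation: each image gets one copy, flipped horizontally with probability 0.5.
flipped = np.where(rng.random((100, 1, 1)) < 0.5, x[:, :, ::-1], x)
x_aug = np.concatenate([x, flipped])
y_aug = np.concatenate([y, y])  # labels are simply copied

# Twice as many samples, but the class proportions are unchanged.
print(np.bincount(y) / len(y))          # [0.9 0.1]
print(np.bincount(y_aug) / len(y_aug))  # [0.9 0.1]
```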
So, I was wondering: is there a way to apply selective data augmentation to balance the dataset? 🤔 That’s what we’re going to find out.
Random OverSampling
There are many techniques for dealing with unbalanced data. One of them is oversampling, which resamples the less frequent classes (for example, by duplicating some of their samples) so that their counts match those of the predominant classes. Although the idea is simple, implementing it is not so easy. Luckily, there’s a Python package for that!
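The core of random oversampling fits in a few lines. The sketch below is a plain NumPy illustration, not imbalanced-learn’s implementation, and `random_oversample` is a hypothetical name. It keeps every original sample and draws extra minority-class samples with replacement until each class matches the majority count, which is conceptually what imbalanced-learn’s `RandomOverSampler` does with its default settings.

```python
import numpy as np

def random_oversample(x, y, seed=None):
    """Hypothetical helper: duplicate randomly chosen samples of the smaller
    classes until every class has as many samples as the largest one."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    indices = []
    for cls, count in zip(classes, counts):
        cls_idx = np.flatnonzero(y == cls)
        # Keep every original sample, then draw the shortfall with replacement.
        extra = rng.choice(cls_idx, size=target - count, replace=True)
        indices.append(np.concatenate([cls_idx, extra]))
    indices = rng.permutation(np.concatenate(indices))  # shuffle the result
    return x[indices], y[indices]

x = np.arange(12).reshape(6, 2)   # 6 toy samples with 2 features each
y = np.array([0, 0, 0, 0, 1, 2])  # class 0 dominates
x_res, y_res = random_oversample(x, y, seed=42)
print(np.bincount(y_res))  # [4 4 4]
```

Because only indices are resampled, features and labels always stay aligned.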
imbalanced-learn is a Python package offering several resampling techniques commonly used on datasets with strong between-class imbalance. It is compatible with scikit-learn and is part of the scikit-learn-contrib projects. Check its code or documentation for installation instructions.
So, if we have a package for oversampling and the Keras ImageDataGenerator class for augmenting images, why not combine them to balance datasets?
That’s what I did!
The Code
Update (2020–08–09): bug fixes in BalancedDataGenerator. Thanks to Sébastien Richoz!
I wrote a custom class that wraps Keras’ ImageDataGenerator and generates images on the fly to balance a given dataset in each epoch. Given enough steps per epoch, it ensures that the number of samples per class follows a uniform distribution regardless of your batch size. All of this takes fewer than 20 lines of code! Check it out:
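The class itself relies on Keras and imbalanced-learn. To illustrate the same idea without either dependency, here is a hypothetical NumPy sketch; `balanced_batches` and `random_flip` are illustrative names, not part of any library, and this is not the author’s BalancedDataGenerator. The dataset is randomly oversampled once per epoch, shuffled, split into batches, and each batch is augmented on the fly. Since one epoch covers the whole oversampled dataset, the per-class counts come out equal whatever the batch size.

```python
import numpy as np

def balanced_batches(x, y, augment, batch_size=32, seed=None):
    """Hypothetical sketch: oversample the dataset (by index) for one epoch,
    then yield shuffled batches with augmentation applied on the fly."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    # Random oversampling: every class is topped up to the majority count.
    idx = np.concatenate([
        np.concatenate([np.flatnonzero(y == c),
                        rng.choice(np.flatnonzero(y == c), target - n)])
        for c, n in zip(classes, counts)
    ])
    idx = rng.permutation(idx)
    steps_per_epoch = int(np.ceil(len(idx) / batch_size))
    for step in range(steps_per_epoch):
        batch = idx[step * batch_size:(step + 1) * batch_size]
        # Augmentation only touches the images; labels pass through unchanged.
        yield augment(x[batch], rng), y[batch]

def random_flip(images, rng):
    """Hypothetical augmentation: flip each image horizontally with p=0.5."""
    flip = rng.random(len(images)) < 0.5
    out = images.copy()
    out[flip] = out[flip][:, :, ::-1]
    return out

rng = np.random.default_rng(0)
x = rng.random((100, 8, 8))
y = np.array([0] * 90 + [1] * 10)
labels = np.concatenate(
    [yb for _, yb in balanced_batches(x, y, random_flip, batch_size=32, seed=1)]
)
print(np.bincount(labels))  # [90 90]
```

Oversampling works on indices, so duplicated images are never stored up front; copies are only materialized batch by batch, which keeps memory usage close to that of the original dataset.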
Ok, but how do I use it? Basically, you only have to do three things:
- Define your data augmentation transformations as usual;
- Instantiate a BalancedDataGenerator object;
- Retrieve its steps_per_epoch and pass it to the fit_generator method.
In code, that’s just three lines:
datagen = ImageDataGenerator(...)  # define your data augmentation
bgen = BalancedDataGenerator(x, y, datagen, batch_size=32)
steps_per_epoch = bgen.steps_per_epoch
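If you want to confirm that the class does what it’s supposed to do, count the labels it yields over one epoch:
import numpy as np
y_gen = np.concatenate([bgen[i][1] for i in range(steps_per_epoch)])  # labels of every batch in one epoch
y_gen = y_gen.argmax(axis=1) if y_gen.ndim > 1 else y_gen  # one-hot labels -> class indices
print(np.unique(y_gen, return_counts=True))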
An end-to-end example
Here’s a full example of training a model with the proposed class:
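from tensorflow.keras.preprocessing.image import ImageDataGenerator
x, y = ...  # load your data
datagen = ImageDataGenerator()  # define your data augmentation
balanced_gen = BalancedDataGenerator(x, y, datagen, batch_size=32)
steps_per_epoch = balanced_gen.steps_per_epoch
model = ...  # define your model
model.compile(...)  # define your compile parameters
model.fit_generator(balanced_gen, steps_per_epoch=steps_per_epoch, epochs=...)  # in TF >= 2.1, prefer model.fit, which also accepts generators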
Easy, isn’t it?
Remember that you can also define separate generators for the training and validation datasets. In that case, instantiate two BalancedDataGenerator objects, one for each dataset.
Final Considerations
My custom class is still a work in progress, and I already have some ideas to improve it. First, and simplest, it could support the other resampling techniques available in imbalanced-learn, including undersampling; a single parameter in the __init__ method might be enough. Second, it would be nice to have a method that works with images stored in a directory, like flow_from_directory in Keras, so the dataset would not need to be loaded into memory beforehand.
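For the undersampling idea mentioned above, a comparable NumPy sketch (again hypothetical, with `random_undersample` as an illustrative name, not imbalanced-learn’s API) keeps a random subset of each class sized to the least frequent one. Unlike oversampling, it discards samples, so each epoch becomes shorter but also sees less data.

```python
import numpy as np

def random_undersample(x, y, seed=None):
    """Hypothetical helper: keep a random subset of each class, sized to the
    least frequent class, so all classes end up equally represented."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.min()
    # Sample without replacement: no duplicates, the extra samples are dropped.
    idx = np.concatenate([
        rng.choice(np.flatnonzero(y == c), size=target, replace=False)
        for c in classes
    ])
    idx = rng.permutation(idx)
    return x[idx], y[idx]

x = np.arange(12).reshape(6, 2)   # 6 toy samples with 2 features each
y = np.array([0, 0, 0, 0, 1, 1])  # class 0 has twice as many samples
x_res, y_res = random_undersample(x, y, seed=0)
print(np.bincount(y_res))  # [2 2]
```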
As it is, though, I believe the class can already help other people working with unbalanced datasets. Using it, I improved the accuracy of a model on a highly unbalanced dataset (100+ classes) from 63% to 97%. Give it a try! It’s worth it, isn’t it? 😉