Solving the Class Imbalance Problem in CNNs

A roundup of methods to tackle class imbalance

abhishek kushwaha
AI Graduate
Jan 29, 2019

Figure: a general representation of a classifier trained on imbalanced classes.

Life isn’t fair. Natural resources are unequally distributed across countries, as are money, political power and so on. Data, too, is rarely distributed fairly among classes. In the real world, one never gets data that is equally distributed among all classes (labels). We have to live with this situation and find ways to deal with it. This is not a new problem: statisticians and mathematicians have spent decades attacking it in different ways. Having done quite a lot of research and experiments on this topic, I think it is a good time to share what I have learned.

What is class imbalance?

Before getting into the solutions, let’s take a quick peek at what class imbalance is. A machine learning algorithm learns from labelled datasets. Neural networks are commonly used for classification tasks, where the network learns by looking at data points belonging to different classes. Imagine a classification problem where you have to identify whether a picture shown to the network contains a dog or a cat. Now assume your training set has a total of 10,000 images: 9,998 are of dogs, only one contains a cat, and the remaining image contains neither. Do you think the neural network will perform well at identifying cats? This is what class imbalance looks like: an unequal distribution of labelled data across classes. For finer details, this article also includes results from a really nice paper, “A systematic study of the class imbalance problem in convolutional neural networks” by Mateusz Buda, Atsuto Maki, and Maciej A. Mazurowski, in which the authors systematically investigate the impact of class imbalance on the classification performance of CNNs.

For now, let’s just dive into an overview of the solutions. The details of each method are too much to cover in a single article, so I will only outline them here and encourage you to read the paper for a deeper understanding.

1. Methods for addressing imbalance

Methods for addressing class imbalance can be divided into two main categories. The first is data-level methods, which operate on the training set. The second is classifier-level (algorithmic) methods, which keep the training dataset unchanged and instead adjust the training or inference algorithm.

1.1 Data level methods

a. Oversampling
b. Undersampling
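
Both data-level methods are simple enough to implement by hand. Here is a minimal NumPy sketch of random oversampling and undersampling; the function names and the array-based data layout are my own illustration, not code from the paper.

```python
import numpy as np

def random_oversample(X, y, seed=0):
    """Replicate minority-class examples (with replacement) until
    every class has as many examples as the largest class."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.max()
    picked = []
    for c in classes:
        c_idx = np.where(y == c)[0]
        # Sample with replacement only if the class is smaller than the target.
        picked.append(rng.choice(c_idx, size=target, replace=len(c_idx) < target))
    idx = np.concatenate(picked)
    return X[idx], y[idx]

def random_undersample(X, y, seed=0):
    """Randomly drop examples until every class has as many
    examples as the smallest class."""
    rng = np.random.default_rng(seed)
    classes, counts = np.unique(y, return_counts=True)
    target = counts.min()
    picked = []
    for c in classes:
        c_idx = np.where(y == c)[0]
        picked.append(rng.choice(c_idx, size=target, replace=False))
    idx = np.concatenate(picked)
    return X[idx], y[idx]
```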

1.2 Classifier level methods

a. Thresholding
b. Cost sensitive learning
c. One-class classification
d. Hybrid of methods
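
Of these, cost-sensitive learning is perhaps the quickest to try: re-weight the loss so that mistakes on rare classes cost more. A minimal PyTorch sketch, assuming hypothetical class counts and a common inverse-frequency weighting scheme (neither is taken from the paper):

```python
import torch
import torch.nn as nn

# Hypothetical class counts from an imbalanced training set.
class_counts = torch.tensor([9000.0, 900.0, 100.0])

# One common scheme: weight each class inversely to its frequency,
# so rare classes contribute more to the loss.
weights = class_counts.sum() / (len(class_counts) * class_counts)

criterion = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(8, 3)            # dummy network outputs
targets = torch.randint(0, 3, (8,))   # dummy integer labels
loss = criterion(logits, targets)
print(loss.item())
```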

2. Methods chosen for the experiment

The experiment covers these five methods, which include most of the commonly used approaches in the context of deep learning:

  1. Random minority oversampling
  2. Random majority undersampling
  3. Thresholding with prior class probabilities
  4. Oversampling with thresholding
  5. Undersampling with thresholding
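
Thresholding with prior class probabilities (method 3) divides the network’s output probabilities by each class’s prior estimated from the training set and renormalizes, so that frequent classes lose their built-in advantage at prediction time. A small NumPy sketch with made-up numbers:

```python
import numpy as np

def threshold_by_priors(probs, train_class_counts):
    """Divide predicted probabilities by the training-set class
    priors and renormalize."""
    priors = train_class_counts / train_class_counts.sum()
    adjusted = probs / priors
    return adjusted / adjusted.sum(axis=1, keepdims=True)

# Hypothetical skewed training set: 9,000 / 900 / 100 examples.
train_class_counts = np.array([9000.0, 900.0, 100.0])

# Raw softmax output for one test sample: class 0 looks most likely...
probs = np.array([[0.60, 0.25, 0.15]])

# ...but after correcting for the priors, class 2 wins.
print(threshold_by_priors(probs, train_class_counts).round(3))
```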

3. Datasets used for the experiment

Two different datasets are used:

  1. MNIST
  2. CIFAR-10

Imbalance was created synthetically.
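
Since both datasets are balanced out of the box, imbalance has to be introduced by subsampling some classes. A rough sketch of the idea (the keep fractions below are illustrative, not the exact ratios from the paper):

```python
import numpy as np

def make_imbalanced(X, y, keep_fraction, seed=0):
    """Subsample class i down to keep_fraction[i] of its examples,
    turning a balanced dataset into a synthetically imbalanced one."""
    rng = np.random.default_rng(seed)
    picked = []
    for c, frac in enumerate(keep_fraction):
        c_idx = np.where(y == c)[0]
        n_keep = max(1, int(len(c_idx) * frac))
        picked.append(rng.choice(c_idx, size=n_keep, replace=False))
    idx = np.concatenate(picked)
    return X[idx], y[idx]

# Hypothetical: keep all of classes 0-4, only 10% of classes 5-9.
keep_fraction = [1.0] * 5 + [0.1] * 5
```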

4. Evaluation metrics and testing

The accuracy metric is misleading on an imbalanced dataset. A better metric is the F1-score, or better still the area under the receiver operating characteristic curve (ROC AUC); the ROC curve plots the true positive rate against the false positive rate across all possible prediction thresholds.
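
With scikit-learn, both metrics are one-liners. A toy example with made-up predictions for a 3-class problem:

```python
import numpy as np
from sklearn.metrics import f1_score, roc_auc_score

# Dummy labels and per-class predicted probabilities (illustrative values only).
y_true = np.array([0, 0, 1, 2, 2, 2])
y_prob = np.array([
    [0.8, 0.1, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.7, 0.1],
    [0.1, 0.2, 0.7],
    [0.3, 0.3, 0.4],
    [0.2, 0.1, 0.7],
])
y_pred = y_prob.argmax(axis=1)

print("macro F1:", f1_score(y_true, y_pred, average="macro"))
print("ROC AUC :", roc_auc_score(y_true, y_prob, multi_class="ovr"))
```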

5. Results

The results, showing the impact of class imbalance on classification performance and comparing the methods for addressing imbalance, are shown in the graphs below.

Figure (from the paper): results for a) MNIST and b) CIFAR-10.

5.1 Generalizations of sampling methods

In some cases undersampling and oversampling perform similarly. For classical machine learning models, it has been shown that oversampling can cause overfitting, especially for minority classes. The experiments in the paper do not confirm this conclusion for convolutional neural networks.

The figure below compares the convergence of the baseline and the sampling methods in the CIFAR-10 experiments with respect to accuracy (on a test set without imbalance). Both oversampling and undersampling helped train a better classifier in terms of performance and generalization, and they also made training more stable. In contrast to traditional machine learning methods, oversampling did not lead to overfitting here.

Figure (from the paper): comparison of methods on CIFAR-10.

6. Conclusion

  • The effect of class imbalance on classification performance is detrimental.
  • The influence of imbalance on classification performance increases with the scale of the task.
  • The impact of imbalance cannot be explained simply by the lower total number of training cases; it depends on the distribution of examples among classes.
  • The method that in most cases outperformed all others with respect to multi-class ROC AUC was oversampling.
  • For extreme imbalance ratios with a large portion of classes in the minority, undersampling performs on a par with oversampling. If training time is an issue, undersampling is the better choice in such a scenario, since it dramatically reduces the size of the training set.
  • Oversampling should be applied to the level that completely eliminates the imbalance, whereas the optimal undersampling ratio depends on the extent of imbalance. The higher the fraction of minority classes in the imbalanced training set, the more the imbalance ratio should be reduced.
  • To achieve the best accuracy, apply thresholding to compensate for the prior class probabilities. Combining thresholding with the baseline or with oversampling is preferable, whereas thresholding should not be combined with undersampling.

I have personally applied these methods in live deep learning projects and have used thresholding heavily. With images, training resources become a bottleneck, so oversampling at times becomes infeasible.

Speeding up deep learning inference by up to 20X: check out my post on speeding up deep learning inference using TensorRT. If you are not using TensorRT, then your whole deployment is outdated.

X8 aims to organize and build a community for AI that is not only open source but also attentive to its ethical and political aspects. More such simplified AI concepts will follow. If you liked this, or have feedback or follow-up questions, please comment below.

Thanks for reading!
