Handling Class Imbalance by Introducing Sample Weighting in the Loss Function
“Nobody is Perfect.” This quote applies not just to us humans but also to the data that surrounds us. Any data science practitioner needs to understand the imperfections present in the data and handle them accordingly in order to get the desired results. One such imperfection is the inherent Class Imbalance that is highly prevalent in most real-world datasets. In this blog we will cover different Sample Weighting schemes that can be applied to any Loss Function in order to cater to the Class Imbalance present in your data.
What is the Class Imbalance Problem?
The Class Imbalance problem plagues most Machine Learning/Deep Learning classification tasks. It occurs when one or more classes (the majority classes) occur more frequently than the other classes (the minority classes). Simply put, the data is skewed towards the majority classes.
Consider the following dataset, which we will use as the running example throughout this blog. It contains 29 classes for a Multi-Label, Multi-Class Text Classification problem.
As you can see, this dataset has a slight, though not severe, imbalance: classes like category_6, category_4, and category_27 can be considered majority classes, while category_7, category_3, category_20, etc. can be considered minority classes.
Why is Class Imbalance a Problem?
So far, we have looked at what the Class Imbalance Problem is. But why is it a problem? Why do we need to overcome it?
Most Machine Learning algorithms implicitly assume that the data is balanced, i.e., equally distributed among all of its classes. When a model is trained on an imbalanced dataset, learning becomes biased towards the majority classes. With more examples available to learn from, the model learns to perform well on the majority classes, but for lack of enough examples it fails to learn the meaningful patterns that could help it on the minority classes.
Let us look at how a RoBERTa Sequence Classification model performs on this dataset. Based on our business problem, we use F0.5 as the evaluation metric: a weighted harmonic mean of Precision and Recall in which Precision is considered twice as important as Recall.
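For reference, the F-beta score with beta = 0.5 can be computed directly from precision and recall. A minimal sketch in plain Python (the function name is mine, not part of any library):

```python
def f_beta(precision: float, recall: float, beta: float = 0.5) -> float:
    """Weighted harmonic mean of precision and recall.

    beta < 1 weights precision more heavily; beta = 0.5 treats
    precision as twice as important as recall.
    """
    if precision == 0.0 and recall == 0.0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)

# Example: high precision but mediocre recall still scores reasonably well
print(round(f_beta(0.8, 0.5), 3))  # 0.714
```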
The image above shows how most of the minority classes report poorer performance than most of the majority classes.
Why not simply Resample the Data?
When we perform Undersampling for the majority classes, we essentially remove a certain number of samples associated with those classes. Oversampling the minority classes, on the other hand, entails repeating samples associated with those classes.
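The two mechanics can be sketched in a few lines of plain Python (helper names are hypothetical; real pipelines would resample whole labelled examples, not bare values):

```python
import random

def oversample(samples, target_size, seed=0):
    """Grow a minority class to target_size by repeating random samples."""
    rng = random.Random(seed)
    extra = [rng.choice(samples) for _ in range(target_size - len(samples))]
    return samples + extra

def undersample(samples, target_size, seed=0):
    """Shrink a majority class to target_size by randomly dropping samples."""
    rng = random.Random(seed)
    return rng.sample(samples, target_size)

minority = ["a", "b", "c"]
majority = list(range(100))
print(len(oversample(minority, 10)), len(undersample(majority, 10)))  # 10 10
```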
Although either strategy balances out the dataset, neither directly tackles the issues caused by Class Imbalance; rather, each risks introducing new ones. Since Oversampling introduces duplicate samples, it can easily slow down training and lead to overfitting. Undersampling, on the other hand, removes samples, so the model may miss out on important concepts it could have learnt from the discarded samples.
What we can do instead is play around with our Loss Function: apply different weights to the loss computed for different samples, based on the class those samples belong to. Let’s look at this in detail in the section below.
Sample Weighting in Loss Function
Introducing Sample Weights in the Loss Function is a simple and neat technique for handling Class Imbalance in your training dataset. The idea is to weigh the loss computed for different samples differently, based on whether they belong to the majority or the minority classes; we essentially want to assign a higher weight to the loss incurred by samples associated with the minority classes.
Let’s consider a Loss Function for our Multi-Label Classification running example. I used PyTorch’s implementation of Binary Cross Entropy, torch.nn.BCEWithLogitsLoss, which combines a Sigmoid layer and the Binary Cross Entropy loss for numerical stability and can be expressed mathematically as:

l_n,c = −W_n,c [ P_c · y_n,c · log σ(x_n,c) + (1 − y_n,c) · log(1 − σ(x_n,c)) ]

where x_n,c is the logit and y_n,c ∈ {0, 1} the target for sample n and class c.
Oftentimes, people confuse W_n,c (weight) with P_c (pos_weight). W_n,c (weight) is the Sample Weight, while P_c (pos_weight) is the Class Weight.
It’s W_n,c, the Sample Weight, that we wish to compute for every sample in a batch; it lets us weigh the contribution of a particular sample towards the overall loss. It is assigned via the ‘weight’ argument and has to be a Tensor of size N×C (N is the batch size, C the total number of classes). ‘pos_weight’ is the weight for positive examples, determined from the proportion of samples labelled with a particular class. It weighs the contribution of a particular class towards the loss and must be a vector with length equal to the number of classes. As the PyTorch documentation explains: “if a dataset contains 100 positive and 300 negative examples of a single class, then pos_weight for the class should be equal to 300/100 = 3. The loss would act as if the dataset contains 3×100 = 300 positive examples.” Therefore pos_weight in a way acts as if we had resampled the data to account for the class imbalance.
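As a concrete sketch, the two arguments can be passed as follows, here via the functional form so that the sample weights can change per batch. The ones tensors are placeholders for weights you would compute from your training-set class counts:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 4, 29  # batch size and number of classes (29 in our running example)

logits = torch.randn(N, C)                     # raw model outputs
targets = torch.randint(0, 2, (N, C)).float()  # multi-label 0/1 targets

# pos_weight (P_c): one entry per class, e.g. (# negatives) / (# positives).
pos_weight = torch.ones(C)

# weight (W_n,c): one entry per (sample, class) pair, derived from the
# frequencies of the classes active in each sample.
sample_weight = torch.ones(N, C)

loss = F.binary_cross_entropy_with_logits(
    logits, targets, weight=sample_weight, pos_weight=pos_weight
)
print(loss.item())
```

With all weights set to 1 this reduces to plain BCE; the weighting schemes below produce the non-trivial values for sample_weight.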
There are different weighting schemes that can be used to compute this Sample Weight. As part of this project, I tried three: Inverse of Number of Samples and Inverse of Square Root of Number of Samples, two of the simplest and most popular schemes, and a relatively new one known as the Effective Number of Samples weighting scheme.
Inverse of Number of Samples (INS)
As the name suggests, we weight the samples as the inverse of the class frequency for the class they belong to.
The function above shows a simple implementation that computes the weights and normalizes them over the classes. The lines of code below normalize these sample weights across a batch of samples.
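A minimal sketch of the idea, with hypothetical helper names (`ins_weights`, `batch_sample_weights`) and a batch-normalization scheme of my own choosing:

```python
import numpy as np

def ins_weights(class_counts):
    """Inverse of Number of Samples: w_c = 1 / n_c, normalized so the
    weights sum to the number of classes (keeps the loss scale stable)."""
    counts = np.asarray(class_counts, dtype=np.float64)
    weights = 1.0 / counts
    return weights / weights.sum() * len(counts)

def batch_sample_weights(class_weights, batch_targets):
    """Expand per-class weights into a per-(sample, class) matrix, then
    normalize across the batch so the entries sum to N * C."""
    targets = np.asarray(batch_targets, dtype=np.float64)  # shape (N, C)
    w = np.where(targets > 0, class_weights, 1.0)          # weight active labels
    return w / w.sum() * w.size

counts = [100, 10]          # toy two-class imbalance
print(ins_weights(counts))  # the minority class gets the larger weight
```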
Inverse of Square Root of Number of Samples (ISNS)
Here we weight the samples as the inverse of the Square Root of class frequency for the class they belong to.
The rest of the implementation details related to normalization remain the same as for INS.
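A sketch mirroring the INS helper above (again with a hypothetical function name):

```python
import numpy as np

def isns_weights(class_counts):
    """Inverse of Square Root of Number of Samples: w_c = 1 / sqrt(n_c),
    normalized over the classes as in INS."""
    counts = np.asarray(class_counts, dtype=np.float64)
    weights = 1.0 / np.sqrt(counts)
    return weights / weights.sum() * len(counts)

# The square root dampens the correction: for counts 100 vs 10 the weight
# ratio is sqrt(10) ~ 3.16 instead of the factor of 10 INS would give.
print(isns_weights([100, 10]))
```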
Effective Number of Samples (ENS)
This weighting scheme was introduced in Google’s CVPR’19 paper, Class-Balanced Loss Based on Effective Number of Samples. The re-weighting strategies above rely on the total number of samples present in each class; this paper instead introduces a weighting scheme that relies on the “Effective Number of Samples”. The authors argue that:
“as the number of samples increases, the additional benefit of a newly added data point will diminish. We introduce a novel theoretical framework to measure data overlap by associating with each sample a small neighboring region rather than a single point. The effective number of samples is defined as the volume of samples and can be calculated by a simple formula (1−β^n)/(1−β), where n is the number of samples and β ∈ [0, 1) is a hyperparameter”
The authors suggest experimenting with different beta values: 0.9, 0.99, 0.999, 0.9999.
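Following the formula quoted above, a sketch of the scheme (hypothetical helper name; normalization matches the earlier helpers):

```python
import numpy as np

def ens_weights(class_counts, beta=0.999):
    """Effective Number of Samples: E_c = (1 - beta^n_c) / (1 - beta).
    Weights are the inverse of E_c, normalized over the classes."""
    counts = np.asarray(class_counts, dtype=np.float64)
    effective_num = (1.0 - np.power(beta, counts)) / (1.0 - beta)
    weights = 1.0 / effective_num
    return weights / weights.sum() * len(counts)

# As beta approaches 1 the effective number approaches the raw count
# (ENS behaves like INS); smaller beta flattens the weights.
for beta in (0.9, 0.99, 0.999, 0.9999):
    print(beta, ens_weights([100, 10], beta=beta))
```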
Putting it all together:
Finally, let us see how these different weighting schemes performed on our running example. Here’s what I was able to achieve with the different Sample Weighting schemes (INS: Inverse of Number of Samples, ISNS: Inverse of Square Root of Number of Samples, ENS: Effective Number of Samples).
The best performance, 76.122% F0.5, was achieved with INS as the weighting scheme, a 2.7-point improvement over the model trained on the same dataset without any weighting scheme (73.422% F0.5). But it is important to see whether this improvement comes from the majority classes or the minority classes: has the model really become better on the minority classes? The graph below helps answer that question.
The graph above shows the percentage change in F0.5, Precision, and Recall for all 29 classes, sorted by support (frequency). You can clearly see that the improvements are much larger for the minority classes, which shows that the Inverse of Number of Samples weighting scheme works well here.
The ISNS and ENS weighting schemes usually shine in cases of extreme class imbalance, which is not really the case with our dataset; therefore, a simple weighting scheme such as INS works best here.
In this blog, we read about the Class Imbalance problem and how it can adversely affect a model’s learning. We then saw how simple resampling techniques such as Oversampling and Undersampling can make matters worse, either through overfitting or through missing out on important concepts. We finally explored different weighting schemes and how to apply them to mitigate the Class Imbalance issue.
About Me: I graduated with a Master’s in Computer Science from ASU and am an NLP Scientist at GumGum. I am interested in applying Machine Learning/Deep Learning to provide some structure to the unstructured data that surrounds us.