Effective Strategies for Handling Class Imbalance in Machine Learning Models

Suvendu Pati · Published in The Deep Hub · Jul 20, 2024 · 8 min read
“Uncommon among commons”

Class imbalance in machine learning occurs when the classes in a dataset are not represented equally. This can lead to issues where the model does not learn enough from the underrepresented class, potentially ignoring it in predictions. For example, in cancer detection, a model might always predict “normal” if cancer cases are rare, achieving high accuracy but poor performance on the minority class. Handling class imbalance is crucial because the cost of misclassifying rare but critical cases, like cancer, is high.

In this article, we will walk through the challenges that class imbalance brings to the table and, of course, the various ways to handle them.

Challenges with Class Imbalance

1. Insufficient Examples for Learning

When a class is underrepresented, the model doesn’t get enough examples to learn from. Imagine teaching a child to identify different types of fruits, but you only show them one or two examples of apples compared to hundreds of examples of oranges. The child will struggle to recognize apples in the future because they haven’t seen enough of them to learn their characteristics.

In machine learning, this means that the model might not get a strong enough signal from the underrepresented class before it has to make a decision. As a result, the model might assume this rare class doesn’t exist at all when making predictions.

2. Stuck with Learning Simple Patterns

Machine learning models can get stuck learning simple patterns, especially when there’s a class imbalance. Let’s take an example of cancer detection using lung X-rays. Suppose 99.99% of the X-rays are of normal lungs and only 0.01% show cancerous lungs. A model could easily “learn” to always predict the lungs as normal, achieving 99.99% accuracy.

However, this is not useful because the model hasn't learned anything meaningful about detecting cancer; it is just exploiting the simple heuristic that most X-rays are normal. Overcoming this simplistic solution can be difficult: adding a small number of cancerous cases may initially seem to make the model perform worse because of the complexity introduced, and it can be very hard for an optimisation algorithm such as gradient descent to escape this local optimum and reach a genuinely good solution.

3. Asymmetric Cost for Wrong Predictions

In real-world applications, the cost of making wrong predictions is not the same for all classes. Consider our cancer detection example again. Misclassifying a cancerous lung as normal (a false negative) has a much higher cost compared to misclassifying a normal lung as cancerous (a false positive). In other words, missing a cancer diagnosis can be life-threatening, while a false alarm might only lead to some additional tests.

If the model’s loss function is not configured to reflect this asymmetric cost, it will treat errors for both classes equally. This is not desirable because we want the model to be more cautious about missing cancer cases. The loss function should penalize the misclassification of cancerous cases more heavily to ensure the model makes better predictions for this minority class.

Handling Class Imbalance

Fortunately, there are several strategies to tackle this issue. This section explores various methods to handle class imbalance, breaking down complex concepts into simple terms.

1. Choosing the Right Metrics

Using the correct metrics is crucial when dealing with imbalanced classes. Accuracy is not always the best metric for imbalanced data since it can be misleading. For instance, a model that always predicts the majority class will have high accuracy but poor performance on the minority class. Here are some metrics and tools that are more informative under class imbalance:

  • Recall measures the proportion of actual positives correctly identified. It's vital when missing a positive case is costly, for example in fraud or cancer detection. Recall is calculated as Recall = TP / (TP + FN).
  • Precision measures the proportion of positive predictions that are actually correct. It's important when false positives are costly, such as in spam classification, where we do not want important emails landing in the spam folder. Precision is calculated as Precision = TP / (TP + FP).
  • Probability Threshold: Adjusting the probability threshold can help increase recall. In cancer detection, the threshold may need to be set below 0.5, because the model tends to output low probabilities even for actual cancer cases. The threshold can be optimised using the ROC curve (see the sketch after this list).
  • ROC Curve plots the true positive rate (recall) against the false positive rate across thresholds. For a perfect classifier, the curve hugs the horizontal line at recall = 1.0; when choosing a threshold, pick the point on the curve closest to that perfect line. The ROC curve is useful for visualizing performance but is better suited to balanced classes.
ROC Curve (Picture obtained from Wikipedia)
  • Precision-Recall Curve is more suitable for imbalanced classes. Precision and Recall are directly concerned with the positive class (often the minority class in imbalanced datasets). Thus, the PR curve gives a clearer picture of performance on the minority class.
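
To make these metrics concrete, here is a minimal scikit-learn sketch on synthetic placeholder data (with a real problem, the labels and scores would come from your fitted classifier). It computes precision and recall at the default threshold, then sweeps thresholds via the precision-recall curve as discussed above; roc_curve works analogously:

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score, precision_recall_curve

rng = np.random.default_rng(0)
# Toy imbalanced data: ~1% positives, with noisy scores for illustration
y_true = (rng.random(5000) < 0.01).astype(int)
y_score = np.clip(0.3 * y_true + rng.normal(0.1, 0.15, 5000), 0, 1)

# Default 0.5 threshold
y_pred = (y_score >= 0.5).astype(int)
print("Recall:   ", recall_score(y_true, y_pred, zero_division=0))
print("Precision:", precision_score(y_true, y_pred, zero_division=0))

# Sweep thresholds; keep those guaranteeing a minimum recall (e.g. 0.95),
# then pick the one with the best precision among them
precision, recall, thresholds = precision_recall_curve(y_true, y_score)
ok = recall[:-1] >= 0.95                  # thresholds has one fewer entry
best = thresholds[ok][np.argmax(precision[:-1][ok])]
print("Threshold for >=0.95 recall:", best)
```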

2. Data-Level Methods: Resampling

Resampling techniques involve modifying the dataset to balance the class distribution. There are two main methods — Undersampling & Oversampling:

  • Undersampling reduces the number of instances of the majority class. One such method is Tomek Links. It removes majority class instances that are close to minority class instances, clarifying the decision boundary but potentially losing important data points.
  • Oversampling increases the number of instances of the minority class. SMOTE (Synthetic Minority Over-sampling Technique) generates new minority class instances by creating synthetic examples. However, it can lead to overfitting.

Caution: When resampling, it’s crucial not to evaluate the model on resampled data, as it can lead to biased results and overfitting.
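
As a minimal sketch (assuming the third-party imbalanced-learn package), here is how both techniques are typically applied. Note that resampling happens only on the training split, so evaluation still uses the original class distribution, per the caution above:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import TomekLinks

# Toy imbalanced dataset: ~5% minority class
X, y = make_classification(n_samples=2000, weights=[0.95], random_state=42)

# Split first, so the test set keeps the original class distribution
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, test_size=0.2, random_state=42)

# Oversampling: SMOTE synthesises new minority-class examples
X_smote, y_smote = SMOTE(random_state=42).fit_resample(X_train, y_train)

# Undersampling: Tomek links removes majority points near the boundary
X_tomek, y_tomek = TomekLinks().fit_resample(X_train, y_train)

# Train on the resampled data, but always evaluate on (X_test, y_test)
```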

Since undersampling can discard important data points and oversampling can cause overfitting, there are more sophisticated techniques that mitigate these drawbacks. Following are two such advanced resampling techniques:

  • Two-phase Learning: As the name suggests, this technique consists of two phases: Phase 1 trains the model on resampled data, and Phase 2 fine-tunes it on the original data (a minimal sketch follows this list).
  • Dynamic Sampling: This one dynamically undersamples the well-performing class and oversamples the underperforming class to balance learning. It is a way to show the model less of what it has learnt and more of what it hasn’t learnt.
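
Here is a sketch of two-phase learning, assuming (X_smote, y_smote) and (X_train, y_train) from the resampling example above, and using scikit-learn's SGDClassifier because it supports incremental fitting:

```python
from sklearn.linear_model import SGDClassifier

# Phase 1: learn initial weights from the balanced, resampled data
clf = SGDClassifier(loss="log_loss", random_state=42)
clf.fit(X_smote, y_smote)

# Phase 2: fine-tune on the original, imbalanced distribution
# (partial_fit continues from the Phase 1 weights instead of refitting)
for _ in range(5):
    clf.partial_fit(X_train, y_train)
```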

3. Algorithm-Level Methods

Algorithm-level methods are strategies that adjust the learning process of a machine learning model to handle class imbalance without altering the distribution of the training data. This can be achieved using two approaches:

  • Make the model pay a higher penalty for misclassifying the rarer or costlier class. For example: Cost-Sensitive Learning & Class-Balanced Loss
  • Make the model focus more on learning hard-to-classify examples and less on easy ones. For example: Focal Loss

Cost-Sensitive Learning: Cost-sensitive learning involves modifying the loss function to incorporate a cost matrix, where different misclassification errors are assigned different penalties. This approach was initially proposed by Charles Elkan.

Each element C(i, j) of the cost matrix represents the cost of classifying an instance of class i as class j. Suppose class i is the majority class and class j is the minority class, and the cost C(j, i) (classifying j as i) is twice the cost C(i, j) (classifying i as j); then C(i, j) = 1 and C(j, i) = 2. Of course, the cost for correct predictions (C(i, i) and C(j, j)) is zero. For binary cross-entropy, with y the true label (1 for the minority class j) and p the predicted probability of that class, the loss function is adjusted using the cost matrix as follows:

L = −[C(j, i) · y · log(p) + C(i, j) · (1 − y) · log(1 − p)]

Similarly, Cost-Sensitive Learning can be applied to other loss functions. This ensures that the model pays more attention to correctly classifying the minority class. The drawback of this technique is that the cost matrix needs to be defined manually, and it differs from task to task.
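
As a minimal PyTorch sketch (the linear model and random batch below are placeholders for your own network and data), the pos_weight argument of BCEWithLogitsLoss scales the positive-class term of the loss, mirroring C(j, i) = 2 versus C(i, j) = 1 above:

```python
import torch
import torch.nn as nn

# Placeholder network and batch for illustration
model = nn.Linear(10, 1)
inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8, 1)).float()

# pos_weight = C(j, i) / C(i, j) = 2: missing a positive costs twice
# as much as a false alarm
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.0]))

loss = criterion(model(inputs), targets)
loss.backward()
```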

Class-Balanced Loss: In this method, the loss function is weighted by the inverse of the number of samples in each class, so that each class contributes on a comparable scale regardless of its frequency. The class-balanced binary cross-entropy loss would look like:

L = −[(1 / N_{j}) · y · log(p) + (1 / N_{i}) · (1 − y) · log(1 − p)]

where N_{i} and N_{j} are the numbers of negative (0) and positive (1) samples, respectively. The weights are determined implicitly by the class frequencies, which is an advantage over Cost-Sensitive Learning: they are assigned automatically rather than defined by hand.

Note: In both Cost-Sensitive and Class-Balanced loss, the loss for the positive class is scaled in such a way that each instance of class 1 contributes more to the total loss than in an unweighted scenario, reflecting its lower prevalence in the dataset; conversely, each instance of class 0 contributes less, reflecting its higher prevalence.
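
As a sketch, scikit-learn can derive these inverse-frequency weights automatically: its class_weight="balanced" option computes n_samples / (n_classes × N_class), a normalised version of the 1/N weighting above:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0] * 990 + [1] * 10)       # toy labels: 99% negatives

weights = compute_class_weight(class_weight="balanced",
                               classes=np.array([0, 1]), y=y)
print(weights)    # ~[0.505, 50.0]: each rare positive weighs ~100x more

# Many estimators accept this shortcut directly, e.g.
# LogisticRegression(class_weight="balanced")
```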

Focal Loss: Focal loss is designed to focus more on hard-to-classify examples and less on easy examples. It down-weights the loss assigned to well-classified examples, allowing the model to pay more attention to the misclassified or hard-to-classify examples.

The focal loss function modifies the standard cross-entropy loss by adding a modulating factor (1 − p_{t})^γ, where p_{t} is the model's estimated probability for the true class and γ is a focusing parameter that adjusts the rate at which easy examples are down-weighted. For binary classification, the focal loss is defined as:

FL(p_{t}) = −(1 − p_{t})^γ · log(p_{t})

When an example is correctly classified with high confidence (e.g., p_{t} is close to 1), the term (1 − p_{t})^γ becomes very small, reducing the loss contribution from such examples. In cases where p_{t} is small (i.e., the model is unsure or wrong), the loss contribution remains significant, encouraging the model to focus on these harder examples.
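
A minimal sketch of binary focal loss in PyTorch, matching the formula above (γ = 2 is a commonly used value; the optional α class-weighting term from the original paper is omitted here):

```python
import torch

def binary_focal_loss(logits, targets, gamma=2.0):
    """Binary focal loss on raw logits; targets are 0/1."""
    p = torch.sigmoid(logits)
    p_t = torch.where(targets == 1, p, 1 - p)   # prob. of the true class
    # (1 - p_t)^gamma shrinks the loss of confident, correct predictions
    return (-((1 - p_t) ** gamma) * torch.log(p_t.clamp(min=1e-8))).mean()

# Confident correct, confident wrong, and uncertain predictions
logits = torch.tensor([4.0, -4.0, 0.1])
targets = torch.tensor([1, 1, 1])
print(binary_focal_loss(logits, targets))
```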

Conclusion

Class imbalance in machine learning can lead to significant challenges and costs. Relying solely on accuracy as a metric can be misleading, making it essential to choose appropriate evaluation metrics. To address class imbalance, various strategies can be employed, ranging from data-level methods like resampling to algorithm-level approaches such as cost-sensitive learning, class-balanced loss, and focal loss. Experimenting with these techniques and their combinations can help you find the best solution tailored to your specific problem, ultimately enhancing model performance and fairness.

References

  1. Huyen, C. (2022). Designing Machine Learning Systems (pp. 102–113). O’Reilly Media.
  2. Fraud Detection Handbook. (n.d.). Cost-sensitive learning. In Fraud Detection Handbook. Retrieved from https://fraud-detection-handbook.github.io/fraud-detection-handbook/Chapter_6_ImbalancedLearning/CostSensitive.html
