Improving Class Imbalance with Class Weights in Machine Learning

Ravi Abhinav
5 min read · Aug 5, 2023


Class imbalance is a common challenge in machine learning, in which one class heavily outweighs the others in the number of samples. This imbalance can lead to biased model predictions: the majority class dominates the learning process, and the minority class is often overlooked. In this blog, we will explore the technique of using class weights to address class imbalance and improve the performance of machine learning models.

In class imbalance, one class heavily outweighs the others in terms of the number of samples.

Understanding Class Imbalance

Before we dive into the class weight technique, let’s first understand what class imbalance means. In a classification problem, class imbalance occurs when one class has significantly more samples than the other class(es). For example, in a binary classification problem, if Class A has 90% of the samples and Class B has only 10%, we have a class imbalance issue.
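
For instance, here is a quick way to inspect the class distribution, using a hypothetical label array y built to match the 90/10 split above:

import numpy as np

# A hypothetical imbalanced label array: 90% Class A (0), 10% Class B (1)
y = np.array([0] * 900 + [1] * 100)

classes, counts = np.unique(y, return_counts=True)
for c, n in zip(classes, counts):
    print(f"Class {c}: {n} samples ({100 * n / len(y):.0f}%)")
# Class 0: 900 samples (90%)
# Class 1: 100 samples (10%)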

Why Class Imbalance Matters

Class imbalance can negatively impact the performance of machine learning models in several ways. Since the majority class has more samples, the model might become biased towards predicting that class, leading to poor generalization on the minority class. In applications where the minority class is of particular interest (e.g., fraud detection or rare disease diagnosis), this imbalance can have severe consequences.
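
As a quick illustration of the risk, consider a degenerate model that always predicts the majority class on the 90/10 split above (a hypothetical sketch):

import numpy as np
from sklearn.metrics import accuracy_score, recall_score

y_true = np.array([0] * 900 + [1] * 100)
y_pred = np.zeros(1000, dtype=int)  # always predict the majority class

print(accuracy_score(y_true, y_pred))  # 0.9: looks impressive
print(recall_score(y_true, y_pred))    # 0.0: the minority class is never detected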

Using Class Weights to Address Class Imbalance

Class weights offer a simple yet effective technique to deal with class imbalance. The idea is to assign higher weights to the samples of the minority class and lower weights to the majority class during the training process. By doing this, the model pays more attention to the minority class and learns to make better predictions for it.

Steps to Implement Class Weights

  1. Understand the data imbalance: Analyze the class distribution in your dataset and identify the minority and majority classes.
  2. Calculate class weights: Determine the appropriate class weights based on your chosen weighting scheme. Common methods include equal weighting, inverse class frequency, or custom weights based on domain knowledge.
  3. Implement class weights in the model: Adjust the loss function to incorporate the class weights during model compilation. Most machine learning libraries support this functionality.
  4. Train the model: Proceed with training the model using the weighted loss function, which gives more importance to the minority class.
  5. Evaluate the results: After training, evaluate the model’s performance on a separate validation or test set, paying attention to metrics for all classes to ensure balanced performance (see the end-to-end sketch after this list).
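
As a rough end-to-end sketch of these five steps, here is a minimal example using scikit-learn and a hypothetical synthetic dataset (X, y, and the 90/10 split are illustrative assumptions; the weight calculation itself is covered in detail below):

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

# Step 1: a hypothetical imbalanced dataset (90% class 0, 10% class 1)
rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 5))
y = np.array([0] * 900 + [1] * 100)

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=42)

# Steps 2-3: inverse class frequency weights, passed into the model
classes, counts = np.unique(y_train, return_counts=True)
class_weights = {c: len(y_train) / (2 * n) for c, n in zip(classes, counts)}
model = LogisticRegression(class_weight=class_weights)

# Step 4: train with the weighted loss
model.fit(X_train, y_train)

# Step 5: check per-class metrics, not just overall accuracy
print(classification_report(y_test, model.predict(X_test)))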

The Logistic Regression Loss Function with Class Weights

In logistic regression, the standard binary cross-entropy loss function can be modified to include class weights. Let’s assume we have two classes, 0 and 1, with class weights w_0 and w_1 respectively. The modified loss function can be expressed as follows:

L(y, p) = -(w_1 * y * log(p) + w_0 * (1 - y) * log(1 - p))

When y = 1 (the positive class), the loss reduces to -w_1 * log(p), where p is the predicted probability for class 1. The model is penalized more for misclassifying the positive class when w_1 is higher, which helps it focus on the minority class.

When y = 0 (the negative class), the loss reduces to -w_0 * log(1 - p), where 1 - p is the predicted probability for class 0. Similarly, the model is penalized more for misclassifying the negative class when w_0 is higher.
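
To see the effect numerically, here is a small sketch evaluating this weighted loss on two hypothetical samples, using the weights a 90/10 split produces (w_0 ≈ 0.56 and w_1 = 5.0, derived in the next section):

import numpy as np

def weighted_bce(y, p, w0, w1):
    # Weighted binary cross-entropy for a single sample
    return -(w1 * y * np.log(p) + w0 * (1 - y) * np.log(1 - p))

w0, w1 = 0.56, 5.0  # hypothetical weights from a 90/10 class split

# A badly misclassified positive (y = 1, predicted p = 0.1)
print(weighted_bce(1, 0.1, w0, w1))  # ~11.51, versus ~2.30 unweighted

# The equivalent mistake on a negative (y = 0, predicted p = 0.9)
print(weighted_bce(0, 0.9, w0, w1))  # ~1.29, versus ~2.30 unweighted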

How to Get Class Weights (w_0 and w_1)

Calculating class weights is an important step in handling class imbalance. The goal is to determine weights that give more importance to the minority class and less to the majority class during model training. There are several approaches to calculating class weights; one common method is the Inverse Class Frequency method.

The Inverse Class Frequency Method

The Inverse Class Frequency method calculates class weights based on the number of samples in each class. For a binary classification problem with classes 0 and 1, the formula for calculating class weights is as follows:

weight_0 = total_samples / (2 * class_0_samples)
weight_1 = total_samples / (2 * class_1_samples)

Where:

  • total_samples is the total number of samples in the dataset.
  • class_0_samples is the number of samples in class 0 (the majority class in our example).
  • class_1_samples is the number of samples in class 1 (the minority class in our example).
  • 2 is the number of classes in this binary problem.

By dividing the total number of samples by twice the number of samples in each class, each class’s total weight (its per-sample weight times its sample count) works out to total_samples / 2, so both classes contribute equally to the overall loss. For the 90/10 example above (900 and 100 samples out of 1,000), this gives weight_0 = 1000 / (2 * 900) ≈ 0.56 and weight_1 = 1000 / (2 * 100) = 5.0.


import numpy as np

def calculate_class_weights(y):
    # Inverse class frequency: total_samples / (n_classes * class_count)
    unique_classes, class_counts = np.unique(y, return_counts=True)
    total_samples = len(y)
    class_weights = {}

    for class_label, class_count in zip(unique_classes, class_counts):
        class_weight = total_samples / (2.0 * class_count)
        class_weights[class_label] = class_weight

    return class_weights

# 'y' contains the class labels (0s and 1s) for the binary classification
# problem, e.g. the 90/10 split from earlier
y = np.array([0] * 900 + [1] * 100)
class_weights = calculate_class_weights(y)
print("Class weights:", class_weights)  # {0: 0.555..., 1: 5.0}

In the code above, the calculate_class_weights function takes an array y containing the class labels (0s and 1s) for the binary classification problem. It then calculates the class weights using the inverse class frequency formula and returns a dictionary with the weights for each class. The total_samples variable stores the total number of samples in the dataset, while unique_classes and class_counts hold the unique class labels and their respective counts.

In short, calculating class weights is a fundamental step in handling class imbalance. The Inverse Class Frequency method is a simple yet effective way to assign weights to the classes based on their frequency. By using class weights during training, you help the model give more attention to the minority class and improve overall performance.

Implementing Class Weights in Logistic Regression

To implement class weights in logistic regression, you can use popular machine learning libraries like TensorFlow/Keras or Scikit-learn.

Using Class Weights in TensorFlow/Keras

from tensorflow.keras import models, layers

# Assuming you have calculated class_weights and have training data X_train, y_train
model = models.Sequential([layers.Dense(1, activation='sigmoid')])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# Note: in Keras, class_weight is passed to fit(), not compile()
model.fit(X_train, y_train, epochs=10, class_weight=class_weights)
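
If you prefer not to hand-roll the calculation, scikit-learn’s compute_class_weight utility produces the same inverse class frequency ("balanced") weights, which can be packed into the dict that Keras expects (y_train is assumed to hold your training labels):

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# 'balanced' implements n_samples / (n_classes * class_count)
weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
class_weights = dict(enumerate(weights))  # e.g. {0: 0.56, 1: 5.0}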

Using Class Weights in Scikit-learn

from sklearn.linear_model import LogisticRegression

# Assuming you have calculated class_weights (a dict like {0: 0.56, 1: 5.0})
logreg = LogisticRegression(class_weight=class_weights)
logreg.fit(X_train, y_train)
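
Alternatively, scikit-learn can compute the weights for you: passing class_weight='balanced' applies the same inverse class frequency formula internally.

from sklearn.linear_model import LogisticRegression

# 'balanced' computes n_samples / (n_classes * class_count) automatically
logreg = LogisticRegression(class_weight='balanced')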

Final Thoughts

Class imbalance is a common issue in machine learning, but it can be effectively addressed using class weights. By giving more importance to the minority class during training, models can learn to make better predictions and achieve balanced performance. However, it’s essential to experiment with different techniques and approaches to find the most suitable solution for your specific problem.

Thank you for reading! We hope this blog has provided valuable insights into dealing with class imbalance in machine learning using class weights. If you have any questions or suggestions, feel free to leave a comment below.

