Probability Calibration on Imbalanced Data

Divyesh Bhatt
The ML Classroom
Mar 21, 2024

Probability calibration is a technique used in machine learning to adjust the predicted probabilities of a classification model so that they better represent the true likelihood of an event occurring. It's particularly useful when the raw output of a model does not provide accurate probabilities, which can often be the case with models trained on imbalanced datasets or with complex decision boundaries.

Why Is Probability Calibration Important?

Imagine you have a weather forecasting model that predicts the probability of rain tomorrow. If the model says there's an 80% chance of rain, you'd likely take an umbrella when you leave the house. However, if, over time, you notice that it only actually rains 30% of the times the model predicts 80%, you'd lose trust in the model's predictions. Probability calibration aims to align these predicted probabilities with the real-world frequencies, so an 80% prediction means it rains 8 out of 10 times you receive such a forecast.

How It Works

Probability calibration involves applying a transformation to the output of a classification model to ensure that the predicted probabilities match the observed frequencies. Two common methods for probability calibration are:

  • Platt Scaling (Logistic Calibration): This method fits a logistic regression (sigmoid) model to the raw scores of your classifier, mapping them onto a well-calibrated probability scale.
  • Isotonic Regression: This is a non-parametric approach that fits a piecewise-constant, non-decreasing function to the raw model outputs. It's more flexible than Platt scaling but can overfit on small calibration sets. A brief sketch of both mappings follows this list.
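
To make these two mappings concrete, here is a minimal sketch (A, B, and the calibration arrays below are placeholders for illustration, not values from this article):

import numpy as np
from sklearn.isotonic import IsotonicRegression

def platt_scale(raw_scores, A, B):
    # Platt scaling maps a raw score s to a probability via the sigmoid 1 / (1 + exp(A*s + B));
    # A and B would be learned by fitting a logistic regression on held-out predictions.
    return 1.0 / (1.0 + np.exp(A * raw_scores + B))

# Isotonic regression instead learns a piecewise-constant, non-decreasing mapping
# from raw scores to observed outcomes on a held-out calibration set, e.g.:
#   iso = IsotonicRegression(out_of_bounds='clip')
#   iso.fit(raw_scores_calib, y_calib)
#   calibrated_probs = iso.predict(raw_scores_test)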

Example in Python

Let's demonstrate probability calibration with a simple example using scikit-learn, focusing on Platt Scaling:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import CalibratedClassifierCV, calibration_curve
import matplotlib.pyplot as plt

# Generate a synthetic dataset
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, weights=[0.9, 0.1], random_state=42)

# Split into training, calibration, and test sets (stratified, since the classes are imbalanced)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, stratify=y, random_state=42)
X_fit, X_calib, y_fit, y_calib = train_test_split(X_train, y_train, test_size=0.25, stratify=y_train, random_state=42)

# Train a RandomForestClassifier on the fitting portion only
clf = RandomForestClassifier(random_state=42)
clf.fit(X_fit, y_fit)

# Calibrate probabilities using Platt Scaling on the held-out calibration set
# (fitting the calibrator on the same data the forest was trained on would bias the calibration)
calibrated_clf = CalibratedClassifierCV(clf, method='sigmoid', cv='prefit')
calibrated_clf.fit(X_calib, y_calib)

# Compare the reliability before and after calibration
prob_true_before, prob_pred_before = calibration_curve(y_test, clf.predict_proba(X_test)[:,1], n_bins=10)
prob_true_after, prob_pred_after = calibration_curve(y_test, calibrated_clf.predict_proba(X_test)[:,1], n_bins=10)

# Plotting
plt.plot(prob_pred_before, prob_true_before, marker='o', linewidth=1, label='Before Calibration')
plt.plot(prob_pred_after, prob_true_after, marker='x', linewidth=1, label='After Calibration')
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfectly calibrated')
plt.xlabel('Mean Predicted Probability')
plt.ylabel('Fraction of Positives')
plt.title('Calibration Plots (Reliability Diagram)')
plt.legend()
plt.show()

Conclusion

In this example, a RandomForestClassifier is trained and its probability predictions are then calibrated on a held-out calibration set using Platt Scaling. The calibration curves (reliability diagrams) before and after calibration show how the calibrated probabilities align more closely with the observed outcomes, moving toward the diagonal that represents perfect calibration. Probability calibration thus enhances the interpretability and reliability of probabilistic predictions, making them more actionable for decision-making.
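
If you prefer a single number to a plot, one common check (not part of the original snippet) is the Brier score, which is lower for better-calibrated probabilities. Assuming the variables from the example above, it could be computed like this:

from sklearn.metrics import brier_score_loss

# Lower is better; compare uncalibrated and calibrated probabilities on the same test set
print('Brier score before calibration:', brier_score_loss(y_test, clf.predict_proba(X_test)[:, 1]))
print('Brier score after calibration: ', brier_score_loss(y_test, calibrated_clf.predict_proba(X_test)[:, 1]))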
