Essential Math for Machine Learning: Cross-Entropy and KL Divergence
Predicted vs. True Distribution
This article is part of the series Essential Math for Machine Learning.
Introduction
Machine learning models often speak the language of probabilities. They tell us how likely it is that an image contains a cat, or how probable it is that a particular customer will click on an ad. To train these models effectively, we need a way to guide them towards making accurate predictions. This is where loss functions, like cross-entropy, come into play.
What is Cross-Entropy?
In essence, cross-entropy measures the “difference” between a model’s predicted probability distribution and the true, underlying probability distribution of our data. Let’s break this down:
- True Distribution: In classification tasks, the true distribution is simple. If we’re classifying images of cats and dogs, and an image is a cat, the true distribution might be [1, 0] (1 for ‘cat’, 0 for ‘dog’).
- Predicted Distribution: Our model outputs its version of this distribution, expressed as probabilities (e.g., [0.8, 0.2] means the model assigns an 80% chance to ‘cat’ and a 20% chance to ‘dog’).
- Cross-Entropy’s Job: Cross-entropy calculates how surprised the model is, on average, by the outcomes that actually occur. The more closely the predicted distribution aligns with the true distribution, the lower the cross-entropy value. When a model makes perfect predictions, the predicted distribution matches the true distribution exactly. In this case, Cross-Entropy(P, Q) becomes equal to the entropy of the true distribution, Entropy(P). So Entropy(P) is the lower bound of Cross-Entropy(P, Q).
The Formula
Cross-entropy is calculated as:
Cross-Entropy(p, q) = Σ p(x) * log(1/q(x))
Where:
- p(x): The true probability of an event x (in classification, this is usually either 1 or 0).
- q(x): The model's predicted probability of event x.
- log(1/q(x)): The length of encoding for event x based on the predicted distribution q.
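The weighted sum is the expected encoding length of an event x when events occur according to the true distribution p but are encoded with a scheme built for the predicted distribution q. With a base-2 logarithm, this length is measured in bits.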
Note that when p and q are the same, cross-entropy falls back to the entropy form:
Entropy(p) = Σ p(x) * log(1/p(x))
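To make the two formulas concrete, here is a minimal sketch in NumPy. The helper names cross_entropy and entropy are chosen for illustration, and the base-2 logarithm is an assumption so that results come out in bits. It reuses the cat/dog example from earlier: true distribution [1, 0], predicted distribution [0.8, 0.2].

```python
import numpy as np


def cross_entropy(p, q):
    """Expected code length, in bits, for events drawn from p but coded for q."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    mask = p > 0  # terms with p(x) = 0 contribute nothing to the sum
    return float(np.sum(p[mask] * np.log2(1.0 / q[mask])))


def entropy(p):
    """Entropy is the cross-entropy of a distribution with itself."""
    return cross_entropy(p, p)


p_true = [1.0, 0.0]  # the image is a cat
q_pred = [0.8, 0.2]  # the model's prediction

ce = cross_entropy(p_true, q_pred)
h = entropy(p_true)

print(f"Cross-Entropy(p, q) = {ce:.4f} bits")
print(f"Entropy(p)          = {h:.4f} bits")
```

With a one-hot true distribution only the true class contributes to the sum, so the cross-entropy is log2(1/0.8), roughly 0.32 bits, while the entropy of [1, 0] is 0.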
What Cross-Entropy Measures
Cross-entropy directly measures the average number of bits you’d need to encode messages from a true distribution P if you used an encoding scheme based on a different (usually estimated) distribution Q.
Optimal Encoding
Imagine you’re designing a system to transmit a simple weather forecast with four possible conditions: Sunny (S), Rainy (R), Cloudy (C), and Snowy (W). Let’s say, based on historical data, you know the true probability distribution is:
- P(S) = 0.5 (50% chance of sunny)
- P(R) = 0.25 (25% chance of rainy)
- P(C) = 0.15 (15% chance of cloudy)
- P(W) = 0.1 (10% chance of snowy)
A smart encoding scheme, like Huffman coding, would use this knowledge of probabilities to assign shorter codes to events with high probabilities and longer codes to events with low probabilities:
- S: 0
- R: 10
- C: 110
- W: 111
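Among schemes that assign a whole number of bits to each condition, this encoding minimizes the average number of bits needed to transmit your forecasts. Here it averages 1.75 bits per forecast.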
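The arithmetic behind this code can be checked in a few lines. The sketch below uses only the probabilities and codewords listed above; the variable names are illustrative. It computes the expected code length and compares it with the entropy of P, which is the theoretical floor for any lossless code.

```python
import math

# True distribution and the Huffman-style code from the example above
p = {"S": 0.5, "R": 0.25, "C": 0.15, "W": 0.1}
codes = {"S": "0", "R": "10", "C": "110", "W": "111"}

# Expected number of bits per forecast when using this code
avg_length = sum(p[x] * len(codes[x]) for x in p)

# Entropy of P in bits: no lossless code can average fewer bits than this
entropy_bits = sum(p[x] * math.log2(1.0 / p[x]) for x in p)

print(f"Average code length: {avg_length:.4f} bits")
print(f"Entropy of P:        {entropy_bits:.4f} bits")
```

The code averages 1.75 bits, just above the entropy of about 1.74 bits. The small gap exists because real codewords must have whole-number lengths, while the ideal length log2(1/p(x)) usually is not a whole number.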
Imperfect Model
Now, let’s say your machine learning model is learning to predict the weather. It might initially output probabilities like:
- Q(S) = 0.3
- Q(R) = 0.3
- Q(C) = 0.2
- Q(W) = 0.2
If you used this model’s distribution (Q) to create an encoding scheme, it wouldn’t be as efficient as the one based on the true distribution (P).
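To put a number on that inefficiency, the sketch below computes the cross-entropy of P under Q, using the ideal code length log2(1/Q(x)) for each condition. It assumes a base-2 logarithm so the result is in bits; the variable names are illustrative.

```python
import numpy as np

# Order: Sunny, Rainy, Cloudy, Snowy
P = np.array([0.5, 0.25, 0.15, 0.1])  # true distribution
Q = np.array([0.3, 0.3, 0.2, 0.2])    # model's distribution

# Ideal code length, in bits, of each condition under a scheme built from Q
code_lengths_q = np.log2(1.0 / Q)

# Expected bits per forecast when the weather follows P but the code follows Q
cross_entropy_pq = float(np.sum(P * code_lengths_q))

print(f"Cross-Entropy(P, Q): {cross_entropy_pq:.4f} bits")
```

The result is roughly 1.88 bits per forecast, compared with an entropy of about 1.74 bits for P itself.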
KL Divergence
Kullback-Leibler (KL) Divergence measures the gap in expected encoding length between the non-optimal encoding and the optimal encoding:
KL(P||Q) = Cross-Entropy(P,Q) - Entropy(P)
It tells you how many extra bits you’d need, on average, to encode data from a true distribution (P) if you used an encoding scheme based on an approximate distribution (Q) instead of the optimal one. In machine learning, KL Divergence is a way to gauge how far a model’s predictions (Q) are from the true distribution of the data (P). Because Entropy(P) depends only on the data and not on the model, it is a constant during training, so minimizing KL(P||Q) is the same as minimizing Cross-Entropy(P,Q).
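As a quick check of the identity above, the following sketch computes KL(P||Q) for the weather example in two ways: directly as Σ p(x) * log(p(x)/q(x)), and as Cross-Entropy(P,Q) minus Entropy(P). The function names are illustrative, the logarithm is base 2, and the code assumes every probability is strictly positive.

```python
import numpy as np


def cross_entropy(p, q):
    # Assumes all entries of p and q are strictly positive
    return float(np.sum(p * np.log2(1.0 / q)))


def kl_divergence(p, q):
    # Direct form: sum over x of p(x) * log2(p(x) / q(x))
    return float(np.sum(p * np.log2(p / q)))


# Order: Sunny, Rainy, Cloudy, Snowy
P = np.array([0.5, 0.25, 0.15, 0.1])  # true distribution
Q = np.array([0.3, 0.3, 0.2, 0.2])    # model's distribution

kl_direct = kl_divergence(P, Q)
kl_from_ce = cross_entropy(P, Q) - cross_entropy(P, P)  # CE(P, P) is Entropy(P)

print(f"KL(P||Q), direct form:         {kl_direct:.4f} bits")
print(f"Cross-Entropy(P,Q) - Entropy(P): {kl_from_ce:.4f} bits")
```

Both forms give about 0.14 extra bits per forecast, and KL(P||P) is 0, which matches the idea that a perfect model wastes no bits.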
Why Cross-Entropy?
By minimizing cross-entropy, your machine learning model gets better at understanding the true patterns in the data. This translates into a theoretical ‘encoding’ of the data that gets closer to the optimal, saving bits and improving communication efficiency in the grand scheme of information theory. Cross-entropy has a few advantages for classification problems:
- Penalizes wrong predictions heavily: As the predicted probability of the true class moves toward 0, the log term in the cross-entropy formula grows without bound, so confident mistakes cost far more than hesitant ones.
- Handles probabilities well: Its design aligns with how probabilities work. Predictions close to 0 or 1 receive minimal penalty when they’re correct.
- Good for optimization: The shape of the cross-entropy function makes it smooth and amenable to gradient-based optimization methods commonly used in machine learning.
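The first point can be seen numerically. The sketch below evaluates the per-sample loss -log(q) for a range of probabilities q assigned to the true class. The specific values of q are arbitrary illustrations, and the natural logarithm is used, as is common for loss functions.

```python
import numpy as np

# Probability the model assigns to the true class, from confident-and-right
# down to confident-and-wrong
q_true_class = np.array([0.99, 0.9, 0.5, 0.1, 0.01])

# Cross-entropy penalty for a single sample with a one-hot label
penalty = -np.log(q_true_class)

for q, loss in zip(q_true_class, penalty):
    print(f"q = {q:<5} loss = {loss:.4f}")
```

The loss climbs from about 0.105 at q = 0.9 to about 4.605 at q = 0.01, so a confident wrong answer is punished far more than a hesitant one.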
When to Use Cross-Entropy
Cross-entropy is your go-to loss function for:
- Classification problems: Especially those with multiple classes (multi-class classification).
- Models that output probabilities: Neural networks with softmax activation at the output layer are perfect examples.
Its effectiveness in these scenarios stems from its ability to highlight differences between the true and predicted probability distributions, encouraging the model to make predictions that are as close to the actual distribution as possible.
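For the multi-class case, a minimal sketch of softmax followed by cross-entropy might look like the following. The function names, the logits, and the one-hot labels are made up for illustration, and the clipping constant is an assumption for numerical safety rather than a required value.

```python
import numpy as np


def softmax(logits):
    # Subtract the row-wise max before exponentiating to avoid overflow
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def categorical_cross_entropy(y_true, probs, eps=1e-15):
    # Clip so that log() never sees an exact 0
    probs = np.clip(probs, eps, 1.0)
    # Per-sample loss is -sum over classes of y * log(prob); average over samples
    return float(-np.mean(np.sum(y_true * np.log(probs), axis=-1)))


# Two samples, three classes (hypothetical raw model outputs)
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
y_true = np.array([[1, 0, 0],
                   [0, 0, 1]])

probs = softmax(logits)
loss = categorical_cross_entropy(y_true, probs)

print("Predicted probabilities:")
print(probs)
print(f"Categorical cross-entropy: {loss:.4f}")
```

Because the labels are one-hot, only the probability assigned to the correct class affects each sample's loss.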
Python Implementation from Scratch
Let’s build a bare-bones Python implementation of cross-entropy loss. The code is available in this Colab notebook.
import numpy as np
import matplotlib.pyplot as plt


def binary_cross_entropy(y_true, y_pred):
    # Per-sample binary cross-entropy, using the natural log (nats).
    # Clipping for numerical stability: keeps log() away from exactly 0.
    y_pred = np.clip(y_pred, 1e-15, 1 - 1e-15)
    return -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))


def loss_function(y_true, y_pred):
    # The average cross entropy across the data set.
    return np.mean(binary_cross_entropy(y_true, y_pred))


def evaluate_accuracy(y_true, y_pred):
    # Threshold predictions to binary outcomes
    y_pred_binary = (y_pred >= 0.5).astype(int)
    accuracy = np.mean(y_true == y_pred_binary)
    return accuracy


def plot_predictions(y_true, y_pred):
    plt.figure(figsize=(10, 6))
    plt.scatter(range(len(y_pred)), y_pred, c='r', label='Predicted Probability')
    plt.scatter(range(len(y_true)), y_true, marker='x', label='Actual Label')
    plt.title('Binary Classification')
    plt.xlabel('Sample Index')
    plt.ylabel('Probability / Binary Label')
    plt.legend()
    plt.grid(True)
    plt.show()


# Example data
y_true = np.array([1, 0, 1, 1, 0])
y_pred = np.array([0.9, 0.1, 0.8, 0.8, 0.2])

# Calculate and print loss (the mean over all samples, not the per-sample array)
loss = loss_function(y_true, y_pred)
print(f"Binary Cross Entropy Loss: {loss}")

# Evaluate accuracy
accuracy = evaluate_accuracy(y_true, y_pred)
print(f"Accuracy: {accuracy}")

# Visualize predictions
plot_predictions(y_true, y_pred)
Output:
Binary Cross Entropy Loss: 0.17603033705165633
Accuracy: 1.0
Conclusion
Cross-entropy is a vital concept in machine learning, serving as a loss function that quantifies the difference between the actual and predicted probability distributions. Its mathematical foundation provides insight into how models are trained to make accurate predictions, especially in classification tasks. By implementing cross-entropy from scratch, practitioners can gain a deeper understanding of what happens when machine learning models are trained.