Essential Math for Machine Learning: The Chain Rule
The Change Lineage Tracker
This article is part of the series Essential Math for Machine Learning.
Introduction
In the complex world of machine learning, mathematics plays the role of a guiding star. It helps algorithms learn and optimize. One vital mathematical concept that powers the optimization process in nearly all machine learning models is the chain rule. Let’s explore what it is, why it’s so important, and how it works its magic.
What is the Chain Rule?
At its core, the chain rule is a calculus technique for calculating the derivative of composite functions. A composite function is just a function within a function. For example, let’s say we have a function:
h(x) = (x² + 1)³
You can break this down into two “nested” functions:
- Outer function: f(u) = u³
- Inner function: g(x) = x² + 1
The chain rule tells us how to find the derivative of h(x), or how a change in the input ‘x’ impacts the final output of our composite function. With u = g(x), it states:
dh/dx = df/du * du/dx
In simpler terms:
- Derivative of outer function (with the inner function plugged in): df/du
- Derivative of inner function: du/dx
Let’s see how the chain rule works:
h(x) = (x² + 1)³
f(u) = u³
g(x) = x² + 1
Derivative of outer function: df/du = 3u²
Derivative of inner function: du/dx = 2x
Apply the chain rule: dh/dx = df/du * du/dx = 3(x² + 1)² * 2x = 6x(x² + 1)²
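You can check this result with SymPy, which applies the chain rule internally when differentiating. A minimal sketch (not from the article’s notebook, just a quick verification):
import sympy as sp
x = sp.symbols('x')
h = (x**2 + 1)**3      # the composite function from above
dh_dx = sp.diff(h, x)  # SymPy applies the chain rule for us
print(dh_dx)           # 6*x*(x**2 + 1)**2, i.e. 3*(x**2 + 1)**2 * 2*x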
Why Chain Rule Matters in Machine Learning
Neural networks, the powerhouse behind many machine learning models, are essentially giant webs of composite functions. Understanding how to find the derivatives of these complex functions with respect to their parameters (weights and biases) is the key to making those models “learn.”
Here’s where the chain rule comes in:
- Backpropagation: The heart of neural network training is an algorithm called backpropagation. It calculates how much each parameter in the network contributes to the overall error. This computation of “error signals” involves repeatedly applying the chain rule to efficiently trace those errors back through layers of the network.
- Gradient Descent: Once we know how an error is linked to parameters, we want to make adjustments. Gradient descent, an optimization algorithm heavily used in machine learning, takes these error signals (gradients) and updates the parameters in a direction that minimizes the error. The chain rule makes this gradient calculation possible.
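To make the gradient descent step concrete, here is a minimal sketch on a toy one-parameter loss (the loss, learning rate, and variable names are illustrative, not from any particular library):
# Minimize the toy loss (w - 3)**2, whose gradient is 2*(w - 3)
w = 0.0
learning_rate = 0.1
for _ in range(100):
    grad = 2 * (w - 3)            # gradient of the loss at the current w
    w = w - learning_rate * grad  # step against the gradient to reduce the loss
print(w)  # approaches 3.0, the value that minimizes the loss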
Example: Logistic Regression
Let’s explore the chain rule and gradient descent within the context of logistic regression.
The Model
In logistic regression, we aim to model the probability of a binary outcome (e.g., 0 or 1). The core building blocks are:
- Linear Combination: First, a linear combination of features (with weights and bias) is computed: z = w * x + b
- Sigmoid Function: The linear combination ‘z’ is passed through the sigmoid function to map it into a probability between 0 and 1: σ(z) = 1 / (1 + exp(-z))
- Loss Function: Commonly used is the binary cross-entropy loss function, comparing the model’s predicted probability with the true label.
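Here is a small NumPy sketch of this forward pass; the feature, weight, bias, and label values are made up for illustration:
import numpy as np
def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))
x, w, b = 2.0, 0.5, -0.25  # one feature, one weight, one bias (toy values)
y = 1.0                    # true label
z = w * x + b                                              # linear combination
y_hat = sigmoid(z)                                         # probability in (0, 1)
loss = -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))  # binary cross-entropy
print(z, y_hat, loss)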
Gradient Calculation with the Chain Rule
For a single data point, let’s say the true label is ‘y’ and the model’s predicted probability is ŷ. Here’s how the chain rule connects the calculation of gradients:
- Loss with respect to ŷ: the derivative of the binary cross-entropy loss function itself.
- ŷ with respect to z: This is the derivative of the sigmoid function.
- z with respect to w (and similarly for b): This is just the derivative of a linear function.
Python Code with SymPy
The code is available in this Colab notebook.
import sympy as sp
# Variables
x = sp.symbols('x')
w = sp.symbols('w')
b = sp.symbols('b')
y = sp.symbols('y') # True label
# Linear combination
z = w * x + b
# Sigmoid activation
y_hat = 1 / (1 + sp.exp(-z))
# Binary cross-entropy loss
loss = -(y * sp.log(y_hat) + (1 - y) * sp.log(1 - y_hat))
# Gradients (chain rule in action!)
gradient_w = sp.diff(loss, w)
gradient_b = sp.diff(loss, b)
# Display
print("Gradient with respect to w:", gradient_w)
print("Gradient with respect to b:", gradient_b)
Output:
Gradient with respect to w: -x*y*exp(-b - w*x)/(exp(-b - w*x) + 1) + x*(1 - y)*exp(-b - w*x)/((1 - 1/(exp(-b - w*x) + 1))*(exp(-b - w*x) + 1)**2)
Gradient with respect to b: -y*exp(-b - w*x)/(exp(-b - w*x) + 1) + (1 - y)*exp(-b - w*x)/((1 - 1/(exp(-b - w*x) + 1))*(exp(-b - w*x) + 1)**2)
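The raw output looks intimidating, but both expressions collapse to remarkably simple forms: x * (ŷ - y) for w and ŷ - y for b (derived in the next section). Continuing from the code above, you can ask SymPy to confirm the equivalence:
# Both differences should simplify to 0, confirming the compact forms
print(sp.simplify(gradient_w - x * (y_hat - y)))  # expected: 0
print(sp.simplify(gradient_b - (y_hat - y)))      # expected: 0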
The Math Behind the Code
Let’s break down the mathematical formulas represented in the SymPy code for logistic regression.
Notations
- x: Input feature (or vector of features)
- w: Weight (or vector of weights)
- b: Bias term
- y: True label (0 or 1)
- ŷ: Predicted probability (output of the model)
- z: Linear combination of input features and weights (input to the sigmoid)
Formulas
Linear Combination:
z = w * x + b
Sigmoid Function:
ŷ = σ(z) = 1 / (1 + exp(-z))
Binary Cross-Entropy Loss:
loss = -(y * log(ŷ) + (1 - y) * log(1 - ŷ))
Derivative of loss = -(y * log(ŷ) + (1 - y) * log(1 - ŷ)) with respect to ŷ:
∂loss/∂ŷ = -(y / ŷ) + ((1 - y) / (1 - ŷ))
Derivative of ŷ = σ(z) = 1 / (1 + exp(-z)) with respect to z:
∂ŷ/∂z = σ(z) * (1 - σ(z))
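This compact identity comes from differentiating the quotient directly:
∂ŷ/∂z = exp(-z) / (1 + exp(-z))² = [1 / (1 + exp(-z))] * [exp(-z) / (1 + exp(-z))] = σ(z) * (1 - σ(z))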
Derivative of z = w * x + b with respect to w:
∂z/∂w = x
Gradient of loss with respect to w:
∂loss/∂w = ∂loss/∂ŷ * ∂ŷ/∂z * ∂z/∂w
The same process applies to b: since ∂z/∂b = 1, the final factor drops out, leaving ∂loss/∂b = ∂loss/∂ŷ * ∂ŷ/∂z.
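Multiplying the three factors together shows how much the expressions simplify:
∂loss/∂w = [-(y / ŷ) + (1 - y) / (1 - ŷ)] * ŷ * (1 - ŷ) * x = (ŷ - y) * x
and likewise ∂loss/∂b = ŷ - y. These are exactly the simplified forms of the SymPy output shown earlier.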
Chain Rule:
The gradients are where the chain rule shines. The derivative of the loss function with respect to ‘w’ (or ‘b’) involves “peeling back” the layers of the logistic regression model:
- From the loss to the predicted probability (ŷ)
- From the predicted probability (ŷ) to the linear combination (z)
- From the linear combination (z) to the parameters (w and b)
Conclusion
The chain rule is a cornerstone of calculus with extensive application in machine learning. It is key to understanding and implementing the optimization algorithms that let models learn from data. By breaking the computation of derivatives of composite functions into manageable steps, the chain rule provides a clear path for gradient calculation in complex models, making it an indispensable tool for anyone delving into the field. Whether you are training a simple linear regression model or a complex deep neural network, a solid grasp of the chain rule will greatly enhance your understanding of how these models learn and improve.