ELI5: Chain Rule of derivative
In my previous post about backpropagation, I mentioned the Chain rule of calculus. Before explaining the chain rule, I am going to talk about the intuition behind derivatives. Don’t worry! I will try to keep it short and simple.
Derivatives
What is the derivative of a function? The derivative measures how much the function's output changes in response to a small change in its input. Let's dumb it down further.
Let 𝑓(𝑥) be a function of a single variable, 𝑓(𝑥)=2𝑥.
For x = 2, 𝑓(𝑥)=4; if we increase x by 0.00001, 𝑓(𝑥) increases by 0.00002. Similarly, for 𝑓(𝑥)=𝑥² and x=2, 𝑓(𝑥)=4, but if we increase x by 0.00001, 𝑓(𝑥) increases by 0.0000400001.
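We can check these numbers directly. Here is a minimal sketch (the helper name `change` is just illustrative) that measures how each function responds to a tiny nudge in x:

```python
# Numerically estimate how f(x) changes when x increases by a tiny dx.
def change(f, x, dx=1e-5):
    return f(x + dx) - f(x)

print(change(lambda x: 2 * x, 2))   # ≈ 0.00002
print(change(lambda x: x ** 2, 2))  # ≈ 0.0000400001
```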
In this image, y=f(x)=x². So, if we increase x by an amount dx, the value of y increases by dy.
For x=2 and dx=0.00001, dy≈0.00004. The derivative is defined as dy/dx, so the derivative, or slope, at x=2 is close to 4, which matches the derivative formula
( x² )’ = 2x
Formulas for common derivatives can be found here.
Now that we have an idea of derivatives, let's talk about the chain rule.
If a function is defined as y=3x ; z=y² then what is the derivative of z with respect to x (dz/dx)?
If we need to compute derivatives of composite functions, we use the chain rule.
The chain rule states
( f( g( x ) ) )’ = f’( g( x ) ) * g’( x )
We can see this in the example above. To compute the derivative of z with respect to x, we first compute the derivative of z with respect to y and then multiply it by the derivative of y with respect to x: dz/dx = dz/dy · dy/dx = 2y · 3 = 6y = 18x.
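The chain-rule computation above can be sketched in code and checked against a finite-difference estimate (the variable names here are just for illustration):

```python
# Chain rule on y = 3x, z = y²: dz/dx = dz/dy * dy/dx = 2y * 3 = 18x.
def dz_dx(x):
    y = 3 * x
    dz_dy = 2 * y   # derivative of z = y² with respect to y
    dy_dx = 3       # derivative of y = 3x with respect to x
    return dz_dy * dy_dx

# Compare with a finite-difference estimate of z = (3x)² at x = 2.
x, dx = 2.0, 1e-6
numeric = ((3 * (x + dx)) ** 2 - (3 * x) ** 2) / dx
print(dz_dx(x), numeric)  # both close to 36
```

The analytic answer at x = 2 is 18 · 2 = 36, and the numerical slope agrees up to the finite-difference error.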
We need the chain rule to compute the derivative or slope of the loss function.
The loss function for logistic regression is defined as
L(y,ŷ) = −(y log(ŷ) + (1−y) log(1−ŷ))
where y is the true label and ŷ is the predicted probability.
The slope of the loss function is defined as the derivative of L with respect to w.
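As a sketch of how the chain rule applies here, assume the standard logistic regression setup ŷ = σ(wx + b); chaining the derivatives of L, σ, and wx + b gives dL/dw = (ŷ − y)·x. The code below (variable names are illustrative) checks this analytic gradient against a finite-difference estimate:

```python
import math

# Assumed setup: y_hat = sigmoid(w*x + b),
# L = -(y*log(y_hat) + (1-y)*log(1-y_hat)).
# Chain rule: dL/dw = dL/dy_hat * dy_hat/d(wx+b) * d(wx+b)/dw = (y_hat - y) * x
def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def grad_w(w, b, x, y):
    y_hat = sigmoid(w * x + b)
    return (y_hat - y) * x

def loss(w, b, x, y):
    y_hat = sigmoid(w * x + b)
    return -(y * math.log(y_hat) + (1 - y) * math.log(1 - y_hat))

# Finite-difference check of the analytic gradient.
w, b, x, y, dw = 0.5, -0.2, 1.5, 1.0, 1e-6
numeric = (loss(w + dw, b, x, y) - loss(w, b, x, y)) / dw
print(grad_w(w, b, x, y), numeric)  # the two should agree closely
```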
Hope it helps. All the code and scripts used to create the visualizations and equations can be found here.