Deriving the batch normalisation formulas

Konstantin Sofeikov
Mar 10, 2019


I got stuck for a little while with the batch normalisation derivative expressions, so I sat down and wrote out their derivations. Hopefully this will help anyone who gets stuck as I did. The formulas and the amount of writing look horrible, but it is actually easier than it seems.

Suppose we get the following input:
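\[
\mathcal{B} = \{x_1, \dots, x_m\}
\]

that is, a mini-batch of m values of a single activation, in the notation of the original paper.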

We perform the following transformations on our data:
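\[
\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i,
\qquad
\sigma_\mathcal{B}^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_\mathcal{B})^2,
\qquad
\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \varepsilon}}
\]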

Whitening of our data

The further transformations include:
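\[
y_i = \gamma \hat{x}_i + \beta
\]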

So, this is batch normalisation. The task is to find the derivatives for backpropagation. Although the final formulas are written in the original paper, it is worth explaining how they are obtained. So in the original paper we have the following.
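\[
\begin{aligned}
\frac{\partial \ell}{\partial \hat{x}_i} &= \frac{\partial \ell}{\partial y_i} \cdot \gamma \\
\frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_\mathcal{B}) \cdot \left(-\tfrac{1}{2}\right)(\sigma_\mathcal{B}^2 + \varepsilon)^{-3/2} \\
\frac{\partial \ell}{\partial \mu_\mathcal{B}} &= \left(\sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_\mathcal{B}^2 + \varepsilon}}\right) + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} \cdot \frac{\sum_{i=1}^{m} -2(x_i - \mu_\mathcal{B})}{m} \\
\frac{\partial \ell}{\partial x_i} &= \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma_\mathcal{B}^2 + \varepsilon}} + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} \cdot \frac{2(x_i - \mu_\mathcal{B})}{m} + \frac{\partial \ell}{\partial \mu_\mathcal{B}} \cdot \frac{1}{m} \\
\frac{\partial \ell}{\partial \gamma} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i} \cdot \hat{x}_i \\
\frac{\partial \ell}{\partial \beta} &= \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i}
\end{aligned}
\]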

The gradient formulas from the original paper, https://arxiv.org/pdf/1502.03167v3.pdf

Let us walk through this line by line and work out these formulas ourselves. Let us work on this one first:
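\[
\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i} \cdot \gamma
\]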

The chain rule tells us that:
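\[
\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i} \cdot \frac{\partial y_i}{\partial \hat{x}_i}
\]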

Now, using the chain rule and the derivative of the loss function with respect to the output, we get the following:
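\[
\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i} \cdot \frac{\partial (\gamma \hat{x}_i + \beta)}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i} \cdot \gamma
\]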

Now the juicy part starts. Let us show that:
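\[
\frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_\mathcal{B}) \cdot \left(-\tfrac{1}{2}\right)(\sigma_\mathcal{B}^2 + \varepsilon)^{-3/2}
\]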

Again, using the chain rule, we get:
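\[
\frac{\partial \ell}{\partial y_i} \cdot \frac{\partial y_i}{\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial \sigma_\mathcal{B}^2}
\]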

The first two factors in this product look familiar; we have already met them:
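\[
\frac{\partial \ell}{\partial y_i} \cdot \frac{\partial y_i}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial \hat{x}_i}
\]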

Let us now work with the last factor of this product:
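\[
\frac{\partial \hat{x}_i}{\partial \sigma_\mathcal{B}^2} = \frac{\partial}{\partial \sigma_\mathcal{B}^2} \left( \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma_\mathcal{B}^2 + \varepsilon}} \right)
\]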

Rewriting the fraction as a product of two functions, we can easily find its derivative with respect to the variance:
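\[
\frac{\partial \hat{x}_i}{\partial \sigma_\mathcal{B}^2}
= \frac{\partial (x_i - \mu_\mathcal{B})}{\partial \sigma_\mathcal{B}^2} \cdot (\sigma_\mathcal{B}^2 + \varepsilon)^{-1/2}
+ (x_i - \mu_\mathcal{B}) \cdot \frac{\partial (\sigma_\mathcal{B}^2 + \varepsilon)^{-1/2}}{\partial \sigma_\mathcal{B}^2}
\]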

The first summand is obviously zero, since the expression under the derivative does not depend on the variance. The second expression is just a polynomial derivative:
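\[
(x_i - \mu_\mathcal{B}) \cdot \frac{\partial (\sigma_\mathcal{B}^2 + \varepsilon)^{-1/2}}{\partial \sigma_\mathcal{B}^2}
= (x_i - \mu_\mathcal{B}) \cdot \left(-\tfrac{1}{2}\right)(\sigma_\mathcal{B}^2 + \varepsilon)^{-3/2}
\]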

And since the loss and the variance depend on all input vectors, we just sum all these derivatives over the batch and get:
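\[
\frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_\mathcal{B}) \cdot \left(-\tfrac{1}{2}\right)(\sigma_\mathcal{B}^2 + \varepsilon)^{-3/2}
\]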

The next part is the derivative of the loss function with respect to the mean. This one is, I think, the hardest, since even the variance depends on the mean; hence the complicated expression for the derivative:
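\[
\frac{\partial \ell}{\partial \mu_\mathcal{B}} = \left(\sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_\mathcal{B}^2 + \varepsilon}}\right) + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} \cdot \frac{\sum_{i=1}^{m} -2(x_i - \mu_\mathcal{B})}{m}
\]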

Again we start from the chain rule:
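\[
\frac{\partial \ell}{\partial \mu_\mathcal{B}} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial \mu_\mathcal{B}}
\]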

Again, we have the product of two functions, hence:
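\[
\frac{\partial \ell}{\partial \mu_\mathcal{B}} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \left[ (x_i - \mu_\mathcal{B}) \cdot \frac{\partial (\sigma_\mathcal{B}^2 + \varepsilon)^{-1/2}}{\partial \mu_\mathcal{B}} + \frac{\partial (x_i - \mu_\mathcal{B})}{\partial \mu_\mathcal{B}} \cdot (\sigma_\mathcal{B}^2 + \varepsilon)^{-1/2} \right]
\]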

Inside the large brackets, the second summand is pretty straightforward:
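\[
\frac{\partial (x_i - \mu_\mathcal{B})}{\partial \mu_\mathcal{B}} \cdot (\sigma_\mathcal{B}^2 + \varepsilon)^{-1/2} = \frac{-1}{\sqrt{\sigma_\mathcal{B}^2 + \varepsilon}}
\]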

It conveniently matches the first summand in our original derivative expression! Let us now work with the other one. To solve it we will use the definition of the variance, that is, the sum of the squared mean-shifted samples divided by the sample size.
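\[
\sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_\mathcal{B}) \cdot \frac{\partial (\sigma_\mathcal{B}^2 + \varepsilon)^{-1/2}}{\partial \mu_\mathcal{B}},
\qquad \text{where} \quad
\sigma_\mathcal{B}^2 = \frac{1}{m} \sum_{j=1}^{m} (x_j - \mu_\mathcal{B})^2
\]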

The last expression is equal to:
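\[
\sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_\mathcal{B}) \cdot \left(-\tfrac{1}{2}\right)(\sigma_\mathcal{B}^2 + \varepsilon)^{-3/2} \cdot \frac{\partial \sigma_\mathcal{B}^2}{\partial \mu_\mathcal{B}}
\]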

Hold on, the first part of this looks exactly like the derivative of the loss function with respect to the variance! Compare it with the expression we derived before; therefore:
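\[
\sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot (x_i - \mu_\mathcal{B}) \cdot \left(-\tfrac{1}{2}\right)(\sigma_\mathcal{B}^2 + \varepsilon)^{-3/2} \cdot \frac{\partial \sigma_\mathcal{B}^2}{\partial \mu_\mathcal{B}}
= \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} \cdot \frac{\partial \sigma_\mathcal{B}^2}{\partial \mu_\mathcal{B}}
\]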

If you carefully expand the brackets in the second part of the expression, you will get:
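\[
\frac{\partial \sigma_\mathcal{B}^2}{\partial \mu_\mathcal{B}} = \frac{1}{m} \sum_{i=1}^{m} \frac{\partial (x_i - \mu_\mathcal{B})^2}{\partial \mu_\mathcal{B}} = \frac{\sum_{i=1}^{m} -2(x_i - \mu_\mathcal{B})}{m}
\]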

And putting everything together, we finally get:
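\[
\frac{\partial \ell}{\partial \mu_\mathcal{B}} = \left(\sum_{i=1}^{m} \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{-1}{\sqrt{\sigma_\mathcal{B}^2 + \varepsilon}}\right) + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} \cdot \frac{\sum_{i=1}^{m} -2(x_i - \mu_\mathcal{B})}{m}
\]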

The expression
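\[
\frac{\partial \ell}{\partial x_i} = \frac{\partial \ell}{\partial \hat{x}_i} \cdot \frac{1}{\sqrt{\sigma_\mathcal{B}^2 + \varepsilon}} + \frac{\partial \ell}{\partial \sigma_\mathcal{B}^2} \cdot \frac{2(x_i - \mu_\mathcal{B})}{m} + \frac{\partial \ell}{\partial \mu_\mathcal{B}} \cdot \frac{1}{m}
\]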

is derived in the same way as the two before, except that it takes an enormous amount of writing, so I will omit it here. The last two gradients, for the linear transformation parameters, are basically trivial:
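\[
\frac{\partial \ell}{\partial \gamma} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i} \cdot \hat{x}_i,
\qquad
\frac{\partial \ell}{\partial \beta} = \sum_{i=1}^{m} \frac{\partial \ell}{\partial y_i}
\]

If you want to sanity-check the algebra numerically, here is a minimal NumPy sketch of the forward and backward pass for a single feature. This is my own illustration of the formulas above, not code from the paper, and the function and variable names are made up:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Forward pass for one feature over a mini-batch x of shape (m,)."""
    mu = x.mean()                            # mini-batch mean
    var = x.var()                            # biased variance: mean of squared deviations
    x_hat = (x - mu) / np.sqrt(var + eps)    # whitening
    y = gamma * x_hat + beta                 # scale and shift
    cache = (x, x_hat, mu, var, gamma, eps)
    return y, cache

def batchnorm_backward(dy, cache):
    """Backward pass: dy is dL/dy per sample; returns dL/dx, dL/dgamma, dL/dbeta."""
    x, x_hat, mu, var, gamma, eps = cache
    m = x.shape[0]

    dx_hat = dy * gamma                                              # dL/dx_hat
    dvar = np.sum(dx_hat * (x - mu)) * -0.5 * (var + eps) ** -1.5    # dL/dvar
    dmu = (np.sum(dx_hat * -1.0 / np.sqrt(var + eps))
           + dvar * np.sum(-2.0 * (x - mu)) / m)                     # dL/dmu
    dx = (dx_hat / np.sqrt(var + eps)
          + dvar * 2.0 * (x - mu) / m
          + dmu / m)                                                 # dL/dx
    dgamma = np.sum(dy * x_hat)                                      # dL/dgamma
    dbeta = np.sum(dy)                                               # dL/dbeta
    return dx, dgamma, dbeta
```

Comparing dx against a finite-difference estimate of the gradient is a quick way to convince yourself that the derivation above is consistent.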

Hope this helps!
