Does Batch Norm really depend on Internal Covariate Shift for its success?

Aakash Bindal · Published in Techspace · 5 min read · Sep 14, 2019

The popular belief is that Batch Norm reduces Internal Covariate Shift (ICS) and that this is what makes training better, but MIT researchers recently published a paper showing that the relationship between Batch Norm and ICS is tenuous.

Broadly speaking, Batch Norm is a technique that aims to improve the training of deep neural networks by stabilizing the distribution of layer inputs. This is done by inserting additional layers that set the mean of each activation to 0 and its variance to 1, followed by a learnable scale and shift. Most people believe that Batch Norm's success is due to the reduction of ICS, the change in the distribution of a layer's inputs caused by updates to the preceding layers.
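
To make this concrete, here is a minimal sketch (in PyTorch, with illustrative names) of what a Batch Norm layer computes at training time; the learnable scale γ and shift β are the gamma and beta parameters below.

```python
import torch

def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Minimal sketch of a Batch Norm forward pass for a (batch, features) tensor.

    Each feature is normalized to zero mean and unit variance over the batch,
    then rescaled by the learnable parameters gamma (scale) and beta (shift).
    """
    mean = x.mean(dim=0, keepdim=True)                  # per-feature batch mean
    var = x.var(dim=0, unbiased=False, keepdim=True)    # per-feature batch variance
    x_hat = (x - mean) / torch.sqrt(var + eps)          # zero mean, unit variance
    return gamma * x_hat + beta                         # learnable scale and shift

# Example: normalize a random batch of 32 samples with 10 features.
x = torch.randn(32, 10)
gamma, beta = torch.ones(10), torch.zeros(10)
y = batch_norm_forward(x, gamma, beta)
print(y.mean(dim=0), y.var(dim=0, unbiased=False))  # ≈ 0 and ≈ 1 per feature
```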

It is conjectured that this continual ICS amounts to a constant shift of the underlying training problem, and is thus believed to have a detrimental effect on training.

ICS of activation i at time t is defined as the difference

$$\big\| G_{t,i} - G'_{t,i} \big\|_2,$$

where

$$G_{t,i} = \nabla_{W_i^t}\, \mathcal{L}\big(W_1^t, \ldots, W_k^t;\; x^t, y^t\big),$$

$$G'_{t,i} = \nabla_{W_i^t}\, \mathcal{L}\big(W_1^{t+1}, \ldots, W_{i-1}^{t+1}, W_i^t, W_{i+1}^t, \ldots, W_k^t;\; x^t, y^t\big).$$

Here, $\mathcal{L}$ (written in calligraphic form) is the loss, $W_1^t, \ldots, W_k^t$ are the parameters of each of the k layers, and $(x^t, y^t)$ is the batch of input–label pairs used to train the network at time t. $G_{t,i}$ corresponds to the gradient of the layer's parameters that would be applied during a simultaneous update of all layers. On the other hand, $G'_{t,i}$ is the same gradient computed after all the preceding layers have been updated with their new values.
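
This definition can be measured directly. Below is a hypothetical sketch of how G, G′, and their l₂ difference could be computed for one layer on one batch; the function name ics_for_layer, the small nn.Sequential MLP, cross-entropy loss, and the single SGD step with a fixed learning rate are all illustrative assumptions, not the paper's measurement code.

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

def ics_for_layer(model, x, y, layer_idx, lr=0.1):
    """Hypothetical sketch of the ICS measurement for one layer on one batch.

    G  = gradient of layer `layer_idx` with all layers at their time-t values.
    G' = gradient of the same layer after layers 0..layer_idx-1 have applied
         their time-t gradient step. ICS is ||G - G'||_2.
    Assumes `model` is an nn.Sequential of nn.Linear / nn.ReLU layers.
    """
    # G: gradients with all parameters at time t.
    model.zero_grad()
    F.cross_entropy(model(x), y).backward()
    g = model[layer_idx].weight.grad.detach().clone()
    step = {name: p.grad.detach().clone() for name, p in model.named_parameters()}

    # Build a copy in which only the *preceding* layers have taken their step.
    shifted = copy.deepcopy(model)
    with torch.no_grad():
        for name, p in shifted.named_parameters():
            owning_layer = int(name.split(".")[0])   # '2.weight' -> 2
            if owning_layer < layer_idx:
                p -= lr * step[name]

    # G': gradient of the (unchanged) layer i, computed with updated preceding layers.
    shifted.zero_grad()
    F.cross_entropy(shifted(x), y).backward()
    g_prime = shifted[layer_idx].weight.grad.detach().clone()

    return (g - g_prime).norm().item()

# Example on a small multilayer perceptron.
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(),
                      nn.Linear(64, 64), nn.ReLU(),
                      nn.Linear(64, 5))
x, y = torch.randn(128, 20), torch.randint(0, 5, (128,))
print(ics_for_layer(model, x, y, layer_idx=2))
```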

Is Batch Norm directly related to ICS, as we might think?

The conventional understanding of Batch Norm is that it should increase the correlation between G and G′, thus reducing ICS.

To show that Batch Norm does not rely on ICS for its success, a small experiment is performed:

A network is trained with random noise injected after its Batch Norm layers. Specifically, the activations of each sample in the batch are perturbed using noise sampled from a distribution with non-zero mean and non-unit variance (the usual belief being that, ideally, the mean should be 0 and the variance should be 1 for good training).

This noise injection produces a severe covariate shift that skews every activation at each time step.

Consequently, every unit in the layer experiences a different distribution of inputs at each time step. We then measure the effect of this deliberately introduced distributional instability on BatchNorm’s performance.
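
One way to picture this experiment is a BatchNorm layer whose output is re-scaled and re-shifted by fresh random noise at every training step. The sketch below is only illustrative (the module name and the noise magnitudes are assumptions, not the paper's exact scheme).

```python
import torch
import torch.nn as nn

class NoisyBatchNorm1d(nn.Module):
    """Illustrative sketch of the 'noisy BatchNorm' experiment: a standard
    BatchNorm layer followed by noise with non-zero mean and non-unit variance,
    drawn independently for every activation of every sample at every step."""

    def __init__(self, num_features, noise_mean_scale=1.0, noise_std_scale=1.0):
        super().__init__()
        self.bn = nn.BatchNorm1d(num_features)
        self.noise_mean_scale = noise_mean_scale
        self.noise_std_scale = noise_std_scale

    def forward(self, x):
        out = self.bn(x)
        if self.training:
            # Each element gets its own random shift and scale, so the distribution
            # fed to the next layer changes at every time step.
            shift = self.noise_mean_scale * torch.rand_like(out)       # non-zero mean
            scale = 1.0 + self.noise_std_scale * torch.rand_like(out)  # non-unit variance
            out = out * scale + shift
        return out

# Drop-in use inside a small network.
net = nn.Sequential(nn.Linear(20, 64), NoisyBatchNorm1d(64), nn.ReLU(), nn.Linear(64, 5))
```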

In the paper's experiments, the networks with Batch Norm and with noisy Batch Norm perform almost identically. Both also perform significantly better than the network without Batch Norm (standard).

Clearly, these findings are hard to reconcile with the belief that the performance gain from Batch Norm stems from increased stability of the layer input distributions.

Moreover,

An experiment was run with VGG and a 25-layer deep linear network (DLN), and the results were surprising. We all believed that Batch Norm reduces ICS, which is the main motivation for using it; here, the results show that Batch Norm actually increases ICS, and yet it still achieves higher accuracy than the identical model without Batch Norm.

Here, ICS is measured with and without Batch Norm using the cosine angle between the gradients G and G′ (ideally 1) and their l₂ difference (ideally 0) for a layer.
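
For reference, a short sketch of how these two quantities could be computed for one layer, given the gradients G and G′ (flattened into vectors); the function name is illustrative.

```python
import torch
import torch.nn.functional as F

def ics_metrics(g, g_prime):
    """Sketch of the two ICS measures: cosine angle (ideally 1) and
    l2 difference (ideally 0) between a layer's gradients G and G'."""
    g, g_prime = g.flatten(), g_prime.flatten()
    cos = F.cosine_similarity(g, g_prime, dim=0)  # 1 means same direction
    l2 = (g - g_prime).norm(p=2)                  # 0 means identical gradients
    return cos.item(), l2.item()
```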

Surprisingly, the networks with Batch Norm exhibit an increase in ICS, particularly in the case of the DLN. The network without Batch Norm shows almost no change in ICS over the entire training run, while the network with Batch Norm shows no convincing relationship between G and G′.

So, if Batch Norm's success is not due to reducing ICS, what is the reason behind it?

Many benefits are attributed to Batch Norm: prevention of exploding or vanishing gradients, keeping activations away from the saturation regions of non-linearities, and robustness to different hyperparameter settings such as learning rate and initialization. But all of these are merely consequences of applying Batch Norm, not the root cause of its success.

There must be another player in this scenario,

Indeed, we identify the key impact that BatchNorm has on the training process: it reparametrizes the underlying optimization problem to make its landscape significantly more smooth. The first manifestation of this impact is an improvement in the Lipschitzness of the loss function. That is, the loss changes at a smaller rate and the magnitudes of the gradients are smaller too.

In the paper's analysis, $\hat{\mathcal{L}}$ refers to the loss of the Batch Norm network and $\mathcal{L}$ to the loss of an identical network without Batch Norm; the paper bounds the gradient of $\hat{\mathcal{L}}$ with respect to a layer's activations by the corresponding gradient of $\mathcal{L}$, scaled by $\gamma^2/\sigma^2$, minus additional non-negative terms.

The biggest player in the success of Batch Norm is β-smoothness (f is said to be β-smooth if its gradient is β-Lipschitz). These smoothing effects impact the performance of the training algorithm in a major way. To understand why, recall that in a vanilla (non-BatchNorm) deep neural network, the loss function is not only non-convex but also tends to have a large number of “kinks”, flat regions, and sharp minima. This makes gradient descent–based training algorithms unstable, e.g., due to exploding or vanishing gradients, and thus highly sensitive to the choice of the learning rate and initialization.
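
For reference, the standard definition being used here: a function f is β-smooth when its gradient is β-Lipschitz, i.e.

$$\|\nabla f(x) - \nabla f(y)\| \;\le\; \beta\,\|x - y\| \quad \text{for all } x, y,$$

so a smaller β means the gradient itself cannot change too abruptly between nearby points.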

Improved Lipschitzness makes the gradients more reliable and predictive: when we take a large step in the direction of the computed gradient, that gradient direction remains a fairly accurate estimate of the actual gradient direction after the step. This allows the algorithm to take larger steps without fear of running into a sudden change in the loss landscape, such as a sharp local minimum or a flat region.
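
As an illustration of what “predictive” means here, the following sketch walks along the computed gradient with a few step sizes and reports how much the loss changes and how well the original gradient direction still matches the gradient at the new point. It is purely illustrative (the function name, step sizes, and loss are assumptions, not the paper's measurement code); a BatchNorm network would be expected to show smaller loss jumps and cosines closer to 1 than the same network without BatchNorm.

```python
import copy
import torch
import torch.nn.functional as F

def gradient_predictiveness(model, x, y, step_sizes=(0.01, 0.05, 0.1)):
    """Sketch of a 'predictiveness' probe: compute the gradient at the current
    point, walk along it with different step sizes, and check (a) how much the
    loss changes and (b) how well the original gradient direction matches the
    gradient at the new point."""
    model.zero_grad()
    loss0 = F.cross_entropy(model(x), y)
    loss0.backward()
    grads = {n: p.grad.detach().clone() for n, p in model.named_parameters()}

    results = []
    for eta in step_sizes:
        probe = copy.deepcopy(model)
        with torch.no_grad():
            for n, p in probe.named_parameters():
                p -= eta * grads[n]                 # move along the original gradient
        probe.zero_grad()
        loss1 = F.cross_entropy(probe(x), y)
        loss1.backward()
        g0 = torch.cat([g.flatten() for g in grads.values()])
        g1 = torch.cat([p.grad.flatten() for _, p in probe.named_parameters()])
        cos = F.cosine_similarity(g0, g1, dim=0).item()
        results.append((eta, loss1.item() - loss0.item(), cos))
    return results  # smoother landscapes give smaller loss jumps and cosines near 1
```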

This allows us to use a wide range of learning rates and initialization schemes. It is worth noting that the properties of Batch Norm discussed earlier can be viewed as manifestations of this β-smoothness.
