Batch Normalization… or not?
Still an arcane expedient: is it possible to do without it?
Batch Normalization (BN or BatchNorm) is a technique that normalizes layer inputs by re-centering and re-scaling them. It works by computing the mean and the standard deviation of each input channel (across the whole batch), normalizing the inputs (check this video) and, finally, applying a scaling and a shifting through two learnable parameters, β and γ. Batch Normalization is quite effective, but the real reasons behind this effectiveness remain unclear.
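The transformation just described can be sketched in a few lines of NumPy (a minimal illustration; names, shapes and the seed are arbitrary choices, not taken from any particular framework):

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize x of shape (batch, channels) per channel across the batch."""
    mu = x.mean(axis=0)                    # per-channel mean
    var = x.var(axis=0)                    # per-channel variance
    x_hat = (x - mu) / np.sqrt(var + eps)  # re-center and re-scale
    return gamma * x_hat + beta            # learnable scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(64, 4))
y = batch_norm(x, gamma=np.ones(4), beta=np.zeros(4))
# with gamma = 1 and beta = 0, each output channel has roughly
# zero mean and unit variance, whatever the input statistics were
```

In a real layer, γ and β are trained by backpropagation, so the network can undo the normalization where that helps.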
The myth of reducing Internal Covariate Shift
Initially, as it was proposed by Sergey Ioffe and Christian Szegedy in their 2015 article, the purpose of BN was to mitigate the internal covariate shift (ICS), defined as “the change in the distribution of network activations due to the change in network parameters during training”. In the figure below, updating the parameters from previous layers results in different input distributions to the next layer.
One reason to normalize inputs is to obtain stable training; unfortunately, while this may hold at the beginning, there is no guarantee of stability once the network trains and the weights move away from their initial values. So, as training progresses, the distribution of layer inputs changes with the weight updates. BN aims to reduce this drift, that is, BN aims to reduce internal covariate shift.
However, some years later, a paper showed that the effectiveness of BN has very little to do with reducing ICS. The authors trained networks with random noise injected after the BatchNorm layers. Specifically, they perturbed each activation of each sample in the batch with i.i.d. noise sampled from a distribution with non-zero mean and non-unit variance. Such noise injections produce a severe covariate shift.
The picture above compares the distributional stability profiles of VGG networks trained without BatchNorm (Standard), with BatchNorm (Standard + BatchNorm) and with explicit covariate shift added after the BatchNorm layers (Standard + “Noisy” BatchNorm). Here, the authors sampled activations of a given layer and visualized their distribution over training steps. The “noisy” BN network has distributional instability induced by adding time-varying, non-zero mean and non-unit variance noise independently to each batch-normalized activation.
The following picture shows that, surprisingly, the “noisy” BN model nearly matches the performance of the standard BN model, despite complete distributional instability. The internal covariate shift in models using BN plus noise is similar to, or even worse than, that of models without BN… yet they achieve better accuracy.
This leads us to reject the idea that reducing internal covariate shift is what yields better results. So, how does BN help?
Furthermore, using a slightly different definition of internal covariate shift, the authors showed that BN did not reduce ICS (surprisingly, they observed that networks with BN often exhibited an increase in ICS).
Loss landscape and variation of gradients
We just mentioned a different definition of ICS. Define ICS as the L2-norm of the difference between G (the gradient of the layer parameters that would be applied during a simultaneous update of all layers) and G’ (the same gradient after all the previous layers have been updated to their new values). The figure below shows G and G’, in gray and yellow respectively.
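This gradient-based measure can be made concrete on a toy two-layer linear network (a hypothetical setup with hand-derived gradients, meant only to illustrate the G versus G’ comparison, not the paper's experiments):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 8))          # inputs
Y = rng.normal(size=(32, 1))          # targets
W1 = 0.1 * rng.normal(size=(8, 8))    # layer 1 weights
W2 = 0.1 * rng.normal(size=(8, 1))    # layer 2 weights
lr = 0.1

def grads(W1, W2):
    """MSE-loss gradients of a two-layer linear network, derived by hand."""
    H = X @ W1                         # layer 1 output
    P = H @ W2                         # prediction
    dP = 2.0 * (P - Y) / len(X)        # dLoss/dP for the mean squared error
    return X.T @ (dP @ W2.T), H.T @ dP  # (dLoss/dW1, dLoss/dW2)

gW1, G = grads(W1, W2)                 # G: layer-2 gradient, simultaneous update
_, G_prime = grads(W1 - lr * gW1, W2)  # G': same gradient after layer 1 moved
ics = np.linalg.norm(G - G_prime)      # L2 measure of internal covariate shift
```

A large `ics` means the gradient a layer computed is stale by the time it is applied, because the layers below have already moved.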
Starting from this definition, further exploration led to the recognition that BN affects both the variation of the loss (loss landscape figure) and the variation of the gradients of the loss (gradient predictiveness figure): the loss varies at a smaller rate and the magnitudes of the gradients are reduced (see picture below). A smoother loss landscape usually allows for higher learning rates, thus reducing training time.
Check this video for more.
Advantages of Batch Normalization
We list some positive properties of BN.
(1) In residual models, Batch Normalization downscales the residual branch. When placed on the residual branch (as is typical), batch normalization downscales the hidden activations on the residual branch at initialization. This ensures that the network has well-behaved gradients early in training, enabling efficient optimization (details here).
(2) Batch normalization eliminates mean-shift. Activation functions like ReLU or GELU produce non-zero mean activations because they are not antisymmetric. As a consequence, the inner product between the activations of independent training examples immediately after the activation function, which should be close to zero, is typically large and positive. This issue gets worse as the network depth increases, introducing a mean-shift in the activations of different training examples on any single channel. Deep networks affected by mean-shift tend to predict the same label for all training examples at initialization. Batch normalization counteracts the mean-shift by ensuring that the mean activation on each channel is zero across the batch.
(3) Batch normalization has a regularizing effect. The regularizing effect is mostly verified experimentally or conjectured. Nonetheless, there are studies like Luo et al. quantifying the regularization strength of the batch normalization statistics μ (mini-batch mean) and σ (mini-batch standard deviation).
(4) Batch normalization allows efficient large-batch training. As previously reported, batch normalization has a “taming” effect on the variation of the loss: it smooths the loss landscape. This is important: with a smoother loss landscape, larger learning rates can be used (improving training speed). The ability to train at larger learning rates is essential if one wishes to train efficiently with large batch sizes.
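The mean-shift described in point (2) is easy to observe numerically. A small NumPy sketch (shapes and seed are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 16))     # zero-mean pre-activations
h = np.maximum(x, 0.0)             # ReLU keeps only the positive half

mean_activation = h.mean()         # strictly positive: the mean-shift
# inner products between examples are positive on average, instead of
# being close to zero as they would be for zero-mean features
avg_inner = (h @ h.T).mean()

# batch normalization re-centers each channel, removing the shift
h_bn = (h - h.mean(axis=0)) / (h.std(axis=0) + 1e-5)
```

After the re-centering step, each channel of `h_bn` has zero mean across the batch, so the shift no longer accumulates with depth.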
Shortcomings of Batch Normalization
Is BN really an improvement? Paradoxically, BN has many undesirable properties stemming from its dependence on the batch size and interactions between examples (you can read more here):
(i) BN is computationally expensive;
(ii) for small minibatches, there is a significant discrepancy between the distribution of normalized activations during testing and their distribution during training (Singh and Shrivastava, 2019);
(iii) BN breaks the independence between training examples in the minibatch;
(iv) BN is involved in some implementation inconsistencies, see this article for details (it also reports a rather odd fact concerning an incorrect implementation of the batch normalization formula in the CNTK deep learning framework);
(v) BN performance is sensitive to the batch size. In particular, normalized networks perform poorly with small batch sizes.
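Points (ii) and (v) can be illustrated with a toy experiment: normalize a single channel with batch statistics (as in training) versus running averages (as at test time), and measure the gap as a function of batch size. The setup below is hypothetical; the momentum, data distribution and step count are arbitrary choices:

```python
import numpy as np

def train_test_gap(batch_size, steps=500, momentum=0.1, seed=0):
    """Average |train-mode output - eval-mode output| for one BN channel."""
    rng = np.random.default_rng(seed)
    run_mu, run_var = 0.0, 1.0
    gap = 0.0
    for _ in range(steps):
        x = rng.normal(loc=2.0, scale=3.0, size=batch_size)
        mu, var = x.mean(), x.var()
        # training normalizes with batch statistics, inference with running ones
        run_mu = (1 - momentum) * run_mu + momentum * mu
        run_var = (1 - momentum) * run_var + momentum * var
        train_out = (x - mu) / np.sqrt(var + 1e-5)
        test_out = (x - run_mu) / np.sqrt(run_var + 1e-5)
        gap += np.abs(train_out - test_out).mean() / steps
    return gap

small_gap, large_gap = train_test_gap(2), train_test_gap(256)
# the train/test discrepancy shrinks as the batch size grows
```

With a batch of 2, the per-batch mean and variance are very noisy estimates of the channel statistics, so what the network sees in training differs markedly from what it sees at evaluation time.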
Overcoming Batch Normalization
Due to the aforementioned drawbacks, there have been several attempts to do without BN, namely normalizer-free models. Surprising as it may seem, the performance of normalizer-free models is comparable to that of normalized models. This may be another sign that the inner mechanisms of some (widely used) deep learning tools are still unclear. For example, this page looks at a type of ResNet that uses the Adaptive Gradient Clipping strategy and no BN.
Nice video about normalizing inputs.
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
S. Ioffe, C. Szegedy
arXiv:1502.03167v3 [cs.LG], 2015.
How Does Batch Normalization Help Optimization?
S. Santurkar, D. Tsipras, A. Ilyas, A. Madry
arXiv:1805.11604v5 [stat.ML], 2018.
How Does Batch Normalization Help Optimization? - video.
Batch Normalization Biases Residual Blocks Towards the Identity Function in Deep Networks
S. De, S. L. Smith
arXiv:2002.10444v2 [cs.LG], 2020.
Towards Understanding Regularization in Batch Normalization
P. Luo, X. Wang, W. Shao, Z. Peng
arXiv:1809.00846v4 [cs.LG], 2019.
Normalizer-Free ResNets (link).
EvalNorm: Estimating Batch Normalization Statistics for Evaluation
S. Singh, A. Shrivastava
Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV), 2019 (link).
CRADLE: Cross-Backend Validation to Detect and Localize Bugs in Deep Learning Libraries
H. V. Pham, T. Lutellier, W. Qi, L. Tan
2019 IEEE/ACM 41st International Conference on Software Engineering (ICSE), 2019 (link).