Batch Normalization

Sai Rahul
Jun 18, 2020


Batch Normalization is a technique to improve the speed, performance and stability of neural networks [1]. It was introduced in the classic paper [2]. This post is not an introduction to Batch Normalization. Instead, it tries to clarify some common questions and misconceptions, and it provides helpful tips and details for when you are getting started. Some of the points discussed here are specific to CNNs, but many apply to other kinds of network models as well.

1. Should BatchNorm come before or after ReLU?

This is a common point of confusion about BN (Batch Normalization). The original paper applies it before the nonlinearity, describing it as follows:

We add the BN transform immediately before the nonlinearity, by normalizing x = Wu + b

In practice, however, some practitioners report that it works better when applied after the nonlinearity [3][4], and according to [4], more recent implementations by the paper's authors also place it after the nonlinearity. The reasoning is that the point of a BN layer is to give the input of the next layer zero mean and unit variance. Applying a nonlinearity (like ReLU) after BN defeats this purpose, because it is the output of the nonlinearity, not of BN, that becomes the input to the next layer.

Adding to the confusion, the standard ResNet models in torchvision apply BN before the nonlinearity, as described in the original paper.
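To make the two orderings concrete, here is a minimal PyTorch sketch (the layer sizes are arbitrary assumptions). In training mode, BN placed after the ReLU hands the next layer per-channel zero-mean activations, while BN placed before the ReLU hands it non-negative, shifted ones.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Ordering from the original paper (and torchvision's ResNet): conv -> BN -> ReLU
bn_before_relu = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),  # BN's shift makes a conv bias redundant
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Alternative ordering: conv -> ReLU -> BN
bn_after_relu = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.BatchNorm2d(16),
)

x = torch.randn(8, 3, 32, 32)
out_before = bn_before_relu(x)  # training mode by default: batch statistics
out_after = bn_after_relu(x)

# conv -> BN -> ReLU: what the next layer sees is non-negative, not zero-mean
print(out_before.min().item(), out_before.mean().item())
# conv -> ReLU -> BN: each channel the next layer sees has (close to) zero mean
print(out_after.mean(dim=(0, 2, 3)).abs().max().item())
```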

2. Should we train BN while fine-tuning a network?

There is no simple answer here; it depends on the problem. If the original task and the fine-tuning task are not closely related, you should train the BN layers. For example, if you are fine-tuning an ImageNet-pretrained model to detect X-ray abnormalities, you should fine-tune the BN layers. On the other hand, if you are fine-tuning an ImageNet-pretrained model for dog classification, you may not need to fine-tune the BN layers; in that case, fine-tuning them can even destroy the statistics they have already learned. When in doubt, experiment.
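If you decide not to fine-tune the BN layers, in PyTorch that means two things: putting them in eval mode (so their running statistics are used but not updated) and turning off gradients for their scale and shift parameters. A minimal sketch, assuming a PyTorch model; the helper freeze_batchnorm and the toy model are hypothetical:

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

def freeze_batchnorm(model: nn.Module) -> None:
    """Hypothetical helper: freeze every BN layer, keeping its running mean/var
    fixed and stopping learning of its scale/shift (weight/bias in PyTorch)."""
    for module in model.modules():
        if isinstance(module, BN_TYPES):
            module.eval()  # normalize with, and stop updating, the running statistics
            for param in module.parameters():
                param.requires_grad_(False)

# Hypothetical stand-in for a pretrained backbone plus a new head
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

model.train()            # model.train() puts BN back in training mode...
freeze_batchnorm(model)  # ...so freeze again after every call to it
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad), lr=0.01
)

bn = model[1]
stats_before = bn.running_mean.clone()
loss = model(torch.randn(4, 3, 16, 16)).sum()
loss.backward()
optimizer.step()
print(torch.equal(bn.running_mean, stats_before))  # running statistics untouched
```

One design caveat: model.train() switches every submodule back to training mode, so the helper has to be called again after each model.train() call. If you decide to train the BN layers instead, simply skip it.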

Batch size affects BN statistics. This seems obvious, but it is easy to overlook. Consider a common use case: when fine-tuning, we often start by training a model on small images (like 224) and later fine-tune it on higher-resolution images (like 1024). At high resolution we often need to reduce the batch size to fit the model in GPU memory. Smaller batches give noisier BN statistics, so the extra training may not improve the overall accuracy. If you are doing stage-wise training like this, it’s often better to freeze the BN layers rather than train them. This seemingly inconspicuous detail has bitten me many times.
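A quick simulation shows why smaller batches hurt. This is a simplified sketch that treats each sample as a single independent value of one channel (the population with mean 2 and standard deviation 3 is a made-up assumption); it measures how much the batch mean, which BN normalizes with, fluctuates from batch to batch.

```python
import torch

torch.manual_seed(0)
# Hypothetical activations of one channel over a dataset: true mean 2.0, std 3.0
population = torch.randn(100_000) * 3.0 + 2.0

def batch_mean_spread(batch_size: int, num_batches: int = 2_000) -> float:
    """Std of the per-batch mean across random mini-batches, i.e. how much the
    mean BN normalizes with jumps around from one batch to the next."""
    idx = torch.randint(0, population.numel(), (num_batches, batch_size))
    return population[idx].mean(dim=1).std().item()

for batch_size in (64, 8, 2):
    # Roughly 3.0 / sqrt(batch_size): smaller batches give noisier statistics
    print(batch_size, round(batch_mean_spread(batch_size), 3))
```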

3. What’s the difference between inference mode and freezing the BN layer?

BN behaves differently during training and testing.

  • During training, we compute the mean and variance over the mini-batch and use them to normalize the input. Ideally, we would normalize with the mean and variance of the entire dataset, but with SGD we only see one mini-batch at a time, and recomputing dataset-wide statistics after every parameter update is impractical. So during training we normalize with batch statistics only, while maintaining a running mean and variance (the population statistics) for testing.
  • During testing, we normalize with these population statistics instead of batch statistics, so the output for a sample no longer depends on the other samples in the batch.
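The two modes can be written out in a few lines. This is a minimal sketch for a (batch, features) input, assuming PyTorch's conventions (momentum 0.1 for the exponential moving average, the unbiased variance stored in the running estimate); batch_norm_1d is a hypothetical helper, checked against nn.BatchNorm1d:

```python
import torch
import torch.nn as nn

def batch_norm_1d(x, gamma, beta, running_mean, running_var,
                  training, momentum=0.1, eps=1e-5):
    """Hypothetical minimal BN for a (batch, features) input, following PyTorch's
    conventions. running_mean and running_var are updated in place in training."""
    if training:
        mean = x.mean(dim=0)                     # batch statistics
        var = ((x - mean) ** 2).mean(dim=0)      # biased variance, used to normalize
        n = x.shape[0]
        with torch.no_grad():
            # Exponential moving average of the population statistics
            # (PyTorch stores the unbiased variance here).
            running_mean.mul_(1 - momentum).add_(momentum * mean)
            running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    else:
        mean, var = running_mean, running_var    # population statistics
    x_hat = (x - mean) / torch.sqrt(var + eps)
    return gamma * x_hat + beta

torch.manual_seed(0)
x = torch.randn(16, 4) * 2 + 1
ref = nn.BatchNorm1d(4)                          # PyTorch's layer, in training mode
gamma, beta = torch.ones(4), torch.zeros(4)      # same as ref's initial weight/bias
running_mean, running_var = torch.zeros(4), torch.ones(4)

ours_train = batch_norm_1d(x, gamma, beta, running_mean, running_var, training=True)
ref_train = ref(x)
print(torch.allclose(ours_train, ref_train, atol=1e-5))
print(torch.allclose(running_mean, ref.running_mean, atol=1e-5),
      torch.allclose(running_var, ref.running_var, atol=1e-5))

ref.eval()
ours_eval = batch_norm_1d(x, gamma, beta, running_mean, running_var, training=False)
print(torch.allclose(ours_eval, ref(x), atol=1e-5))
```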

Which statistics are used, and when, depends on the library, so it’s worth knowing a few library-specific gotchas.

PyTorch: For testing, we need both torch.no_grad() and model.eval().

with torch.no_grad(): This disables gradient computation, so no gradients are computed for the alpha and beta parameters. But if the model is still in training mode, the forward pass still uses batch statistics and still updates the running mean and variance (population statistics).

model.eval(): In eval mode, BN layers normalize with the population statistics and stop updating them. Eval mode alone does not disable gradient computation, though.

For testing, we need both steps [12].
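A small experiment on a single nn.BatchNorm1d layer makes the difference visible (the input values are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)           # starts in training mode, running_mean = 0
x = torch.randn(8, 3) + 5.0      # batch whose mean (~5) is far from running_mean

# 1) torch.no_grad() alone: BN still uses batch statistics and updates running stats
with torch.no_grad():
    bn(x)
stats_after_no_grad = bn.running_mean.clone()
print(stats_after_no_grad)       # moved from 0 toward the batch mean

# 2) model.eval() alone: population statistics are used and left untouched,
#    but autograd still records the graph
bn.eval()
frozen = bn.running_mean.clone()
y_eval = bn(x)
print(torch.equal(bn.running_mean, frozen), y_eval.requires_grad)  # True True

# 3) Both together: the usual recipe for testing
with torch.no_grad():
    y_test = bn(x)
print(y_test.requires_grad)      # False
```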

TensorFlow: The behavior differs between TensorFlow 1.x and 2.x [13].

* In TensorFlow 1.x, layer.trainable = False freezes the layer but does not switch it to inference mode.

* In TensorFlow 2.x, setting trainable = False on a BN layer also makes it run in inference mode from then on.
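A minimal check of the TensorFlow 2.x behavior, assuming tf.keras and default BatchNormalization arguments (moving mean 0, moving variance 1 and epsilon 0.001 before any update). With trainable = False, the layer ignores training=True and normalizes with its moving statistics:

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = np.random.default_rng(0).normal(5.0, 3.0, size=(32, 4)).astype("float32")
bn(x, training=False)         # first call builds the layer; inference leaves moving stats as-is

bn.trainable = False          # TF 2.x: the layer now runs in inference mode...
y = bn(x, training=True)      # ...even though training=True is passed
print(np.asarray(bn.moving_mean))  # still the initial zeros: no statistics update

# Output uses the moving stats (mean 0, variance 1), not the batch stats,
# so y is just x / sqrt(1 + epsilon) and its mean stays near 5.
print(float(np.mean(np.asarray(y))))
```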

4. BN and Dropout

BN introduces some randomness into training because the statistics (mean and variance) are computed on each mini-batch. The normalized values are then scaled and shifted by the alpha and beta parameters. In effect, the network sees slightly different variations of the same input depending on the other items in the batch, which forces the model to become robust to this noise. Dropout does something similar by randomly “dropping out”, or omitting, units during training [9]. So although BN is not designed as a regularizer, it does have a regularizing effect.
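The batch dependence is easy to see directly. In this sketch (arbitrary random data), the same sample is normalized inside two different batches in training mode and comes out different:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)                 # training mode: normalizes with batch statistics
sample = torch.randn(1, 4)

# The same sample placed in two batches with different batch-mates
batch_a = torch.cat([sample, torch.randn(7, 4)])
batch_b = torch.cat([sample, torch.randn(7, 4) * 2 + 1])

out_a = bn(batch_a)[0].detach()
out_b = bn(batch_b)[0].detach()
print(out_a)
print(out_b)   # different from out_a, although the input sample is identical
```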

Using Dropout and BN together has some implications. Dropout randomly ignores some units during training (and is switched off at test time), which distorts the statistics computed by a BN layer that comes after it. So the order of BN and Dropout matters: BN should precede Dropout, i.e. BatchNorm -> Dropout. You can read about this interaction in detail in [11].
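The variance shift described in [11] can be reproduced in a few lines. Dropout rescales the kept units by 1/(1-p) during training, so with p = 0.5 the variance of unit-variance inputs roughly doubles during training and returns to 1 at test time. A BN layer placed after Dropout would store the training-time variance in its running statistics and then see a different variance at test time. A minimal sketch with arbitrary data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(100_000)      # unit-variance activations feeding Dropout
dropout = nn.Dropout(p=0.5)

dropout.train()
train_var = dropout(x).var().item()  # kept units are scaled by 1/(1-p) = 2, variance ~2

dropout.eval()
test_var = dropout(x).var().item()   # Dropout is the identity at test time, variance ~1

print(round(train_var, 2), round(test_var, 2))
```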

In practice, if you have a lot of data, you can leave out Dropout and use BN alone; ResNet [10], for example, does this. With less data, it’s often better to use both BN and Dropout, since BN alone may overfit the training data. If you use both, remember that the order matters, as explained above.
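One simple way to follow the BatchNorm -> Dropout rule in a CNN is to add Dropout only after the last BN layer, right before the classifier, and to leave it out entirely when data is plentiful. A minimal sketch; make_model and its layer sizes are hypothetical:

```python
import torch
import torch.nn as nn

def make_model(num_classes: int, dropout_p: float = 0.0) -> nn.Sequential:
    """Hypothetical small CNN: BN in every conv block, optional Dropout at the end."""
    layers = [
        nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(16), nn.ReLU(),
        nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(32), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    ]
    if dropout_p > 0:
        # Placed after every BN layer, so no BN layer normalizes dropped-out activations
        layers.append(nn.Dropout(dropout_p))
    layers.append(nn.Linear(32, num_classes))
    return nn.Sequential(*layers)

large_data_model = make_model(10)                 # plenty of data: BN alone
small_data_model = make_model(10, dropout_p=0.3)  # less data: BN, then Dropout
print(small_data_model(torch.randn(4, 3, 32, 32)).shape)
```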

5. Does BN reduce Internal Covariate Shift?

The most widely accepted explanation of BN’s success, as well as its original motivation, relates to internal covariate shift (ICS) [8]. Informally, ICS refers to the change in the distribution of a layer’s inputs caused by updates to the preceding layers [8]. However, the experiments in [8] show that BN does not necessarily reduce ICS, and in some settings ICS even increases with BN. The actual reason for BN’s success might instead be that it smooths the optimization landscape, as explained in [8].

I find Ian Goodfellow’s explanation of BN [6] the most convincing and intuitive. In short, BN reduces the higher-order interactions between layers. The following explanation is taken from [6].

Consider a toy neural network with 5 layers, where every layer has one weight matrix and there is no nonlinearity, so the output is just the input multiplied by all the layers’ weight matrices in turn.

a -> b -> c -> d -> e

The weights of layer “a” (together with those of “b” and “c”) determine the input statistics of layer “d”, i.e. the mean and variance of the input to “d” depend on “a”. But if we add BN after each layer, like this:

a -> bn -> b -> bn -> c -> bn -> d -> bn -> e

then the input statistics of “d” don’t depend on “a”, because the input to “d” is always normalized to zero mean and unit standard deviation. This is a powerful idea, because gradient descent is blind to the interactions between layers: it computes the update for each layer as if all the other layers stayed fixed. Consider layer “a” again. During backpropagation, the gradient for “a” depends on the downstream layers “b”, “c”, “d” and “e”; if they shrink the gradient toward zero, the update step for “a” will be small, and vice versa. These small or large changes to “a” in turn change the inputs to all the subsequent layers on the next forward pass. Such interactions make it hard for the learning algorithm to converge. We could use second-order (Hessian) information to get an accurate picture of the interactions and update accordingly, but the Hessian matrix grows quadratically with the number of parameters, and it still only captures pairwise interactions. Our toy example has fifth-order interactions; using higher-order information quickly becomes unwieldy, and there are no straightforward math or linear-algebra solvers for it.
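The toy argument can be checked numerically. In this sketch (random weights for layers a, b and c, a made-up batch of 256 inputs, and BN without its learned scale and shift), scaling the weights of “a” by 10 scales the standard deviation of “d”’s input by 10 without BN, but leaves it unchanged with BN:

```python
import torch

torch.manual_seed(0)
x = torch.randn(256, 8)
# Hypothetical weights for layers a, b, c of the toy linear network (no nonlinearity)
W_a, W_b, W_c = (torch.randn(8, 8) / 8 ** 0.5 for _ in range(3))

def normalize(h: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """BN without the learned scale/shift: zero mean, unit std per feature over the batch."""
    mean = h.mean(dim=0)
    var = ((h - mean) ** 2).mean(dim=0)
    return (h - mean) / torch.sqrt(var + eps)

def input_to_d(scale_a: float, use_bn: bool) -> torch.Tensor:
    """Forward through a -> (bn) -> b -> (bn) -> c -> (bn) and return the input to d."""
    h = x
    for W in (scale_a * W_a, W_b, W_c):
        h = h @ W
        if use_bn:
            h = normalize(h)
    return h

for use_bn in (False, True):
    std_1 = input_to_d(1.0, use_bn).std().item()
    std_10 = input_to_d(10.0, use_bn).std().item()
    print(f"BN={use_bn}: std of d's input {std_1:.3f} -> {std_10:.3f} after scaling a by 10")
```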

By using BN we reduce the interactions between layers. For example, the mean and standard deviation of the input to “d” don’t depend on “a”, since the input to every layer is normalized. The statistics of “d”’s input depend only on the alpha and beta parameters of the preceding BN layer. Remember that alpha and beta are learned parameters that scale and shift the normalized input. (By now it should be apparent why we need these two extra parameters: if we only normalized to zero mean and unit variance, we would reduce the expressiveness of the network.) So backpropagation only needs to adjust these two parameters to get whatever mean and standard deviation it wants at that layer. It also allows SGD to safely take large steps in earlier layers like a, b and c without destroying the statistics of later layers.
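Conversely, the two learned parameters alone set the output statistics. In this sketch, alpha is PyTorch’s weight and beta its bias (the original paper calls them gamma and beta); whatever the scale of the input, in training mode the output has mean beta and standard deviation approximately alpha (up to the small epsilon). The values 3 and 2 are arbitrary:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(1)            # PyTorch calls alpha "weight" and beta "bias"
with torch.no_grad():
    bn.weight.fill_(3.0)          # alpha: target standard deviation
    bn.bias.fill_(2.0)            # beta: target mean

x = torch.randn(10_000, 1) * 50 - 7   # whatever the previous layers produce...
y = bn(x).detach()                    # ...normalized with batch statistics (training mode)
print(y.mean().item(), y.std().item())  # close to 2.0 and 3.0
```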

6. Other types of normalization: Layer Normalization and Group Normalization

We can’t apply BN directly to recurrent networks, because the statistics of the inputs change from one time step to the next, so we would need separate BN statistics for each time step. This also breaks down when sequences at inference time are longer than those seen during training. Layer Normalization [15] was developed to handle these cases: it normalizes each sample over its own features, so it doesn’t depend on the batch at all.
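A minimal PyTorch sketch of Layer Normalization on a (batch, time, features) tensor (the sizes are arbitrary): each position of each sample is normalized over its own features, so a longer sequence or a batch of one is no problem, and a sample’s output doesn’t depend on its batch-mates.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(16)                  # normalizes over the last (feature) dimension

short_batch = torch.randn(4, 10, 16)   # (batch, time, features) seen in training
long_seq = torch.randn(1, 50, 16)      # longer sequence at inference, batch of 1

out_long = ln(long_seq)
print(ln(short_batch).shape, out_long.shape)

# Each sample is normalized on its own, so its output doesn't depend on batch-mates
alone = ln(short_batch[:1])
in_batch = ln(short_batch)[:1]
print(torch.allclose(alone, in_batch))
```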

As I warned in the second section of this article, BN inherently depends on the training batch size. It’s difficult to estimate BN statistics when the batch size is very small (e.g. when the model is big, like BERT or GPT) or when training on multiple GPUs (which is very common these days), since by default each GPU computes statistics over only its own share of the batch. Group Normalization [16] instead divides the channels into groups and normalizes each group independently within each sample, so it doesn’t depend on the batch size. See the Group Normalization (Paper Explained) YouTube video [17] for a detailed explanation.
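A minimal sketch comparing nn.GroupNorm with a manual implementation (4 groups of 4 channels is an arbitrary choice); it also shows that the result for a sample is the same whether it is processed alone or in a batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_groups, channels = 4, 16
gn = nn.GroupNorm(num_groups, channels)       # affine params start at weight=1, bias=0
x = torch.randn(2, channels, 8, 8) * 3 + 1    # works even with a tiny batch

# Manual Group Norm: split consecutive channels into groups, normalize within each sample
n, c, h, w = x.shape
g = x.view(n, num_groups, -1)                 # (batch, groups, channels_per_group * h * w)
mean = g.mean(dim=-1, keepdim=True)
var = ((g - mean) ** 2).mean(dim=-1, keepdim=True)
manual = ((g - mean) / torch.sqrt(var + gn.eps)).view(n, c, h, w)

out = gn(x).detach()
print(torch.allclose(out, manual, atol=1e-5))
# Batch size 1 works too, and a sample's output doesn't depend on the rest of the batch
print(torch.allclose(gn(x[:1]).detach(), out[:1], atol=1e-6))
```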

BN is one of the key techniques that made training large neural networks easier and faster. This article tried to clarify common questions and misconceptions you may run into when getting started. I hope it helped.

References

1. https://en.wikipedia.org/wiki/Batch_normalization

2. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

3. https://forums.fast.ai/t/questions-about-batch-normalization/230/10

4. https://github.com/keras-team/keras/issues/1802#issuecomment-187966878

5. https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/

6. Ian Goodfellow lectures about Batch Normalization and Convolutional Networks.

7. Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift

8. How Does Batch Normalization Help Optimization?

9. https://en.wikipedia.org/wiki/Dilution_(neural_networks)

10. Deep Residual Learning for Image Recognition

11. Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift

12. https://stackoverflow.com/questions/55627780/evaluating-pytorch-models-with-torch-no-grad-vs-model-eval

13. https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

14. https://blog.paperspace.com/busting-the-myths-about-batch-normalization/

15. Layer Normalization

16. Group Normalization

17. Group Normalization (Paper Explained), YouTube video
