Exploring Batch Normalisation with PyTorch
In continuation of my previous post, this time we will discuss “Batch Normalisation” and its implementation in PyTorch.
Batch normalisation is a technique used to improve the efficiency of neural networks. It works by stabilising the distributions of hidden layer inputs, which improves training speed.
1. Essence of Batch Normalisation
In neural networks, the inputs to each layer are affected by the parameters of all preceding layers, and changes to the network parameters amplify as the network becomes deeper. During training, the weights get updated through backpropagation, so the distribution of each layer’s inputs keeps changing as training progresses, which makes training difficult. These changes in the distribution of internal nodes of a deep neural network are termed Internal Covariate Shift.
Internal Covariate Shift is defined as the change in the distribution of network activations due to the change in network parameters during training.
What does Batch Normalisation do?
Batch Normalisation fixes the distribution of the hidden layer values as training progresses. It makes sure that the values of the hidden units have a standardised mean and variance.
How does Batch Normalisation fix Internal Covariate Shift?
It accomplishes this via a normalisation step that fixes the means and variances of layer inputs.
Let’s take a deep dive into this process:-
Explanation:-
- Calculate the mean (µ) of each channel of x over the batch (of size m) — hence the name Batch Normalisation.
- Calculate the variance (σ²) of each channel of x over the batch (of size m).
- Subtract the mean (µ) from the channel values, then divide by the square root of the sum of the channel variance (σ²) and ε (added to avoid division by zero).
At this point, the transformed channel values have zero mean and unit variance.
- Multiply the value obtained above by gamma (γ) (scale operation), then add beta (β) (shift operation). By using γ and β, the original activations can be restored.
For each channel, two trainable parameters (γ and β) and two non-trainable parameters (µ and σ²) are added to the network.
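The steps above can be sketched with plain tensor operations. This is a minimal illustration (random data, identity γ and β), not the fused implementation PyTorch uses internally:

```python
import torch

# Sketch of the batch-norm transform for a 4D input
# of shape (batch, channels, height, width).
torch.manual_seed(0)
x = torch.randn(8, 3, 4, 4)   # batch of 8, 3 channels
eps = 1e-5                    # avoids division by zero

# Per-channel statistics over the batch and spatial dimensions
mu = x.mean(dim=(0, 2, 3), keepdim=True)                   # mean µ
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)   # variance σ²

# Normalise: zero mean, unit variance per channel
x_hat = (x - mu) / torch.sqrt(var + eps)

# Learnable scale (γ) and shift (β), one pair per channel;
# initialised here to the identity transform (γ=1, β=0)
gamma = torch.ones(1, 3, 1, 1)
beta = torch.zeros(1, 3, 1, 1)
y = gamma * x_hat + beta
```

With γ = 1 and β = 0 this matches what `nn.BatchNorm2d` computes in training mode at initialisation.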
Why Beta and Gamma have been introduced?
Network training converges faster if its inputs are whitened, i.e. linearly transformed to have zero mean and unit variance. However, zero mean and unit variance is not always the best choice for the hidden layer values; some other distribution might work better.
To address this, we make sure that the transformation inserted in the network can represent the identity transform, by introducing γ and β, which scale and shift the normalised value. The hidden units then have a standardised mean and variance that are controlled by two explicit trainable parameters, γ and β.
Deep neural nets can be trained faster and generalise better when the distribution of activations is kept normalised during backpropagation.
2. Batch Normalisation in PyTorch
Batch Normalisation can be implemented using torch.nn.BatchNorm2d. It takes a num_features argument, which must equal the number of output channels of the layer preceding it.
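A minimal sketch of this pairing — the num_features of BatchNorm2d matches the out_channels of the convolution feeding it (the layer shapes here are illustrative, not the models from this post):

```python
import torch
import torch.nn as nn

# num_features of BatchNorm2d must equal the out_channels
# of the convolution that feeds it.
conv = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3)
bn = nn.BatchNorm2d(num_features=16)   # 16 = conv out-channels

x = torch.randn(4, 1, 28, 28)          # batch of MNIST-sized images
out = bn(conv(x))
print(out.shape)                       # torch.Size([4, 16, 26, 26])
```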
Let’s understand the impact of Batch Normalisation by considering two neural networks, one with Batch-Norm layers and the other without.
a) Interpreting Model Summary
- Model 1 — without Batch-Norm
- Model 2- with Batch-Norm layers.
I have updated the Model 1 network by adding BatchNorm layers between each Convolution layer and ReLU to construct Model 2.
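The change can be sketched as follows. These are hypothetical blocks just to show where the BatchNorm layer is inserted; the actual architectures are in the linked repository:

```python
import torch.nn as nn

# Model 1 style: Convolution -> ReLU
block_plain = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3),
    nn.ReLU(),
)

# Model 2 style: Convolution -> BatchNorm -> ReLU
block_bn = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3),
    nn.BatchNorm2d(16),   # inserted between convolution and ReLU
    nn.ReLU(),
)
```

The only structural difference is the BatchNorm2d layer, which adds 2 × 16 = 32 trainable parameters (one γ and one β per output channel).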
Inferences:-
- The number of trainable parameters increases on adding Batch-Norm layers: earlier we had 48.8k parameters; after adding multiple Batch-Norm layers we have 49.3k parameters.
- For each Batch-Norm layer, the number of parameters is double the number of output channels, e.g. for layer BatchNorm2d-2 there are 16 output channels, hence 32 trainable parameters (gammas and betas).
- For a Batch-Norm layer, the input shape and output shape are the same.
b) Model efficiency and speed
For both networks shown above (with and without Batch Normalisation), I ran 10 epochs on the MNIST dataset.
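For reference, a single training step of the Batch-Norm variant looks roughly like this. Random tensors stand in for MNIST batches here so the sketch stays self-contained; the real runs use the torchvision MNIST dataset, and the model below is a simplified stand-in for the one in the repository:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Simplified stand-in for the Batch-Norm model
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 26 * 26, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Random stand-ins for one MNIST batch (images and digit labels)
images = torch.randn(32, 1, 28, 28)
labels = torch.randint(0, 10, (32,))

model.train()                  # BatchNorm uses batch statistics in train mode
optimizer.zero_grad()
loss = loss_fn(model(images), labels)
loss.backward()
optimizer.step()
```

Note the model.train() / model.eval() distinction matters for Batch-Norm: in eval mode it switches from batch statistics to the running mean and variance accumulated during training.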
- Model 1- without Batch-Norm
- Model 2- with Batch-Norm
Inference:-
- Comparing the final accuracies in the last (10th) epoch: the model without Batch-Norm reached 98.75% train accuracy and 98.4% test accuracy, whereas the model with Batch-Norm reached 99.92% train accuracy and 99.39% test accuracy.
By using Batch-Norm, the model trains faster and its capacity is also higher. The point to note here is that with Batch Normalisation we can reach a given accuracy in fewer training steps than without it, hence the model learns faster. There is still a need to address overfitting (for the examples taken here), which we will discuss in coming posts.
These learnable parameters can be turned off by setting affine=False; by default it is set to True. There is a debate on whether Batch-Norm should be used before ReLU or after; in this example I have used it before the ReLU layer.
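A quick sketch of the affine flag: with affine=False the γ/β pair is never created, so the layer contributes no trainable parameters:

```python
import torch.nn as nn

bn_affine = nn.BatchNorm2d(16)                 # affine=True by default
bn_fixed = nn.BatchNorm2d(16, affine=False)    # no learnable γ and β

print(sum(p.numel() for p in bn_affine.parameters()))  # 32 (γ and β, 16 each)
print(sum(p.numel() for p in bn_fixed.parameters()))   # 0
```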
You can find the corresponding codes for with and without Batch-Norm layer models in this repository.
References