Boosting Neural Network Training: The Power of Batch Normalization

Ashwin N
Jul 21, 2023


Batch normalization is a technique used in neural networks to speed up and stabilize training, and it often improves the model’s final performance. It was introduced to address internal covariate shift: the change in the distribution of each layer’s inputs as the parameters of the preceding layers are updated during training. This shift can slow training and make it harder for the model to converge.

Why Batch Normalization?

To relate this to an everyday example, let’s consider a chef preparing a meal for a large party. The chef needs to ensure that each dish is cooked perfectly, which can be challenging when cooking for a large number of guests. In this analogy, the neural network is the chef, and the guests represent the training data.

Imagine the chef starts cooking with a recipe that hasn’t been adjusted for the number of guests. As the chef progresses and prepares each dish, the ingredients’ proportions may not be suitable for the current number of guests, leading to inconsistencies in the taste and quality of the dishes. This situation can be time-consuming, and the chef might have to keep adjusting the recipe after every new batch of guests arrives.

In the context of neural networks, without batch normalization, the distribution of each layer’s inputs changes during training as the weights of the preceding layers are updated. Each layer therefore has to keep adapting to a moving input distribution, which makes training slower and less stable.

Now, let’s introduce batch normalization into the cooking scenario. Batch normalization for the chef would be like having a kitchen assistant who continuously ensures that the ingredient proportions for each dish are appropriately adjusted based on the total number of guests. With batch normalization, the chef can maintain consistent dish quality and reduce the need for constant adjustments.

Similarly, in neural networks, batch normalization normalizes the inputs to each layer using the mean and standard deviation computed over the current mini-batch. This keeps the input distributions more consistent from one update to the next and reduces internal covariate shift. As a result, the network can usually train more efficiently, converge faster, and often perform better on test data.
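As a minimal sketch of this idea, the NumPy snippet below standardizes two mini-batches whose raw distributions differ sharply. The batch sizes, means, spreads, and the helper name `normalize` are assumptions chosen for illustration; they do not come from any particular library.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two hypothetical mini-batches (8 samples x 4 features) with very
# different means and spreads, standing in for shifting layer inputs.
batch_a = rng.normal(loc=0.0, scale=1.0, size=(8, 4))
batch_b = rng.normal(loc=5.0, scale=3.0, size=(8, 4))

def normalize(batch, eps=1e-5):
    """Standardize each feature (column) with this mini-batch's own statistics."""
    mean = batch.mean(axis=0)
    var = batch.var(axis=0)
    return (batch - mean) / np.sqrt(var + eps)

norm_a = normalize(batch_a)
norm_b = normalize(batch_b)

print("raw means:       ", batch_a.mean(axis=0).round(2), batch_b.mean(axis=0).round(2))
print("normalized means:", norm_a.mean(axis=0).round(2), norm_b.mean(axis=0).round(2))
print("normalized stds: ", norm_a.std(axis=0).round(2), norm_b.std(axis=0).round(2))
```

After normalization, both batches have per-feature means close to 0 and standard deviations close to 1, so a layer receiving them sees a similar distribution in each case.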

This mirrors how the chef, with the assistant’s help, can cook a large meal efficiently and with consistent quality.

Exploding Gradient Problem

When training a deep neural network, a common issue is the exploding gradient problem: the gradients grow very large as they are propagated backward through the layers, often because the weights or activations have become too large. The resulting oversized updates cause wild fluctuations in weight values, and in extreme cases the loss becomes NaN, indicating a numerical overflow.
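A rough illustration of how gradient magnitudes can blow up with depth when weights are large is sketched below. It is a deliberate simplification: the layers are purely linear (no activation functions), and the function name `backprop_norms`, the depth, the width, and the weight scales are all assumptions made for this example.

```python
import numpy as np

def backprop_norms(weight_scale, depth=30, width=64, seed=0):
    """Push a gradient backward through `depth` linear layers and record its norm."""
    rng = np.random.default_rng(seed)
    grad = np.ones(width) / np.sqrt(width)  # unit-norm starting gradient
    norms = []
    for _ in range(depth):
        # Random weights with entry std = weight_scale / sqrt(width), so each
        # layer multiplies the gradient norm by roughly weight_scale.
        W = rng.normal(size=(width, width)) * weight_scale / np.sqrt(width)
        grad = W.T @ grad  # backward pass through a linear layer
        norms.append(np.linalg.norm(grad))
    return norms

large = backprop_norms(weight_scale=2.0)
modest = backprop_norms(weight_scale=1.0)

print(f"weight scale 2.0: gradient norm after 30 layers = {large[-1]:.3e}")
print(f"weight scale 1.0: gradient norm after 30 layers = {modest[-1]:.3e}")
```

With the larger weight scale, the gradient norm grows roughly geometrically with depth, which is the kind of runaway growth that eventually overflows.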

One reason for scaling the input data is to give training a stable start and reduce the risk of exploding gradients. By scaling the input, for example mapping pixel values from the range 0–255 to between -1 and 1, we avoid immediately creating huge activation values that could lead to exploding gradients. However, as the network trains and the weights move away from their initial values, a phenomenon called covariate shift can occur.
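The pixel scaling mentioned above can be sketched in a couple of lines of NumPy. The tiny 2 × 3 image below is made-up data used only to show the arithmetic.

```python
import numpy as np

# Hypothetical 2x3 grayscale image with 8-bit pixel values in [0, 255].
pixels = np.array([[0, 64, 128], [191, 255, 32]], dtype=np.uint8)

# Convert to float first so the arithmetic is not performed in uint8,
# then map 0 -> -1.0 and 255 -> 1.0.
scaled = pixels.astype(np.float32) / 127.5 - 1.0

print(scaled.min(), scaled.max())
```

Dividing by 127.5 and subtracting 1 is one common choice; dividing by 255 to get values in [0, 1] is another, and which one is used is a design decision rather than a requirement.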

Covariate shift is akin to carrying a tall pile of books and getting hit by gusts of wind. As you compensate for the wind’s force, the pile becomes slightly more unstable with each gust. Similarly, in neural networks, each layer assumes consistent input distributions from the layer beneath during updates. But if the activation distributions shift significantly in a certain direction, this can lead to runaway weight values and a collapse of the network.

In summary, scaling input data helps prevent exploding gradients at the start of training, but covariate shift can still occur during later iterations, potentially causing instability in weight values and overall network collapse. Addressing covariate shift is essential to maintain a stable and well-performing neural network.

Batch Normalization Process

Let’s go through a step-by-step example of the batch normalization process with a simple neural network.

Suppose we have a neural network with two hidden layers and an output layer. The architecture is as follows:

Input Layer (4 features) -> Hidden Layer 1 (6 neurons) -> Hidden Layer 2 (4 neurons) -> Output Layer (3 classes)

1. Forward Pass:
Let’s assume we have a mini-batch of 8 samples as input to the neural network. During the forward pass, the input data flows through the network, and activations are calculated for each layer.

a. Input Layer:
The mini-batch of 8 samples, each with 4 features, enters the network as an 8 × 4 matrix. The input layer performs no computation of its own; it simply passes the features on to Hidden Layer 1.

b. Hidden Layer 1:
The activations from the input layer are multiplied by the corresponding weights for Hidden Layer 1. Then, the bias is added to each neuron, and the output is passed through an activation function (e.g., ReLU).

c. Hidden Layer 2:
Similar to Hidden Layer 1, the activations from Hidden Layer 1 are multiplied by the corresponding weights for Hidden Layer 2. The bias is added, and the output is passed through an activation function.

d. Output Layer:
Finally, the activations from Hidden Layer 2 are multiplied by the weights for the Output Layer, and the bias is added. The output is then processed through the activation function (e.g., softmax) to get the probabilities for each class.
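The forward pass described in steps a to d can be sketched in NumPy as follows. The weights are random placeholders and the biases are zeros, purely for illustration; the helper names `relu` and `softmax` are defined here and are not taken from a library.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    # Subtract the row-wise max before exponentiating for numerical stability.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Layer sizes from the architecture above: 4 -> 6 -> 4 -> 3.
sizes = [4, 6, 4, 3]
# Illustrative parameters: random weights and zero biases.
weights = [rng.normal(scale=0.5, size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]

x = rng.normal(size=(8, 4))  # mini-batch: 8 samples, 4 features

h1 = relu(x @ weights[0] + biases[0])         # Hidden Layer 1, shape (8, 6)
h2 = relu(h1 @ weights[1] + biases[1])        # Hidden Layer 2, shape (8, 4)
probs = softmax(h2 @ weights[2] + biases[2])  # Output Layer, shape (8, 3)

print(h1.shape, h2.shape, probs.shape)
```

Each row of `probs` holds the class probabilities for one sample and sums to 1.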

2. Batch Normalization:
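Batch normalization is applied during the forward pass, inside each hidden layer, rather than as a separate stage before the loss is computed. In the original paper, the normalization is applied to the weighted sum (weights times inputs, plus bias) before the activation function; some implementations place it after the activation instead. For each hidden layer, the following steps are carried out on the current mini-batch: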

a. Calculate Batch Mean and Variance:
For each neuron in the hidden layers, we calculate the mean and variance of the activations for the entire mini-batch.

b. Normalize the Activations:
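Using the batch mean and variance, we normalize the activations of each neuron by subtracting the mean and dividing by the square root of the variance plus a small constant (epsilon), which prevents division by zero. After this step, each neuron’s activations have approximately zero mean and unit variance across the mini-batch.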

c. Scale and Shift (Gamma and Beta):
To allow the model to learn the optimal scale and shift for each normalized activation, we introduce two learnable parameters: gamma (scale) and beta (shift). These parameters are learned during training.

d. Apply Gamma and Beta:
The normalized activations are multiplied by gamma, and then beta is added. This step restores flexibility to the network: if the best behavior is to leave the activations unnormalized, the network can learn a gamma and beta that undo the normalization and recover the original activations.
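Steps a to d can be sketched as a single NumPy function. This is a training-time sketch only, under a few assumptions: the function name `batch_norm_forward` is hypothetical, the input stands in for the values of Hidden Layer 1 (6 neurons) on a mini-batch of 8, and a small `eps` is added to the variance for numerical stability.

```python
import numpy as np

def batch_norm_forward(z, gamma, beta, eps=1e-5):
    """Batch-normalize z of shape (batch_size, num_neurons)."""
    mean = z.mean(axis=0)                     # a. per-neuron batch mean
    var = z.var(axis=0)                       # a. per-neuron batch variance
    z_hat = (z - mean) / np.sqrt(var + eps)   # b. normalize
    out = gamma * z_hat + beta                # c./d. scale and shift
    return out, z_hat, mean, var

rng = np.random.default_rng(0)
z = rng.normal(loc=3.0, scale=2.0, size=(8, 6))  # made-up layer values

# gamma = 1 and beta = 0 is a common starting point for the learnable parameters.
gamma, beta = np.ones(6), np.zeros(6)
out, z_hat, mean, var = batch_norm_forward(z, gamma, beta)
print(out.mean(axis=0).round(3), out.std(axis=0).round(3))

# With gamma = sqrt(var + eps) and beta = mean, the layer undoes its own
# normalization and returns the original values.
restored, *_ = batch_norm_forward(z, np.sqrt(var + 1e-5), mean)
print(np.allclose(restored, z))
```

The second call shows why gamma and beta matter: normalization never removes anything the network cannot put back by learning suitable values for these two parameters.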

3. Backward Pass:
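Once the loss has been computed, the backward pass calculates gradients for every parameter. Because batch normalization is part of the forward computation, the gradients flow through the normalization step as well, and gamma and beta are updated along with the weights and biases using an optimizer such as gradient descent.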
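A sketch of how gradients pass through a batch normalization layer is given below. The function names `bn_forward` and `bn_backward`, the random stand-in for the upstream gradient, and the learning rate are assumptions for illustration. The backward formula is the standard closed form obtained by applying the chain rule to the mean, variance, and normalization steps.

```python
import numpy as np

def bn_forward(z, gamma, beta, eps=1e-5):
    mean = z.mean(axis=0)
    var = z.var(axis=0)
    inv_std = 1.0 / np.sqrt(var + eps)
    z_hat = (z - mean) * inv_std
    return gamma * z_hat + beta, (z_hat, inv_std, gamma)

def bn_backward(dout, cache):
    """Return gradients of the loss w.r.t. the layer input, gamma, and beta."""
    z_hat, inv_std, gamma = cache
    n = dout.shape[0]
    dbeta = dout.sum(axis=0)              # beta is added to every sample
    dgamma = (dout * z_hat).sum(axis=0)   # gamma scales z_hat
    dz_hat = dout * gamma
    # The batch mean and variance depend on every sample, hence the two sums.
    dz = inv_std / n * (
        n * dz_hat - dz_hat.sum(axis=0) - z_hat * (dz_hat * z_hat).sum(axis=0)
    )
    return dz, dgamma, dbeta

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 6))
gamma, beta = np.ones(6), np.zeros(6)
dout = rng.normal(size=(8, 6))  # stand-in for the gradient from later layers

out, cache = bn_forward(z, gamma, beta)
dz, dgamma, dbeta = bn_backward(dout, cache)

# One plain gradient-descent step on the learnable parameters.
lr = 0.1
new_gamma = gamma - lr * dgamma
new_beta = beta - lr * dbeta
print(dz.shape, new_gamma.shape, new_beta.shape)
```

In practice, deep learning frameworks compute these gradients automatically; the explicit version is shown only to make the flow of gradients through the normalization visible.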

By applying batch normalization, the network can often train more effectively and converge faster. According to the original paper, it reduces internal covariate shift; in practice, it also makes training less sensitive to the choice of learning rate and lowers the risk of exploding gradients, leading to more stable training of the neural network.

References:

Sergey Ioffe and Christian Szegedy, “Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift,” 11 February 2015, https://arxiv.org/abs/1502.03167.

