Batch Normalization in Neural Networks

Tiroshan Madushanka

Motivating Batch Normalization

Let’s consider a feed-forward neural network classifying images as “dog” or “not dog.” During training, the network processes hundreds of image-label pairs, updating its parameters after each iteration. To visualize this, imagine a contour plot representing two network parameters — though, in reality, there are hundreds or thousands of parameters.

Contour plot of two network parameters before applying batch normalization.

In this contour plot, blue and green areas represent low loss (desirable outcomes), while red and orange areas represent high loss. Initially, the network starts at a high-loss position and gradually moves towards lower-loss regions as training progresses. However, this path can zigzag because the loss surface is uneven: a change in one parameter may affect the loss far more than an equal change in another. This unevenness can make training unstable.

The Role of Batch Normalization

Batch normalization addresses this issue by normalizing the outputs of each neuron. This normalization results in a smoother contour plot, making the training process more stable and efficient. Instead of a zigzag path, the parameter updates lead more directly to the minimum loss point.

Contour plot of two network parameters after applying batch normalization.

Technical Details of Batch Normalization

Internal Covariate Shift

One of the reasons for the uneven terrain in the contour plot is internal covariate shift. Let’s explore this with an example:

1. Iteration 1: A neuron has an activation value of 4.
2. Iteration 10: The same neuron’s activation value updates to 9.
3. Iteration 100: The neuron’s activation value further updates to 20.

As training progresses, the distribution of the neuron's outputs keeps shifting, so they exhibit high variance. High variance in neuron outputs means that small parameter changes can cause large changes in the output and, consequently, in the loss. This is why the contour plot appears stretched along certain parameters.

Distribution of neuron activations before and after batch normalization.

Normalizing Neuron Outputs

Batch normalization normalizes the outputs of neurons to address this issue. It operates across a batch of samples, calculating the mean and variance of neuron activations within the batch. The activations are then normalized by subtracting the mean and dividing by the standard deviation. Each neuron also has two learnable parameters: gamma (which scales the normalized output) and beta (which shifts it).

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (https://arxiv.org/pdf/1502.03167)

The result is a more consistent distribution of neuron activations, leading to more stable and efficient training. Low variance in neuron outputs means small parameter changes result in small, predictable changes in the output and loss.
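To make this concrete, here is a minimal sketch of what a batch normalization layer computes for one fully connected layer, written with plain PyTorch tensor operations (the function name batch_norm_sketch is illustrative, and the sketch omits the running statistics that the built-in layer also tracks for use at inference time):

import torch

def batch_norm_sketch(x, gamma, beta, eps=1e-5):
    # x: activations of shape (batch_size, num_features)
    mean = x.mean(dim=0)                        # per-feature mean over the batch
    var = x.var(dim=0, unbiased=False)          # per-feature variance over the batch
    x_hat = (x - mean) / torch.sqrt(var + eps)  # normalize to zero mean, unit variance
    return gamma * x_hat + beta                 # learnable scale (gamma) and shift (beta)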

Batch Normalization in Code

Let’s see how to implement batch normalization using PyTorch.

Defining the Network

We start by importing the necessary libraries and loading the dataset:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

Next, we define the neural network with batch normalization layers:

class NeuralNetwork(nn.Module):
    def __init__(self, use_batch_norm=False):
        super(NeuralNetwork, self).__init__()
        self.use_batch_norm = use_batch_norm
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

        if self.use_batch_norm:
            self.bn1 = nn.BatchNorm1d(512)
            self.bn2 = nn.BatchNorm1d(256)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten each 28x28 image into a vector
        x = self.fc1(x)
        if self.use_batch_norm:
            x = self.bn1(x)  # normalize the layer's outputs across the batch
        x = torch.relu(x)
        x = self.fc2(x)
        if self.use_batch_norm:
            x = self.bn2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x
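As a quick, illustrative check that the layers are wired up correctly, we can push a dummy batch of MNIST-sized images through the network (the variable names below are just for this example):

# Sanity check: a fake batch of 64 images should produce 64 rows of 10 logits
model = NeuralNetwork(use_batch_norm=True)
dummy = torch.randn(64, 1, 28, 28)
print(model(dummy).shape)  # torch.Size([64, 10])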

Training the Network

We define the training function to compare the performance with and without batch normalization:

def train_model(model, train_loader, optimizer, criterion, epochs=5):
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}')


# Initialize models
model_without_bn = NeuralNetwork(use_batch_norm=False)
model_with_bn = NeuralNetwork(use_batch_norm=True)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_without_bn = optim.SGD(model_without_bn.parameters(), lr=0.01, momentum=0.9)
optimizer_with_bn = optim.SGD(model_with_bn.parameters(), lr=0.01, momentum=0.9)

# Train models
print("Training without Batch Normalization:")
train_model(model_without_bn, train_loader, optimizer_without_bn, criterion)
print("\nTraining with Batch Normalization:")
train_model(model_with_bn, train_loader, optimizer_with_bn, criterion)

Results and Evaluation

After training, we evaluate the models on the test dataset to compare their performance. Typically, the model with batch normalization shows faster convergence and better accuracy due to more stable training dynamics.
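A minimal evaluation sketch could look like the following, assuming a test_loader built the same way as train_loader but from datasets.MNIST('./data', train=False, download=True, transform=transform). Calling model.eval() matters here, because it switches the batch normalization layers from per-batch statistics to the running estimates collected during training:

def evaluate_model(model, test_loader):
    model.eval()  # BatchNorm now uses running statistics instead of batch statistics
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    return correct / total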

Summary

Batch normalization is a powerful technique for stabilizing and accelerating the training of neural networks. Normalizing neuron outputs within each batch reduces internal covariate shift, leading to smoother and more predictable training. This results in faster convergence and improved performance.

Reference

[1] S. Ioffe and C. Szegedy, "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift," 2015. https://arxiv.org/pdf/1502.03167
