Batch Normalization in Neural Networks
Motivating Batch Normalization
Let’s consider a feed-forward neural network classifying images as “dog” or “not dog.” During training, the network processes hundreds of image-label pairs, updating its parameters after each iteration. To visualize this, imagine a contour plot representing two network parameters — though, in reality, there are hundreds or thousands of parameters.
In this contour plot, blue and green areas represent low loss (desirable outcomes), while red and orange areas represent high loss. Initially, the network starts at a high-loss position and gradually moves towards lower-loss areas as training progresses. However, this path often zigzags because the loss surface is uneven: a change in one parameter can affect the loss far more than an equal change in another. This unevenness can make training unstable.
The Role of Batch Normalization
Batch normalization addresses this issue by normalizing the outputs of each neuron. This normalization results in a smoother contour plot, making the training process more stable and efficient. Instead of a zigzag path, the parameter updates lead more directly to the minimum loss point.
Technical Details of Batch Normalization
Internal Covariate Shift
One of the reasons for the uneven terrain in the contour plot is internal covariate shift. Let’s explore this with an example:
1. Iteration 1: A neuron has an activation value of 4.
2. Iteration 10: The same neuron’s activation value updates to 9.
3. Iteration 100: The neuron’s activation value further updates to 20.
As training progresses, the distribution of the neuron's outputs shifts and their variance grows. When neuron outputs have high variance, small parameter changes can produce large changes in the output and, consequently, in the loss. This is why the contour plot appears stretched along certain parameters.
Normalizing Neuron Outputs
To address this, batch normalization standardizes neuron outputs. It operates across a batch of samples: for each neuron, it computes the mean and variance of the activations within the batch, then subtracts the mean and divides by the standard deviation. Each neuron also has two learnable parameters: gamma, which scales the normalized output, and beta, which shifts it.
The result is a more consistent distribution of neuron activations, leading to more stable and efficient training. Low variance in neuron outputs means small parameter changes result in small, predictable changes in the output and loss.
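To make the computation concrete, here is a minimal sketch of what a batch normalization layer computes, written with PyTorch. The toy tensor x and the epsilon value are illustrative assumptions; nn.BatchNorm1d is the built-in layer we use later in this article.

import torch
import torch.nn as nn

# A toy batch: 4 samples, 3 features (pre-activation outputs of 3 neurons)
x = torch.tensor([[4.0, 1.0, -2.0],
                  [9.0, 0.5, -1.0],
                  [20.0, 2.0, 0.0],
                  [7.0, 1.5, 1.0]])

eps = 1e-5                                   # small constant for numerical stability
mean = x.mean(dim=0)                         # per-feature mean over the batch
var = x.var(dim=0, unbiased=False)           # per-feature (biased) variance over the batch
x_hat = (x - mean) / torch.sqrt(var + eps)   # normalize: zero mean, unit variance

gamma = torch.ones(3)                        # learnable scale, initialized to 1
beta = torch.zeros(3)                        # learnable shift, initialized to 0
y_manual = gamma * x_hat + beta

# The built-in layer performs the same computation in training mode
bn = nn.BatchNorm1d(3)
y_builtin = bn(x)
print(torch.allclose(y_manual, y_builtin, atol=1e-5))  # True

With gamma initialized to 1 and beta to 0, the manual computation matches the built-in layer; during training, gamma and beta are updated by gradient descent like any other parameters.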
Batch Normalization in Code
Let’s see how to implement batch normalization using PyTorch.
Defining the Network
We start by importing the necessary libraries and defining the dataset:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
Next, we define the neural network with batch normalization layers:
class NeuralNetwork(nn.Module):
    def __init__(self, use_batch_norm=False):
        super(NeuralNetwork, self).__init__()
        self.use_batch_norm = use_batch_norm
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
        if self.use_batch_norm:
            self.bn1 = nn.BatchNorm1d(512)
            self.bn2 = nn.BatchNorm1d(256)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.fc1(x)
        if self.use_batch_norm:
            x = self.bn1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        if self.use_batch_norm:
            x = self.bn2(x)
        x = torch.relu(x)
        x = self.fc3(x)
        return x
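As a quick sanity check (not part of the original listing), we can instantiate the network and push a dummy batch through it. Keep in mind that batch normalization layers behave differently in training and evaluation mode: during training they normalize with batch statistics, while after calling model.eval() they use the running mean and variance accumulated during training.

# Hypothetical sanity check: verify the forward pass and output shape
model = NeuralNetwork(use_batch_norm=True)
dummy = torch.randn(64, 1, 28, 28)   # a fake batch shaped like MNIST images
out = model(dummy)
print(out.shape)                      # torch.Size([64, 10])

# In evaluation mode, BatchNorm1d switches from batch statistics
# to the running mean/variance tracked during training
model.eval()
out_eval = model(dummy)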
Training the Network
We define the training function to compare the performance with and without batch normalization:
def train_model(model, train_loader, optimizer, criterion, epochs=5):
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}')
# Initialize models
model_without_bn = NeuralNetwork(use_batch_norm=False)
model_with_bn = NeuralNetwork(use_batch_norm=True)
# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer_without_bn = optim.SGD(model_without_bn.parameters(), lr=0.01, momentum=0.9)
optimizer_with_bn = optim.SGD(model_with_bn.parameters(), lr=0.01, momentum=0.9)
# Train models
print("Training without Batch Normalization:")
train_model(model_without_bn, train_loader, optimizer_without_bn, criterion)
print("\nTraining with Batch Normalization:")
train_model(model_with_bn, train_loader, optimizer_with_bn, criterion)
Results and Evaluation
After training, we evaluate the models on the test dataset to compare their performance. Typically, the model with batch normalization shows faster convergence and better accuracy due to more stable training dynamics.
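The evaluation code is not shown above, but a minimal sketch might look like the following. The test_loader and evaluate helper are assumptions introduced here for illustration: the test split is loaded with datasets.MNIST(..., train=False, ...) using the same transform as before.

# Assumed evaluation helper (not part of the original listing)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

def evaluate(model, loader):
    model.eval()                      # switch BatchNorm to running statistics
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
    return correct / len(loader.dataset)

print(f"Accuracy without BN: {evaluate(model_without_bn, test_loader):.4f}")
print(f"Accuracy with BN:    {evaluate(model_with_bn, test_loader):.4f}")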
Summary
Batch normalization is a powerful technique for stabilizing and accelerating the training of neural networks. Normalizing neuron outputs within each batch reduces internal covariate shift, leading to smoother and more predictable training. This results in faster convergence and improved performance.