A Step-by-Step Guide to Early Stopping in TensorFlow and PyTorch

Vrunda Bhattbhatt
5 min read · Jan 10, 2024


Early Stopping in Neural Networks

Training neural networks can be a thrilling journey, but it’s not without its challenges. One of the most common pitfalls is overfitting, where your model memorizes the training data too closely and fails to generalize to unseen examples. This leads to poor performance in real-world applications.

Enter early stopping, a powerful technique that helps prevent overfitting by stopping training when the model’s performance on a separate validation dataset stops improving. It’s like putting a wise coach on your team who calls time-out before you get too fixated on one training session and lose sight of the bigger picture.

Why is early stopping so important?

  • Prevents overfitting: By halting training before the model memorizes the training data, early stopping improves the model’s ability to generalize to unseen data.
  • Reduces training time: Why waste time training a model that’s not improving? Early stopping saves you precious computational resources.
  • Improves model performance: By avoiding overfitting, you get a model that performs better on real-world data.
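
Before looking at the framework-specific code, here is the core idea as a framework-agnostic sketch. The names train_one_epoch, validate, and save_best_weights are illustrative placeholders for whatever training, evaluation, and checkpointing routines you use, not real library functions:

best_val_loss = float('inf')
epochs_without_improvement = 0
patience = 10  # number of epochs we tolerate without improvement

for epoch in range(100):
    train_one_epoch(model)       # hypothetical: one pass over the training data
    val_loss = validate(model)   # hypothetical: loss on the held-out validation set

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0
        save_best_weights(model)  # hypothetical: remember the best model seen so far
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"No improvement for {patience} epochs, stopping at epoch {epoch}")
            break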

Now, let’s dive into the step-by-step implementation of early stopping in both TensorFlow and PyTorch.

Implementing Early Stopping in U-Net

U-Net is a popular architecture for image segmentation tasks, known for its effectiveness in biomedical image segmentation. Integrating early stopping with U-Net enhances its ability to generalize, making it more robust for practical applications.

Step-by-Step Guide in PyTorch

  1. Import libraries
import torch
import numpy as np
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Upsample  # torch.nn has no Concatenate layer; torch.cat is used instead
from torch.optim import Adam
import copy

2. Define the U-Net Architecture

class UNet(nn.Module):
    def __init__(self, input_channels, output_channels):
        super(UNet, self).__init__()

        # Contracting path
        self.conv1 = Conv2d(input_channels, 64, 3, padding=1)
        self.conv2 = Conv2d(64, 64, 3, padding=1)
        self.pool = MaxPool2d(2, 2)
        self.conv3 = Conv2d(64, 128, 3, padding=1)
        self.conv4 = Conv2d(128, 128, 3, padding=1)
        self.conv5 = Conv2d(128, 256, 3, padding=1)
        self.conv6 = Conv2d(256, 256, 3, padding=1)

        # Expanding path (input channels account for the concatenated skip connections)
        self.up7 = Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv7 = Conv2d(256 + 128, 128, 3, padding=1)
        self.conv8 = Conv2d(128, 128, 3, padding=1)
        self.up8 = Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv9 = Conv2d(128 + 64, 64, 3, padding=1)
        self.conv10 = Conv2d(64, 64, 3, padding=1)

        # Output layer
        self.conv11 = nn.Conv2d(64, output_channels, 1)

    def forward(self, x):
        # Contracting path
        x1 = nn.functional.relu(self.conv1(x))
        x1 = nn.functional.relu(self.conv2(x1))   # 64 channels, full resolution (kept for skip connection)
        p1 = self.pool(x1)
        x2 = nn.functional.relu(self.conv3(p1))
        x2 = nn.functional.relu(self.conv4(x2))   # 128 channels, 1/2 resolution (kept for skip connection)
        p2 = self.pool(x2)

        # Bottleneck
        x3 = nn.functional.relu(self.conv5(p2))
        x3 = nn.functional.relu(self.conv6(x3))   # 256 channels, 1/4 resolution

        # Expanding path
        x4 = self.up7(x3)                         # back to 1/2 resolution
        x4 = torch.cat([x4, x2], dim=1)           # Skip connection: 256 + 128 channels
        x4 = nn.functional.relu(self.conv7(x4))
        x4 = nn.functional.relu(self.conv8(x4))
        x5 = self.up8(x4)                         # back to full resolution
        x5 = torch.cat([x5, x1], dim=1)           # Skip connection: 128 + 64 channels
        x5 = nn.functional.relu(self.conv9(x5))
        x5 = nn.functional.relu(self.conv10(x5))

        # Output layer (sigmoid so BCELoss receives per-pixel probabilities)
        output = torch.sigmoid(self.conv11(x5))
        return output

3. Load your data

X_train = torch.from_numpy(np.load('your_training_images.npy')).float()
y_train = torch.from_numpy(np.load('your_training_segmentations.npy'))
X_val = torch.from_numpy(np.load('your_validation_images.npy')).float()
y_val = torch.from_numpy(np.load('your_validation_segmentations.npy'))
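
Note that PyTorch convolutions expect channels-first tensors of shape (N, C, H, W). If your .npy arrays happen to be stored channels-last as (N, H, W, C) — an assumption you should check against your own data — you can reorder the axes:

# Only needed if the arrays are stored channels-last (N, H, W, C)
X_train = X_train.permute(0, 3, 1, 2)
X_val = X_val.permute(0, 3, 1, 2)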

4. Define Hyperparameters

input_channels = X_train.shape[1]  # Channel dimension of the (N, C, H, W) tensors; adjust for your images
output_channels = 1  # Single channel for binary segmentation

5. Create UNet model

model = UNet(input_channels, output_channels)

6. Initialize Optimizer and Loss Function

optimizer = Adam(model.parameters())
criterion = nn.BCELoss()  # Expects probabilities, which the model's sigmoid output provides

7. Training loop with early stopping

# Initialize variables for early stopping
best_loss = float('inf')
best_model_weights = None
patience = 10

# Training loop with early stopping
for epoch in range(100):
    # Set model to training mode
    model.train()

    # Forward pass and loss calculation
    outputs = model(X_train)
    loss = criterion(outputs, y_train.float())  # Convert y_train to float for BCELoss

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation
    model.eval()  # Set model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation for validation
        val_outputs = model(X_val)
        val_loss = criterion(val_outputs, y_val.float())

    # Early stopping
    if val_loss < best_loss:
        best_loss = val_loss
        best_model_weights = copy.deepcopy(model.state_dict())  # Deep copy here
        patience = 10  # Reset patience counter
    else:
        patience -= 1
        if patience == 0:
            break

# Load the best model weights
model.load_state_dict(best_model_weights)
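
The bookkeeping above can also be factored into a small reusable helper. This is just a sketch — the EarlyStopper name and its interface are illustrative, not a built-in PyTorch class:

class EarlyStopper:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.best_weights = None

    def step(self, val_loss, model):
        """Returns True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_weights = copy.deepcopy(model.state_dict())
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

With this helper, the training loop would simply call something like if stopper.step(val_loss, model): break after each validation pass, and restore stopper.best_weights at the end.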

8. Inference

# Set model to evaluation mode
model.eval()

# Perform inference on new images
with torch.no_grad():
    new_images = torch.from_numpy(np.load('your_new_images.npy')).float()
    predictions = model(new_images)

# Process and visualize predictions as needed
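
Since the model ends in a sigmoid, predictions holds per-pixel probabilities. A common follow-up, assuming a 0.5 cut-off suits your data, is to threshold them into a binary mask:

# Convert per-pixel probabilities into a binary segmentation mask
binary_masks = (predictions > 0.5).float()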

Step-by-Step Guide in TensorFlow

  1. Import libraries
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping

2. Define the U-Net Architecture

def unet_model(input_shape):
    inputs = Input(shape=input_shape)

    # Contracting path
    c1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    c1 = Conv2D(64, 3, activation='relu', padding='same')(c1)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = Conv2D(128, 3, activation='relu', padding='same')(p1)
    c2 = Conv2D(128, 3, activation='relu', padding='same')(c2)
    p2 = MaxPooling2D((2, 2))(c2)

    # Bottleneck
    c3 = Conv2D(256, 3, activation='relu', padding='same')(p2)
    c3 = Conv2D(256, 3, activation='relu', padding='same')(c3)

    # Expanding path
    u4 = UpSampling2D((2, 2))(c3)
    u4 = Concatenate()([u4, c2])  # Skip connection
    c4 = Conv2D(128, 3, activation='relu', padding='same')(u4)
    c4 = Conv2D(128, 3, activation='relu', padding='same')(c4)

    u5 = UpSampling2D((2, 2))(c4)
    u5 = Concatenate()([u5, c1])  # Skip connection
    c5 = Conv2D(64, 3, activation='relu', padding='same')(u5)
    c5 = Conv2D(64, 3, activation='relu', padding='same')(c5)

    outputs = Conv2D(1, 1, activation='sigmoid')(c5)  # Single-channel output for binary segmentation

    model = Model(inputs=[inputs], outputs=[outputs])
    return model

3. Load your data

X_train = np.load('your_training_images.npy')
y_train = np.load('your_training_segmentations.npy')
X_val = np.load('your_validation_images.npy')
y_val = np.load('your_validation_segmentations.npy')
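
Depending on how the .npy files were created, the images may still be stored as 8-bit integers. If so — this is an assumption about your data, so check your own arrays — scaling them to [0, 1] usually helps training:

# Only needed if the images are stored as uint8 pixel values in [0, 255]
X_train = X_train.astype('float32') / 255.0
X_val = X_val.astype('float32') / 255.0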

4. Create and Compile the Model

model = unet_model(input_shape=X_train[0].shape)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

5. Define the EarlyStopping Callback

early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
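
The callback accepts a few more knobs worth knowing about. For example, min_delta sets how much the monitored metric must improve to count as progress, and verbose=1 logs the epoch at which training stops; a more explicit configuration might look like this:

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=1e-4,            # ignore improvements smaller than this
    restore_best_weights=True,
    verbose=1                  # print a message when training stops early
)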

6. Train the Model with Early Stopping

model.fit(
    X_train, y_train,
    epochs=100,
    validation_data=(X_val, y_val),
    callbacks=[early_stopping]
)
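
As a side note, fit returns a History object. If you assign it when calling fit above (for example, history = model.fit(...)), you can check where early stopping actually halted training; the short sketch below assumes such a history variable exists:

# Assumes the fit call above was captured as: history = model.fit(...)
print(f"Training ran for {len(history.history['val_loss'])} epochs")
print(f"Best validation loss: {min(history.history['val_loss']):.4f}")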

7. Inference

# Load new images for inference
new_images = np.load('your_new_images.npy')

# Make predictions
predictions = model.predict(new_images)

# Process and visualize predictions as needed
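
One way to inspect the results, assuming matplotlib is installed as an extra dependency and a 0.5 threshold suits your data, is to binarize the sigmoid outputs and display a mask:

import matplotlib.pyplot as plt

# Threshold per-pixel probabilities into a binary mask and show the first prediction
binary_masks = (predictions > 0.5).astype(np.uint8)
plt.imshow(binary_masks[0, :, :, 0], cmap='gray')
plt.title('Predicted segmentation mask')
plt.show()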

Conclusion

The implementation of early stopping in both PyTorch and TensorFlow serves as a strategic approach to enhance the training of neural networks, especially for intricate tasks such as image segmentation using U-Net. This technique hinges on vigilantly monitoring a chosen performance metric, often the validation loss, to gauge the model’s generalization capabilities. When this metric ceases to show improvement, it signals an optimal moment to halt the training process. This methodology not only bolsters the efficiency and effectiveness of the model but also stands as a bulwark against overfitting.

Key Points to Remember:

  1. Metric Selection: Choose an appropriate metric to monitor, such as validation loss or accuracy. This choice is crucial, as it directly determines when training stops (see the sketch after this list).
  2. Patience Tuning: Set the ‘patience’ parameter thoughtfully, considering the specifics of your dataset and model. It determines how many epochs training continues without improvement in the monitored metric.
  3. Performance Evaluation: After training, assess your model’s performance on an independent test set. This step is vital for understanding how well the model generalizes to new, unseen data.
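
To illustrate points 1 and 3 in Keras terms, here is a minimal sketch. It assumes the model was compiled with metrics=['accuracy'] (as in the TensorFlow guide above) and that a held-out test set is available in hypothetical X_test and y_test arrays; it monitors validation accuracy instead of loss (note mode='max', since higher accuracy is better) and then evaluates the restored best model on the test set:

from tensorflow.keras.callbacks import EarlyStopping

# Monitor accuracy instead of loss; mode='max' tells Keras that larger values are better
early_stopping = EarlyStopping(monitor='val_accuracy', mode='max', patience=10,
                               restore_best_weights=True)

# After training, check generalization on data the model never saw during training or validation
test_loss, test_accuracy = model.evaluate(X_test, y_test)
print(f"Test loss: {test_loss:.4f}, test accuracy: {test_accuracy:.4f}")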

By incorporating early stopping into your neural network training regimen, you embark on a journey towards developing models that are not only precise in their predictions but also adept at generalizing to new data. This balance is key to creating robust neural networks fit for real-world applications. With this strategy in hand, you are better equipped to tackle the challenges of overfitting and enhance the overall performance of your neural networks. Embrace this technique and advance your neural network training to new heights of efficiency and effectiveness.
