Fine-Tuning a Pre-Trained Model in PyTorch: A Step-by-Step Guide for Beginners

Santosh Premi Adhikari

--

Fine-tuning pre-trained models can save time and resources while achieving high performance on new tasks.

Introduction

Fine-tuning is a powerful technique that allows you to adapt a pre-trained model to a new task. In this tutorial, we’ll guide you through fine-tuning a ResNet18 model for handwritten digit classification on MNIST using PyTorch.

Step 1: Setting Up the Environment and Model

Let’s begin by importing the necessary libraries and modifying the pre-trained ResNet18 model.

import torch  
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# Load a pre-trained ResNet18 model
# (the weights= API replaces the deprecated pretrained=True argument in recent torchvision)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Modify the last layer for MNIST (10 classes)
model.fc = nn.Linear(model.fc.in_features, 10)

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Explanation: ResNet18, pre-trained on ImageNet, is adapted to the new task by replacing its final fully connected layer with one that predicts 10 classes (digits 0–9). All layers stay trainable, so the whole network is fine-tuned.
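
If you would rather train only the new head and keep the pre-trained features fixed, a minimal sketch of that variant looks like this. Note that freezing must happen before replacing model.fc, since freshly created layers are trainable by default:

# Optional variant: freeze the backbone, train only the new head
for param in model.parameters():
    param.requires_grad = False  # freeze all pre-trained weights

model.fc = nn.Linear(model.fc.in_features, 10)  # fresh layer, trainable by default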

Step 2: Preparing the Dataset

ResNet18 expects 3-channel images sized 224x224, while MNIST images are 28x28 grayscale. We resize the images and replicate the single channel to three.

# Convert to 3 channels, resize to 224x224, and normalize
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # ResNet18's first conv expects 3 channels
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

Explanation: The images are converted to three channels, resized, and normalized so they match the input the pre-trained model was built for.
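
Before training, it is worth a quick sanity check that the pipeline produces the shapes the model expects; assuming the transform above, one batch should come out as [64, 3, 224, 224]:

# Sanity check: inspect one batch from the training loader
images, labels = next(iter(trainloader))
print(images.shape)  # expected: torch.Size([64, 3, 224, 224])
print(labels.shape)  # expected: torch.Size([64])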

Step 3: Setting Up the Loss Function and Optimizer

Define the loss function and optimizer. We use Adam, a popular choice for fine-tuning.

# Define loss function and optimizer  
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

Explanation: CrossEntropyLoss is the standard loss for multi-class classification, and Adam adapts per-parameter step sizes, which helps fine-tuning converge quickly. The StepLR scheduler cuts the learning rate by 10x every 7 epochs; note that with the 5 epochs used below it never actually fires, so lower step_size if you want the decay to kick in.
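
A common refinement when fine-tuning is to give the pre-trained backbone a smaller learning rate than the freshly initialized head, since its weights already encode useful features. Here is a sketch using optimizer parameter groups (the 1e-4/1e-3 values are illustrative, not tuned):

# Optional: per-group learning rates (illustrative values)
backbone_params = [p for name, p in model.named_parameters() if not name.startswith('fc')]
optimizer = optim.Adam([
    {'params': backbone_params, 'lr': 1e-4},        # pre-trained backbone: small steps
    {'params': model.fc.parameters(), 'lr': 1e-3},  # new head: larger steps
])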

Step 4: Fine-Tuning the Model

Train the model for a few epochs. Adjust the number of epochs if training takes too long.

num_epochs = 5

model.train()  # ensure layers like batch norm are in training mode
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    scheduler.step()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}")

print('Fine-tuning complete!')

Explanation: For each batch we run a forward pass, compute the loss, backpropagate, and update the weights; scheduler.step() then advances the learning-rate schedule once per epoch. This loop is what adapts the pre-trained weights to our specific task.
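
If you want feedback during training rather than only at the end, you can run a quick accuracy pass after each epoch. A minimal sketch (for brevity it reuses testloader; a held-out validation split would be cleaner):

# Optional helper: accuracy on a loader, callable once per epoch
def evaluate(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    model.train()  # restore training mode before the next epoch
    return 100 * correct / total

# Inside the epoch loop: print(f"Accuracy: {evaluate(model, testloader):.2f}%")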

Step 5: Saving the Fine-Tuned Model

Save the fine-tuned model for later use.

torch.save(model.state_dict(), 'finetuned_resnet18_mnist.pth')  
print('Model saved!')

Explanation: We save the model’s state dictionary (parameters) to a file. This allows us to load it later without retraining.
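
If you may want to resume training later rather than only run inference, a common pattern is to save a fuller checkpoint that also captures the optimizer and scheduler state. A sketch (the filename is just an example):

# Optional: a resumable checkpoint (example filename)
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'epochs_completed': num_epochs,
}
torch.save(checkpoint, 'finetune_checkpoint.pth')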

Step 6: Evaluating the Model

Check how well the model performs on unseen data.

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct / total:.2f}%')

Explanation: model.eval() switches layers such as batch normalization to inference behavior, and torch.no_grad() turns off gradient tracking. We then measure accuracy on the test set to see how well the model generalizes to unseen data.
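
Overall accuracy can hide weak classes. A small extension of the same loop gives a per-digit breakdown:

# Optional: per-digit accuracy breakdown
class_correct = [0] * 10
class_total = [0] * 10

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        for label, pred in zip(labels, predicted):
            class_total[label.item()] += 1
            class_correct[label.item()] += int(label.item() == pred.item())

for digit in range(10):
    print(f"Digit {digit}: {100 * class_correct[digit] / class_total[digit]:.2f}%")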

Step 7: Making Predictions

Finally, use the fine-tuned model to make predictions on a single image.

# Load the model for inference (re-create the architecture first)
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load('finetuned_resnet18_mnist.pth', map_location=device))
model.eval()
model = model.to(device)

# Make a prediction on a single image from the test set
test_image, _ = testset[0]  # get the first image from the test set
test_image = test_image.unsqueeze(0).to(device)  # add a batch dimension and move to device

with torch.no_grad():  # no gradients needed for inference
    output = model(test_image)
_, predicted = torch.max(output, 1)

print('Predicted label:', predicted.item())

Explanation: We load the saved model, set it to evaluation mode, and use it to predict the class of a new image.
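
The raw output is a vector of logits. If you also want a confidence score, apply softmax to turn the logits into probabilities:

# Optional: confidence alongside the predicted label
probabilities = torch.softmax(output, dim=1)
confidence, predicted = torch.max(probabilities, 1)
print(f"Predicted {predicted.item()} with {100 * confidence.item():.1f}% confidence")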

Conclusion

Fine-tuning allows you to leverage pre-trained models for new tasks with minimal effort. Here’s what we covered:

  1. Setting up a pre-trained model for a new task.
  2. Preparing the dataset.
  3. Training and fine-tuning the model.
  4. Saving and loading the fine-tuned model.
  5. Making predictions.

Found this helpful? Leave a clap or share your thoughts in the comments!
Thank You.
