Transfer Learning with Practical Implementation
In many real-world situations you will not want to train a whole convolutional network from scratch, unless you happen to have a very large dataset and a lot of compute resources. Training a modern CNN on a dataset like ImageNet takes days or weeks on multiple GPUs. Instead, in most cases you will use transfer learning.
Why Transfer Learning?
- Efficiency: Training models from scratch can be computationally expensive and time-consuming. Transfer learning allows for faster training times.
- Data Scarcity: In many cases, large labeled datasets are not available for the specific task at hand. Transfer learning can help achieve good performance with limited data.
- Performance: Models pre-trained on large datasets can capture more complex patterns and generalize better to new tasks, often leading to higher accuracy.
How Transfer Learning Works
Transfer learning typically involves the following steps:
- Select a Pre-trained Model: Choose a model that has been pre-trained on a large dataset, such as ImageNet for image tasks or BERT for natural language processing tasks.
- Freeze Initial Layers: Freeze the weights of the initial layers of the pre-trained model to retain the learned features.
- Replace Final Layers: Replace the final layers of the model with new layers that are suitable for the target task.
- Fine-Tune the Model: Fine-tune the entire model or just the new layers on the target dataset (a minimal end-to-end sketch follows this list).
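Putting the four steps together, a minimal PyTorch sketch might look like the following. This is illustrative only; the full, runnable version is developed step by step in the rest of this article, and the choice of ResNet-18 and 5 classes here is arbitrary:
import torchvision.models
from torch import nn

# 1. Select a pre-trained model (here, a ResNet-18 trained on ImageNet)
model = torchvision.models.resnet18(pretrained=True)

# 2. Freeze the initial layers to retain the learned features
for p in model.parameters():
    p.requires_grad = False

# 3. Replace the final layer with one sized for the target task
n_classes = 5  # e.g., 5 flower species
model.fc = nn.Linear(model.fc.in_features, n_classes)

# 4. Fine-tune: the new head's parameters (requires_grad=True by default
#    for a freshly created nn.Linear) are the only ones the optimizer
#    will update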
Let’s explore a practical implementation of transfer learning with structured, well-explained code. This section will guide you through the process step by step.
Import Libraries:
import os
import math
import multiprocessing
from pathlib import Path
import torch
import torch.optim
from torch import nn
import torchvision.datasets as datasets
import torchvision.transforms as T
import torchvision.models
import matplotlib.pyplot as plt
from tqdm import tqdm
Data Loader Function:
def compute_mean_and_std():
    """
    Compute per-channel mean and std of the dataset (to be used in transforms.Normalize())
    """
    cache_file = "mean_and_std.pt"
    if os.path.exists(cache_file):
        print("Reusing cached mean and std")
        d = torch.load(cache_file)
        return d["mean"], d["std"]

    # get_data_location() is a helper (defined elsewhere in the project)
    # that returns the path of the dataset folder
    folder = get_data_location()
    ds = datasets.ImageFolder(
        folder, transform=T.Compose([T.ToTensor()])
    )
    dl = torch.utils.data.DataLoader(
        ds, batch_size=1, num_workers=multiprocessing.cpu_count()
    )

    mean = 0.0
    for images, _ in tqdm(dl, total=len(ds), desc="Computing mean", ncols=80):
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean += images.mean(2).sum(0)
    mean = mean / len(dl.dataset)

    var = 0.0
    npix = 0
    for images, _ in tqdm(dl, total=len(ds), desc="Computing std", ncols=80):
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        var += ((images - mean.unsqueeze(1)) ** 2).sum([0, 2])
        npix += images.nelement()
    std = torch.sqrt(var / (npix / 3))

    # Cache results so we don't need to redo the computation
    torch.save({"mean": mean, "std": std}, cache_file)

    return mean, std
def get_transforms(rand_augment_magnitude):
    # Per-channel mean and std computed over the dataset
    mean, std = compute_mean_and_std()

    # Define our transformations
    return {
        "train": T.Compose(
            [
                # Resize the images so we can then take a random crop of the
                # size expected by the pretrained network
                T.Resize(256),
                # take a random part of the image
                T.RandomCrop(224),
                # Horizontal flip is not part of RandAugment according to the RandAugment
                # paper, so we apply it separately
                T.RandomHorizontalFlip(0.5),
                # RandAugment has 2 main parameters: how many transformations should be
                # applied to each image, and the strength of these transformations. This
                # latter parameter should be tuned through experiments: the higher, the
                # stronger the regularization effect
                T.RandAugment(
                    num_ops=2,
                    magnitude=rand_augment_magnitude,
                    interpolation=T.InterpolationMode.BILINEAR,
                ),
                T.ToTensor(),
                T.Normalize(mean, std),
            ]
        ),
        "valid": T.Compose(
            [
                # Resize and center-crop to the input size the pretrained
                # network expects
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(mean, std),
            ]
        ),
    }
def get_data_loaders(
    batch_size: int = 32, valid_size: float = 0.2, num_workers: int = -1,
    limit: int = -1, rand_augment_magnitude: int = 9
):
    """
    Create and return the train and validation data loaders.

    :param batch_size: size of the mini-batches
    :param valid_size: fraction of the dataset to use for validation. For example 0.2
                       means that 20% of the dataset will be used for validation
    :param num_workers: number of workers to use in the data loaders. Use -1 to mean
                        "use all my cores"
    :param limit: maximum number of data points to consider
    :param rand_augment_magnitude: strength of the RandAugment transformations
    :return a dictionary with 2 keys, 'train' and 'valid', containing respectively the
            train and validation data loaders
    """
    if num_workers == -1:
        # Use all cores
        num_workers = multiprocessing.cpu_count()

    # We will fill this up later
    data_loaders = {"train": None, "valid": None}

    base_path = Path(get_data_location())

    # Compute mean and std of the dataset
    mean, std = compute_mean_and_std()
    print(f"Dataset mean: {mean}, std: {std}")

    data_transforms = get_transforms(rand_augment_magnitude)

    train_data = datasets.ImageFolder(
        base_path,
        transform=data_transforms["train"]
    )
    # The validation dataset is a split from the train dataset, so we read
    # from the same folder, but we apply the transforms for validation
    valid_data = datasets.ImageFolder(
        base_path,
        transform=data_transforms["valid"]
    )

    # obtain training indices that will be used for validation
    n_tot = len(train_data)
    indices = torch.randperm(n_tot)

    # If requested, limit the number of data points to consider
    if limit > 0:
        indices = indices[:limit]
        n_tot = limit

    split = int(math.ceil(valid_size * n_tot))
    train_idx, valid_idx = indices[split:], indices[:split]

    # define samplers for obtaining training and validation batches
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    valid_sampler = torch.utils.data.SubsetRandomSampler(valid_idx)

    # prepare data loaders
    data_loaders["train"] = torch.utils.data.DataLoader(
        train_data,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
    )
    data_loaders["valid"] = torch.utils.data.DataLoader(
        valid_data,
        batch_size=batch_size,
        sampler=valid_sampler,
        num_workers=num_workers,
    )

    return data_loaders
When the dataset is small it is difficult to split it three ways (train, validation and test), because you might not have enough data to train, validate and test. Instead, you can use k-fold cross-validation: divide the dataset into k parts, then repeat the training k times, each time taking k-1 parts as your training dataset and the remaining part as your validation dataset. You record your metrics for each iteration, and then use the average performance as your final metric. A minimal sketch of this procedure follows.
To make things faster here we won't be doing full k-fold cross-validation; we will still only split the data into train and validation (in practice, we execute just one of the k training runs of k-fold cross-validation).
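For reference, here is a minimal sketch of how the k-fold loop could look in PyTorch. It is illustrative only: `train_data` and `valid_data` are the ImageFolder datasets from above, while `run_training` is a hypothetical function standing in for a full train-and-validate cycle:
# Hypothetical k-fold sketch; run_training() is NOT defined in this article.
k = 5
n = len(train_data)
indices = torch.randperm(n)
fold_size = n // k

scores = []
for fold in range(k):
    lo, hi = fold * fold_size, (fold + 1) * fold_size
    valid_idx = indices[lo:hi]                          # the held-out fold
    train_idx = torch.cat([indices[:lo], indices[hi:]])  # the other k-1 folds

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=32,
        sampler=torch.utils.data.SubsetRandomSampler(train_idx),
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_data, batch_size=32,
        sampler=torch.utils.data.SubsetRandomSampler(valid_idx),
    )

    scores.append(run_training(train_loader, valid_loader))

print(f"Cross-validated score: {sum(scores) / k:.4f}")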
For simplicity, the data-loading code is provided in the snippets above; feel free to look into it to see how the data is loaded.
data_loaders = get_data_loaders(batch_size=32, rand_augment_magnitude=15)
classes = ["daisy", "tulip", "dandelion", "sunflower", "rose"]
Create the model and substitute the head
Let’s now load a pretrained model and substitute its head. We will use a ResNet-50 from torchvision and replace its head with a fully-connected (Linear) layer with the right input and output dimensions:
model = torchvision.models.resnet50(pretrained=True)
# NOTE: in recent torchvision versions, pretrained=True is deprecated in
# favor of torchvision.models.resnet50(weights="IMAGENET1K_V2")

n_classes = len(classes)
n_inputs = model.fc.in_features

# Feel free to experiment with more complicated heads (see the sketch below)
model.fc = nn.Linear(n_inputs, n_classes)
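As an example of a "more complicated head", you could swap the single Linear layer for a small MLP with dropout. This is only a sketch, not what was used for the results reported below, and the hidden size of 512 is an arbitrary choice:
# Alternative head sketch (hidden size 512 is arbitrary). The results later
# in this article were obtained with the single Linear head above.
model.fc = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(512, n_classes),
)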
Freeze the backbone and thaw the head
Now we need to freeze all layers except the one we just added. Let’s keep track of the parameters we are freezing so we will be able to unfreeze them later:
frozen_parameters = []

for p in model.parameters():
    # Freeze only parameters that are not already frozen
    # (if any)
    if p.requires_grad:
        p.requires_grad = False
        frozen_parameters.append(p)

print(f"Froze {len(frozen_parameters)} groups of parameters")

# Now let's thaw the parameters of the head we have
# added
for p in model.fc.parameters():
    p.requires_grad = True
Train
Now we can train our model. We start with the usual learning rate finder:
loss = nn.CrossEntropyLoss()
losses = lr_finder(1e-5, 0.1, 100, loss, model, data_loaders)
Learning Rate Finder:
import torch
import numpy as np
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
import torch.optim as optim


def lr_finder(min_lr, max_lr, n_steps, loss, model, data_loaders):
    if torch.cuda.is_available():
        model.cuda()

    # Save initial weights so we can restore them at the end
    torch.save(model.state_dict(), "__weights_backup")

    # specify optimizer
    optimizer = optim.SGD(model.parameters(), lr=min_lr)

    # We create a learning rate scheduler that increases the learning
    # rate at every batch.
    # Find the factor r such that min_lr * r**(n_steps-1) = max_lr
    r = np.power(max_lr / min_lr, 1 / (n_steps - 1))

    def new_lr(epoch):
        """
        This should return the *factor* by which the initial learning
        rate must be multiplied to get the desired learning rate
        """
        return r ** epoch

    # This scheduler increases the learning rate by a constant factor (r)
    # at every iteration
    lr_scheduler = LambdaLR(optimizer, new_lr)

    # Set the model in training mode
    # (so all layers that behave differently between training and evaluation,
    # like batchnorm and dropout, will select their training behavior)
    model.train()

    # Loop over the training data
    losses = {}
    train_loss = 0.0
    keep_going = True
    n = 0

    while keep_going:
        for batch_idx, (data, target) in tqdm(
            enumerate(data_loaders["train"]),
            desc="Learning rate finder",
            total=len(data_loaders["train"]),
            leave=True,
            ncols=80,
        ):
            # move data to GPU if available
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()

            # 1. clear the gradients of all optimized variables
            optimizer.zero_grad()
            # 2. forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # 3. calculate the loss
            loss_value = loss(output, target)
            # 4. backward pass: compute gradient of the loss with respect to model parameters
            loss_value.backward()
            # 5. perform a single optimization step (parameter update)
            optimizer.step()

            # update the running average of the training loss
            train_loss = train_loss + (
                (1 / (n + 1)) * (loss_value.data.item() - train_loss)
            )

            losses[lr_scheduler.get_last_lr()[0]] = train_loss

            # Stop if the loss gets too big
            if train_loss / min(losses.values()) > 10:
                keep_going = False
                break

            if n == n_steps - 1:
                keep_going = False
                break
            else:
                # Increase the learning rate for the next iteration
                lr_scheduler.step()
                n += 1

    # Restore model to its initial state
    model.load_state_dict(torch.load("__weights_backup"))

    return losses
Plot Learning Rate Finder Graph:
_ = plt.plot(list(losses.keys()), list(losses.values()))
_ = plt.xscale("log")
# You might need to adjust this to cover the right
# portion of the y axis to better see the decrease
# in the loss before it shoots up again
_ = plt.ylim([1.57, 1.75])
It looks like a good learning rate is 0.005. Let’s first train the head for a few epochs, to bring its weights (which right now are randomly initialized) to a place where they interact well with the feature extraction part:
Optimization Helper Functions:
def train_one_epoch(train_dataloader, model, optimizer, loss):
    """
    Performs one epoch of training
    """
    # Move model to GPU if available
    if torch.cuda.is_available():
        model.cuda()

    # Set the model in training mode
    # (so all layers that behave differently between training and evaluation,
    # like batchnorm and dropout, will select their training behavior)
    model.train()

    # Loop over the training data
    train_loss = 0.0

    for batch_idx, (data, target) in tqdm(
        enumerate(train_dataloader),
        desc="Training",
        total=len(train_dataloader),
        leave=True,
        ncols=80,
    ):
        # move data to GPU if available
        if torch.cuda.is_available():
            data, target = data.cuda(), target.cuda()

        # 1. clear the gradients of all optimized variables
        optimizer.zero_grad()
        # 2. forward pass: compute predicted outputs by passing inputs to the model
        output = model(data)
        # 3. calculate the loss
        loss_value = loss(output, target)
        # 4. backward pass: compute gradient of the loss with respect to model parameters
        loss_value.backward()
        # 5. perform a single optimization step (parameter update)
        optimizer.step()

        # update average training loss
        train_loss = train_loss + (
            (1 / (batch_idx + 1)) * (loss_value.data.item() - train_loss)
        )

    return train_loss
def valid_one_epoch(valid_dataloader, model, loss):
    """
    Validate at the end of one epoch
    """
    # During validation we don't need to accumulate gradients
    with torch.no_grad():

        # set the model to evaluation mode
        # (so all layers that behave differently between training and evaluation,
        # like batchnorm and dropout, will select their evaluation behavior)
        model.eval()

        # If the GPU is available, move the model to the GPU
        if torch.cuda.is_available():
            model.cuda()

        # Loop over the validation dataset and accumulate the loss
        valid_loss = 0.0
        for batch_idx, (data, target) in tqdm(
            enumerate(valid_dataloader),
            desc="Validating",
            total=len(valid_dataloader),
            leave=True,
            ncols=80,
        ):
            # move data to GPU if available
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()

            # 1. forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # 2. calculate the loss
            loss_value = loss(output, target)

            # Calculate average validation loss
            valid_loss = valid_loss + (
                (1 / (batch_idx + 1)) * (loss_value.data.item() - valid_loss)
            )

    return valid_loss
def optimize(data_loaders, model, optimizer, loss, n_epochs, save_path, interactive_tracking=False):

    def after_subplot(ax: plt.Axes, group_name: str, x_label: str):
        """Add title, xlabel and legend to a single chart"""
        ax.set_title(group_name)
        ax.set_xlabel(x_label)
        ax.legend(loc="center right")

    if interactive_tracking:
        # Requires the livelossplot package:
        # from livelossplot import PlotLosses
        # from livelossplot.outputs import MatplotlibPlot
        liveloss = PlotLosses(outputs=[MatplotlibPlot(after_subplot=after_subplot)])
    else:
        liveloss = None

    # initialize tracker for minimum validation loss
    valid_loss_min = None
    logs = {}

    # Learning rate scheduler: setup a learning rate scheduler that
    # reduces the learning rate when the validation loss reaches a
    # plateau
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, "min", verbose=True, threshold=0.01
    )

    for epoch in range(1, n_epochs + 1):
        train_loss = train_one_epoch(
            data_loaders["train"], model, optimizer, loss
        )
        valid_loss = valid_one_epoch(data_loaders["valid"], model, loss)

        # If the validation loss decreases by more than 1%, save the model
        if valid_loss_min is None or (
            (valid_loss_min - valid_loss) / valid_loss_min > 0.01
        ):
            # Save the weights to save_path
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss

        # Update learning rate, i.e., make a step in the learning rate scheduler
        scheduler.step(valid_loss)

        # Log the losses and the current learning rate
        if interactive_tracking:
            logs["loss"] = train_loss
            logs["val_loss"] = valid_loss
            logs["lr"] = optimizer.param_groups[0]["lr"]

            liveloss.update(logs)
            liveloss.send()
def one_epoch_test(test_dataloader, model, loss):
    # monitor test loss and accuracy
    test_loss = 0.
    correct = 0.
    total = 0.

    # we do not need the gradients
    with torch.no_grad():

        # set the model to evaluation mode
        model.eval()

        # if the GPU is available, move the model to the GPU
        if torch.cuda.is_available():
            model = model.cuda()

        # Loop over test dataset
        # We also accumulate predictions and targets so we can return them
        preds = []
        actuals = []
        for batch_idx, (data, target) in tqdm(
            enumerate(test_dataloader),
            desc='Testing',
            total=len(test_dataloader),
            leave=True,
            ncols=80
        ):
            # move data to GPU if available
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()

            # 1. forward pass: compute predicted outputs by passing inputs to the model
            logits = model(data)
            # 2. calculate the loss
            loss_value = loss(logits, target).detach()

            # update average test loss
            test_loss = test_loss + ((1 / (batch_idx + 1)) * (loss_value.data.item() - test_loss))

            # convert logits to predicted class
            # NOTE: the predicted class is the index of the max of the logits
            pred = logits.data.max(1, keepdim=True)[1]

            # compare predictions to true label
            correct += torch.sum(torch.squeeze(pred.eq(target.data.view_as(pred))).cpu())
            total += data.size(0)

            preds.extend(pred.data.cpu().numpy().squeeze())
            actuals.extend(target.data.view_as(pred).cpu().numpy().squeeze())

    print('Test Loss: {:.6f}\n'.format(test_loss))
    print('\nTest Accuracy: %2d%% (%2d/%2d)' % (
        100. * correct / total, correct, total))

    return test_loss, preds, actuals
Create the optimizer and train the head:
n_epochs = 5
lr = 0.005
optimizer = torch.optim.Adam(model.parameters(), lr)

optimize(
    data_loaders,
    model,
    optimizer,
    loss,
    n_epochs,
    'initial.pt',
    interactive_tracking=True
)
_ = one_epoch_test(data_loaders['valid'], model, loss)
Testing: 100%|██████████████████████████████████| 27/27 [00:06<00:00, 4.11it/s]
Test Loss: 0.318999

Test Accuracy: 90% (780/864)
Just by training the new head for a few epochs we already reached very good performance. Now let’s unfreeze the parameters we froze before:
# Thaw parameters we had frozen before
for p in frozen_parameters:
    p.requires_grad = True
and let’s fine-tune the model. Note how we are using a learning rate that is much smaller than what we used before. This avoids taking large steps that would destroy the advantage of starting from a pre-trained network:
n_epochs = 10
lr = 0.005 / 100
optimizer = torch.optim.Adam(model.parameters(), lr)

optimize(
    data_loaders,
    model,
    optimizer,
    loss,
    n_epochs,
    "best.pt",
    interactive_tracking=True,
)
_ = one_epoch_test(data_loaders['valid'], model, loss)
Testing: 100%|██████████████████████████████████| 27/27 [00:06<00:00, 4.11it/s]
Test Loss: 0.185599

Test Accuracy: 95% (826/864)
With fine-tuning we have significantly increased our performance, to 95%!
Can you explain why?
The dataset we have is fairly small, so when training from scratch the network does not have enough examples to fully learn the task. With transfer learning, instead, we start from the weights obtained by training the network on ImageNet. ImageNet contains natural images, so the network has already learned how to extract meaningful features from natural images. Because ImageNet is a very large dataset, the backbone can extract good features out of the box, without much additional training. So we can substitute just the head and train: the features extracted by the backbone are already enough to assign our custom classes to the images. Since our dataset is small, but not very small, we can then give a final touch by unfreezing the backbone and training it with a low learning rate.
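A common refinement of this recipe, not used above but worth knowing, is discriminative learning rates: standard optimizer parameter groups let you give the pretrained backbone a smaller learning rate than the fresh head during fine-tuning, instead of freezing and thawing in two phases. A minimal sketch, assuming the model defined earlier (the specific values are illustrative, not tuned):
# Sketch: gentler updates for pretrained weights, larger steps for the head.
head_params = list(model.fc.parameters())
head_param_ids = {id(p) for p in head_params}
backbone_params = [p for p in model.parameters() if id(p) not in head_param_ids]

optimizer = torch.optim.Adam([
    {"params": backbone_params, "lr": 0.005 / 100},
    {"params": head_params, "lr": 0.005 / 10},
])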
Applications of Transfer Learning
Transfer learning is widely used in various domains, including:
- Computer Vision: Image classification, object detection, and segmentation.
- Natural Language Processing: Text classification, sentiment analysis, and language translation.
- Speech Recognition: Recognizing and transcribing spoken language.
- Healthcare: Predicting diseases from medical images and patient records.
Practical Considerations
- Dataset Compatibility: Ensure that the data distribution of the pre-trained model is somewhat similar to the target task.
- Layer Selection: The choice of which layers to freeze and which to fine-tune can significantly impact performance (see the sketch after this list).
- Hyperparameter Tuning: Fine-tuning requires careful selection of hyperparameters to avoid overfitting or underfitting.
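For instance, with a torchvision ResNet you can choose to freeze only the earlier stages and fine-tune the later ones. A sketch, assuming the ResNet-50 used above (which stages to thaw is a design choice to validate empirically):
# Freeze everything, then unfreeze only the last residual stage and the head
for p in model.parameters():
    p.requires_grad = False

for p in model.layer4.parameters():  # last residual stage of a torchvision ResNet
    p.requires_grad = True

for p in model.fc.parameters():      # classification head
    p.requires_grad = True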
Challenges and Limitations
- Domain Difference: Significant differences between the source and target domains can lead to poor transfer learning performance.
- Overfitting: Over-reliance on the pre-trained model might lead to overfitting on the target dataset if not managed properly.
- Computational Cost: Fine-tuning large models can still be computationally intensive.
Conclusion
Transfer learning is a powerful technique that leverages the strengths of pre-trained models to tackle new tasks efficiently and effectively. By understanding and implementing transfer learning, practitioners can achieve better performance and faster results, especially when dealing with limited data.
Thank you for taking the time to read this article on transfer learning. If you found it helpful and would like to stay updated with more insights, please like and share. Your support is greatly appreciated! Feel free to connect with me on LinkedIn at Ankit Malik