Building Deep Learning Models using Lightning.ai
This article gives you a brief introduction to PyTorch Lightning. It also serves as a journal of how I learn a new tool, or revisit one I learned in the past that may have changed over time.
You may ask why I would write another article when the documentation already guides the user. Well, I tried the documentation and found it a bit confusing, and its use case was not the one I expected when learning the library in the first place.
PyTorch Lightning is the deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. Lightning evolves with you as your projects go from idea to paper/production.
As the name suggests, this library is built on top of PyTorch. If you are already comfortable with PyTorch, congrats! You can understand this library within 15 minutes.
The purpose of this library is to cut the boilerplate you write every time you train a model with PyTorch. This makes development faster and leaves less messy code for configuring this and that.
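To make "boilerplate" concrete, here is a rough sketch of the kind of hand-written loop that Lightning's Trainer takes over. It is not from the Lightning documentation; the toy data, the `toy_model` name, and the hyperparameters are made up for illustration only.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 64 samples, 8 features, 3 classes.
torch.manual_seed(0)
features = torch.randn(64, 8)
labels = torch.randint(0, 3, (64,))
loader = DataLoader(TensorDataset(features, labels), batch_size=16, shuffle=True)

toy_model = nn.Linear(8, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-2)

# Device handling is one of the chores you repeat in every script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
toy_model.to(device)

epoch_losses = []
for epoch in range(5):
    toy_model.train()
    running_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(toy_model(x), y)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size to get a per-sample average.
        running_loss += loss.item() * x.size(0)
    epoch_losses.append(running_loss / len(loader.dataset))
```

With Lightning, the loop, the device moves, and the loss bookkeeping move out of your script, and you mostly describe what happens in a single step.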
We will create an image classification model using ResNet50 and CIFAR10 as the dataset.
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.
Load libraries
import os
import torch
from torch import nn
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torchvision.models as models
from torch.utils.data import DataLoader, random_split
import lightning.pytorch as pl
# Keep all Lightning imports from the same package (lightning.pytorch)
from lightning.pytorch import loggers as pl_loggers
Load dataset
data_set = CIFAR10(os.getcwd(), download=True, train=True, transform=transforms.ToTensor())
test_set = CIFAR10(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
Split dataset into train and validation sets
train_set_size = int(len(data_set) * 0.8)
valid_set_size = len(data_set) - train_set_size
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator = seed)
print(f"Train size: {train_set_size}")
print(f"Valid size: {valid_set_size}")
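Why pass a seeded generator to random_split? A small sketch, assuming a made-up dataset of 100 integers (`toy_set` and `split_indices` are hypothetical names, not part of the tutorial), suggests that the same seed gives the same split on every run:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical dataset of 100 items standing in for CIFAR10.
toy_set = TensorDataset(torch.arange(100))

def split_indices(seed_value):
    """Split 80/20 with a seeded generator and return the chosen indices."""
    generator = torch.Generator().manual_seed(seed_value)
    train_part, valid_part = random_split(toy_set, [80, 20], generator=generator)
    return train_part.indices, valid_part.indices

train_a, valid_a = split_indices(42)
train_b, valid_b = split_indices(42)

# Same seed, same split; the two parts never share a sample.
print(train_a == train_b)
print(set(train_a).isdisjoint(valid_a))
```

This matters when you compare runs: the validation loss is only comparable if every run is validated on the same images.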
Setup model class
class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, num_target_classes=10):
        super().__init__()
        # Pretrained ResNet50 without its final fully connected layer
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)
        # New linear head for the CIFAR10 classes
        self.classifier = nn.Linear(num_filters, num_target_classes)
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Keep the backbone frozen: eval mode and no gradient tracking
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        # The classifier runs outside no_grad so that it can be trained
        y_pred = self.classifier(representations)
        loss = self.criterion(y_pred, y)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        y_pred = self.classifier(representations)
        loss = self.criterion(y_pred, y)
        self.log("valid_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        y_pred = self.classifier(representations)
        loss = self.criterion(y_pred, y)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        # Backbone parameters receive no gradients, so only the classifier is updated
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer
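The important trick in the class above is that the feature extractor runs under torch.no_grad() while the classifier does not. Here is a minimal sketch to check that idea without downloading ResNet50, assuming a tiny stand-in network (`tiny_backbone` and `tiny_head` are hypothetical names of mine):

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for the ResNet50 backbone and the linear head.
tiny_backbone = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.AdaptiveAvgPool2d(1))
tiny_head = nn.Linear(4, 10)

x = torch.randn(2, 3, 32, 32)  # two fake CIFAR10-sized images
y = torch.tensor([1, 7])       # two fake labels

# Same pattern as in training_step: frozen features, trainable head.
tiny_backbone.eval()
with torch.no_grad():
    representations = tiny_backbone(x).flatten(1)
logits = tiny_head(representations)
loss = nn.functional.cross_entropy(logits, y)
loss.backward()

# Gradients should reach the head but not the backbone.
backbone_has_grad = any(p.grad is not None for p in tiny_backbone.parameters())
head_has_grad = all(p.grad is not None for p in tiny_head.parameters())
print(backbone_has_grad, head_has_grad)
```

If the classifier call were moved inside the no_grad block, the loss would carry no gradient at all and calling backward() on it would raise an error, so the indentation of that one line matters.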
Training model
model = ImagenetTransferLearning()
# Shuffle only the training data
train_loader = DataLoader(train_set, batch_size=512, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=512)
test_loader = DataLoader(test_set, batch_size=512)
tb_logger = pl_loggers.TensorBoardLogger('cifar10_logs/')
trainer = pl.Trainer(
    max_epochs=10,
    default_root_dir="resnet50/",
    enable_checkpointing=True,
    logger=tb_logger,
)
trainer.fit(model, train_loader, valid_loader)
trainer.test(model, test_loader)
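The steps above log only the cross-entropy loss. If you also want an accuracy number, one possible way to compute it from the classifier output is sketched below. The `batch_accuracy` helper and the example tensors are my own illustration, not part of Lightning; the returned value could then be passed to self.log next to the loss.

```python
import torch

def batch_accuracy(logits, targets):
    """Fraction of samples whose highest-scoring class matches the target."""
    predictions = logits.argmax(dim=1)
    return (predictions == targets).float().mean().item()

# Hypothetical logits for 4 samples and 3 classes.
example_logits = torch.tensor([
    [2.0, 0.1, 0.3],
    [0.2, 1.5, 0.1],
    [0.9, 0.8, 0.1],
    [0.1, 0.2, 3.0],
])
example_targets = torch.tensor([0, 1, 1, 2])

accuracy = batch_accuracy(example_logits, example_targets)
print(accuracy)  # 3 of the 4 predictions match, so 0.75
```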
Visualize training
This article uses the TensorBoard logger, so we can use TensorBoard to visualize the metrics. Run this command in a command prompt.
tensorboard --logdir cifar10_logs/lightning_logs/
Final words
I hope this bunch of code snippets helps you understand and build deep learning models using PyTorch Lightning.