Building Deep Learning Models using Lightning.ai

AC
Data Folks Indonesia
Jun 17, 2023
Photo by Holly Mandarich on Unsplash

This article gives you a brief introduction to PyTorch Lightning. It is also a journal of how I learn something new, or revisit something I learned in the past, since tools change as time goes on.

You might ask why I am writing another article when the documentation already guides the user. Well, I tried the documentation, found it a bit confusing, and its use cases were not the ones I had in mind when I set out to learn the library.

PyTorch Lightning is the deep learning framework for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale. Lightning evolves with you as your projects go from idea to paper/production.

As the name suggests, this library is built on top of PyTorch. If you are already comfortable with PyTorch, congrats! You can understand this library within 15 minutes.

The purpose of this library is to cut the boilerplate you write every time you train a model with PyTorch, such as the training loop, moving tensors to the right device, logging, and checkpointing. This makes development faster and leaves you with less messy configuration code.
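To make "boilerplate" concrete, here is a rough sketch of the loop you write by hand in plain PyTorch. The toy data, model size, and learning rate are made up for illustration; the point is the manual steps (device placement, zero_grad, backward, step, epoch-level loss bookkeeping) that Lightning's Trainer runs for you once you define training_step and configure_optimizers.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy data: 512 samples with 20 features and 3 learnable classes.
x = torch.randn(512, 20)
y = (x[:, 0] > 0).long() + (x[:, 1] > 0).long()  # labels in {0, 1, 2}
loader = DataLoader(TensorDataset(x, y), batch_size=64, shuffle=True)

model = nn.Linear(20, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)

# Device placement is manual in plain PyTorch.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

epoch_losses = []
for epoch in range(5):
    model.train()
    running_loss = 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()            # clear gradients from the previous step
        loss = criterion(model(xb), yb)
        loss.backward()                  # backpropagate
        optimizer.step()                 # update the weights
        running_loss += loss.item() * xb.size(0)
    # Epoch-level logging is manual too.
    epoch_losses.append(running_loss / len(loader.dataset))
    print(f"epoch {epoch}: train_loss={epoch_losses[-1]:.4f}")
```

In the Lightning version later in this article, only the body of the inner loop (forward pass and loss) survives, inside training_step.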

We will create an image classification model using transfer learning from an ImageNet-pretrained ResNet50, with CIFAR-10 as the dataset.

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

Load libraries

```python
import os

import torch
from torch import nn
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torchvision.models as models
from torch.utils.data import DataLoader, random_split

# Import everything Lightning-related from one package; mixing
# `pytorch_lightning` and `lightning.pytorch` can lead to confusing errors.
import lightning.pytorch as pl
from lightning.pytorch import loggers as pl_loggers
```

Load dataset

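```python
# Download CIFAR-10 into the current directory. ToTensor turns each PIL image
# into a float tensor of shape (3, 32, 32) with values in [0, 1].
data_set = CIFAR10(os.getcwd(), download=True, train=True, transform=transforms.ToTensor())
test_set = CIFAR10(os.getcwd(), download=True, train=False, transform=transforms.ToTensor())
```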

Split dataset for train and dev

```python
# Use 80% of the 50,000 training images for training and the rest for validation.
train_set_size = int(len(data_set) * 0.8)
valid_set_size = len(data_set) - train_set_size

# A fixed seed makes the split reproducible across runs.
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=seed)

print(f"Train size: {train_set_size}")
print(f"Valid size: {valid_set_size}")
```
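Why the seeded generator matters: with the same seed, random_split returns the same indices every run, so the validation set does not silently change between experiments. A small sketch using a plain list of 50,000 integers as a hypothetical stand-in for the CIFAR-10 training set (random_split only needs `len()` and indexing):

```python
import torch
from torch.utils.data import random_split

# Stand-in dataset with the same length as the CIFAR-10 training split.
data_set = list(range(50000))

train_set_size = int(len(data_set) * 0.8)        # 40000
valid_set_size = len(data_set) - train_set_size  # 10000


def split(seed_value):
    """Split data_set 80/20 with a freshly seeded generator."""
    gen = torch.Generator().manual_seed(seed_value)
    return random_split(data_set, [train_set_size, valid_set_size], generator=gen)


train_a, valid_a = split(42)
train_b, valid_b = split(42)

print(len(train_a), len(valid_a))          # 40000 10000
print(train_a.indices == train_b.indices)  # True: same seed, same split
```

Recent PyTorch versions also accept fractions such as `[0.8, 0.2]` as the lengths argument, which avoids the manual size arithmetic.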

Set up the model class

```python
class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, num_target_classes=10):
        super().__init__()

        # ResNet50 pretrained on ImageNet; drop its final fc layer and keep
        # everything up to the global average pool as a feature extractor.
        backbone = models.resnet50(weights="DEFAULT")
        num_filters = backbone.fc.in_features  # 2048 for ResNet50
        layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*layers)

        # New linear head for the CIFAR-10 classes: the only part we train.
        self.classifier = nn.Linear(num_filters, num_target_classes)
        self.criterion = nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch

        # Keep the backbone frozen: eval() fixes the BatchNorm statistics and
        # no_grad() stops gradients from flowing into it.
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)

        y_pred = self.classifier(representations)

        loss = self.criterion(y_pred, y)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)

        y_pred = self.classifier(representations)

        loss = self.criterion(y_pred, y)
        self.log("valid_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch

        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)

        y_pred = self.classifier(representations)

        loss = self.criterion(y_pred, y)
        self.log("test_loss", loss, prog_bar=True, on_step=False, on_epoch=True)

        return loss

    def configure_optimizers(self):
        # Only the classifier receives gradients, so only its parameters are optimized.
        optimizer = torch.optim.AdamW(self.classifier.parameters(), lr=1e-3)
        return optimizer
```

Training model

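```python
model = ImagenetTransferLearning()

# Shuffle the training data every epoch; validation and test order can stay fixed.
train_loader = DataLoader(train_set, batch_size=512, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=512)
test_loader = DataLoader(test_set, batch_size=512)

# Metrics are written under cifar10_logs/lightning_logs/version_<n>/
tb_logger = pl_loggers.TensorBoardLogger("cifar10_logs/")

trainer = pl.Trainer(max_epochs=10,
                     default_root_dir="resnet50/",
                     enable_checkpointing=True,
                     logger=tb_logger)

trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

trainer.test(model, dataloaders=test_loader)
```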

Visualize training

This article uses the TensorBoard logger, so we can use TensorBoard to visualize the metrics. Run this command in a terminal or command prompt:

```shell
tensorboard --logdir cifar10_logs/lightning_logs/
```

Final words

I hope this bunch of code snippets helps you understand PyTorch Lightning and build deep learning models with it.
