Simplify your PyTorch code with PyTorch Lightning

Sandy Bei
Published in Innovation-res
5 min read · Mar 24, 2022

An introduction to PyTorch Lightning, a framework for making deep learning model training easier and faster.

What is PyTorch Lightning?

PyTorch Lightning is a lightweight, high-performance framework built on top of PyTorch that helps you organize your code and automates the engineering side of training (the training loop, backward passes, optimizer steps and device placement). It also provides the following features:

  • metrics (accuracy, precision, recall, etc.)
  • metrics logging
  • model checkpointing
  • early stopping
  • training on multiple GPUs, TPUs or CPUs
  • minimal speed overhead (about 300 ms per epoch compared with pure PyTorch)

The official docs can be found here and the source code on GitHub here.

Installation

Before making any changes, you must install Lightning:

pip install pytorch-lightning

and import it into your code:

import pytorch_lightning as pl

Organize your PyTorch code

There are different ways to use PyTorch Lightning. In this post we will only use pl.LightningModule, to replace both nn.Module and the hand-written training loop.

A LightningModule is a torch.nn.Module but with added functionality. Use it as such!

To use PyTorch Lightning, you must structure your code under the functions of LightningModule. The module and its required functions are shown below:

These 4 functions are the minimum required for training your model with Lightning. Other functions you will probably need to add are: prepare_data(), validation_step(), test_step() and predict_step().

So, the changes you have to make are shown in the picture below:

Steps:

  1. subclass pl.LightningModule instead of nn.Module
  2. move all required code into the relevant methods of the module
  3. remove the .to(device) calls: Lightning automatically moves the model and each batch to the correct device

The full code for the above examples can be found here.

For a great side-by-side comparison of PyTorch and PyTorch Lightning, read the following article, written by one of PyTorch Lightning’s creators:

Train your model

Once your code is structured as a LightningModule, you can train your model in just three lines of code with Trainer, a class that wraps the nested loops over all epochs and over all batches of a dataloader:

model = NeuralNetworkLit()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_dataloader)

All the following are handled by Trainer:

  • loops over epochs
  • loops over batches
  • disabling gradients during validation and testing
  • backward passes
  • optimizer steps
  • learning-rate scheduler steps
  • aggregating and logging metrics
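For comparison, the plain-PyTorch sketch below writes out roughly what Trainer automates for you. The linear model, the synthetic data, the StepLR scheduler and the hyperparameters are illustrative assumptions, and Lightning's real loop also handles callbacks, logging, checkpointing and distributed training:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # reproducible initialization and shuffling

# Synthetic regression data (y = 2x + 1) standing in for a real dataset
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1
train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(x, y), batch_size=16)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(1, 1).to(device)  # manual device placement
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)


def evaluate(loader):
    """Average loss over a loader, with gradients disabled."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            total += loss_fn(model(xb), yb).item() * len(xb)
    return total / len(loader.dataset)


initial_val_loss = evaluate(val_loader)
val_losses = []
for epoch in range(5):                       # loop over epochs
    model.train()
    for xb, yb in train_loader:              # loop over batches
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()                # reset gradients from the last step
        loss = loss_fn(model(xb), yb)
        loss.backward()                      # backward pass
        optimizer.step()                     # optimizer update
    scheduler.step()                         # scheduler update, once per epoch
    val_losses.append(evaluate(val_loader))  # epoch-level metric
    print(f"epoch {epoch}: val_loss={val_losses[-1]:.4f}")
```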

Logging 📃

In a LightningModule, you can log metrics (at each step or at each epoch) during training, validation or testing. To do that, call self.log() inside the step method for each metric you want to monitor. log() sends the value to a logger, which by default stores it in the lightning_logs directory inside your working directory. To log a metric:

  • pass a key for the quantity you want to monitor (here: “train_loss”)
  • pass the variable that stores that quantity (here: loss)

def training_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)  # forward pass
    loss = self.loss_fn(pred, y)
    self.log("train_loss", loss)
    return loss

By default, log():

  • logs at the end of every step, when placed inside training_step()
  • logs at the end of every epoch, when placed inside validation_step() and test_step()

You can change this with the on_step and on_epoch parameters, as shown below:

self.log("train_loss", loss, on_step=False, on_epoch=True)

❗To checkpoint your model or stop training early based on a metric, you must log that metric with log().

Model Checkpointing 💾

With the ModelCheckpoint callback, you can automatically save your model's weights during training and keep the checkpoint with the best value of a monitored metric (e.g. validation accuracy or loss). To checkpoint your model:

1. Import the ModelCheckpoint callback:

from pytorch_lightning.callbacks import ModelCheckpoint

2. Add log() to the metric you want to monitor:

def validation_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)  # forward pass
    loss = self.loss_fn(pred, y)
    self.log("val_loss", loss)
    return loss

3. Create an instance of the ModelCheckpoint class:

checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")

  • set monitor to the key of the metric you want to monitor (the string you passed to log())
  • set mode to ‘min’ if lower values of the monitored quantity are better (e.g. a loss) or ‘max’ if higher values are better (e.g. accuracy); Lightning keeps the checkpoint with the best value

4. Pass the checkpoint callback to Trainer through the callbacks parameter:

trainer = pl.Trainer(max_epochs=5, callbacks=[checkpoint_callback])

Also, to automatically save your model’s hyperparameters, call self.save_hyperparameters() in the LightningModule’s __init__(). The hyperparameters (the arguments of __init__()) are then stored in the self.hparams attribute and saved inside the model checkpoint:

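def __init__(self, hidden_size=64, learning_rate=1e-3):  # example hyperparameters
    super().__init__()
    self.save_hyperparameters()  # stores hidden_size and learning_rate in self.hparams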

Below, you can see the lightning_logs folder that gets created by the logger. This is the default directory where the logs, checkpoints and hyperparameters get stored.

Early Stopping 🛑

With the EarlyStopping callback, you can stop training before max_epochs is reached when the monitored metric stops improving (decreasing or increasing). To stop training early:

1. Import the EarlyStopping callback:

from pytorch_lightning.callbacks.early_stopping import EarlyStopping

2. Add log() to the metric you want to monitor:

def validation_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)  # forward pass
    loss = self.loss_fn(pred, y)
    self.log("val_loss", loss)
    return loss

3. Create an instance of the EarlyStopping class:

early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=10)

  • set monitor to the key of the metric you want to monitor (the string you passed to log())
  • set mode to ‘min’ to stop training when the monitored quantity stops decreasing, or ‘max’ to stop when it stops increasing
  • set patience to the number of consecutive checks with no improvement after which training stops (with the default settings, one check per epoch)

4. Pass the EarlyStopping callback to Trainer through the callbacks parameter:

trainer = pl.Trainer(max_epochs=5, callbacks=[early_stopping])

Multi-GPU Training ⏭

To train on multiple GPUs, pass the number of GPUs you want to use to the gpus parameter of Trainer.

For example, to use 2 GPUs:

trainer = pl.Trainer(gpus=2)

To train on all available GPUs, use gpus=-1:

trainer = pl.Trainer(gpus=-1)

Lightning can do much more than this tutorial covers, such as structuring your data cleaning, processing and splitting with LightningDataModule, gradient clipping, gradient accumulation, the learning-rate finder, the batch-size finder and more.

Thanks for reading!

References

  1. https://www.pytorchlightning.ai/
  2. https://pytorch-lightning.readthedocs.io/en/latest/starter/converting.html
  3. https://research.aimultiple.com/pytorch-lightning/
  4. https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/
  5. https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09
