Simplify your PyTorch code with PyTorch Lightning
An introduction to PyTorch Lightning, a framework for making deep learning model training easier and faster.
What is PyTorch Lightning?
PyTorch Lightning is a lightweight and high-performance framework built on top of PyTorch that allows you to organize your code and automate the optimization process of training. It also provides the following features:
- metrics (accuracy, precision, recall, etc.)
- metrics logging
- model checkpointing
- early stopping
- training on multiple GPUs, TPUs, CPUs
- minimal speed overhead (about 300 ms per epoch compared with pure PyTorch)
The official docs can be found here and the source code on GitHub here.
Installation
Before making any changes, you must install Lightning:
pip install pytorch-lightning
and import it to your code:
import pytorch_lightning as pl
Organize your PyTorch code
There are different ways to use PyTorch Lightning. In this post we will use only pl.LightningModule to replace the nn.Module and the training loop.
A LightningModule is a torch.nn.Module but with added functionality. Use it as such!
To use PyTorch Lightning, you must organize your code into the functions of LightningModule. The module and its required functions are shown below:
These 4 functions are the minimum required for training your model with Lightning. Other functions you will probably need to add are: prepare_data(), validation_step(), test_step() and predict_step().
So, the changes you have to make are shown in the picture below:
Steps:
- subclass pl.LightningModule instead of nn.Module
- move all required code under the relevant functions inside the module
- you can remove the .to(device) calls: Lightning automatically moves the model and the batches coming from its dataloaders to the right device
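For reference, here is a sketch of the kind of plain PyTorch code these steps convert. The NeuralNetwork class, the toy data (points on y = 2x + 1) and the hyperparameters are hypothetical; note the manual .to(device) calls and the hand-written loops that Lightning takes over:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical toy data: 64 points on the line y = 2x + 1
x = torch.linspace(-1, 1, 64).unsqueeze(1)
train_dataloader = DataLoader(TensorDataset(x, 2 * x + 1), batch_size=16, shuffle=True)


class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)


model = NeuralNetwork().to(device)  # manual device placement
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(20):  # loop over epochs
    model.train()
    for xb, yb in train_dataloader:  # loop over batches
        xb, yb = xb.to(device), yb.to(device)  # manual .to(device)
        loss = loss_fn(model(xb), yb)
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # backward pass
        optimizer.step()       # optimizer update
    epoch_losses.append(loss.item())  # loss of the last batch in the epoch
```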
The full code for the above examples can be found here.
For a great side-by-side comparison of PyTorch and PyTorch Lightning, read the following article, written by one of PyTorch Lightning’s creators:
Train your model
If you have structured your code in a LightningModule, you can train your model in just 3 lines of code with Trainer, a class that wraps the nested loops over all epochs and over all batches of a dataloader:
model = NeuralNetworkLit()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_dataloader)
All of the following are handled by Trainer:
- loops over epochs
- loops over batches
- disabling gradients during validation and testing
- backward passes
- optimizer updates
- learning-rate scheduler steps
- aggregation of logged metrics
Logging 📃
In a LightningModule, you can log metrics (at each step or epoch) from a training, validation or test step. To do that, call the log() method inside the step for the metric you want to monitor. This method sends the value to a logger, which stores it in a default directory (lightning_logs) inside your working directory. To log a metric:
- pass a key for the quantity you want to monitor (here: "train_loss")
- pass the variable that stores that quantity (here: loss)
def training_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)  # forward pass
    loss = self.loss_fn(pred, y)
    self.log("train_loss", loss)
    return loss
By default, log():
- logs at the end of every step, when called inside training_step()
- logs at the end of every epoch, when called inside validation_step() and test_step()
You can change that through the on_step and on_epoch parameters, as shown below:
self.log("train_loss", loss, on_step=False, on_epoch=True)
❗If you want to checkpoint your model or add early stopping to your training, you must log the metric you want to monitor with log().
Model Checkpointing 💾
With the ModelCheckpoint callback, you can automatically save the weights of your model based on a metric you monitor during training or validation (e.g. accuracy, loss). To checkpoint your model:
1. Import the ModelCheckpoint callback:
from pytorch_lightning.callbacks import ModelCheckpoint
2. Log the metric you want to monitor with log():
def validation_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)  # forward pass
    loss = self.loss_fn(pred, y)
    self.log("val_loss", loss)
    return loss
3. Create an instance of the ModelCheckpoint class:
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
- pass to the monitor parameter the key of the metric you want to monitor (the string you defined in log())
- pass to the mode parameter 'min' or 'max' to tell the callback whether a lower or a higher value of the monitored quantity counts as an improvement (e.g. 'min' for a loss, 'max' for accuracy); by default, the callback keeps only the checkpoint with the best value
4. Pass the checkpoint callback to Trainer through the callbacks parameter:
trainer = pl.Trainer(max_epochs=5, callbacks=[checkpoint_callback])
Also, to automatically save your model’s hyperparameters (the arguments passed to __init__()), call self.save_hyperparameters() in the LightningModule's __init__(). The model’s hyperparameters will then be stored in the self.hparams attribute and also within the model checkpoint:
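def __init__(self, learning_rate=1e-3):  # learning_rate: example hyperparameter
    super().__init__()
    self.save_hyperparameters()  # stores learning_rate in self.hparams and in checkpoints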
Below, you can see the lightning_logs folder that gets created by the logger. This is the default directory where the logs, checkpoints and hyperparameters get stored.
Early Stopping 🛑
With the EarlyStopping callback, you can stop training before the last epoch when the monitored metric stops improving (decreasing or increasing). To stop training early:
1. Import the EarlyStopping callback:
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
2. Log the metric you want to monitor with log():
def validation_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)  # forward pass
    loss = self.loss_fn(pred, y)
    self.log("val_loss", loss)
    return loss
3. Create an instance of the EarlyStopping class:
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=10)
- pass to the monitor parameter the key of the metric you want to monitor (the string you defined in log())
- pass to the mode parameter 'min' or 'max' to stop training when the monitored quantity has stopped decreasing or increasing, respectively
- pass to the patience parameter the number of checks with no improvement after which training is stopped (by default, one check per epoch)
4. Pass the EarlyStopping callback to Trainer through the callbacks parameter:
trainer = pl.Trainer(max_epochs=5, callbacks=[early_stopping])
Multi-GPU Training ⏭
To train on multiple GPUs, just pass the number of GPUs you want to use to the gpus parameter of Trainer.
For example, to use 2 GPUs:
trainer = pl.Trainer(gpus=2)
To train on all available GPUs, use gpus=-1:
trainer = pl.Trainer(gpus=-1)
With Lightning you can do many more things not covered in this tutorial, such as structuring your data cleaning, processing and splitting with the LightningDataModule class, gradient clipping, gradient accumulation, a learning-rate finder, a batch-size finder and more.
Thanks for reading!
References
- https://www.pytorchlightning.ai/
- https://pytorch-lightning.readthedocs.io/en/latest/starter/converting.html
- https://research.aimultiple.com/pytorch-lightning/
- https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/
- https://towardsdatascience.com/from-pytorch-to-pytorch-lightning-a-gentle-introduction-b371b7caaf09