PyTorch Lightning — it’s a wrapper!

Ruben Stefanus
Data Folks Indonesia
3 min read · Jun 28, 2020
PyTorch Lightning

The lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate.

Are you familiar with using PyTorch for deep learning problems?

If YES, you will have noticed that you end up writing the same engineering code every time you start an AI research project.

The Typical AI Research project

In a research project, we normally want to identify the following key components:

  • the model(s)
  • the data
  • the loss
  • the optimizer(s)

Let’s say we want to break down the model component…

In plain PyTorch, we always need to write the training loop ourselves: for every epoch we iterate over the data and perform the steps below (see the sketch after this list).

A typical training iteration:

  • Unpack the data inputs and labels
  • Load the data onto the GPU
  • Clear out the gradients calculated in the previous pass
  • Forward pass (feed the input data through the network)
  • Backward pass (backpropagation)
  • Tell the network to update its parameters with optimizer.step()
  • Track variables to monitor progress
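
For reference, here is a minimal sketch of such a loop in plain PyTorch. The names model, train_loader, loss_fn, and optimizer are placeholders assumed to be defined elsewhere; this is not code from the original post.

```python
import torch

# Assumed to be defined elsewhere: model, train_loader, loss_fn, optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
num_epochs = 3

for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:                        # unpack the data
        inputs, labels = inputs.to(device), labels.to(device)  # load data onto the GPU
        optimizer.zero_grad()                                   # clear previous gradients
        outputs = model(inputs)                                 # forward pass
        loss = loss_fn(outputs, labels)
        loss.backward()                                         # backward pass
        optimizer.step()                                        # update parameters
        running_loss += loss.item()                             # track progress
    print(f"epoch {epoch}: loss {running_loss / len(train_loader):.4f}")
```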

Writing all of these steps over and over really consumes our time. So PyTorch Lightning steps in as a wrapper that handles the engineering code and the other non-essential parts (logging, etc.) for us.

The official PyTorch Lightning GitHub README puts it this way:

Lightning is a way to organize your PyTorch code to decouple the science code from the engineering. It’s more of a PyTorch style-guide than a framework.

In Lightning, you organize your code into 3 distinct categories:
  1. Research code (goes in the LightningModule).
  2. Engineering code (you delete, and is handled by the Trainer).
  3. Non-essential research code (logging, etc… this goes in Callbacks).

Here’s an example of how to refactor your research code into a LightningModule.

Refactor your PyTorch code
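
The original post embeds a gist here; since it isn’t reproduced, below is a minimal sketch of the idea, assuming a toy two-layer network (exact API details, such as what training_step returns, vary slightly across Lightning versions). The model, the training step, and the optimizer all move into a single LightningModule.

```python
import torch
from torch import nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Research code: the model definition (a toy network, for illustration only)
        self.net = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        # Research code: what used to be the body of the training loop
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        # Research code: the optimizer(s)
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```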

The rest of the code is automated by the Trainer!

Trainer

What does Lightning control for me?

Everything in blue! This is how Lightning separates the science (red) from the engineering (blue).

I also created a template for the MNIST classification problem, which you can check out in the pl-mnist repository.

What did I do in this template?

  • Add a learning rate scheduler and a logger
  • Add a custom callback class
  • Add an early stopping callback
  • Add model checkpointing
  • Add TensorBoard for metric visualisation
  • Show how to train, validate, test, and run inference with the model

Model Class
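
The actual model class is embedded as a gist in the original post and lives in the pl-mnist repository; the following is only a hypothetical sketch of what such a class could look like, including the learning rate scheduler and logging mentioned above (self.log is the logging API of Lightning 1.x and later; older releases used a dict-based mechanism).

```python
import torch
from torch import nn
import pytorch_lightning as pl


class MNISTClassifier(pl.LightningModule):
    """Hypothetical sketch of an MNIST classifier as a LightningModule."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256), nn.ReLU(),
            nn.Linear(256, 10),
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self(x), y)
        self.log("train_loss", loss)  # picked up by the TensorBoard logger
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("val_loss", self.loss_fn(logits, y))
        self.log("val_acc", (logits.argmax(dim=1) == y).float().mean())

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.log("test_loss", self.loss_fn(logits, y))
        self.log("test_acc", (logits.argmax(dim=1) == y).float().mean())

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
        return [optimizer], [scheduler]
```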

Custom Callbacks
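
Again as a hypothetical sketch, not the exact callback from the repository: a custom callback simply subclasses Callback and overrides the hooks it needs.

```python
from pytorch_lightning.callbacks import Callback


class PrintingCallback(Callback):
    """Hypothetical example: report when training starts and ends."""

    def on_train_start(self, trainer, pl_module):
        print("Training is starting...")

    def on_train_end(self, trainer, pl_module):
        print("Training is done.")
```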

Prepare MNIST data
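
Roughly, preparing the data amounts to downloading MNIST via torchvision, splitting off a validation set, and wrapping everything in DataLoaders. A sketch under that assumption (the repository may organize this differently, e.g. inside a LightningDataModule):

```python
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

# Standard MNIST normalization constants
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

mnist_full = MNIST("data/", train=True, download=True, transform=transform)
mnist_train, mnist_val = random_split(mnist_full, [55000, 5000])
mnist_test = MNIST("data/", train=False, download=True, transform=transform)

train_loader = DataLoader(mnist_train, batch_size=64, shuffle=True)
val_loader = DataLoader(mnist_val, batch_size=64)
test_loader = DataLoader(mnist_test, batch_size=64)
```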

Train Model
If you want to train the model on a GPU/TPU, you can simply add the “gpus=1” or “tpu_cores=1” parameter to the Trainer, where ‘1’ is the number of GPU/TPU cores.
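
Putting the pieces together, a hedged sketch of the training, testing, and logging setup might look like this. It reuses the MNISTClassifier, PrintingCallback, and DataLoaders sketched above. Note that gpus=1 matches Lightning releases from around the time of this post; newer versions use accelerator="gpu" and devices=1 instead, and 0.8.x passed early stopping and checkpointing through dedicated Trainer arguments rather than the callbacks list.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

model = MNISTClassifier()

logger = TensorBoardLogger("lightning_logs", name="mnist")      # TensorBoard logs
early_stopping = EarlyStopping(monitor="val_loss", patience=3)  # stop when val_loss plateaus
checkpoint = ModelCheckpoint(monitor="val_loss")                # keep the best model

trainer = pl.Trainer(
    max_epochs=10,
    gpus=1,  # or tpu_cores=1; omit to train on CPU
    logger=logger,
    callbacks=[PrintingCallback(), early_stopping, checkpoint],
)

trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)

# Inspect the logged metrics afterwards with: tensorboard --logdir lightning_logs
```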

The training process looks like this:

training report

TensorBoard

Visualize your metrics


Thank you for reading this article ~
You can connect with me through https://rubentea16.github.io/
If you like this article, just give it some claps and share it! haha :)
