Warp speed model training in PADL with PyTorch-Lightning

Sebastian Lee
PADL Developer Blog
4 min read · Mar 9, 2022


PADL + PyTorch-Lightning provides additional convenience, comfort and speed

PyTorch-Lightning is a package for PyTorch which delivers flexible training routines incorporating the best and latest practices in deep learning: optimization tricks, mixed precision, and toggle-on/off switches for easy scaling to multi-GPU and multi-node parallelism. This lets the data scientist concentrate on the science and reduces the boilerplate associated with training. In this post, we discuss how to leverage this power together with PADL, to get an additional boost into hyper-space.

PADL is a functional model builder for PyTorch allowing for full export of model pipelines including pre-processing, forward pass and post-processing. In addition, PADL offers some super handy usability features, such as pipeline building via operator overloading, interactive/notebook-friendly design, and tools for pipeline inspection and debugging. Follow us on GitHub, and read more about PADL in the official docs and on the developer’s blog. You can try the full notebook on Colab!

We can use the full range of PyTorch functionality with PADL. We’ll also import pytorch_lightning and some connectors from the PADL-extensions package padl_ext. PADL also allows you to easily incorporate components from the entire Python ecosystem in your pipeline - for instance, numpy.
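A typical import cell for this walkthrough might look like the following sketch. The exact import path of the Lightning connector inside padl_ext is an assumption here and may differ between versions of the extensions package.

```python
import numpy as np
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from padl import transform, batch, unbatch, identity

# The PyTorch Lightning connector ships with the PADL-extensions package;
# this import path is an assumption and may differ in your installed version.
from padl_ext.pytorch_lightning import PadlLightning
```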

For simplicity, we’ll use MNIST data in this tutorial. You can use any dataset with PADL.
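As a minimal sketch, the raw images and labels can be pulled in via torchvision (an assumption of this walkthrough, not a PADL requirement). All preprocessing will live inside the PADL pipeline later.

```python
from torchvision.datasets import MNIST

# Keep the raw PIL images and integer labels; preprocessing happens in the pipeline.
mnist_train = MNIST('data/', train=True, download=True)
mnist_test = MNIST('data/', train=False, download=True)

train_data = [(img, label) for img, label in mnist_train]
test_data = [(img, label) for img, label in mnist_test]
```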

We’ll be using a simple convolutional net on greyscale images. We wrap the class definition with the @transform decorator, which lets the PyTorch layer use all of the cool PADL functionality while still profiting from the usual PyTorch features.
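A small illustrative network in this spirit could look like the sketch below. The exact architecture is an assumption; the important part is the @transform decorator on the class.

```python
@transform
class CNN(torch.nn.Module):
    """A small convolutional net for 28x28 greyscale images (illustrative)."""

    def __init__(self, n_classes=10):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 16, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(16, 32, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
        )
        self.fc = torch.nn.Linear(32 * 7 * 7, n_classes)

    def forward(self, x):
        x = self.conv(x)
        return self.fc(x.flatten(1))


layer = CNN()
```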

We can now build the hybrid PyTorch/PADL object into some pipelines, which we’ll use to train and test the layer. PADL makes use of operator overloading, which makes it fun and simple to combine PADL transforms and pipelines: >> composes components ("transforms" and "pipelines"), and / applies components in parallel. See more here.
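Concretely, a training pipeline in this style might be sketched as follows. The preprocessing helpers and the exact placement of the batch markers are assumptions made for illustration; the original notebook may differ in detail.

```python
@transform
def preprocess(img):
    # PIL image -> normalised float tensor with a channel dimension.
    arr = np.array(img, dtype='float32') / 255.0
    return torch.tensor(arr).unsqueeze(0)


@transform
def to_tensor(label):
    return torch.tensor(label)


# ">>" composes steps; "/" applies branches to the elements of a tuple in parallel.
# The pipeline consumes an (image, label) pair and ends in a scalar loss.
train_model = (
    (preprocess >> batch) / (to_tensor >> batch)
    >> layer / identity
    >> transform(F.cross_entropy)
)
```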

The output of the train-pipeline is a scalar loss (padl.transform(F.cross_entropy)). This differs from how you might normally write your training. Usually in PyTorch you would write a dataset, then wrap it in a dataloader, then define a layer and loss. In the training loop you'd fetch batches from the dataloader, push them through the model, and then evaluate the loss on the outputs. In PADL this logic is handled inside the pipeline. That makes it super easy to define very useful objects for talking to the PyTorch ecosystem, as we do here with PyTorch Lightning.

Let’s test the pipeline on a single data point. To do that we’ll use .infer_apply. There's also .eval_apply and .train_apply which allow batching (see this fully worked out example of pure PADL training).
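For example (assuming the pipeline sketched above), a single (image, label) pair goes through the full pipeline without any DataLoader:

```python
img, label = train_data[0]
print(train_model.infer_apply((img, label)))   # a single scalar loss

# .eval_apply / .train_apply iterate over many samples with automatic batching.
losses = list(train_model.eval_apply(train_data[:32], batch_size=8))
```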

We can define an auxiliary model with weights tied to train_model simply by reusing the layer. This auxiliary model is useful in practice, since it outputs predictions as raw floats. This can then be plugged straight into whatever production environment you want. We could have also created a JSON output or whatever else we liked.
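A sketch of such an inference pipeline, reusing the same layer instance so the weights stay tied; the post-processing step here is only an illustrative assumption:

```python
@transform
def postprocess(logits):
    # Raw logits -> plain Python output, ready for a production environment.
    probs = torch.softmax(logits, 0)
    return {'prediction': int(probs.argmax()), 'confidence': float(probs.max())}


# Reusing `layer` ties the inference model's weights to train_model.
inference_model = preprocess >> batch >> layer >> unbatch >> postprocess
```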

The PADL extensions package contains a PyTorch Lightning plugin. This is a very lightweight extension of the PyTorch Lightning module. As usual with PyTorch Lightning, we can extend functionality by overriding the default methods. For example, the MyModule instance is a PyTorch Lightning module, and also a PADL object. This has some very handy advantages when saving the results of training.
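As a hedged sketch (the details of PadlLightning’s interface are assumptions based on this post, so check the padl_ext docs), overriding a default method could look like this:

```python
class MyModule(PadlLightning):
    # Override standard PyTorch Lightning hooks to customise behaviour,
    # e.g. which optimizer trains the wrapped PADL pipeline.
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```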

The difference between using PadlLightning and the standard PyTorch Lightning setup is that far less needs to be defined by hand. Thanks to the structure of PADL pipelines, methods that usually have to be written manually can be determined automatically. That means these methods don't need to be defined:

  • train_dataloader
  • val_dataloader
  • test_dataloader
  • training_step
  • validation_step
  • test_step
  • on_save_checkpoint

In fact, in the majority of cases, the PadlLightning object may be used directly out of the box.

If you want to save an inference model as well as the training model, it can also be passed to the PadlLightning module. The training model is then used to compute losses, while the inference model, whose weights are tied to the training model’s, is saved whenever PyTorch Lightning’s checkpoint monitoring decides that a better model has been found. All of the standard PyTorch Lightning functionality may be used, just as with a standard PyTorch Lightning trainer. This comes in very handy when the layer is used very differently in inference than in training (think beam search in neural translation).

Let’s fit the module on the data.
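Sketched under the same assumptions about PadlLightning’s constructor (the keyword names below are illustrative, not the definitive API), fitting looks like any other PyTorch Lightning run:

```python
module = MyModule(
    train_model,                      # pipeline ending in a scalar loss
    inference_model=inference_model,  # weights tied to train_model, saved on checkpoints
    train_data=train_data,
    val_data=test_data,
    batch_size=64,
)

trainer = pl.Trainer(max_epochs=3)
trainer.fit(module)
```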

Now let’s try a few predictions!
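Since the inference model’s weights are tied to the train model, it already reflects the training (using the illustrative pipeline from above):

```python
for img, label in test_data[:5]:
    print(label, inference_model.infer_apply(img))
```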

Now we can make use of a major practical advantage of PADL: saving and loading are completely self-contained and take care of all aspects of the pipeline. That means that the following cell works in a completely new session! You can restart the kernel, do this in a new session/server, etc. That makes the results you obtained with the PyTorch Lightning trainer super portable and reusable!
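For example, in a fresh session (the checkpoint path below is only a placeholder for wherever your training run saved its best pipeline):

```python
from padl import load

# Loads the complete pipeline: pre-processing, forward pass and post-processing.
loaded_inference = load('checkpoints/inference_model.padl')
```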

Let’s verify the predictions of the loaded model:
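A quick check with the loaded pipeline (same placeholder assumptions as above):

```python
# Re-create test_data first if you really are in a fresh session.
for img, label in test_data[:5]:
    print(label, loaded_inference.infer_apply(img))
```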

Continuing training is effortless!
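Under the same assumptions about the PadlLightning interface, continuing training amounts to reloading the training pipeline and fitting again; the path is again a placeholder:

```python
loaded_train = load('checkpoints/train_model.padl')

module = MyModule(loaded_train, train_data=train_data, val_data=test_data, batch_size=64)
pl.Trainer(max_epochs=1).fit(module)
```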

With PADL plus PyTorch Lightning, your design patterns can take a very satisfying, conventional form:

  1. Combine pre-processing, forward pass, post-processing, and loss into one or two PADL pipelines.
  2. Pass the pipeline to the PadlLightning trainer and start training.
  3. Save the trainer (which saves the contained pipelines).
  4. Reload the trainer in one line of code.
  5. Continue the training on the updated/latest data.
  6. Resave the trainer.
  7. Rinse and repeat previous steps ad infinitum.

And that’s all there is to it. Nice pipeline definitions in PADL trained at warp speed with PyTorch Lightning.

Happy PADL-ling!

Do you want to know more about PADL? Here are some resources:
