PyTorch Lightning V1.2.0: DeepSpeed, Pruning, Quantization, SWA

New release including many new PyTorch integrations, DeepSpeed model parallelism, and more.

PyTorch Lightning team · PyTorch · 5 min read · Feb 19, 2021

We are happy to announce PyTorch Lightning V1.2.0 is now publicly available. It is packed with new integrations for anticipated features such as:

  • PyTorch autograd profiler support
  • DeepSpeed model parallelism
  • Pruning, quantization, and stochastic weight averaging (SWA)
  • Finetuning callbacks and a PyTorch Geometric integration

Continue reading to learn more about what’s available. As always, feel free to reach out on Slack or discussions for any questions you might have or issues you are facing.

PyTorch Profiler [BETA]

PyTorch Autograd provides a profiler that lets you inspect the cost of different operations inside your model, both on the CPU and the GPU (read more about the profiler in the PyTorch documentation). You can now enable the PyTorch profiler in Lightning out of the box:
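A minimal sketch, based on the 1.2 docs:

```python
# Enable the built-in PyTorch profiler by name; Lightning wires it into fit().
from pytorch_lightning import Trainer

trainer = Trainer(profiler="pytorch")
```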

Or initialize the profiler for further customization:
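A sketch of further customization; the argument names below follow the 1.2-era PyTorchProfiler and may differ in other versions:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PyTorchProfiler

# Configure what the profiler records before handing it to the Trainer.
profiler = PyTorchProfiler(
    record_shapes=True,   # also record the input shapes of each operation
    profile_memory=True,  # track tensor memory allocations and frees
)
trainer = Trainer(profiler=profiler)
```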

At the end of a run, the profiler produces a report summarizing the time spent in each recorded operation.

Learn about all the Lightning supported profilers here.

DeepSpeed Plugin [BETA]

DeepSpeed offers additional CUDA deep learning training optimizations for training massive billion-parameter models. It provides lower-level training optimizations such as ZeRO-Offload, as well as memory- and speed-efficient optimizers such as 1-bit Adam. We have trained models with more than 10 billion parameters using our default training configuration on multiple GPUs, with follow-up technical details coming soon.

To enable DeepSpeed in Lightning 1.2, simply pass plugins='deepspeed' to your Lightning Trainer (docs).
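A minimal sketch of that Trainer call; the GPU count and 16-bit precision below are illustrative choices:

```python
from pytorch_lightning import Trainer

# DeepSpeed is selected through the plugins flag and runs with the
# default training configuration unless you customize it.
trainer = Trainer(gpus=4, plugins="deepspeed", precision=16)
```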

Learn more about DeepSpeed implementation with technical publications here.

Pruning [BETA]

Pruning is a technique to optimize model memory, hardware, and energy requirements by eliminating some of the model weights. Pruning is able to achieve significant model efficiency improvements while minimizing the drop in task performance. The pruned model is smaller in size and faster to run.

To enable pruning during training in Lightning 1.2, simply pass the ModelPruning callback to the Lightning Trainer (it uses torch.nn.utils.prune under the hood).
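For example, a minimal sketch with an illustrative pruning function and amount:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelPruning

# Prune 50% of the weights with L1 unstructured pruning during training.
trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=0.5)])
```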

This callback supports multiple pruning functions (pass any torch.nn.utils.prune function as a string to select which weights to prune), setting the pruning percentage, performing iterative pruning, applying the lottery ticket hypothesis, and more (docs).
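A sketch of those options; the argument names follow the 1.2-era callback and may differ in other versions:

```python
from pytorch_lightning.callbacks import ModelPruning

pruning = ModelPruning(
    pruning_fn="l1_unstructured",                      # any torch.nn.utils.prune function, by name
    amount=lambda epoch: 0.1 if epoch < 10 else 0.25,  # schedule the pruned fraction per epoch
    use_lottery_ticket_hypothesis=True,                # reset surviving weights to their original values
    make_pruning_permanent=True,                       # strip the pruning re-parametrization at the end
)
```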

Quantization [BETA]

Model quantization is another performance optimization technique that speeds up inference and decreases memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating-point precision. Quantization not only reduces the model size but also speeds up loading, since fixed-point operations are faster than floating-point ones.

Quantization Aware Training (QAT) mimics the effects of quantization during training: all computations are carried out in floating point while training, simulating the effects of lower-precision integer arithmetic, and the weights and activations are quantized to lower precision only once training is completed.

Lightning 1.2 includes a Quantization Aware Training callback (using PyTorch native quantization, read more here), which allows you to create fully quantized models (compatible with TorchScript):
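A minimal sketch using the callback's defaults:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import QuantizationAwareTraining

# Insert fake-quantization observers during training and convert the model
# to a quantized one when fitting finishes.
trainer = Trainer(callbacks=[QuantizationAwareTraining()])
```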

You can further customize the callback:
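For example, a sketch where the quantization backend and observer type are illustrative choices and the argument names follow the 1.2-era API:

```python
from pytorch_lightning.callbacks import QuantizationAwareTraining

qcb = QuantizationAwareTraining(
    qconfig="fbgemm",           # PyTorch quantization configuration / backend
    observer_type="histogram",  # how activation ranges are observed during training
)
```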

Read the docs here.

Stochastic Weight Averaging [BETA]

Stochastic Weight Averaging (SWA) can make your models generalize better at virtually no additional cost. It can be used with both untrained and already-trained models. The SWA procedure smooths the loss landscape, making it harder to end up in a local minimum during optimization.

Lightning 1.2 supports SWA (using the native PyTorch implementation) with a simple Trainer flag (available with PyTorch 1.6 and higher):
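A minimal sketch of the flag:

```python
from pytorch_lightning import Trainer

# Average model weights toward the end of training using PyTorch's SWA utilities.
trainer = Trainer(stochastic_weight_avg=True)
```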

Or for further customization use the StochasticWeightAveraging callback:
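A sketch; the argument names and values follow the 1.2-era callback and are illustrative:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

swa = StochasticWeightAveraging(
    swa_epoch_start=0.75,  # start averaging after 75% of the training epochs
    swa_lrs=0.05,          # learning rate used while averaging
)
trainer = Trainer(callbacks=[swa])
```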

Read the docs here.

Finetuning [BETA]

Finetuning (or transfer learning) is the process of adapting a model trained on a large dataset to a particular (likely much smaller) dataset. For more details on finetuning, see this Flash notebook.

To make finetuning simpler with Lightning, we are introducing the BackboneFinetuning callback, which you can customize for your own use case; you can also create your own callback by subclassing BaseFinetuning:
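A minimal sketch; BackboneFinetuning assumes your LightningModule exposes a self.backbone attribute, and the epoch below is an illustrative choice:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import BackboneFinetuning

# Keep the backbone frozen for the first 10 epochs, then unfreeze and finetune it.
trainer = Trainer(callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=10)])
```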

PyTorch Geometric integration

PyTorch Geometric (PyG) is a popular geometric deep learning extension library for PyTorch (see Fast Graph Representation Learning with PyTorch Geometric by Matthias Fey and Jan E. Lenssen). PyG currently provides more than 60 state-of-the-art models and methods for graph convolution. You can now train PyG models with the Lightning Trainer! See examples here.
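As a sketch of what that can look like (not from the original post), here is a small two-layer GCN trained on the Planetoid/Cora dataset with the Lightning Trainer; the dataset, module, and hyperparameters are illustrative choices:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv


class LitGCN(pl.LightningModule):
    def __init__(self, num_features, num_classes, hidden=16):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden)
        self.conv2 = GCNConv(hidden, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

    def training_step(self, batch, batch_idx):
        # Cora is a single full-batch graph; the mask selects the training nodes.
        out = self(batch.x, batch.edge_index)
        loss = F.cross_entropy(out[batch.train_mask], batch.y[batch.train_mask])
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)


dataset = Planetoid(root="data/Cora", name="Cora")
loader = DataLoader(dataset, batch_size=1)  # one graph per batch
model = LitGCN(dataset.num_features, dataset.num_classes)
pl.Trainer(max_epochs=200).fit(model, loader)
```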

New Accelerator/plugins API

Training deep learning models at scale while retaining full flexibility requires a lot of orchestration between different responsibilities. The new API isolates these responsibilities by introducing a new Accelerator API as well as new types of plugins: one for the different training types (such as single device, DDP, etc.) and one for handling different floating-point precisions during training. Having dedicated interfaces reduces code duplication and enhances usability for power users.

The Trainer interface did not change for most use-cases but was extended to allow further customization through plugins.

Simple accelerator use:
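A minimal sketch; the GPU count is illustrative:

```python
from pytorch_lightning import Trainer

# Select the DDP training type through the accelerator flag.
trainer = Trainer(gpus=2, accelerator="ddp")
```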

Or pass in a plugin for customization:
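A sketch; DDPPlugin and its find_unused_parameters kwarg follow the 1.2-era plugins API, and the kwarg is forwarded to DistributedDataParallel:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Customize how the DDP training type wraps the model.
trainer = Trainer(
    gpus=2,
    accelerator="ddp",
    plugins=[DDPPlugin(find_unused_parameters=False)],
)
```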

For help migrating custom plugins to the new API, reach out to us on Slack or via support@pytorchlightning.ai.

Special shout out to Justus Schock and Adrian Wälchli for all the hard work!

Other improvements

  • Added support for multiple train loaders
  • Added trainer.predict for simple inference with Lightning
  • New metrics: HammingDistance, StatScores, R2Score
  • Added LightningModule.configure_callbacks to enable the definition of model-specific callbacks
  • Enabled self.log in callbacks
  • Changed the order of the on_train_batch_end / on_batch_end and on_train_epoch_end / on_epoch_end hooks

See all changes in the release notes.

Thank you!

Big kudos to all the community members for their contributions and feedback. We now have over 400 Lightning contributors! Want to give open source a try and get free Lightning swag? We have a #new_contributors channel on Slack. Check it out!
