PyTorch Lightning 1.1 - Model Parallelism Training and More Logging Options

[Header image by Author]

Lightning 1.1 is now available with some exciting new features. Since the launch of the v1.0.0 stable release, we have hit some incredible milestones: 10K GitHub stars, 350 contributors, and many new members in our Slack community! A few highlights include:

  • Sharded model training: save up to 55% of memory without losing speed
  • Sequential Model Parallelism
  • Automatic logging for callbacks and any LightningModule hook
  • Lightning Bolts 0.2.6 release

We're thrilled to introduce the beta version of our new sharded model training plugin, in collaboration with FairScale by Facebook. Sharded Training utilizes Data-Parallel Training under the hood, but optimizer states and gradients are sharded across GPUs. This means the memory overhead per GPU is lower, as each GPU only has to maintain a partition of your optimizer state and gradients. You can use this plugin to reduce memory requirements by up to 60% (!) by simply adding a single flag to your Lightning trainer, with no performance loss.

# install fairscale
pip install https://github.com/PyTorchLightning/fairscale/archive/pl_1.1.0.zip
# train using Sharded DDP
trainer = Trainer(gpus=8, accelerator='ddp', plugins='ddp_sharded')
Average peak memory training a Transformer LM (22 layers, hidden size 3072, trained on SST; 2-billion-parameter variant with 32 layers), SwAV Wide ResNet (trained on STL-10), DeepSpeech2 (trained on Librispeech100), and iGPT (trained on MNIST) using 8 A100s. The same hyper-parameters and batch size are used per model, with model capacity increased to roughly a billion parameters. Lower is better. Image by Author

To learn more about our new sharded training, read this blog.

This release also includes an integration for Sequential Model Parallelism from FairScale. Sequential Model Parallelism splits a sequential module onto multiple GPUs according to your preferred balance, reducing peak GPU memory requirements. Furthermore, it supports micro-batches and gradient checkpointing for fitting even larger sequential models.

To use Sequential Model Parallelism, define an nn.Sequential module containing the layers you wish to parallelize across GPUs, and store it in the sequential_module attribute of your LightningModule.

Want to give it a try? We provide a minimal example of Sequential Model Parallelism using a convolutional model trained on CIFAR-10, split across GPUs here. Simply run:

pip install pytorch-lightning-bolts
python pl_examples/basic_examples/conv_sequential_example.py --batch_size 1024 --gpus 2 --accelerator ddp --use_ddp_sequential

In 1.0 we introduced a new, easy way to log any scalar in the training or validation step, using the self.log method. It is now available in all LightningModule and Callback hooks (except the *_batch_start hooks, such as on_train_batch_start or on_validation_batch_start; use on_train_batch_end/on_validation_batch_end instead!).

Depending on where self.log is called from, Lightning auto-determines the correct logging mode for you (logs after every step in training_step, logs epoch accumulated metrics for every epoch in validation or test steps). But of course, you can override the default behavior by manually setting the log() parameters.

self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

Read more about logging in our docs.

  • MultiClass AUROC metric
  • New API for ConfusionMatrix, PrecisionRecallCurve, ROC, and AveragePrecision class metrics
[Example plots: ROC, PrecisionRecallCurve, and AveragePrecision. Images by Author]
  • Added the step index to the checkpoint filename (so the filename will be something like epoch=0-step=428.ckpt)
  • Added a changeable extension variable for ModelCheckpoint, so you can override the default ".ckpt" extension
  • Added on_after_backward and on_before_zero_grad hooks to callbacks
  • Added the ability to optionally log momentum values in the LearningRateMonitor
  • DDP now works with manual optimization

We’d like to thank all the hard-working contributors who took part in this release. Kudos! If you want to give back to the community, here’s a list of issues for new contributors you can try to solve.

Let’s meet!

Want to learn more about new features and get inspired by community projects? In our next community meetup we’re introducing Lightning Talks: 5 projects in 5 minutes. Join us on December 17th at 1 PM EST to learn more about the new sharded model training, self-supervised learning for object detection, and how a Kaggle Grandmaster is using Lightning in his projects! RSVP here.

Interested in presenting in our next meetup? Fill this out! It’s a great way to make connections, spread the word about your work, and help your fellow researchers.

Thanks to Thomas Chaton.

Written by the PyTorch Lightning team. PyTorch Lightning is a deep learning research framework to run complex models without the boilerplate.

Published in PyTorch: an open source machine learning framework that accelerates the path from research prototyping to production deployment.
