PyTorch Lightning 1.1 - Model Parallelism Training and More Logging Options

  • Sharded model training: save up to 55% of memory without losing speed
  • Sequential Model Parallelism
  • Automatic logging for callbacks and any LightningModule hook*.
  • Lightning Bolts 0.2.6 release

Sharded model training [BETA]

# install fairscale
pip install fairscale
# train using Sharded DDP
from pytorch_lightning import Trainer
trainer = Trainer(gpus=8, accelerator='ddp', plugins='ddp_sharded')
Average peak memory when training a Transformer LM (22 layers, hidden size 3072, trained on SST; 2-billion-parameter variant with 32 layers), SwAV Wide ResNet (trained on STL-10), DeepSpeech2 (trained on LibriSpeech100), and iGPT (trained on MNIST) on 8 A100s. The same hyperparameters and batch size are used per model. We increase model capacity to roughly a billion parameters. Lower is better. Image by author.
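The memory saving comes from sharding: instead of every GPU holding a full copy of the optimizer state, each data-parallel rank keeps state only for its own slice of the parameters. The plain-Python sketch below illustrates that idea for Adam, which stores two state tensors per parameter. It is not fairscale's or Lightning's code; the helper names `partition_params` and `adam_state_per_rank`, and the greedy largest-first partitioning, are assumptions made for illustration.

```python
from typing import List


def partition_params(param_sizes: List[int], world_size: int) -> List[List[int]]:
    """Greedily assign parameter indices to ranks, largest first, to balance element counts.

    Hypothetical helper: a simplified stand-in for how a sharded optimizer might split params.
    """
    shards: List[List[int]] = [[] for _ in range(world_size)]
    totals = [0] * world_size
    for idx in sorted(range(len(param_sizes)), key=lambda i: -param_sizes[i]):
        rank = totals.index(min(totals))  # rank currently holding the fewest elements
        shards[rank].append(idx)
        totals[rank] += param_sizes[idx]
    return shards


def adam_state_per_rank(param_sizes: List[int], world_size: int, sharded: bool) -> List[int]:
    """Optimizer-state elements held by each rank; Adam keeps exp_avg and exp_avg_sq per parameter."""
    if not sharded:
        # Plain DDP: every rank keeps state for every parameter.
        return [2 * sum(param_sizes)] * world_size
    # Sharded: each rank keeps state only for the parameters assigned to it.
    shards = partition_params(param_sizes, world_size)
    return [2 * sum(param_sizes[i] for i in shard) for shard in shards]


sizes = [400, 300, 200, 100]  # hypothetical parameter tensor sizes (number of elements)
print(adam_state_per_rank(sizes, world_size=2, sharded=False))  # [2000, 2000]
print(adam_state_per_rank(sizes, world_size=2, sharded=True))   # [1000, 1000]
```

With these made-up sizes on two ranks, each rank holds 1,000 optimizer-state elements instead of 2,000; the saving grows with the number of GPUs, since the state is divided among more ranks.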

Pipeline model sharding [BETA]

pip install pytorch-lightning-bolts
python pl_examples/basic_examples/ --batch_size 1024 --gpus 2 --accelerator ddp --use_ddp_sequential
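Sequential model parallelism places consecutive layers of an `nn.Sequential` model on different GPUs. Run naively, only one GPU is busy at a time. A common remedy is GPipe-style pipelining, which splits each batch into micro-batches so the stages work concurrently. The sketch below simulates such a forward schedule in plain Python, assuming one clock tick per stage per micro-batch; `gpipe_forward_schedule` and `bubble_fraction` are hypothetical helpers, not part of Lightning or fairscale.

```python
from typing import List, Optional


def gpipe_forward_schedule(num_stages: int, num_microbatches: int) -> List[List[Optional[int]]]:
    """For each clock tick, list which micro-batch each stage runs (None means idle)."""
    num_ticks = num_stages + num_microbatches - 1
    schedule = []
    for tick in range(num_ticks):
        row: List[Optional[int]] = []
        for stage in range(num_stages):
            # Micro-batch m reaches stage s at tick m + s.
            mb = tick - stage
            row.append(mb if 0 <= mb < num_microbatches else None)
        schedule.append(row)
    return schedule


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Share of stage-ticks spent idle during the forward pass."""
    num_ticks = num_stages + num_microbatches - 1
    busy = num_stages * num_microbatches
    return 1 - busy / (num_stages * num_ticks)


for tick, row in enumerate(gpipe_forward_schedule(num_stages=2, num_microbatches=4)):
    print(tick, row)
print(round(bubble_fraction(2, 4), 3))  # 0.2
```

With 2 stages and 4 micro-batches the forward pass takes 5 ticks and the idle fraction is (S - 1) / (M + S - 1) = 1/5, so more micro-batches per batch shrink the idle "bubble".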

Automatic logging everywhere

self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
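With `on_step=True` and `on_epoch=True`, the metric is recorded at every step and also aggregated over the epoch; Lightning reduces epoch-level values with a mean by default and suffixes the two series with `_step` and `_epoch`. The `MiniLogger` class below is a hypothetical, simplified model of those semantics (unweighted mean, `prog_bar` and `logger` flags ignored), not Lightning's implementation.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List


class MiniLogger:
    """Hypothetical stand-in that mimics how on_step / on_epoch split a logged metric."""

    def __init__(self) -> None:
        self.history: Dict[str, List[float]] = defaultdict(list)  # what a logger would receive
        self._pending: Dict[str, List[float]] = defaultdict(list)  # values awaiting epoch reduction

    def log(self, name: str, value: float, *, on_step: bool, on_epoch: bool) -> None:
        both = on_step and on_epoch
        if on_step:
            # Per-step value is written immediately.
            self.history[f"{name}_step" if both else name].append(value)
        if on_epoch:
            # Epoch value is buffered and reduced when the epoch ends.
            self._pending[f"{name}_epoch" if both else name].append(value)

    def end_epoch(self) -> None:
        # Reduce buffered values with a plain mean (assumption: no batch-size weighting).
        for key, values in self._pending.items():
            self.history[key].append(mean(values))
        self._pending.clear()


logger = MiniLogger()
for loss in [1.0, 2.0, 3.0]:
    logger.log("my_loss", loss, on_step=True, on_epoch=True)
logger.end_epoch()
print(dict(logger.history))  # {'my_loss_step': [1.0, 2.0, 3.0], 'my_loss_epoch': [2.0]}
```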

More improvements

  • Multiclass AUROC metric.
  • New API for the ConfusionMatrix, PrecisionRecallCurve, ROC, and AveragePrecision class metrics.
  • Added the step index to the checkpoint filename (so the filename will look like epoch=0-step=428.ckpt).
  • Added a changeable extension variable to ModelCheckpoint, so you can override the default “.ckpt” extension.
  • Added on_after_backward and on_before_zero_grad hooks to callbacks.
  • Added the option to log momentum values in the LearningRateMonitor.
  • DDP now works with manual optimization.
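A common definition of multiclass AUROC is one-vs-rest: treat each class in turn as the positive class, compute binary AUROC from that class's scores, then macro-average over classes. The sketch below assumes that definition and computes binary AUROC with the pairwise (Mann-Whitney) formulation; it is a plain-Python illustration, not Lightning's implementation, and the library's averaging options may differ.

```python
from typing import List, Sequence


def binary_auroc(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Probability that a random positive outscores a random negative (ties count as 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


def multiclass_auroc(probs: List[List[float]], target: List[int]) -> float:
    """One-vs-rest AUROC for each class, macro-averaged over classes."""
    num_classes = len(probs[0])
    per_class = []
    for c in range(num_classes):
        scores = [row[c] for row in probs]             # scores for class c
        labels = [1 if t == c else 0 for t in target]  # class c vs. the rest
        per_class.append(binary_auroc(scores, labels))
    return sum(per_class) / num_classes


print(binary_auroc([0.1, 0.4, 0.35, 0.8], [0, 0, 1, 1]))  # 0.75
```

The pairwise loop is O(n_pos × n_neg), which keeps the definition obvious; a production metric would instead sort the scores once.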

