Multi-Node Multi-GPU Comprehensive Working Example for PyTorch Lightning on AzureML

Joel Stremmel
Oct 20, 2021

Image 0: Multi-node multi-GPU cluster example

Objectives

This blogpost provides a comprehensive working example of training a PyTorch Lightning model on an AzureML GPU cluster consisting of multiple machines (nodes) and multiple GPUs per node. The code provided enables training on millions of data samples with thousands of features. While there are helpful examples of multi-node training in the PyTorch Lightning and AzureML documentation, this example provides critical, missing information, demonstrating how to:

1. Train on more data than will fit in node memory
2. Connect to and pull data from Azure blob storage
3. Configure environment variables for distributed data parallel training with Open MPI, NCCL, and InfiniBand
4. Build a consistent Python environment for your training pipeline with Docker
5. Structure your distributed training code
6. Submit your training script and validate that your model properly trains

Background and Motivation

Deep learning models and the datasets used to train them are getting bigger. At the time of writing, the largest models, such as GPT-3 and Megatron-Turing NLG, have hundreds of billions of parameters and are trained on billions of words. PyTorch Lightning provides a Python API for distributing and training deep learning models across multiple machines. Distributed data parallel training can drastically reduce the time it takes to train on large datasets: each GPU in the cluster runs the forward and backward passes of the model on its own batch, so the cluster processes as many batches at once as it has GPUs, and the gradients are averaged across GPUs before each optimizer step so that every copy of the model stays identical.

I am a principal machine learning engineer at UnitedHealth Group R&D focused on building models to represent, classify, and extract information from medical documents. At UnitedHealth Group our goal is to make the healthcare system more efficient and work better for everyone. We use insights from data to improve the quality and affordability of care. On my team we use deep learning models to automate the otherwise manual process of information retrieval from text to ensure high-quality patient documentation, uncover patterns, and provide downstream information to models used for population health management. Distributed training helps our models scale to the millions of patients for whom we provide insurance and care.

In the steps to follow, I will walk through the process of training a distributed deep learning model on AzureML with PyTorch Lightning in the same way we do at UnitedHealth Group. We hope this blogpost is instructive and helps advance the work of other data scientists and ML engineers using deep learning to innovate and solve the challenging problems of our time.

Section 1: Train on more data than will fit in memory

Distributed deep learning allows us to train bigger models on more data than what is possible with a single machine. In this blogpost, I will focus specifically on scaling training to large datasets and steps required to make distributed training with PyTorch Lightning work on AzureML. For information on how to scale model size with PyTorch Lightning, read more here.

The more examples a model sees, the more patterns it can uncover. For this comprehensive example, I use two of Azure’s Standard_ND40rs_v2 machines, each with 672 GB of RAM and eight GPUs. To train your model faster, you can use as many machines with as many GPUs as you like, but just as GPU memory constrains the batch size when training a neural network, node RAM constrains the size of the training set, unless you partition your data as described below.
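Because node RAM is the binding constraint, it helps to estimate sizes before generating any data. The sketch below is a back-of-the-envelope calculation, assuming float32 tensors (4 bytes per value) and ignoring pickle and Python object overhead; the function names are illustrative. It uses the defaults from params.yml (10,000 samples of 1,000 features per file, 1,000 files, 8 GPUs per node) and the fact that, in the data module used in this post, every DDP process on a node loads its own copy of the current partition.

```python
# Back-of-the-envelope RAM estimate for the synthetic partitions.
# Assumption: float32 tensors (4 bytes per value); pickle and Python
# object overhead are ignored.
BYTES_PER_FLOAT32 = 4


def partition_bytes(num_records: int, num_features: int) -> int:
    """Bytes for one partition: a features matrix plus one label per row."""
    features = num_records * num_features * BYTES_PER_FLOAT32
    labels = num_records * 1 * BYTES_PER_FLOAT32
    return features + labels


def node_ram_bytes(num_records: int, num_features: int, gpus_per_node: int) -> int:
    """RAM per node if every GPU process holds its own copy of one partition."""
    return gpus_per_node * partition_bytes(num_records, num_features)


per_file = partition_bytes(10_000, 1_000)
full_dataset = per_file * 1_000
print(f"One partition:          {per_file / 1e9:.2f} GB")
print(f"All 1,000 partitions:   {full_dataset / 1e9:.2f} GB")
print(f"8 processes, one file:  {node_ram_bytes(10_000, 1_000, 8) / 1e9:.2f} GB")
```

By this estimate the full 10,000,000-sample dataset is about 40 GB, so giving each of eight GPU processes its own copy would need roughly 320 GB of a node's 672 GB, while one partition per process needs only about 0.3 GB in total. Doubling num_features roughly doubles every figure.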

The following YAML file (save it as params.yml) specifies parameters for our distributed training run, including how much randomly generated data to create for the training job. The num_train_files and num_records_per_train_file parameters control how many partition files are created and how many samples each file holds, so you can test with as much data as you like. The number of features per sample (num_features) also affects how much data you can load into memory at once during distributed training. By default, the configuration creates 1,000 files with 10,000 samples per file, for a total of 10,000,000 samples. Structuring the training data this way allows us to load it into memory one partition file at a time, as the data module later in this section shows. When using params.yml, be sure to set the Azure workspace, resource group, and cluster name for your Azure subscription.

```yaml
# Run parameters for training a PyTorch Lightning model on AzureML
# Number of nodes in cluster
nodes: 2
# Number of GPUs per node
gpus: 8
# Total number of train partitions the model will see (one per epoch;
# the data module cycles back to the first file after the last one)
epochs: 10000
# For ddp, effective batch size is batch_size * gpus * nodes
batch_size: 1
# Replace with your workspace name
ws_name: '<workspace-name>-aml'
# Replace with your resource group name
resource_group: '<workspace-name>-common'
# Replace with the name of your cluster
compute_target: '<cluster-name>'
# The script to train the model
job_script_name: 'train.py'
# Name the experiment whatever you want
exp_name: 'pl-aml-mnmgpu'
# Register a File Dataset in AzureML and provide the name
dataset_name: 'pl-aml-mnmgpu-test'
# Dataset location
dataset_path: '/mnt/azureblobshare/pl-aml-mnmgpu-test'
# Train file partitions
train_files: '/mnt/azureblobshare/pl-aml-mnmgpu-test/train_records/'
# One file of validation data
val_file: '/mnt/azureblobshare/pl-aml-mnmgpu-test/val.pt'
# AzureML's outputs directory
model_output_dir: './outputs/model_checkpoints'
# Number of features for the dataset
num_features: 1000
# Number of model outputs (1 logit for binary 0/1 labels)
num_classes: 1
# Number of train file partitions
num_train_files: 1000
# Number of training samples to generate per train file
num_records_per_train_file: 10000
# Number of validation samples to generate for the validation file
num_val_records: 10000
```
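Several of these parameters interact: the effective batch size, the number of optimizer steps per epoch, and how often each partition is revisited all follow from them. A small sketch of that arithmetic, assuming DDP with a distributed sampler that splits each partition evenly across all GPU processes; the helper names are illustrative and not part of the project files:

```python
# Derived quantities for a run configured like params.yml.
# Values are copied from params.yml; edit them to match yours.
params = {
    "nodes": 2,
    "gpus": 8,
    "batch_size": 1,
    "epochs": 10000,
    "num_train_files": 1000,
    "num_records_per_train_file": 10000,
}


def effective_batch_size(p):
    # Each DDP process trains on batch_size samples per optimizer step.
    return p["batch_size"] * p["gpus"] * p["nodes"]


def steps_per_epoch(p):
    # One epoch = one partition split across all processes. Ceiling
    # division, because the distributed sampler pads the last shard and
    # the final partial batch is kept.
    world_size = p["gpus"] * p["nodes"]
    per_process = -(-p["num_records_per_train_file"] // world_size)
    return -(-per_process // p["batch_size"])


def partition_for_epoch(p, epoch):
    # Mirrors data.py: epoch e reads entry e % num_train_files
    # of the *sorted* list of training files.
    return epoch % p["num_train_files"]


print("effective batch size:", effective_batch_size(params))
print("optimizer steps per epoch:", steps_per_epoch(params))
print("passes over each partition:", params["epochs"] // params["num_train_files"])
```

With the defaults, each of the 16 processes takes 625 steps per epoch at an effective batch size of 16, and the 10,000 epochs cycle through the 1,000 partitions ten times. Note that the index refers to the sorted file list, and sorting is lexicographic, so entry 2 is data10.pt rather than data2.pt.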

To create records with these parameters, save and run a script called create_data.py from a machine with access to your Azure blob storage. Modify the paths in params.yml as necessary to point to your azureblobshare file system.

"""
Create files for file dataset on azureblobshare.
"""

import os
import yaml
import torch
import pickle
def generate_data(n, d, c): return {
"features": torch.randn(n, d),
"labels": torch.randint(0, c + 1, (n, 1)).float(),
}
def main():
# Load run parameters
with open("params.yml", "r") as stream:
PARAMS = yaml.safe_load(stream)

# Make train records output directory if non-existent
if not os.path.exists(PARAMS["train_files"]):
os.makedirs(PARAMS["train_files"])
# Create train files
for i in range(PARAMS["num_train_files"]):
data = generate_data(
n=PARAMS["num_records_per_train_file"],
d=PARAMS["num_features"],
c=PARAMS["num_classes"]
)
with open(PARAMS["train_files"] + f"data{i}.pt", "wb") as f:
pickle.dump(data, f)
# Create val file
data = generate_data(
n=PARAMS["num_val_records"],
d=PARAMS["num_features"],
c=PARAMS["num_classes"]
)
with open(PARAMS["val_file"], "wb") as f:
pickle.dump(data, f)
if __name__ == "__main__": main()

To load the data into the memory of our Standard_ND40rs_v2 nodes one partition at a time, and to ensure that each GPU in the cluster receives different samples from each partition, we use a PyTorch Lightning DataModule designed to load one partition per epoch. This module represents the feature and label records from create_data.py as a PyTorch dataset. At the start of each training epoch it loads a new partition, and, under the hood, Lightning wraps the PyTorch DataLoader in a distributed sampler so that each record in the partition is used only once per epoch across the cluster. As the number of GPUs in the cluster grows, each GPU receives a smaller share of every partition, so the cluster works through each partition faster. Note that the validation dataset is kept as a single partition, which is reasonable if the validation set is an i.i.d. sample of sufficient size. For brevity, a test dataloader is omitted, but trained models should be evaluated on a final test set unseen during training. The PyTorch Lightning Trainer’s .test method can use the same data module that we will later pass to its .fit method.

"""
Script: data.py
About:
Defines a PyTorch dataset for file partitions and a lightning data module to load datasets of new file partitions at each epoch.
"""
import os
import torch
import pickle
import pytorch_lightning as pl
from typing import Optional
from torch.utils.data import Dataset, DataLoader
class PartitionDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self): return len(self.data["features"])
def __getitem__(self, idx): return (
self.data["features"][idx],
self.data["labels"][idx]
)
class PartitionPerEpochDataModule(pl.LightningDataModule):
def __init__(
self, batch_size, train_files, val_file, num_workers=2
):
super().__init__()
self.train_files = sorted(train_files)
self.val_file = val_file
self.batch_size = batch_size
self.num_workers = num_workers
self.val_data = self.load_data(self.val_file)
def load_data(self, file): with open(file, "rb") as f:
data = pickle.load(f)
return data
def prepare_data(self): pass
def setup(self, stage: Optional[str] = None):
"""
Anything called here is being distributed across GPUs
(do many times). Lightning handles distributed sampling.
"""
# Build the val dataset
self.val_dataset = PartitionDataset(data=self.val_data)
def train_dataloader(self):
"""
This function sends the same file to each GPU and
loops back after running out of files.
Lightning will apply distributed sampling to
the data loader so that each GPU receives
different samples from the file until exhausted.
"""
# Load the data file with the right index
total = len(self.train_files)
train_file_idx = self.trainer.current_epoch % total
train_file = self.train_files[train_file_idx]
train_data = self.load_data(train_file)
# Build the train dataset
train_dataset = PartitionDataset(data=train_data)
# Return the dataloader, which lightning will turn
# into a distributed data loader, ensuring that
# different samples are selected on each GPU
return DataLoader(
train_dataset,
self.batch_size,
num_workers=self.num_workers,
pin_memory=True
)
def val_dataloader(self): return DataLoader(
self.val_dataset,
self.batch_size,
num_workers=self.num_workers,
pin_memory=True
)
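Lightning adds the distributed sampler for you, but it can help to see what it does. The sketch below calls PyTorch’s DistributedSampler directly with explicit num_replicas and rank arguments, so no process group is needed; the 16-record dataset and 4 replicas are illustrative stand-ins for one partition and the GPU processes in a cluster:

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# A toy "partition" of 16 records, shared by 4 simulated GPU processes
dataset = TensorDataset(torch.arange(16))
world_size = 4

shards = {}
for rank in range(world_size):
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0
    )
    # All ranks use the same epoch (and seed), so they agree on one
    # shuffled order before each takes its own slice of it.
    sampler.set_epoch(0)
    shards[rank] = list(sampler)

for rank, idx in shards.items():
    print(f"rank {rank}: {idx}")

# The shards are disjoint and together cover the whole partition
all_idx = [i for idx in shards.values() for i in idx]
assert sorted(all_idx) == list(range(16))
```

Each rank draws the same shuffled permutation (seeded by seed + epoch) and keeps every world_size-th index starting at its own rank, which is why the shards never overlap. If a partition’s size is not divisible by the number of processes, the sampler pads by repeating a few indices, so a handful of records may be seen twice in that epoch.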

Section 2: Connect to and pull data from Azure blob storage

After creating the records for training, register them as a FileDataset in AzureML. In the datasets tab in the AzureML UI, select the option to create a new dataset from a datastore:

Image 1: Create dataset from datastore

Next, select the option to create a file dataset:

Image 2: Create file dataset

Finally, provide a path to the records on your azureblobshare file system. In the "Select or search by name" field, specify the storage account for your workspace, <workspace-name>-stocont, which should be an option in the dropdown. If you follow the naming conventions from params.yml, replace path_to_dataset/** below with pl-aml-mnmgpu-test/**. You can use any path that is a subdirectory of /mnt/azureblobshare/, but omit the /mnt/azureblobshare/ prefix when typing into the Path field. For example, if your data is located at /mnt/azureblobshare/myfiles/data/, specify myfiles/data/** below.

Image 3: Provide path to file dataset

For more information on creating AzureML datasets, see the dataset documentation.
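The path rule above (drop the /mnt/azureblobshare/ prefix, then append /**) is easy to get wrong when paths are nested. A minimal sketch of it as a function; datastore_glob is a hypothetical helper for illustration, not part of the AzureML SDK:

```python
# Hypothetical helper: turn a path on the azureblobshare mount into the
# glob pattern typed into the dataset's Path field.
MOUNT_PREFIX = "/mnt/azureblobshare/"


def datastore_glob(mount_path: str) -> str:
    if not mount_path.startswith(MOUNT_PREFIX):
        raise ValueError(f"{mount_path!r} is not under {MOUNT_PREFIX}")
    # Strip the mount prefix and any surrounding slashes, then match
    # every file below that directory
    relative = mount_path[len(MOUNT_PREFIX):].strip("/")
    return f"{relative}/**"


print(datastore_glob("/mnt/azureblobshare/pl-aml-mnmgpu-test"))
print(datastore_glob("/mnt/azureblobshare/myfiles/data/"))
```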

Section 3: Configure environment variables for distributed data parallel training with Open MPI, NCCL, and InfiniBand

AzureML supports several communication backends for distributed training. In this blogpost, we use MPI (Message Passing Interface) as implemented by Open MPI, a high-performance message-passing library. MPI, InfiniBand, and the NVIDIA Collective Communications Library (NCCL) work together to enable and manage fast communication between nodes in the cluster. These backends require several environment variables to be set. You can check out the AzureML distributed training documentation for more information, but to use MPI with the package versions and Docker image referenced in Section 4, you can simply use the functions below by writing a Python module called environment.py, which we will reference later. Because AzureML launches one MPI process per node by default, each process’s MPI rank serves as its node rank, and Lightning then starts one training process per GPU on each node.

"""
Script: environment.py
About:
Configures and documents a distributed AzureML environment.
"""

import os
import torch
import pytorch_lightning as pl
def print_dl_library_versions(): print(f"CUDNN version: {torch.backends.cudnn.version()}")
print(f"Torch version: {torch.__version__}")
print(f"PyTorch Lightning version: {pl.__version__}")
def configure_multi_node_environment(nodes): os.environ["MASTER_ADDR"] = os.environ["AZ_BATCHAI_MPI_MASTER_NODE"]
os.environ["MASTER_PORT"] = "6105"
os.environ["NODE_RANK"] = os.environ["OMPI_COMM_WORLD_RANK"]
print(f"Configuring environment for {nodes} nodes.")
print("MASTER_ADDR = {}".format(os.environ["MASTER_ADDR"]))
print("MASTER_PORT = {}".format(os.environ["MASTER_PORT"]))
print("NODE_RANK = {}".format(os.environ["NODE_RANK"]))

Section 4: Build a consistent Python environment for your training pipeline with Docker

Reproducibility is essential in any scientific field, especially data science. To ensure that we can recreate experimental results from our model, use our model on different compute platforms, and deploy our model for production inference, we need to make sure the Python environment used to train our model can be reproduced. In this section we introduce the file that will manage all of our Python package dependencies, commands to build a Python environment with the conda package manager, and the script that will pull a Docker image from Azure, install our dependencies, mount our dataset, and submit our training job.

Create a file called requirements.txt containing the Python packages used across the various Python files described in this blogpost. By pinning the package versions, we ensure that, even if one or more of these open-source packages change in the future such that they are no longer compatible with each other, we can re-run our distributed training job using these exact package versions. Note that only these top-level packages are pinned; their own dependencies are not.

```text
PyYAML==5.4.1
PyJWT==2.2.0
torch==1.9.1
pytorch-lightning==1.4.9
azureml-core==1.33.0
azureml-dataset-runtime==1.34.0
```

To build a Python environment from these dependencies, make sure conda is installed, then run the following commands on the machine from which you intend to submit the training job (likely your laptop).

```shell
conda create --verbose --yes --name pl-aml-mnmgpu python=3.7.9
source activate pl-aml-mnmgpu
pip install -r requirements.txt
```

Finally, we need a script to build this environment on the nodes in our cluster, mount our registered dataset to these nodes, and submit a training job. Create a file called submit.py. This file is our orchestrator, calling all the steps required to run distributed training.

"""
Script: submit.py
About:
Submit the multi-node multi-GPU PyTorch Lightning training script on an AzureML cluster according to the parameters in the params.yml file.
"""
# Imports
import os
import yaml
from azureml.core import (
ScriptRunConfig,
Workspace,
Environment,
Experiment,
Dataset,
)
from azureml.core.runconfig import MpiConfiguration
def main():
# Load run parameters
with open("params.yml", "r") as stream:
PARAMS = yaml.safe_load(stream)
# Define Azure subscription and workspace.
# You might need to replace os.environ["SUB_ID"]
# with your Azure subscription ID or an
# environment variable referencing this ID.
ws = Workspace.get(
name=PARAMS["ws_name"],
subscription_id=os.environ["SUB_ID"],
resource_group=PARAMS["resource_group"],
)
# Define Python environment from requirements file
myenv = Environment.from_pip_requirements(
name=PARAMS["exp_name"] + "_env",
file_path="requirements.txt"
)
# Configure VMs in cluster with base image from Microsoft
myenv.docker.enabled = True
myenv.docker.base_image = (
"mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.0.3-cudnn8-ubuntu18.04"
)
# Attach to dataset created from Azureblobstore
ds = Dataset.get_by_name(
workspace=ws, name=PARAMS["dataset_name"]
).as_mount(path_on_compute=PARAMS["dataset_path"])
# Set distributed run parameters
distributed_job_config = MpiConfiguration(
node_count=PARAMS["nodes"]
)
# Define job to run on AML compute
src = ScriptRunConfig(
source_directory=".",
script=PARAMS["job_script_name"],
compute_target=PARAMS["compute_target"],
distributed_job_config=distributed_job_config,
environment=myenv,
arguments=[ds],
)
# Define experiment
experiment = Experiment(workspace=ws, name=PARAMS["exp_name"])
# Submit experiment run
run = experiment.submit(config=src)
run.wait_for_completion(show_output=True)
if __name__ == "__main__":
main()
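submit.py passes the mounted dataset to the job as arguments=[ds], which should resolve to the mount path on each node at run time. The train.py script in Section 5 ignores that argument and reads its paths from params.yml instead. A hypothetical variant could derive the paths from the argument, so the mount location is defined in one place; parse_data_paths and the train_records/val.pt layout below are assumptions that mirror params.yml:

```python
import argparse
import os


def parse_data_paths(argv=None):
    """Hypothetical: build train/val paths from the dataset mount point."""
    parser = argparse.ArgumentParser()
    # Mount point of the FileDataset, e.g. /mnt/azureblobshare/pl-aml-mnmgpu-test
    parser.add_argument("data_dir")
    args = parser.parse_args(argv)
    train_dir = os.path.join(args.data_dir, "train_records")
    val_file = os.path.join(args.data_dir, "val.pt")
    return train_dir, val_file


# Simulate the command line AzureML would build for train.py
train_dir, val_file = parse_data_paths(["/mnt/azureblobshare/pl-aml-mnmgpu-test"])
print(train_dir)
print(val_file)
```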

Section 5: Structure your distributed training code

We’ve been organizing the code to manage and run our distributed training job along the way. At this point we only need two more files: model.py and train.py. After completing this section, check that the following files exist:

  • create_data.py
  • data.py
  • environment.py
  • model.py
  • params.yml
  • requirements.txt
  • submit.py
  • train.py

The reason we’re here is model.py. This could be a convolutional neural network for images, a language model for text, a feed-forward neural network for structured data, or any other architecture that conforms to the PyTorch Lightning Module format. When applying this code to your use case, much can remain the same, but you will likely want to replace 1) the randomly generated samples with your own data and 2) the BoringModel used here with a model appropriate for the problem you are working on. The boring model below is useful for debugging distributed training.

"""
Script: model.py
About:
Defines a basic PyTorch Lightning model.
"""
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
class BoringModel(pl.LightningModule):
"""
A very boring model from:
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report_model.py
with the necessary methods for lightning training.
"""
def __init__(self, num_features, num_classes): super().__init__()
self.save_hyperparameters()
self.num_features = num_features
self.num_classes = num_classes
self.layer = torch.nn.Linear(
self.num_features, self.num_classes
)
def forward(self, x): return self.layer(x)
def training_step(self, batch, batch_idx): data, labels = batch
yhat = self.forward(data)
train_loss = F.binary_cross_entropy_with_logits(
yhat, labels
)
self.log("train_loss", train_loss, on_epoch=True)
return train_loss
def validation_step(self, batch, batch_idx): data, labels = batch
yhat = self.forward(data)
val_loss = F.binary_cross_entropy_with_logits(
yhat, labels
)
self.log("val_loss", val_loss, on_epoch=True)
def configure_optimizers(self): return torch.optim.SGD(self.layer.parameters(), lr=0.1)
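Before spending GPU hours, it can help to check the model’s loss on a CPU without Lightning. A minimal sketch that reproduces the BoringModel’s forward pass and loss with plain PyTorch: a Linear layer producing one logit per sample, scored with binary_cross_entropy_with_logits against float 0/1 labels of shape (N, 1). The batch size of 4 is arbitrary:

```python
import math

import torch
import torch.nn.functional as F

num_features, num_classes, batch = 1000, 1, 4
layer = torch.nn.Linear(num_features, num_classes)

# With all weights and the bias at zero, every logit is 0 (p = 0.5),
# so the loss is ln 2 regardless of the labels.
torch.nn.init.zeros_(layer.weight)
torch.nn.init.zeros_(layer.bias)

x = torch.randn(batch, num_features)
labels = torch.randint(0, 2, (batch, 1)).float()
logits = layer(x)
loss = F.binary_cross_entropy_with_logits(logits, labels)

print(tuple(logits.shape), loss.item())
# Logits and labels must have matching shapes for this loss
assert logits.shape == labels.shape
assert math.isclose(loss.item(), math.log(2), rel_tol=1e-6)
```

A freshly initialized model on random labels should start near this ln 2 ≈ 0.693 baseline, which is a useful reference point when reading the val_loss values in the checkpoint names later.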

The following script calls functions and classes from the other files in this blogpost to set environment variables, build the data module, and fit the model. It creates an instance of the PyTorch Lightning Trainer class and uses it to run the forward and backward passes that train the model. The trainer also applies checkpointing, keeping the five checkpoints with the lowest validation loss, and early stopping, which ends training once validation loss has not improved for five consecutive validation checks. Note that the data loaders are reloaded every epoch to make sure we always train on a new data partition.

Because of the way our data module is designed in data.py, each epoch runs forward and backward passes over one partition rather than the full dataset. This is a good idea for large datasets: when an epoch covers the full dataset, a model can start to overfit before the epoch ends, whereas short epochs give frequent chances to validate and checkpoint. If validating after every partition is too frequent, pass the check_val_every_n_epoch flag to the trainer; you can likewise limit how often PyTorch Lightning creates checkpoints with every_n_epochs (both available in Lightning 1.4). Finally, we use the ddp backend to train the model, but other backends, such as ddp_sharded, are available and worth exploring for models with hundreds of millions of parameters.

"""
Script: train.py
About:
Train the multi-node multi-GPU PyTorch Lightning training script on an AzureML cluster according to the parameters in the params.yml file.
"""
# Open imports
import os
import yaml
import glob
import torch
import pytorch_lightning as pl
import pytorch_lightning.callbacks as plc
from torch.utils.data import DataLoader
from environment import (
configure_multi_node_environment,
print_dl_library_versions,
)
# Project imports
from model import BoringModel
from data import PartitionPerEpochDataModule
def main():
# Load parameters from YAML config
with open("params.yml", "r") as stream:
PARAMS = yaml.safe_load(stream)
# Print DL library version details
print_dl_library_versions()
# Configure multi-node environment
if PARAMS["nodes"] > 1:
configure_multi_node_environment(nodes=PARAMS["nodes"])
# Seed everything
pl.seed_everything(42)
# Create checkpoint directory
if not os.path.exists(PARAMS["model_output_dir"]):
os.makedirs(PARAMS["model_output_dir"])
# Define checkpoint callback
checkpoint_callback = plc.ModelCheckpoint(
monitor="val_loss",
dirpath=PARAMS["model_output_dir"],
filename="{epoch:02d}-{val_loss:.5f}",
save_top_k=5,
mode="min",
)
# Create early stopping condition
early_stopping_callback = plc.early_stopping.EarlyStopping(
"val_loss", patience=5, mode="min"
)
# Build model
model = BoringModel(
num_features=PARAMS["num_features"],
num_classes=PARAMS["num_classes"]
)
# Define lightning model trainer
# Reloading the dataloaders every epoch is essential
# When switching the training files, as we do in data.py
trainer = pl.Trainer(
gpus=PARAMS["gpus"],
num_nodes=PARAMS["nodes"],
accelerator="ddp",
num_sanity_val_steps=0,
max_epochs=PARAMS["epochs"],
checkpoint_callback=True,
callbacks=[early_stopping_callback, checkpoint_callback],
reload_dataloaders_every_n_epochs=1,
)
# Build the data module
data = PartitionPerEpochDataModule(
train_files=glob.glob(PARAMS["train_files"] + "data*.pt"),
val_file=PARAMS["val_file"],
batch_size=PARAMS["batch_size"],
)
# Train the model
trainer.fit(model, datamodule=data)
if __name__ == "__main__":
main()
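Validating less often interacts with early stopping, because EarlyStopping’s patience counts validation checks rather than training epochs. A rough sketch of the trade-off, assuming one validation check every check_val_every_n_epoch epochs and one 10,000-record partition per epoch as in params.yml; the helper names are illustrative:

```python
def epochs_before_stop(patience: int, check_val_every_n_epoch: int) -> int:
    """Epochs without improvement before early stopping fires."""
    return patience * check_val_every_n_epoch


def samples_between_checks(records_per_partition: int, check_val_every_n_epoch: int) -> int:
    """Training samples seen between validation checks (one partition per epoch)."""
    return records_per_partition * check_val_every_n_epoch


for k in (1, 5, 10):
    print(
        f"check_val_every_n_epoch={k:>2}: "
        f"{samples_between_checks(10_000, k):>7,} samples per check, "
        f"stop after {epochs_before_stop(5, k)} epochs without improvement"
    )
```

With patience=5, validating every 10 partitions means training can run 50 partitions (500,000 samples) past the best checkpoint before stopping, so it may be worth lowering patience when raising check_val_every_n_epoch.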

Section 6: Submit your training script and validate that your model properly trains

After registering your dataset, configuring all the files mentioned above, and activating your conda environment, it’s time to submit the training job. To do this, simply run submit.py.

```shell
python submit.py
```

AzureML will provide a URL for your workspace in the terminal where you run your job. Visit the URL to view log files and outputs associated with your training job. In our params.yml file, we specify a model checkpoint directory which we can monitor in the AzureML UI. Check that as your model trains, the best checkpoints are being logged according to the monitoring metric specified in train.py. Here we see that our boring model achieved its lowest validation loss at epoch 2.

Image 4: Examining model checkpoints
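Inside train.py, the ModelCheckpoint callback’s best_model_path attribute reports the best checkpoint after fitting. When you only have a listing of the outputs/model_checkpoints directory, a small parser can pick the same file. This is a hypothetical helper; it assumes Lightning inserts metric names into the filename pattern "{epoch:02d}-{val_loss:.5f}", giving names like epoch=02-val_loss=0.69310.ckpt, and the listing below is made up for illustration:

```python
import re

CKPT_RE = re.compile(r"epoch=(\d+)-val_loss=([0-9.]+)\.ckpt$")


def best_checkpoint(filenames):
    """Return the checkpoint name with the lowest val_loss."""
    parsed = []
    for name in filenames:
        match = CKPT_RE.search(name)
        if match:
            parsed.append((float(match.group(2)), int(match.group(1)), name))
    if not parsed:
        raise ValueError("no checkpoint names matched")
    # Lowest val_loss wins, matching mode="min" in train.py
    return min(parsed)[2]


# Made-up directory listing, for illustration only
listing = [
    "epoch=00-val_loss=0.70125.ckpt",
    "epoch=02-val_loss=0.69310.ckpt",
    "epoch=04-val_loss=0.69402.ckpt",
]
print(best_checkpoint(listing))
```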

Conclusion

This blogpost provides a comprehensive working example of training a PyTorch Lightning model on an AzureML GPU cluster consisting of multiple nodes and multiple GPUs per node. It addresses many of the gaps in the existing documentation for these technologies, providing explicit direction that will help data scientists and ML engineers train deep learning models on large datasets using a GPU cluster. This is an example of some of the work we are doing at UnitedHealth Group to use artificial intelligence at scale to make the healthcare system more efficient, improve operations, and most importantly, enhance the quality of care we provide to patients.

References

The distributed training setup detailed in this blogpost is based on a combination of examples from PyTorch Lightning issue threads, this Medium blogpost, this nag blogpost, and AzureML’s PyTorch Lightning documentation.

The core idea behind distributed data parallel training in PyTorch is presented in “PyTorch Distributed: Experiences on Accelerating Data Parallel Training” from Li et al.

Special Thanks

Special thanks to colleagues Galina Grunin and Andrew Plesniak for their assistance, to Dan Mulcahy and Sanji Fernando for reviewing and supporting this work, and to OptumLabs and UnitedHealth Group for permission to share this work publicly.
