Multi-GPU Training in PyTorch with Code (Part 4): Torchrun

Anthony Peng
Polo Club of Data Science | Georgia Tech
5 min read · Jul 7, 2023

We discussed single-GPU training in Part 1, multi-GPU training with DP in Part 2, and multi-GPU training with DDP in Part 3. In this article, we explore how to make training fault tolerant and resumable with Torchrun. An OOM or other unexpected error on one of the GPUs can otherwise force us to retrain from scratch if no snapshot was saved during training. Combined with DDP, Torchrun handles such failures gracefully.

Part 1. Single GPU Example — Training ResNet34 on CIFAR10

Part 2. Data Parallel — Training code & issue between DP and NVLink

Part 3. Distributed Data Parallel — Training code & Analysis

Part 4. Torchrun (this article) — Fault tolerance

Torchrun

1. What is torchrun?

torchrun provides a superset of the functionality of torch.distributed.launch, with the following additional functionalities:

Worker failures are handled gracefully by restarting all workers.

Worker RANK and WORLD_SIZE are assigned automatically.

Number of nodes is allowed to change between minimum and maximum sizes (elasticity).

For more details, please refer to the official documentation. A launch command illustrating these options is sketched below.
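
As a rough sketch (the flag values here are illustrative and not taken from this article's runs; HOST_NODE_ADDR is a placeholder for a host reachable by all nodes), a fault-tolerant, elastic launch could look like the following: torchrun restarts the workers on failure up to --max_restarts times, and the number of participating nodes may vary within the --nnodes range.

$ torchrun --nnodes=1:2 --nproc_per_node=2 --max_restarts=3 \
      --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR:29500 main.py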

DDP initialization. With torchrun, we no longer set environment variables such as MASTER_ADDR and MASTER_PORT ourselves; the launcher exports them (together with RANK, LOCAL_RANK, and WORLD_SIZE) for every worker, so init_process_group can pick them up directly.

def ddp_setup_torchrun():
    # torchrun already exports MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE
    init_process_group(backend="nccl")
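
As a quick, illustrative check (not part of the original code), each worker can print the variables torchrun exports; init_process_group reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the environment when they are not passed explicitly.

import os

# Illustrative sanity check: torchrun populates these for every worker process.
for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"):
    print(key, "=", os.environ.get(key))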

TrainerDDPTorchrun. We need to save all necessary state_dicts of the model, optimizer, and lr_scheduler in a snapshot. Note that we can now fetch gpu_id from the LOCAL_RANK environment variable set by torchrun.

class TrainerDDPTorchrun(TrainerDDP):
    def __init__(
        self,
        model: nn.Module,
        trainloader: DataLoader,
        testloader: DataLoader,
        sampler_train: DistributedSampler,
    ) -> None:
        # torchrun sets LOCAL_RANK for every worker, so gpu_id is no longer passed in
        self.gpu_id = int(os.environ["LOCAL_RANK"])
        self.epochs_run = 0  # updated by _load_snapshot when resuming
        super().__init__(self.gpu_id, model, trainloader, testloader, sampler_train)

    def _save_snapshot(self, epoch: int):
        snapshot = dict(
            EPOCHS=epoch,
            MODEL_STATE=self.model.state_dict(),
            OPTIMIZER=self.optimizer.state_dict(),
            LR_SCHEDULER=self.lr_scheduler.state_dict(),
        )
        model_path = self.const["trained_models"] / "snapshot.pt"
        torch.save(snapshot, model_path)

    def _load_snapshot(self, path: str):
        snapshot = torch.load(path, map_location="cpu")
        self.epochs_run = snapshot["EPOCHS"] + 1
        self.model.load_state_dict(snapshot["MODEL_STATE"])
        self.optimizer.load_state_dict(snapshot["OPTIMIZER"])
        self.lr_scheduler.load_state_dict(snapshot["LR_SCHEDULER"])
        print(
            f"[GPU{self.gpu_id}] Resuming training from snapshot at Epoch {snapshot['EPOCHS']}"
        )

    def train(self, max_epochs: int, snapshot_path: str):
        if Path(snapshot_path).exists():
            print("Loading snapshot")
            self._load_snapshot(snapshot_path)

        self.model.train()
        for epoch in range(self.epochs_run, max_epochs):
            # https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
            self.sampler_train.set_epoch(epoch)

            self._run_epoch(epoch)
            # only save once on the master GPU
            if self.gpu_id == 0 and epoch % self.const["save_every"] == 0:
                self._save_snapshot(epoch)
        # save the final model after the last epoch
        self._save_checkpoint(max_epochs - 1)
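
For reference, the saved snapshot is an ordinary dictionary, so it can be inspected or reused outside the trainer. The short sketch below is not part of the training code; the path mirrors the one used in the experiments, and the keys match _save_snapshot above.

import torch

# Load onto the CPU so this works on a machine without the training GPUs.
snapshot = torch.load("./trained_models/snapshot.pt", map_location="cpu")
print("last saved epoch:", snapshot["EPOCHS"])
print("keys:", list(snapshot.keys()))  # EPOCHS, MODEL_STATE, OPTIMIZER, LR_SCHEDULER
# If the trainer stores the DDP-wrapped model, the MODEL_STATE keys carry a "module." prefix,
# which would need to be stripped before loading into an unwrapped model.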

Main function. Under torchrun, one Python process is launched per GPU, and each process calls this function directly.

def main_ddp_torchrun(
    snapshot_path: str,
    final_model_path: str,
):
    ddp_setup_torchrun()

    const = prepare_const()
    train_dataset, test_dataset = cifar_dataset(const["data_root"])
    train_dataloader, test_dataloader, train_sampler = cifar_dataloader_ddp_uneven(
        train_dataset, test_dataset, const["batch_size"]
    )
    model = cifar_model()
    trainer = TrainerDDPTorchrun(
        model=model,
        trainloader=train_dataloader,
        testloader=test_dataloader,
        sampler_train=train_sampler,
    )
    trainer.train(const["total_epochs"], snapshot_path=snapshot_path)
    trainer.test(final_model_path)

    destroy_process_group()  # clean up
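
For context, a DistributedSampler-based loader along the lines of cifar_dataloader_ddp_uneven (defined in Part 3) splits the 50,000 CIFAR10 training samples across the processes and keeps the last, smaller batch instead of dropping it, which is why each GPU reports 196 steps for 25,000 samples at batch size 128. The body below is only an assumed sketch, not the author's implementation.

from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

def cifar_dataloader_ddp_uneven_sketch(trainset: Dataset, testset: Dataset, batch_size: int):
    # rank and world size are taken from the initialized process group
    sampler_train = DistributedSampler(trainset)
    trainloader = DataLoader(
        trainset,
        batch_size=batch_size,
        sampler=sampler_train,
        drop_last=False,  # keep the final, smaller ("uneven") batch
        num_workers=2,
        pin_memory=True,
    )
    testloader = DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True
    )
    return trainloader, testloader, sampler_train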

Experiments

if __name__ == "__main__":
    snapshot_path = Path("./trained_models/snapshot.pt")
    final_model_path = Path("./trained_models/CIFAR10_ddp_epoch14.pt")
    main_ddp_torchrun(snapshot_path, final_model_path)

Output. We intentionally interrupt the run partway through (the most recent snapshot was saved at epoch 6) and then resume training by loading that snapshot. Note that we can also resume with different GPU IDs or even a different total number of GPUs; an example command is given after the log.

$ CUDA_VISIBLE_DEVICES=4,5 torchrun --standalone --nproc_per_node=2 main.py
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
------------------------------------------------------------------------------------------
[GPU0] Epoch 0 | Batchsize: 128 | Steps: 196 | LR: 0.1000 | Loss: 2.2538 | Acc: 20.95%
------------------------------------------------------------------------------------------
[GPU1] Epoch 0 | Batchsize: 128 | Steps: 196 | LR: 0.1000 | Loss: 2.2456 | Acc: 20.95%
------------------------------------------------------------------------------------------
[GPU0] Epoch 1 | Batchsize: 128 | Steps: 196 | LR: 0.1000 | Loss: 1.5236 | Acc: 43.74%
------------------------------------------------------------------------------------------
[GPU1] Epoch 1 | Batchsize: 128 | Steps: 196 | LR: 0.1000 | Loss: 1.5287 | Acc: 43.74%
------------------------------------------------------------------------------------------
[GPU1] Epoch 2 | Batchsize: 128 | Steps: 196 | LR: 0.1000 | Loss: 1.2607 | Acc: 54.80%
------------------------------------------------------------------------------------------
[GPU0] Epoch 2 | Batchsize: 128 | Steps: 196 | LR: 0.1000 | Loss: 1.2567 | Acc: 54.80%
------------------------------------------------------------------------------------------
[GPU0] Epoch 3 | Batchsize: 128 | Steps: 196 | LR: 0.1000 | Loss: 1.0464 | Acc: 62.83%
------------------------------------------------------------------------------------------
[GPU1] Epoch 3 | Batchsize: 128 | Steps: 196 | LR: 0.1000 | Loss: 1.0456 | Acc: 62.83%
------------------------------------------------------------------------------------------
[GPU0] Epoch 4 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.8687 | Acc: 68.89%
------------------------------------------------------------------------------------------
[GPU1] Epoch 4 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.8847 | Acc: 68.89%
------------------------------------------------------------------------------------------
[GPU0] Epoch 5 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.6019 | Acc: 78.55%
------------------------------------------------------------------------------------------
[GPU1] Epoch 5 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.6197 | Acc: 78.55%
------------------------------------------------------------------------------------------
[GPU0] Epoch 6 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.5142 | Acc: 82.04%
------------------------------------------------------------------------------------------
[GPU1] Epoch 6 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.5194 | Acc: 82.04%
------------------------------------------------------------------------------------------
[GPU1] Epoch 7 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.4458 | Acc: 84.78%
------------------------------------------------------------------------------------------
[GPU0] Epoch 7 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.4485 | Acc: 84.78%
^CWARNING:torch.distributed.elastic.agent.server.api:Received 2 death signal, shutting down workers
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 1322493 closing signal SIGINT
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 1322494 closing signal SIGINT
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f0a4d538e50>
Traceback (most recent call last):

...

$ CUDA_VISIBLE_DEVICES=4,5 torchrun --standalone --nproc_per_node=2 main.py
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Loading snapshot
Loading snapshot
[GPU1] Resuming training from snapshot at Epoch 6
[GPU0] Resuming training from snapshot at Epoch 6
------------------------------------------------------------------------------------------
[GPU1] Epoch 7 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.4456 | Acc: 84.77%
------------------------------------------------------------------------------------------
[GPU0] Epoch 7 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.4486 | Acc: 84.77%
------------------------------------------------------------------------------------------
[GPU1] Epoch 8 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.3724 | Acc: 87.50%
------------------------------------------------------------------------------------------
[GPU0] Epoch 8 | Batchsize: 128 | Steps: 196 | LR: 0.0100 | Loss: 0.3707 | Acc: 87.50%
------------------------------------------------------------------------------------------
[GPU1] Epoch 9 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.2852 | Acc: 90.46%
------------------------------------------------------------------------------------------
[GPU0] Epoch 9 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.2898 | Acc: 90.46%
------------------------------------------------------------------------------------------
[GPU1] Epoch 10 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.1625 | Acc: 95.34%
------------------------------------------------------------------------------------------
[GPU0] Epoch 10 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.1610 | Acc: 95.34%
------------------------------------------------------------------------------------------
[GPU1] Epoch 11 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.1262 | Acc: 96.49%
------------------------------------------------------------------------------------------
[GPU0] Epoch 11 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.1341 | Acc: 96.49%
------------------------------------------------------------------------------------------
[GPU1] Epoch 12 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.1079 | Acc: 97.07%
------------------------------------------------------------------------------------------
[GPU0] Epoch 12 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.1128 | Acc: 97.07%
------------------------------------------------------------------------------------------
[GPU1] Epoch 13 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.0966 | Acc: 97.56%
------------------------------------------------------------------------------------------
[GPU0] Epoch 13 | Batchsize: 128 | Steps: 196 | LR: 0.0010 | Loss: 0.0951 | Acc: 97.56%
------------------------------------------------------------------------------------------
[GPU1] Epoch 14 | Batchsize: 128 | Steps: 196 | LR: 0.0001 | Loss: 0.0840 | Acc: 97.82%
------------------------------------------------------------------------------------------
[GPU0] Epoch 14 | Batchsize: 128 | Steps: 196 | LR: 0.0001 | Loss: 0.0856 | Acc: 97.82%
[GPU1] Test Acc: 73.0500%
[GPU0] Test Acc: 73.0500%
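
As noted above, the snapshot is not tied to a particular GPU assignment. For example (an illustrative command, not an actual run from this article), the same training could be resumed on four different GPUs:

$ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --standalone --nproc_per_node=4 main.py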
