PyTorch Data Parallel Best Practices on Google Cloud
Authors: Shen Li (Meta AI), Jessica Choi (Meta AI), Pavel Belevich (Meta AI), Yanli Zhao (Meta AI), Rohan Varma (Meta AI), Geeta Chauhan (Meta AI), Pritam Damania (Meta AI), Mahesh Yadav (Meta AI)
Collaborators: Vaibhav Singh (Google), Isaack Karanja (Google), Parvez Mulla (Google)
Distributed training techniques have advanced rapidly in recent years. Many of them are available in the PyTorch distributed package, including DistributedDataParallel (DDP), PipelineParallel (Pipe) and FullyShardedDataParallel (FSDP). Together they provide a rich set of tools for machine learning applications in distributed environments. Although in principle these features target different scenarios, it is not immediately clear which one is the best option for a specific hardware and model configuration.
This post uses an extensive set of experiments on Google Cloud to quantify the performance of DDP, Pipe and FSDP with different model sizes, batch sizes, cluster sizes and network bandwidths. We hope this post can help users make design decisions on their distributed training solutions.
- When the amortized per-sample communication overhead (ACO) is high (due to either a slow network or a small batch size), PipelineParallel + DistributedDataParallel (PDP) achieves the highest throughput for models with fewer than 10B parameters. Compared to DistributedDataParallel alone, the speedup comes from a smaller AllReduce ring and from concurrent AllReduces on two disjoint sets of devices.
- When ACO is low (due to either a fast network or a large batch size), DistributedDataParallel attains the highest throughput for models with fewer than 3B parameters. In our experiments, the boundary between high and low ACO is around 20ms/sample.
- For the remaining cases, FullyShardedDataParallel is the best option. As of v1.11, it can scale to 1T-parameter models.
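The three findings above can be read as a small decision rule. The sketch below encodes them in plain Python; the function name `recommend_strategy` is hypothetical, and treating exactly 20ms/sample as "high" is an assumption, since the post only gives an approximate boundary for its own experiments.

```python
def recommend_strategy(model_params_billion: float, aco_ms_per_sample: float) -> str:
    """Return "PDP", "DDP" or "FSDP" following the three findings above.

    Thresholds (10B, 3B, ~20ms/sample) are taken from the findings; they
    describe this post's experiments and are not universal constants.
    """
    high_aco_ms = 20.0  # approximate boundary between high and low ACO
    high_aco = aco_ms_per_sample >= high_aco_ms  # assumption: boundary counts as high

    if high_aco and model_params_billion < 10:
        return "PDP"   # Pipe + DDP wins when communication is expensive
    if not high_aco and model_params_billion < 3:
        return "DDP"   # plain DDP wins when communication is cheap
    return "FSDP"      # everything else


print(recommend_strategy(2.7, 30.0))
print(recommend_strategy(2.7, 5.0))
print(recommend_strategy(13.0, 30.0))
```

A real decision should still be validated by measuring throughput on the target cluster, since the boundary depends on hardware and model.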
Experiment Setup
All experiments presented in this post are conducted on Google Cloud Platform. We used Terraform to deploy a 64-node Slurm cluster of A2 machines, where each A2 machine is equipped with 8 A100 GPUs (40GB memory). Within a single machine, GPUs communicate through a third-generation NVLink with 600GB/s bandwidth (4800 Gbps). Different machines are connected with 100 Gbps Ethernet. For experiments focusing on network impact, we also enabled GCP’s proprietary NCCL Fast Socket, which conditionally improves the performance of collective communication through Ethernet. We used CUDA 11.4 with driver version 470.57.02.
Overview and Terminology
Features
In this post, we will evaluate three main data-parallel paradigms for distributed training:
- DDP uses torch.nn.parallel.DistributedDataParallel to wrap the entire model, which replicates all model parameters to every device and runs AllReduce during the backward pass to synchronize gradients across model replicas. Please refer to the DDP tutorial and DDP paper for more details.
- PDP partitions the model into an nn.Sequential, enables pipeline parallelism using torch.distributed.pipeline.sync.Pipe and wraps each Pipe instance with torch.nn.parallel.DistributedDataParallel for data parallelism. Please refer to the PyTorch Pipe tutorial, GPipe and TorchGPipe papers for more details.
- FSDP wraps sub-modules into torch.distributed._fsdp.FullyShardedDataParallel units. Each unit shards model parameters and scatters shards onto data-parallel processes. It then runs AllGather to collect the full parameter within that unit before the forward and backward computations and discards collected shards afterward. Finally, it launches ReduceScatter to synchronize gradients. The implementation was upstreamed from FairScale FSDP and is available as a prototype feature in PyTorch 1.11. Please refer to ZeRO and Sharded DataParallel papers for more details.
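To make the difference between the three paradigms concrete, the sketch below estimates per-GPU parameter memory only. It is a deliberately simplified model under stated assumptions: fp16 parameters (2 bytes each), perfectly balanced pipeline stages, and no accounting for gradients, optimizer state, activations, or the unit that FSDP temporarily gathers. The function name `param_gib_per_gpu` is hypothetical.

```python
def param_gib_per_gpu(num_params: float, strategy: str,
                      world_size: int = 1, pipeline_len: int = 1,
                      bytes_per_param: int = 2) -> float:
    """Rough per-GPU parameter footprint in GiB (parameters only)."""
    total_gib = num_params * bytes_per_param / 2**30
    if strategy == "ddp":
        return total_gib                 # every device holds a full replica
    if strategy == "pdp":
        return total_gib / pipeline_len  # assumes evenly balanced stages
    if strategy == "fsdp":
        return total_gib / world_size    # each rank keeps one shard at rest
    raise ValueError(f"unknown strategy: {strategy}")


# Example: a 2.7B-parameter model on 8 GPUs.
for name, kwargs in [("ddp", {}), ("pdp", {"pipeline_len": 2}), ("fsdp", {"world_size": 8})]:
    print(name, round(param_gib_per_gpu(2.7e9, name, **kwargs), 2), "GiB")
```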
Besides these 3 main features, experiments also cover generic memory-saving techniques as shown in Figure 2, including activation checkpointing (cp), activation offloading (ao), parameter offloading (po) and micro-batching. These techniques will be presented and discussed close to where they are involved.
Experiments
Experiments are organized into 4 themes:
- Scaling Efficiency measures how distributed training paradigms scale from small to large GPU clusters
- Model Scale Limit determines the largest model that fits into a given number of GPUs under each paradigm
- Impact of Network Bandwidth evaluates training speedup obtained by increasing network bandwidth
- Best Practices provides guidelines on what features to use for specific model size, batch size and network bandwidth combinations
Implementation
Models are implemented based on minGPT. We scale the models from 162M parameters (GPTSmall) all the way to 1T parameters. To accelerate model initialization, we also integrated the GPT model with the PyTorch “meta” device, which only creates shape and dtype fields without allocating real tensor storage. More specifically, models are initially created on the meta device. Then, PDP uses the shape and dtype information of the meta tensors to estimate model size, and automatically partitions and balances the layers into stages across the specified devices. PyTorch does not provide native auto-partitioning algorithms yet, but it’s not too hard to implement one for Transformer-based models. The code snippet below shows how we implemented auto-partitioning for our experiments. It does not guarantee optimal partitioning, but serves our purpose.
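The authors' own snippet is not reproduced in this text. As an independent illustration of the idea, and not the code used in the experiments, the sketch below greedily groups consecutive layers into stages of roughly equal parameter count. It assumes the per-layer sizes have already been estimated (for example from meta-tensor shapes); the function name `partition_layers` is hypothetical.

```python
from typing import List


def partition_layers(layer_sizes: List[int], num_stages: int) -> List[List[int]]:
    """Split consecutive layers into `num_stages` stages of similar total size.

    Returns a list of stages, each a list of layer indices. Greedy, so the
    result is balanced but not guaranteed to be optimal.
    """
    if not 1 <= num_stages <= len(layer_sizes):
        raise ValueError("need between 1 and len(layer_sizes) stages")

    stages: List[List[int]] = []
    current: List[int] = []
    current_size = 0
    remaining_total = sum(layer_sizes)  # size not yet assigned to a closed stage

    for idx, size in enumerate(layer_sizes):
        current.append(idx)
        current_size += size
        stages_left = num_stages - len(stages)    # includes the open stage
        layers_left = len(layer_sizes) - idx - 1  # layers after this one
        target = remaining_total / stages_left    # ideal size of the open stage
        # Close early if every remaining stage needs one of the remaining layers.
        must_close = layers_left == stages_left - 1
        if stages_left > 1 and (current_size >= target or must_close):
            stages.append(current)
            remaining_total -= current_size
            current, current_size = [], 0

    stages.append(current)  # the last stage takes whatever is left
    return stages


# Hypothetical layer sizes: two large embedding-like layers around small blocks.
print(partition_layers([4, 1, 1, 1, 1, 4], num_stages=2))
```

Each returned stage could then be placed on its own device and the stages combined into an nn.Sequential for Pipe; that step is omitted here.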
FSDP wraps each transformer layer into an FSDP unit when the model size is below 81B. For larger models, FSDP wraps individual Linear layers within the transformers into FSDP units to reduce peak memory consumption. PDP partitions and FSDP shards are directly materialized on the destination device, avoiding the overhead of initializing the model on the CPU and copying it to the GPU. All experiments use fp16 tensors as parameters and random integer tensors as dummy input data, where the vocabulary size is 50K and the block size is 256. All experiments use the SGD optimizer.
Scaling Efficiency
Varying World Size
We first measure the scaling efficiency in terms of global throughput (1k tokens per second) by varying the number of GPUs (1–128). Throughput is calculated as the number of tokens processed per iteration (Block Size × Batch Size × World Size) divided by the per-iteration latency. In this section, experiments always use 8 GPUs per machine with different numbers of machines, except when the total number of GPUs is less than 8. Models cover GPTSmall, GPTLarge and GPT2.7B, as 2.7B parameters is roughly the largest model that DDP can handle using 40GB A100 GPUs. Given that models and batch sizes are both relatively small, we only employ vanilla DDP, PDP and FSDP without any activation memory optimizations or parameter offloading techniques, because those techniques usually sacrifice throughput to reduce memory footprint. For PDP experiments, each pipeline spans 2 devices and divides each mini-batch into 2 micro-batches. In other words, given the same number of GPUs, the world size of PDP experiments is half that of DDP and FSDP experiments. Hence, to maintain the same per-GPU batch sizes as DDP and FSDP (8 and 20), PDP sets its per-pipeline batch sizes to 16 and 40 respectively. Although a local pipeline can span all GPUs on the same machine, using more GPUs within one pipeline incurs higher device-to-device (D2D) communication overhead. It usually only makes sense to use longer pipelines to accommodate larger models. We will evaluate 4-device and 8-device pipelines in the “Model Scale Limit” section.
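The throughput bookkeeping can be sketched as follows, assuming the per-iteration token count (Block Size × Batch Size × World Size) is divided by the measured per-iteration latency to obtain tokens per second. The helper name and the 1.0-second latency are placeholders, not measurements from the post.

```python
def tokens_per_second(block_size: int, batch_size: int, world_size: int,
                      iteration_seconds: float) -> float:
    """Global throughput: tokens processed per iteration / iteration latency."""
    return block_size * batch_size * world_size / iteration_seconds


num_gpus = 128
block_size = 256          # block size used throughout the post
per_gpu_batch = 8
latency = 1.0             # placeholder latency in seconds, for illustration only

# DDP/FSDP: one process per GPU.
ddp = tokens_per_second(block_size, per_gpu_batch, num_gpus, latency)

# PDP: each pipeline spans 2 GPUs, so the world size halves and the
# per-pipeline batch size doubles to keep the per-GPU batch size equal.
pdp = tokens_per_second(block_size, 2 * per_gpu_batch, num_gpus // 2, latency)

print(ddp, pdp)  # same token count per iteration at equal latency
```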
Results are presented in Figures 3–5. Bars marked ddp8 and ddp20 denote DDP with batch sizes 8 and 20 respectively. The dashed and solid gray lines marked opt8 and opt20 show the expected throughput with perfect scaling (i.e., 10 GPUs attain 10X the throughput of 1 GPU). All experiments show a considerable throughput dip when switching from 8 GPUs to 16 GPUs. This is because, with 16 GPUs, cross-machine communication must travel through 100 Gbps Ethernet, which is much slower than the intra-machine NVLink interconnect (600GB/s, i.e., 4800 Gbps) available to a single machine with 8 GPUs. For this reason, using more GPUs does not always lead to higher throughput, as network bandwidth also plays an important role.
When using a single machine, DDP attains the highest throughput across the 3 model configurations, which is expected as DDP incurs minimum cross-device communication volume. However, for multi-machine experiments, PDP outperforms DDP and the gap slightly widens with the increase of model size. The speedup comes from 2 sources:
- Compared to DDP, PDP breaks one large AllReduce ring into 2 smaller ones, which speeds up individual communication operations. The profiling results below show that PDP’s AllReduce (before overlapping with another AllReduce) is about 6% faster than the same AllReduce in DDP.
- As PDP breaks the devices into 2 smaller and disjoint sets, AllReduce can run concurrently and safely on these 2 sets. When AllReduce overlap occurs, each PDP AllReduce takes roughly 25ms while each DDP AllReduce takes at least 17ms. Two overlapped PDP AllReduces therefore finish in about 25ms, versus at least 17 × 2 = 34ms for two back-to-back DDP AllReduces, i.e., a speedup of 1 − 25 / (17 × 2) ≈ 26%. This also implies that, by default, running a single AllReduce at a time cannot saturate the bandwidth.
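The overlap arithmetic in the second point can be checked with a few lines of Python. The 25ms and 17ms figures come from the text above; the function name is hypothetical, and the model assumes two fully overlapped AllReduces on one side and two strictly sequential ones on the other.

```python
def overlap_speedup(overlapped_ms: float, sequential_each_ms: float,
                    num_ops: int = 2) -> float:
    """Fraction of time saved when `num_ops` AllReduces overlap in
    `overlapped_ms` instead of running back to back."""
    sequential_total = sequential_each_ms * num_ops
    return 1 - overlapped_ms / sequential_total


speedup = overlap_speedup(overlapped_ms=25, sequential_each_ms=17)
print(f"{speedup:.0%}")
```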
Varying Batch Size
The above experiments also show that batch size 20 consistently outperforms batch size 8 in terms of throughput. This is expected, because a larger batch size means a smaller amortized per-sample communication overhead (ACO). However, different batch sizes might lead to different model accuracy and activation memory consumption, so the batch size is not always freely configurable. Therefore, instead of suggesting the optimal batch size, we measure the impact of batch size on throughput and peak memory consumption for different distributed training techniques, and present the results to help you pick the best combination for your applications.
For all figures below, the x-axis is the per-GPU batch size. Since the pipeline spans 2 devices, the per-pipeline input batch size is 2X larger to maintain the same per-GPU batch size. In this set of experiments, we also vary the pipeline chunk size, i.e., the number of micro-batches within each mini-batch. The legends pdp2 and pdp4 refer to PDP with pipeline chunk sizes 2 and 4 respectively.
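One way to picture ACO is to spread a fixed per-iteration communication cost over the samples in the batch. The sketch below assumes this simple definition, which the post does not spell out, and uses a made-up communication time of 200ms purely for illustration; gradient synchronization volume depends on model size rather than batch size, which is why the cost is held constant here.

```python
def amortized_comm_overhead_ms(comm_ms_per_iteration: float, batch_size: int) -> float:
    """Per-sample share of the communication time of one iteration."""
    return comm_ms_per_iteration / batch_size


comm_ms = 200.0  # hypothetical per-iteration communication time, not a measurement
for batch_size in (8, 20):
    aco = amortized_comm_overhead_ms(comm_ms, batch_size)
    print(f"batch size {batch_size}: ACO = {aco} ms/sample")
```

Under these illustrative numbers, batch size 8 falls on the high side of the roughly 20ms/sample boundary and batch size 20 on the low side.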
From the results, the first observation is that throughput steadily increases with batch size until it approaches the largest possible batch size, where it regresses. Reading throughput and peak memory together, it is clear that this regression usually occurs when the program uses more than 32GB of GPU memory. This is because, when operating near the GPU memory capacity, the PyTorch CUDACachingAllocator might need to frequently launch defragmentation procedures to accommodate new allocations. Under the hood, it frees cached memory blocks and invokes cudaMalloc, both of which are expensive operations.
Another observation is that, although increasing pipeline chunk size slightly hurts throughput, it significantly reduces the peak memory consumption. This is because using larger chunk sizes (i.e., smaller micro-batch sizes), on the one hand, leads to more CUDA kernels with smaller inputs, which incurs higher scheduling overhead. On the other hand, smaller micro-batches help to reduce the size of the temporary tensors allocated by the autograd engine, which in turn reduces the peak memory consumption.
Model Scale Limit
One major challenge of training large models is fitting model parameters and activations into limited GPU memory. With 40GB A100 GPUs, it will be hard for vanilla DDP to scale beyond 3B-parameter models. Pipeline supports larger models by partitioning the model and placing smaller partitions on different devices, which reduces per-device parameter memory consumption and activation memory consumption. FSDP shards individual layers and scatters shards across data-parallel processes, which saves per-device parameter memory. Besides these distributed training paradigms, there are also multiple generic memory optimization techniques that apply to almost all use cases.
- Activation Checkpointing (cp) usually organizes the model into a sequence of stages, saves the outer activations at stage boundaries, and discards inner activations within a stage during the forward pass. In the backward pass, it recomputes inner activations one stage at a time before conducting backward computations, which avoids materializing the full activations of the entire forward pass.
- Activation Offloading (ao) offloads activation to CPU memory during the forward pass, and loads it back to GPU on demand during the backward pass. This technique can be combined with Activation Checkpointing to offload outer activations.
- Parameter Offloading (po) offloads parameters to CPU memory and loads them back to GPU on demand during the forward and backward passes.
- Micro-Batching divides each mini-batch into smaller micro-batches, and runs forward/backward on micro-batches one at a time. It does not reduce the memory held by saved activations, but can reduce the size of temporary tensors created by the autograd engine.
Although, technically, the above 4 memory optimization techniques can work with DDP, PDP and FSDP, PyTorch only natively supports a subset of the combinations as of v1.11. Figure 2 describes the current status.
Among these 4 memory optimization techniques, activation checkpointing (cp) and activation offloading (ao) are already (partially) compatible with all 3 distributed training paradigms. Micro-batching is implemented natively for Pipeline and is available for DDP through the no_sync() API. Parameter offloading is only available for FSDP. Therefore, we measure these combinations accordingly.
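Micro-batching relies on the fact that gradients accumulated over micro-batches match the gradient of the full mini-batch. The sketch below demonstrates this in plain Python for a one-parameter linear model with a mean-squared-error loss. It is a toy illustration with hypothetical helper names, not PyTorch code; with DDP, the no_sync() context mentioned above would be used to skip gradient synchronization on all but the last micro-batch.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def grad_micro_batched(w, xs, ys, chunks):
    """Accumulate the same gradient one micro-batch at a time."""
    n = len(xs)
    size = -(-n // chunks)  # ceiling division: samples per micro-batch
    accumulated = 0.0
    for start in range(0, n, size):
        mx, my = xs[start:start + size], ys[start:start + size]
        # Weight each micro-batch mean by its share of the mini-batch.
        accumulated += grad_mse(w, mx, my) * len(mx) / n
    return accumulated


xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.0, 7.2, 7.9]
full = grad_mse(0.3, xs, ys)
micro = grad_micro_batched(0.3, xs, ys, chunks=4)
print(math.isclose(full, micro))
```

Only one micro-batch's temporaries are live at a time, which is where the memory saving described above comes from.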
Experiments cover models ranging from 162M parameters all the way to 1T. All experiments in this section use 32 GPUs on 4 machines and set batch size to 16. Only FSDP can scale to 1-trillion-parameter models, but each iteration takes excessively long (4085 seconds) on a 100 Gbps network even with batch size 1, where communication delay dominates. We excluded the 1T numbers from the plots below, as such a low throughput does not serve practical purposes. Such large models require either a faster network (400 Gbps+) or aggressive communication compression and optimization. We organized the remaining results for 10 different model configurations into 2 sets of figures. Figures 10–11 present the throughput and peak GPU memory for small- to medium-sized models. DDP hits OOM errors beyond the 2.7B model.
Figures 12–13 show medium-to-large models. PDP with 2-device pipeline hits OOM on the 16B parameter model. PDP with 4-device pipeline hits OOM on the 34B parameter model, which is about the size limit for the PDP paradigm using single-machine pipelines.
Impact of Network Bandwidth
Network bandwidth plays a crucial role in distributed training efficiency. To quantify the impact of network bandwidth, we conducted 8-GPU experiments on a single 8-GPU machine and across 4 machines (using only 2 GPUs per machine). Intra-machine communication through NVLink has a bandwidth of 600GB/s (4800 Gbps) and inter-machine communication through Ethernet has a bandwidth of 100 Gbps. We use 4 machines instead of 8 because PyTorch only supports single-machine pipeline parallelism as of v1.11, which requires at least 2 devices on the same machine to form a pipeline. As discussed above, batch size also affects throughput. So, we repeat the same set of experiments using batch sizes of 8 and 64, and then measure the speedup ratio attained by the faster network (NVLink).
More specifically, the ratio is calculated as (Per-Iteration Delay with Ethernet) / (Per-Iteration Delay with NVLink). Results are plotted in the 2 figures below.
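For reference, the ratio and the unit conversion between the two interconnect figures can be written as below. The helper names are hypothetical and the delays in the example are placeholders rather than measured values.

```python
def gbytes_per_s_to_gbps(gbytes_per_s: float) -> float:
    """Convert gigabytes per second to gigabits per second (8 bits per byte)."""
    return gbytes_per_s * 8


def network_speedup(delay_ethernet_s: float, delay_nvlink_s: float) -> float:
    """Per-iteration delay with Ethernet divided by the delay with NVLink."""
    return delay_ethernet_s / delay_nvlink_s


print(gbytes_per_s_to_gbps(600))   # NVLink: 600GB/s expressed in Gbps
print(network_speedup(3.0, 1.5))   # placeholder delays, for illustration only
```

A ratio above 1 means the faster network shortens each iteration; a ratio close to 1 means communication was not the bottleneck.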
In the first figure with small batch size, the speedup ratio increases with the model size. This aligns with our expectations, since large model size means higher communication overhead for all three distributed training paradigms. PDP sees the least speedup, which conforms with previous observations that PDP is the most communication-efficient solution in multi-node experiments.
One interesting observation is that the speedup for FSDP drops significantly on 13B or larger models. Those experiments use parameter offloading and finer FSDP wrapping granularity, which reduces GPU memory consumption but introduces more device/host communication overhead. In other words, the cross-machine communication overhead is less dominant. The traces below confirm the reasoning, where Host-to-Device (H2D) and Device-to-Host (D2H) become the bottleneck.
We then try a larger batch size (64), i.e., smaller amortized per-sample communication overhead. As a result, the impact of faster networks also falls.
GCP offers a proprietary network transport called NCCL Fast Socket to speed up collective communications through Ethernet. Fast Socket improves NCCL performance by reducing the contention between multiple TCP connections, and is purported to improve AllReduce throughput by 30–60%. This is helpful when using GPUs across multiple machines, since the 100 Gbps network is much slower than the 4800 Gbps NVLink. We measured its impact on per-iteration delay by comparing performance with and without Fast Socket for 8 GPUs on 4 machines (2 GPUs each). Results in Figures 18–19 show that it clearly helps communication-bound FSDP in almost all scenarios, with up to 3X speedup. For DDP, it sometimes slows down training, especially on small models where communication is not the bottleneck. For PDP, it almost never helped. We believe this is because PDP already utilizes network bandwidth well by concurrently launching multiple AllReduce operations, and the additional complexity introduced by Fast Socket has a negative impact on speed. Given that Fast Socket is only available on GCP and does not always speed up training, experiments presented in all other sections ran without Fast Socket.
Best Practices
Previous sections have explored combinations of various distributed training paradigms, memory optimization techniques and model/strategy configurations. From an ML practitioner’s perspective, it would be useful to know the best strategy and configuration for a specific model, network, and batch size combination. This section summarizes the experiments and visualizes the recommendations. We repeat the same set of experiments using 8 GPUs on slower networks (100 Gbps Ethernet) and faster networks (600GB/s NVLink). For each batch size and model size combination, we pick the winning strategy with the highest throughput. Results are shown in the following two figures. Note that with both slow and fast networks, using batch size 256 with 13B+ models leads to CUDA OOM errors, so the corresponding boxes in the figures are left blank.
To conclude, for small- to medium-sized models, if the amortized communication overhead (ACO) is high, PDP attains the best throughput. Otherwise, if ACO is low, DDP should be the best option. For large models with more than 10B parameters, FSDP is optimal.
Acknowledgments
We would like to thank Google Cloud Platform experts Vaibhav Singh, Isaack Karanja, Parvez Mulla, Lori Baker and Sanjay Jacob for supporting our experiments, thank our colleagues Murray Kucherawy, Can Balioglu, Anjali Sridhar, Min Xu, Tingting Markstrum and Naman Goyal for building features that enabled this work, thank Susan Zhang, Hongyi Jia, Mingzhe Li, Mark Saroufim, Dmytro Dzhulgakov, Alisson Gusatti Azzolini, Myle Ott, Yaroslav Bulatov, Hamid Shojanazeri, Jongsoo Park, Liang Luo and Will Feng for technical discussions and feedback, and thank Carlos Escapa, Dwarak Rajagopal, Bernard Nguyen, Lin Qiao, Donny Greenberg, Brian O’Horo, Chris Gottbrath and Priya Sethuraman for supporting this project.