PyTorch Distributed Data Parallel Training

There are many ways to do distributed training for your neural network, such as using Horovod, BytePS, or the distributed package in PyTorch. This post looks into how PyTorch implements its distributed package.

How to launch distributed data parallel training in PyTorch?

Assume there is an application that uses DataParallel to train the network on a single node. All you need to do is modify the code:

  • initialize a process group for communication
  • wrap the network in DistributedDataParallel instead of DataParallel
  • use a distributed data sampler to split the data across ranks

Refer to the PyTorch documentation for a complete example; the rest of this post focuses on the implementation.
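Before diving into the internals, here is a minimal sketch of those three changes. The model, dataset, and hyperparameters are placeholders for illustration, and it assumes one process per GPU, launched for example via python -m torch.distributed.launch.

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    def main():
        # 1. Initialize the process group (env:// reads MASTER_ADDR, MASTER_PORT,
        #    RANK and WORLD_SIZE from environment variables).
        dist.init_process_group(backend="nccl", init_method="env://")
        local_rank = dist.get_rank() % torch.cuda.device_count()
        torch.cuda.set_device(local_rank)

        # 2. Wrap the model in DistributedDataParallel (1:1 process-GPU mapping).
        model = torch.nn.Linear(128, 10).cuda(local_rank)   # placeholder model
        model = DDP(model, device_ids=[local_rank])

        # 3. Use a DistributedSampler so that every rank sees a different shard.
        dataset = TensorDataset(torch.randn(1024, 128), torch.randint(0, 10, (1024,)))
        sampler = DistributedSampler(dataset)
        loader = DataLoader(dataset, batch_size=32, sampler=sampler)

        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        for epoch in range(2):
            sampler.set_epoch(epoch)                 # reshuffle the shards every epoch
            for x, y in loader:
                x, y = x.cuda(local_rank), y.cuda(local_rank)
                optimizer.zero_grad()
                loss_fn(model(x), y).backward()      # gradients are allreduced during backward
                optimizer.step()

    if __name__ == "__main__":
        main()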

Initialize Process Group

In the distributed training scenario, there are usually several processes. The set of all processes participating in the training is called the world, and each process is assigned a rank. In recent PyTorch versions, DistributedDataParallel recommends a 1:1 mapping between processes and GPU devices, but there is another mode with a 1:N mapping.

Back to the process group: let’s take a look at what the init_process_group function does.

Each participating process calls the rendezvous function to find the others and establish connections for communication. There are three types of init methods (a usage sketch follows the list):

  • file: Use a shared file to implement a store.
  • tcp: Use a TCP connection to implement a store.
  • env: The same as the tcp method, except that the address and port are read from environment variables. Launching the processes with the python -m torch.distributed.launch utility sets these environment variables automatically.
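As a rough sketch, this is what each init method looks like from the caller’s side; the file path, address, and port below are placeholders.

    import torch.distributed as dist

    def init_by_method(method: str, rank: int, world_size: int) -> None:
        """Illustrative only: a real job picks exactly one rendezvous method."""
        if method == "file":
            # A file on a shared filesystem backs the store.
            dist.init_process_group("nccl", init_method="file:///tmp/ddp_rendezvous",
                                    rank=rank, world_size=world_size)
        elif method == "tcp":
            # Rank 0 hosts a TCP store at the given address and port.
            dist.init_process_group("nccl", init_method="tcp://10.0.0.1:23456",
                                    rank=rank, world_size=world_size)
        else:
            # Same as tcp, but MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are
            # read from environment variables (torch.distributed.launch sets them).
            dist.init_process_group("nccl", init_method="env://")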

A store is essentially a key-value database, as its interface shows: keys can be set, read, waited on, and atomically incremented.
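For illustration, the store can be exercised directly from Python through TCPStore; the host, port, and key names below are placeholders, and in a real job the rank and world size come from the launcher.

    from datetime import timedelta
    import torch.distributed as dist

    rank, world_size = 0, 1      # placeholders; normally provided by the launcher

    # Rank 0 hosts the store as a server; every other rank connects as a client.
    store = dist.TCPStore("127.0.0.1", 29500, world_size,
                          is_master=(rank == 0),
                          timeout=timedelta(seconds=30))

    store.set("nccl_unique_id", "opaque-id-bytes")   # write a key/value pair
    value = store.get("nccl_unique_id")              # blocking read of a key
    store.wait(["nccl_unique_id"])                   # block until the keys exist
    store.add("ready_ranks", 1)                      # atomic add on an integer key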

So when a TCPStore is initialized, it tries to connect to the server, which runs in the rank 0 process. The server listens on MASTER_ADDR:MASTER_PORT in a daemon thread. The rank 0 process blocks while initializing the TCPStore until all the ranks (including rank 0) have connected to the server, while the other ranks continue as soon as they themselves have connected. However, the other ranks then block in the broadcast step until all the ranks have connected to the server. After that, the rank 0 process creates an NCCL unique id and broadcasts it to the other ranks, and all the ranks use this unique id to establish NCCL communication.
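Conceptually, the unique-id handshake boils down to a small exchange over the store. The sketch below is illustrative only: the real logic lives in C++ inside ProcessGroupNCCL, and os.urandom stands in for NCCL’s ncclGetUniqueId.

    import os

    def exchange_nccl_unique_id(store, rank: int) -> bytes:
        """Sketch: rank 0 publishes an opaque id; every other rank reads it."""
        key = "nccl_unique_id"
        if rank == 0:
            store.set(key, os.urandom(128))   # stand-in for the real NCCL unique id
        else:
            store.wait([key])                 # blocks until rank 0 has published the id
        return store.get(key)                 # all ranks end up with the same id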

After a store is created, all the processes continue to create a default process group by calling _new_process_group_helper based on the backend. A ProcessGroupNCCL is created if the backend is nccl. Essentially, a process group is a binding to the backend, which provides collective algorithms such as allreduce and broadcast. The NVIDIA NCCL library is used under the nccl backend.
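Once the default process group exists, collectives can be issued directly through torch.distributed. A small sketch, assuming init_process_group("nccl", ...) has already run and each rank has set its CUDA device:

    import torch
    import torch.distributed as dist

    # Each rank contributes a tensor filled with its own rank number.
    t = torch.full((4,), float(dist.get_rank()), device="cuda")

    dist.all_reduce(t, op=dist.ReduceOp.SUM)   # every rank now holds the sum over ranks
    dist.broadcast(t, src=0)                   # every rank receives rank 0's tensor
    dist.barrier()                             # wait until all ranks reach this point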

The DistributedDataParallel module

Many deep learning frameworks provide data parallel training, in which a model is trained by iterating over the data in mini-batches. In distributed data parallel training, a mini-batch is split into chunks, and each GPU will consume one chunk during an iteration. The split is done by a distributed sampler, which is not covered in this post.

Now let’s take a look at the DistributedDataParallel module. There are two ways to use the DistributedDataParallel class:

  • Single-Process Multi-GPU (SPMG, 1:N mapping)
  • Single-Process Single-GPU (SPSG, 1:1 mapping)

From PyTorch 1.5 on, SPSG is recommended.

Inside the constructor, it first broadcasts the module states to the other ranks, so that all of the ranks start with the same states. But what are module states here? They include the module parameters and the module buffers. Unlike a parameter, a tensor in the module buffers is not learned. For example, in a BatchNorm module, running_mean is a buffer rather than a parameter.
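A quick way to see the difference, using BatchNorm as the example:

    import torch

    bn = torch.nn.BatchNorm1d(8)

    # Learnable parameters: the affine weight (gamma) and bias (beta).
    print([name for name, _ in bn.named_parameters()])
    # ['weight', 'bias']

    # Buffers: running statistics updated in forward, never touched by the optimizer.
    print([name for name, _ in bn.named_buffers()])
    # ['running_mean', 'running_var', 'num_batches_tracked']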

More initialization work is done in the _ddp_init_helper function.

The _compute_bucket_assignment_by_size function returns a list of bucket assignments, and then a Reducer is initialized. Because backward propagation runs from the output layer back to the input layer, the parameters are assigned to buckets in reversed order (a toy sketch of the idea follows).
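Here is a toy version of the bucketing idea, not PyTorch’s actual _compute_bucket_assignment_by_size (which lives in C++): walk the parameters in reverse registration order and start a new bucket whenever a size cap (25 MB by default in DDP) would be exceeded.

    from typing import List
    import torch

    def toy_bucket_assignment(params: List[torch.Tensor],
                              cap_bytes: int = 25 * 1024 * 1024) -> List[List[int]]:
        """Greedy sketch: group parameter indices, in reverse, under a byte cap."""
        buckets, current, current_bytes = [], [], 0
        for idx in reversed(range(len(params))):          # output-layer parameters first
            nbytes = params[idx].numel() * params[idx].element_size()
            if current and current_bytes + nbytes > cap_bytes:
                buckets.append(current)
                current, current_bytes = [], 0
            current.append(idx)
            current_bytes += nbytes
        if current:
            buckets.append(current)
        return buckets

    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 10))
    print(toy_bucket_assignment(list(model.parameters()), cap_bytes=1024 * 1024))
    # [[3, 2, 1], [0]] -- the output layer's parameters land in the first bucket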

In my opinion, partitioning the parameters into buckets does not make the communication more efficient. Unlike Horovod’s implementation, which fuses small tensors before communicating them, ProcessGroupNCCL does not fuse the tensors within a bucket. Fusion might be better, as it can increase the throughput of the communication.

However, partitioning does make the computation more efficient. A parameter’s gradient needs to be synchronized after it is computed on the local GPU, but backpropagation does not have to wait for that synchronization to finish before computing the gradients of the other parameters. That is because backpropagation only needs the local gradients (the forward output of a rank is computed from its local input only). Thus communication and computation can be overlapped, which speeds up the training.
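The overlap relies on non-blocking collectives. A minimal sketch of the pattern, assuming a process group has already been initialized:

    import torch
    import torch.distributed as dist

    def launch_bucket_allreduce(bucket: torch.Tensor):
        """Kick off a non-blocking allreduce and return immediately."""
        # async_op=True returns a work handle; the NCCL kernel runs in the
        # background while autograd keeps producing gradients for earlier layers.
        return dist.all_reduce(bucket, async_op=True)

    # work = launch_bucket_allreduce(ready_bucket)
    # ... the backward pass keeps running here, overlapped with the allreduce ...
    # work.wait()   # block only when the reduced gradients are actually needed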

The allreduce operations on the buckets are kicked off in bucket-index order, not in the order in which the buckets happen to become ready. So if bucket 1 is ready before bucket 0, it still has to wait for bucket 0 to become ready before its allreduce is launched. If the buckets were not indexed in roughly the reverse of the forward order, this constraint would leave little overlap between communication and computation and hurt performance.

The autograd_hook in Reducer

Inside the reducer.cpp file, the Reducer registers an autograd hook for each parameter.

When the local gradient of a parameter is ready, autograd_hook is called, and the corresponding parameter is marked as ready for reduction inside the hook. Once all the parameters in a bucket are ready, the bucket itself is ready for reduction. The hook then kicks off allreduce operations on ready buckets, in order, until it hits a bucket that is not ready yet.
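A toy Python model of this bookkeeping (illustrative only; the real Reducer is written in C++ and also copies gradients into per-bucket buffers): parameters are marked ready as their gradients arrive, and buckets are reduced strictly in index order.

    class ToyReducer:
        """Sketch of the Reducer's ready-bucket bookkeeping."""

        def __init__(self, buckets, launch_allreduce):
            # buckets: list of lists of parameter indices; bucket 0 holds the
            # parameters whose gradients are expected first (near the output).
            self.buckets = buckets
            self.launch_allreduce = launch_allreduce
            self.pending = [set(b) for b in buckets]
            self.next_bucket = 0

        def mark_ready(self, param_idx):
            # Called from the autograd hook when param_idx's local gradient is ready.
            for pending in self.pending:
                pending.discard(param_idx)
            # Launch allreduce for the leading run of fully-ready buckets so that
            # every rank issues the collectives in the same order; stop at the
            # first bucket that still has pending parameters.
            while (self.next_bucket < len(self.buckets)
                   and not self.pending[self.next_bucket]):
                self.launch_allreduce(self.buckets[self.next_bucket])
                self.next_bucket += 1

    reducer = ToyReducer([[3, 2, 1], [0]], launch_allreduce=print)
    for idx in (3, 1, 0, 2):     # gradients can become ready out of order
        reducer.mark_ready(idx)
    # Both buckets are reduced only after bucket 0 is complete, even though
    # bucket 1 (parameter 0) was ready earlier.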

In PyTorch’s C++ code, we can get the grad_accumulator of a parameter easily. However, that API is not exported to Python. What should we do if we want to implement the same mechanism ourselves in Python? NVIDIA’s Apex library uses a trick in its own implementation of the DistributedDataParallel class.

In the computation graph, every leaf tensor (a tensor created by the user or initialized by an initializer) has an AccumulateGrad node whose output is accumulated into the tensor’s grad field. The grad_accumulator function returns this AccumulateGrad node, but it is not available in Python. So Apex calls expand_as to make a new tensor, and the new tensor’s grad_fn has an entry pointing to the AccumulateGrad node of the original tensor, where the hook can be registered.
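A minimal recreation of that trick (a sketch, not Apex’s actual code): expand_as builds a trivial non-leaf view of the parameter, the view’s grad_fn has a next_functions entry pointing at the parameter’s AccumulateGrad node, and a hook registered on that node fires once the gradient has been accumulated into the parameter’s grad field.

    import torch

    def on_ready_hooks(model, on_grad_ready):
        """Call on_grad_ready(name) after each parameter's gradient is accumulated."""
        grad_accs, handles = [], []
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            p_tmp = p.expand_as(p)                          # non-leaf view of the leaf p
            grad_acc = p_tmp.grad_fn.next_functions[0][0]   # p's AccumulateGrad node

            def _hook(*unused, name=name):
                on_grad_ready(name)

            handles.append(grad_acc.register_hook(_hook))
            grad_accs.append(grad_acc)   # keep the nodes alive; the parameter holds only a weak reference
        return grad_accs, handles

    model = torch.nn.Linear(4, 2)
    refs = on_ready_hooks(model, lambda name: print(name, "gradient is ready"))
    model(torch.randn(3, 4)).sum().backward()
    # prints "bias gradient is ready" and "weight gradient is ready"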
