

A Comprehensive Tutorial to Pytorch DistributedDataParallel

Photo by XPS on Unsplash

The limited computing resources at school discouraged me from trying distributed training across multiple gpus. I learned it for the first time when I joined Microsoft as an intern. Wrapping the model with DDP (short for DistributedDataParallel) is basically an easy job. What frustrated me was that I could not properly adapt my workflow to multiple gpus, including the DataLoader, Sampler, training and evaluation. Tutorials and blogs on the Internet rarely cover all of these. After fixing the many bugs I came across, I've come up with the best practice I know so far.

In this blog, I want to share my code and my insights with all beginners in DDP. I hope it helps them avoid horrible bugs and mistakes. I'm not going to explain in detail how DDP works; instead, I provide the minimum knowledge needed to make a model run on multiple gpus. Note that I only introduce DDP on one machine with multiple gpus, which is the most common case (if the model does not fit on one gpu, we should use model parallel instead, as stated in the official blog). This blog first gives an overview of DDP and then walks through the workflow step by step.

BTW, I’m using torch==1.7.1, but I think it will work just fine in torch>=1.7.1.

Overview of DDP

First we must understand several terms used in distributed training:

  • master node: the main process (rank 0) and its gpu, responsible for synchronizations, making copies, loading models and writing logs;
  • process group: if you want to train/test the model over K gpus, then the K processes form a group, which is supported by a backend (PyTorch manages that for you; according to the documentation, nccl is the most recommended backend for gpus);
  • rank: within the process group, each process is identified by its rank, from 0 to K-1;
  • world size: the number of processes in the group, i.e. the number of gpus, K.
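To see how these terms show up in code, here is a minimal sketch that joins a process group and asks torch.distributed for its rank and world size. It assumes the CPU-only gloo backend and a group of a single process, so that it runs without gpus; the helper name describe_process and port 29501 are made up for illustration.

```python
import os

import torch.distributed as dist


def describe_process(rank, world_size):
    """Hypothetical helper: join a process group and report this process's identity."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    # gloo runs on CPU, so this sketch works without gpus; use "nccl" for gpu training
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    info = {
        "rank": dist.get_rank(),              # 0 .. K-1 within the group
        "world_size": dist.get_world_size(),  # K, the number of processes
        "is_master": dist.get_rank() == 0,    # rank 0 plays the "master" role
    }
    dist.destroy_process_group()
    return info
```

In a real job, each of the K processes would call it with its own rank and the same world_size.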

PyTorch provides two settings for distributed training: torch.nn.DataParallel (DP) and torch.nn.parallel.DistributedDataParallel (DDP), where the latter is officially recommended. In short, DDP is faster and more flexible than DP. The fundamental thing DDP does is to copy the model to multiple gpus and, during the backward pass, average the gradients over all K processes (an all-reduce), so that every replica applies the same update and the K copies stay synchronized. We can also gather/scatter tensors and objects other than gradients with torch.distributed functions such as gather, scatter and reduce.
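The gradient averaging can be written by hand with dist.all_reduce. The sketch below is only a simplified picture of that step, not how DDP is actually implemented (DDP groups gradients into buckets, overlaps the communication with the backward pass, and also broadcasts rank 0's initial weights so that all replicas start identical). The helpers average_gradients and demo_step are hypothetical, and the demo uses a one-process gloo group on CPU.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def average_gradients(model):
    """Average every parameter's gradient over all processes in the group.

    Simplified picture of what DDP does during loss.backward().
    """
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # sum the gradient tensors of all ranks in place ...
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            # ... then divide by K to get the mean
            param.grad /= world_size


def demo_step():
    """Hypothetical single-process demo (gloo, CPU) of one manual data-parallel step."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"
    dist.init_process_group("gloo", rank=0, world_size=1)
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    before = model.weight.grad.clone()
    average_gradients(model)
    after = model.weight.grad.clone()
    dist.destroy_process_group()
    return before, after
```

With world_size=1 the averaged gradient equals the local one; with K processes, each would hold the mean of the K local gradients after the call.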

If the model fits on one gpu (i.e. it can be trained on one gpu with batch_size=1) and we want to train/test it on K gpus, the best practice of DDP is to copy the model onto the K gpus (the DDP class does this for you automatically) and split the data into K non-overlapping parts, one for each model replica.

Now, things are clear to us. We have to do the following things:

  1. set up the process group, which is three lines of code and needs no modification;
  2. split the dataloader across the processes in the group, which can be easily achieved by torch.utils.data.distributed.DistributedSampler or any customized sampler;
  3. wrap our model with DDP, which is one line of code and barely needs modification;
  4. train/test our model, which is the same as on 1 gpu;
  5. clean up the process groups (like free in C), which is one line of code;
  6. optional: gather extra data among processes (possibly needed for distributed testing), which is basically one line of code.

Very easy, right? In fact it is. Let’s do it step by step.

1. Setup the process group

Here it is, no extra steps.

import os
import torch.distributed as dist

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
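The setup function hard-codes nccl, which needs gpus. When debugging the workflow on a CPU-only machine, a variant that falls back to gloo can be handy. The sketch below is such a variant (the name setup_any_backend, the port parameter and the fallback are my own additions); it also calls torch.cuda.set_device(rank), which is recommended later in this post before collective calls such as all_gather_object.

```python
import os

import torch
import torch.distributed as dist


def setup_any_backend(rank, world_size, port="12355"):
    """Hypothetical variant of setup(): nccl when gpus are available, else gloo."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    if torch.cuda.is_available():
        # pin this process to its own gpu before any collective call
        torch.cuda.set_device(rank)
        backend = "nccl"
    else:
        # CPU fallback, handy for debugging the workflow without gpus
        backend = "gloo"
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    return backend
```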

2. Split the dataloader

We can easily split our dataloader with torch.utils.data.distributed.DistributedSampler. The sampler returns an iterator over indices, which are fed into the dataloader to be batched.

The DistributedSampler splits the indices of the dataset into world_size parts and hands one part to the dataloader in each process, so the processes see disjoint data (apart from the padding described below).

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def prepare(rank, world_size, batch_size=32, pin_memory=False, num_workers=0):
    dataset = Your_Dataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False)

    # shuffle must stay False in the DataLoader when a sampler is given; the sampler decides the order
    dataloader = DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers, drop_last=False, shuffle=False, sampler=sampler)

    return dataloader
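prepare uses shuffle=False. If you switch the sampler to shuffle=True, it seeds its permutation with its seed plus the current epoch, so unless set_epoch is called, every epoch sees exactly the same order. The quick check below assumes range(10) as a stand-in dataset; since num_replicas and rank are passed explicitly, no process group is needed. The helper epoch_orders is hypothetical.

```python
from torch.utils.data.distributed import DistributedSampler


def epoch_orders(num_epochs, call_set_epoch, world_size=2, rank=0):
    """Return the indices one rank sees in each epoch (range(10) stands in for a dataset)."""
    sampler = DistributedSampler(range(10), num_replicas=world_size, rank=rank, shuffle=True)
    orders = []
    for epoch in range(num_epochs):
        if call_set_epoch:
            # reseeds the shuffle with seed + epoch
            sampler.set_epoch(epoch)
        orders.append(list(sampler))
    return orders
```

This is why the training loop tells the sampler which epoch it is at the start of every epoch.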

Suppose K=3 and the dataset has 10 samples. Note that DistributedSampler always gives every rank the same number of indices.

  • If we set drop_last=False when defining DistributedSampler, it pads automatically by repeating the first indices. For example, it splits indices [0,1,2,3,4,5,6,7,8,9] into [0,3,6,9] for rank=0, [1,4,7,0] for rank=1, and [2,5,8,1] for rank=2. As you can see, such padding may cause issues, because the padded 0 and 1 are real data records that get processed twice (e.g. they are counted twice in test metrics).
  • Otherwise, it strips off the trailing elements. For example, it splits the indices into [0,3,6] for rank=0, [1,4,7] for rank=1, and [2,5,8] for rank=2. In this case, it drops 9 to make the number of indices divisible by world_size.
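To check the partition on your own installation, you can ask DistributedSampler directly; with num_replicas and rank given explicitly, it does not need a process group. The helper split_indices is hypothetical, and range(dataset_len) stands in for a real dataset.

```python
from torch.utils.data.distributed import DistributedSampler


def split_indices(dataset_len, world_size, drop_last):
    """Indices each rank receives from DistributedSampler (no shuffling)."""
    return [
        list(DistributedSampler(range(dataset_len), num_replicas=world_size,
                                rank=rank, shuffle=False, drop_last=drop_last))
        for rank in range(world_size)
    ]
```

With drop_last=False, ranks 1 and 2 each receive one repeated index (0 and 1), which is the padding to keep in mind when computing test metrics.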

It is very simple to customize our Sampler. We only need to create a subclass of torch.utils.data.Sampler and define its __iter__() and __len__() methods. Refer to the official documentation for more details.
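As an example of a customized sampler, here is a hypothetical UnpaddedDistributedSampler that uses the same strided split as DistributedSampler but neither pads nor drops, so every sample is seen exactly once. Ranks can then get different numbers of samples, so I would only use it for evaluation: uneven numbers of batches during DDP training can leave the gradient all-reduce waiting for ranks that have already finished.

```python
from torch.utils.data import Sampler


class UnpaddedDistributedSampler(Sampler):
    """Hypothetical sampler: strided split across ranks, no padding, no dropping."""

    def __init__(self, dataset, num_replicas, rank):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # rank r takes indices r, r + K, r + 2K, ...
        return iter(range(self.rank, len(self.dataset), self.num_replicas))

    def __len__(self):
        # number of indices this rank yields (may differ between ranks)
        return len(range(self.rank, len(self.dataset), self.num_replicas))
```

It plugs into a DataLoader like any other sampler, e.g. DataLoader(dataset, batch_size=32, sampler=UnpaddedDistributedSampler(dataset, world_size, rank)).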

BTW, you'd better set num_workers=0 in distributed training, because creating extra worker processes inside the child processes may be problematic. I also found that pin_memory=False avoids many horrible bugs. Maybe such things are machine-specific; please email me if you have explored more details.

3. Wrap the model with DDP

We should first move our model to its gpu (recall that each model replica resides on one gpu), then wrap it with the DDP class. The following function takes an argument rank, which we will introduce soon. For now, just keep in mind that rank equals the gpu id.

from torch.nn.parallel import DistributedDataParallel as DDP

def main(rank, world_size):
    # setup the process groups
    setup(rank, world_size)
    # prepare the dataloader
    dataloader = prepare(rank, world_size)

    # instantiate the model (it's your own model) and move it to the right device
    model = Your_Model().to(rank)

    # wrap the model with DDP
    # device_ids tells DDP which gpu holds this replica
    # output_device tells DDP where to put the output; in our case, it is rank
    # find_unused_parameters=True makes DDP detect parameters that receive no gradient
    # in a forward pass (needed if some outputs do not contribute to the loss; it adds overhead)
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)

There are a few tricky things here:

  • When we want to access customized attributes of the DDP-wrapped model, we must go through model.module: our model instance is stored as the module attribute of the DDP wrapper. If we assign an attribute xxx other than the built-in properties or functions, we must access it as model.module.xxx.
  • When we save the DDP model, every key of its state_dict gets a module. prefix.
  • Consequently, if we want to load a checkpoint saved from a DDP model into a non-DDP model, we have to strip the extra prefix manually. I provide my code below:
# in case we load a DDP model checkpoint to a non-DDP model
from collections import OrderedDict

model_dict = OrderedDict()
prefix = 'module.'
for k, v in state_dict.items():
    # strip only a leading 'module.'; a regex like 'module.' would also match inside other names
    if k.startswith(prefix):
        k = k[len(prefix):]
    model_dict[k] = v
model.load_state_dict(model_dict)
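A quick way to see both effects is to wrap a tiny model and look at its keys. The sketch assumes a one-process gloo group on CPU (for CPU modules, DDP is created without device_ids); TinyModel, its num_classes attribute and inspect_ddp_wrapping are hypothetical. It also shows a common way to avoid the prefix altogether: save model.module.state_dict() instead of model.state_dict().

```python
import os

import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class TinyModel(nn.Module):
    """Hypothetical model with a custom attribute, used only for this demo."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(2, 2)
        self.num_classes = 2  # custom (non-parameter) attribute


def inspect_ddp_wrapping():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29504"
    # one CPU process with gloo; device_ids stays None for CPU modules
    dist.init_process_group("gloo", rank=0, world_size=1)
    model = DDP(TinyModel())
    wrapped_keys = list(model.state_dict().keys())        # keys carry the 'module.' prefix
    inner_keys = list(model.module.state_dict().keys())   # prefix-free keys
    num_classes = model.module.num_classes                # custom attributes live on .module
    dist.destroy_process_group()
    return wrapped_keys, inner_keys, num_classes
```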

4. Train/test our model

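This part is the key to implementing DDP. First we need to know the basics of multiprocessing: every child process runs the same code (here, the same main function with a different rank). With the spawn start method, each child also re-imports our script, which is why the spawning code must be guarded by if __name__ == '__main__':.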

In PyTorch, torch.multiprocessing provides convenient ways to create parallel processes. As the official documentation says,

The spawn function below addresses these concerns and takes care of error propagation, out of order termination, and will actively terminate processes upon detecting an error in one of them.

So, using spawn is a good choice.

In our script, we should define a train/test function and then spawn it in parallel processes:

def main(rank, world_size):
    # setup the process groups
    setup(rank, world_size)
    # prepare the dataloader
    dataloader = prepare(rank, world_size)

    # instantiate the model (it's your own model) and move it to the right device
    model = Your_Model().to(rank)

    # wrap the model with DDP
    # device_ids tells DDP which gpu holds this replica
    # output_device tells DDP where to put the output; in our case, it is rank
    # find_unused_parameters=True makes DDP detect parameters that receive no gradient in a forward pass
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=True)
    #################### The above is defined previously

    optimizer = Your_Optimizer(model.parameters())
    loss_fn = Your_Loss()
    num_epochs = 10  # placeholder: set your own number of epochs
    for epoch in range(num_epochs):
        # if we are using DistributedSampler, we have to tell it which epoch this is
        dataloader.sampler.set_epoch(epoch)

        for step, x in enumerate(dataloader):
            # move x to gpu `rank` first if your dataset returns cpu tensors
            optimizer.zero_grad()
            pred = model(x)
            label = x['label']

            loss = loss_fn(pred, label)
            # DDP averages the gradients across processes during backward()
            loss.backward()
            optimizer.step()

    # clean up the process group (step 5)
    cleanup()

This main function is run in every parallel process. We now need to call it by spawn method. In our .py script, we write:

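import torch.multiprocessing as mp

if __name__ == '__main__':
    # suppose we have 3 gpus
    world_size = 3
    # start world_size processes; each one runs main(rank, world_size)
    mp.spawn(main, args=(world_size,), nprocs=world_size)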

Remember that the first argument of main is rank? mp.spawn passes it to each process automatically, so we don't need to pass it explicitly. rank=0 is the master node by default. The rank ranges from 0 to K-1 (2 in our case).
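Before renting gpus, it can help to check the whole workflow on CPU. The sketch below is a hypothetical CPU-only counterpart of main: it swaps nccl for gloo, uses a random TensorDataset and an nn.Linear in place of Your_Dataset and Your_Model, and omits device_ids because the model stays on CPU. Seeding before building the data keeps the dataset identical on every rank, which the sampler's split relies on. The name run and port 29505 are assumptions.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def run(rank, world_size, port="29505"):
    """Hypothetical CPU-only version of main(): gloo backend, toy data, plain SGD."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    torch.manual_seed(0)  # same toy dataset and initial weights on every rank
    dataset = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)
    model = DDP(nn.Linear(3, 1))  # CPU module: no device_ids
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()
    for epoch in range(2):
        sampler.set_epoch(epoch)
        for x, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()  # gradients are all-reduced across ranks here
            optimizer.step()
    final_loss = loss.item()
    dist.destroy_process_group()
    return final_loss
```

It can be launched the same way as main, e.g. mp.spawn(run, args=(2,), nprocs=2) under if __name__ == '__main__':, where each process receives its rank as the first argument.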

5. Clean up the process groups

The last line of the main function is the clean-up function, which is:

def cleanup():
    dist.destroy_process_group()
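If main raises an exception, the cleanup() call at its end is never reached. One way to guarantee the clean-up runs is a context manager with try/finally. The sketch is hypothetical: the name distributed_session, the port and the gloo default are my assumptions (pass backend="nccl" on gpus).

```python
import os
from contextlib import contextmanager

import torch.distributed as dist


@contextmanager
def distributed_session(rank, world_size, backend="gloo", port="29506"):
    """Hypothetical helper: set up the group on entry, always destroy it on exit."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    try:
        yield
    finally:
        # runs even if the training code raises, unlike a plain cleanup() at the end
        dist.destroy_process_group()
```

Inside main, the training code would then sit in a with distributed_session(rank, world_size): block.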

Bravo! We have completed the basic workflow of distributed training/testing!

6. Optional: Gather extra data among processes

Sometimes we need to collect some data from all processes, such as the testing result. We can easily gather tensors by dist.all_gather and objects by dist.all_gather_object.

Without loss of generality, I assume we want to collect python objects. The only constraint on the object is that it must be picklable, which covers most ordinary python objects. One should always call torch.cuda.set_device(rank) before using all_gather_xxx. And if we want to store a tensor in the object, it must be located on the output_device.

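import torch
import torch.distributed as dist

def main(rank, world_size):
    setup(rank, world_size)
    # pin this process to its own gpu before calling all_gather_object
    torch.cuda.set_device(rank)
    data = {
        'tensor': torch.ones(3, device=rank) + rank,
        'list': [1 + rank, 2 + rank, 3 + rank],
        'dict': {'rank': rank}
    }

    # we have to create enough room to store the collected objects
    outputs = [None for _ in range(world_size)]
    # the first argument is the list that receives the results, the second is this process's data
    dist.all_gather_object(outputs, data)
    # we only want to operate on the collected objects at master node
    if rank == 0:
        print(outputs)  # e.g. merge the test results from all ranks here
    cleanup()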
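For tensors, dist.all_gather skips the pickling that all_gather_object performs, but every rank must contribute a tensor of the same shape and dtype, and the output list has to be pre-allocated. The sketch assumes a one-process gloo group on CPU; with nccl the tensors would live on each process's gpu. gather_tensors and port 29507 are hypothetical.

```python
import os

import torch
import torch.distributed as dist


def gather_tensors(rank, world_size, port="29507"):
    """Collect one tensor from every rank with dist.all_gather (CPU/gloo sketch)."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    local = torch.ones(3) + rank
    # pre-allocate one slot per rank, same shape and dtype as the local tensor
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    dist.destroy_process_group()
    # gathered[i] now holds rank i's tensor
    return torch.stack(gathered)
```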

Issues about dist.barrier()

The most confusing thing to me is when to use dist.barrier(). As the documentation says, it synchronizes processes. In other words, it blocks every process until all of them reach the same line of code: dist.barrier(). I summarize its usage as follows:

  1. we do not need it when training, since DDP already synchronizes the processes when it all-reduces the gradients in loss.backward();
  2. we do not need it when gathering data, since dist.all_gather_object does it for us;
  3. we need it to enforce the execution order of code, say when one process loads the model that another process saves (I can hardly imagine this scenario is needed).
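Case 3 typically looks like this: rank 0 writes a checkpoint, and dist.barrier() keeps the other ranks from reading the file before it exists. The helper save_then_load, the port and the CPU gloo setup are assumptions for the sketch; on gpus you would pass a map_location such as f"cuda:{rank}" to torch.load.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn


def save_then_load(rank, world_size, ckpt_path, port="29508"):
    """Hypothetical pattern: rank 0 writes a checkpoint, every rank reads it after a barrier."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = nn.Linear(2, 2)
    if rank == 0:
        torch.save(model.state_dict(), ckpt_path)
    # no rank may read the file before rank 0 has finished writing it
    dist.barrier()
    # map_location keeps tensors on CPU in this sketch
    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state)
    dist.destroy_process_group()
    return sorted(state.keys())
```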

In this post, we learnt how to implement DDP in our models from scratch. I hope everyone who reads this will benefit. Thank you.


