Distributed Machine Learning Training (Part 1 — Data Parallelism)

Jonathan Nguyen
Mar 25, 2024


With the ever-increasing size and complexity of datasets, the need for efficient and scalable machine learning models has never been greater. Distributed machine learning training is a powerful technique that allows us to train models on large datasets by distributing the workload across multiple machines. By doing so, we can significantly reduce training times, enabling us to build and iterate on models more quickly than ever before. In this blog post, I’ll be exploring the key concepts and techniques behind distributed machine learning training, as well as its benefits and challenges.

Whether you’re a seasoned machine learning practitioner or just starting out, this article is designed to provide a comprehensive overview of distributed machine learning training, and how it can help you build more powerful and efficient models. So, let’s dive in!

Why do we need distributed machine learning?

In recent years, datasets have grown dramatically in size, especially with the rise of large language models (LLMs), which typically require at least 1,000 GB of text data for training. LLMs have billions of parameters and are trained on huge text corpora, so training them effectively requires substantial computational resources and specialized expertise. For example, training GPT-3, a previous-generation model with 175 billion parameters, would take an estimated 288 years on a single NVIDIA V100 GPU.

Single-node training is too slow

In vanilla model training, both the training data and the ML model are loaded onto the same accelerator (for example, a GPU). This is called single-node training.

There are two kinds of bandwidth:

  • Data loading bandwidth
  • Model training bandwidth

The amount of input data keeps growing, so we want the data loading bandwidth to be as large as possible. However, the limited on-device memory of GPUs and other accelerators also limits the actual model training bandwidth. In single-node training, this mismatch between data loading bandwidth and model training bandwidth becomes the bottleneck of the training pipeline.

If the data loading bandwidth and the model training bandwidth are already matched on a single node, there is no need for parallel model training, because distributing the work always introduces coordination overhead.

Data parallelism

Simplified workflow of data parallel training

The main difference between single-node training and data parallel training is that we split the data loading bandwidth among multiple workers (GPUs). Therefore, for each GPU involved in a data parallel training job, the gap between its local data loading bandwidth and its model training bandwidth is much smaller than in the single-node case.
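The argument can be made concrete with a back-of-the-envelope sketch. The numbers and helper names below are hypothetical, chosen only for illustration: an input pipeline delivering 8,000 samples/s feeding GPUs that can each train on 1,000 samples/s.

```python
import math


def per_worker_gap(load_bw, train_bw, num_workers):
    """Return (samples/s each worker receives, surplus over what it can train on).

    load_bw and train_bw are hypothetical throughputs in samples per second.
    """
    per_worker_load = load_bw / num_workers  # the input stream is split evenly
    return per_worker_load, per_worker_load - train_bw


def workers_to_match(load_bw, train_bw):
    # Smallest N for which load_bw / N no longer exceeds train_bw
    return math.ceil(load_bw / train_bw)


print(per_worker_gap(8000, 1000, 1))  # (8000.0, 7000.0): single node, large surplus
print(per_worker_gap(8000, 1000, 8))  # (1000.0, 0.0): eight workers, surplus gone
print(workers_to_match(8000, 1000))   # 8
```

In this simplified model, adding workers shrinks each GPU's surplus linearly until the data rate a worker receives matches what it can train on.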

Each GPU then trains and validates its local model on the batch of augmented input data it receives. This raises two difficulties:

  1. Each GPU trains on a different batch of input data, so no single GPU sees the full training data. The standard single-node gradient descent procedure therefore cannot be applied as-is.
  2. In the end we want a single trained model, but each GPU now holds its own, diverging copy of the model weights. To give all the workers the same view of the model parameters, we need an additional step called model synchronization, which collects and aggregates the local gradients generated by the different nodes.
Data parallelism with Model Synchronization

Model synchronization

  1. Collects and sums up the gradients from all the GPUs in use.
  2. Broadcasts the aggregated gradients to all the GPUs.

After model synchronization, every GPU holds the same aggregated gradients locally. Each GPU uses these aggregated gradients to update its model, so, as long as all the replicas started from the same parameters, the updated model parameters are identical across GPUs after the first data parallel training iteration. The same holds for every following iteration: we run model synchronization after each GPU computes its local gradients, so model synchronization keeps the parameters identical on all GPUs after every training iteration of the job.
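A toy, single-process sketch (plain Python, not PyTorch) of why synchronization keeps the replicas identical. Assumptions: a one-parameter linear model y_hat = w * x, three simulated workers that start from the same weight, and plain gradient descent; local_gradient, synchronized_step and train are hypothetical names.

```python
def local_gradient(w, shard):
    # Gradient of the summed squared error of y_hat = w * x over one worker's shard
    return sum(2 * (w * x - y) * x for x, y in shard)


def synchronized_step(replicas, shards, lr):
    # Each worker computes a gradient on its own data shard
    local_grads = [local_gradient(w, shard) for w, shard in zip(replicas, shards)]
    # Model synchronization, step 1: collect and sum the local gradients
    grad_sum = sum(local_grads)
    # Step 2: broadcast grad_sum; every worker applies the same update locally
    return [w - lr * grad_sum for w in replicas]


def train(replicas, shards, lr, steps):
    for _ in range(steps):
        replicas = synchronized_step(replicas, shards, lr)
    return replicas


shards = [[(1, 3), (2, 6)], [(3, 9)], [(4, 12)]]  # toy data with y = 3x, split over 3 workers
replicas = train([0.0, 0.0, 0.0], shards, lr=0.01, steps=50)
print(replicas)  # three identical values, all close to 3.0
```

Because the summed gradient equals the gradient over the full dataset, this reproduces ordinary full-batch gradient descent. Many frameworks average rather than sum the gradients; for plain SGD that only rescales the learning rate.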

To guarantee model consistency, two methodologies can be applied:

  • Parameter server (centralized method): we keep the model parameters in one place (a centralized node). Whenever a GPU/node needs to conduct model training, it pulls the parameters from the centralized node, trains the model, then pushes the model updates back to the centralized node. Model consistency is guaranteed because all the GPUs/nodes pull from the same centralized node.
  • All-Reduce (decentralized method): every GPU/node keeps its own copy of the model parameters, so the copies must be synchronized periodically. Each GPU trains its local model replica on its own training data partition. Without synchronization, the replicas held on different GPUs would drift apart because they are trained on different input data. Therefore, we inject a global synchronization step into each training iteration that aggregates (for example, averages) the gradients held on the different GPUs, so model consistency is guaranteed in this fully distributed manner.

Parameter Server

The parameter server architecture mainly consists of two roles: parameter server and worker. The parameter server can be regarded as the master node in the traditional Master/Worker architecture.

  • Master node (parameter server): responsible for aggregating model updates from all the workers and updating the model parameters held on the parameter server.
  • Workers: the compute nodes or GPUs responsible for model training. We split the total training data among all the workers, and each worker trains its local model on the training data partition assigned to it.
Step-by-step Parameter Server Architecture with a single server node
  1. Pull Weights: All the workers pull the model parameters/weights from the centralized parameter server.
  2. Push Gradients: Each worker trains its local model with its local training data partition and generates local gradients. Then, all the workers push their local gradients to the centralized parameter server.
  3. Aggregate Gradients: After collecting all the gradients that have been sent from the worker nodes, the parameter server will sum up all the gradients.
  4. Model Update: Once the aggregated gradients have been calculated, the parameter server uses the aggregated gradients to update the model parameters on this centralized server.
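The four steps can be simulated in a single process. This is a minimal sketch, assuming a two-parameter linear model y_hat = w0 + w1 * x, two workers with one data shard each, and plain gradient descent. ToyParameterServer, worker_gradient and train_with_ps are hypothetical names, and Python lists stand in for tensors and network messages.

```python
class ToyParameterServer:
    def __init__(self, weights, lr):
        self.weights = list(weights)  # the single, authoritative copy of the model
        self.lr = lr

    def pull(self):
        # Step 1 (Pull Weights): hand each worker a copy of the current weights
        return list(self.weights)

    def push_and_update(self, worker_grads):
        # Step 3 (Aggregate Gradients): sum the gradients pushed by all workers
        aggregated = [sum(g) for g in zip(*worker_grads)]
        # Step 4 (Model Update): apply one gradient descent step on the server
        self.weights = [w - self.lr * g for w, g in zip(self.weights, aggregated)]


def worker_gradient(weights, shard):
    # Step 2 (Push Gradients): squared-error gradient on this worker's local shard
    w0, w1 = weights
    g0 = sum(2 * (w0 + w1 * x - y) for x, y in shard)
    g1 = sum(2 * (w0 + w1 * x - y) * x for x, y in shard)
    return [g0, g1]


def train_with_ps(ps, shards, steps):
    for _ in range(steps):
        grads = [worker_gradient(ps.pull(), shard) for shard in shards]
        ps.push_and_update(grads)
    return ps.weights


ps = ToyParameterServer([0.0, 0.0], lr=0.02)
shards = [[(0, 1), (1, 3)], [(2, 5), (3, 7)]]  # toy data with y = 1 + 2x
print(train_with_ps(ps, shards, steps=500))  # close to [1.0, 2.0]
```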

Communication Bottleneck

1. Fan-out Weights Pulling: First, consider the communication pattern when the workers pull weights from the parameter server. As shown in the preceding diagram, this is one-to-all communication: the centralized parameter server must send the model weights to all the worker nodes simultaneously.

Assume that each node has a communication bandwidth of 1 and that there are N workers in the training job. Since the centralized parameter server must send the model to N workers concurrently, its sending bandwidth (BW) to each worker is only 1/N. Each worker, on the other hand, could receive at a bandwidth of 1, which is much larger than the 1/N the parameter server can supply. Therefore, during the weight-pulling stage, the communication bottleneck is on the parameter server side.

2. Fan-in Gradients Pushing: Now, let’s look at the communication pattern during the gradient pushing process. As shown in the following diagram, during this process, all the GPUs concurrently send their local gradients to the centralized parameter server.

Given N workers in the parameter server architecture, each worker can send its local gradients with a sending bandwidth of 1. However, since the parameter server must receive gradients from all the workers at the same time, it can receive from each worker at a bandwidth of only 1/N. Therefore, the communication bottleneck is again on the parameter server side during the gradient-pushing stage.
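The 1/N argument can be written out directly. This is a minimal sketch, assuming identical links of bandwidth `bandwidth` on every node and ignoring latency; fan_out_stats is a hypothetical helper. By symmetry, the same numbers apply to fan-in gradient pushing.

```python
def fan_out_stats(model_size, bandwidth, num_workers):
    """Return (bandwidth each worker gets, duration of the pull phase).

    The single server's link is shared by num_workers concurrent transfers.
    """
    per_worker_bw = bandwidth / num_workers
    # Each worker needs model_size units at per_worker_bw,
    # i.e. num_workers * model_size / bandwidth in total
    phase_time = model_size / per_worker_bw
    return per_worker_bw, phase_time


print(fan_out_stats(1.0, 1.0, 1))  # (1.0, 1.0)
print(fan_out_stats(1.0, 1.0, 4))  # (0.25, 4.0)
```

With N workers, the phase takes N times longer than a single worker's own link would need, so it grows linearly with the size of the job.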

Sharding the model among parameter servers

We can relieve this communication bottleneck with load balancing:

Instead of having a single parameter server, we shard the model across N parameter servers, each responsible for storing and updating 1/N of the model parameters. Each worker then exchanges only 1/N of the model with each server, so the traffic is spread evenly instead of converging on one node.

Sharded Parameter Servers
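A sketch of the sharded layout, using the same simplified bandwidth model (identical links, no latency). The parameters are split into contiguous index ranges, one per server; shard_ranges and per_server_pull_time are hypothetical helpers, and real systems may assign parameters to servers differently.

```python
def shard_ranges(num_params, num_servers):
    # Contiguous [lo, hi) index ranges whose sizes differ by at most one parameter
    bounds = [num_params * i // num_servers for i in range(num_servers + 1)]
    return [(bounds[i], bounds[i + 1]) for i in range(num_servers)]


def per_server_pull_time(num_params, bandwidth, num_workers, num_servers):
    # Each server sends only its own shard to every worker
    largest = max(hi - lo for lo, hi in shard_ranges(num_params, num_servers))
    server_time = num_workers * largest / bandwidth
    # Each worker still receives the whole model over its own link
    worker_time = num_params / bandwidth
    # The phase ends when the slowest link finishes
    return max(server_time, worker_time)


print(shard_ranges(10, 4))                    # [(0, 2), (2, 5), (5, 7), (7, 10)]
print(per_server_pull_time(100, 1.0, 4, 1))   # 400.0: one server is the bottleneck
print(per_server_pull_time(100, 1.0, 4, 4))   # 100.0: matches a worker's own link
```

With as many servers as workers, no server link carries more than a worker link, which is the load-balancing effect described above.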

Sample code for a ParameterServer and a Worker (a single-process simulation with one worker, trained on MNIST):

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torchvision import datasets, transforms

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class MyNet(nn.Module):
    # Simple stand-in MNIST classifier (an assumption: any nn.Module works here)
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return F.log_softmax(self.fc2(x), dim=1)


class ParameterServer(nn.Module):
    def __init__(self):
        super().__init__()
        # The server holds the authoritative copy of the model parameters
        self.model = MyNet().to(device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.05)

    def get_weights(self):
        return self.model.state_dict()

    def update_model(self, grads):
        # Install the gradients received from the worker, then take one SGD step
        for para, grad in zip(self.model.parameters(), grads):
            para.grad = grad
        self.optimizer.step()
        self.optimizer.zero_grad()


class Worker(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = MyNet().to(device)

    def pull_weights(self, model_params):
        # Overwrite the local replica with the server's current weights
        self.model.load_state_dict(model_params)

    def push_gradients(self, batch_idx, data, target):
        data, target = data.to(device), target.to(device)
        self.model.zero_grad()  # clear gradients left over from the previous batch
        output = self.model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        # Copy the local gradients so they can be sent to the server
        grads = [p.grad.detach().clone() for p in self.model.parameters()]
        print(f"batch {batch_idx} training :: loss {loss.item()}")
        return grads


transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', download=True, train=True, transform=transform),
    batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./mnist_data', download=True, train=False, transform=transform),
    batch_size=128, shuffle=True)


def main():
    ps = ParameterServer()
    worker = Worker()

    for batch_idx, (data, target) in enumerate(train_loader):
        params = ps.get_weights()                                # 1. pull weights
        worker.pull_weights(params)
        grads = worker.push_gradients(batch_idx, data, target)  # 2. push gradients
        ps.update_model(grads)                                   # 3-4. aggregate and update
    print("Done Training")


if __name__ == "__main__":
    main()

All-Reduce Architecture

In the All-Reduce architecture, we drop the parameter server role entirely. Every node is equivalent, and all of them are worker nodes.

The All-Reduce paradigm is borrowed from the traditional Message Passing Interface (MPI) domain. MPI is a standardized application programming interface that defines a model of parallel computing in which each parallel process has its own local memory, and data must be shared explicitly by passing messages between processes.

Reduce

Reduce Operator with 3 workers

The Reduce operator aggregates the values from different nodes using an operation such as sum, average, or multiplication, and stores the result on a single node.

At the beginning, Worker 1 holds a gradient value of a, Worker 2 holds b, and Worker 3 holds c. After a (sum) Reduce, Worker 1 holds a+b+c instead of a, while the values held on Worker 2 and Worker 3 do not change.
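A toy simulation of Reduce (plain Python, no real communication). values[i] is the value held by worker i, using 0-based indices, so Worker 1 in the text is index 0. reduce_to_root is a hypothetical helper name.

```python
import operator
from functools import reduce as fold


def reduce_to_root(values, op=operator.add, root=0):
    """Combine every worker's value with `op` and store the result on worker `root`.

    All other workers keep their original values.
    """
    result = fold(op, values)
    out = list(values)
    out[root] = result
    return out


print(reduce_to_root([1, 2, 3]))  # [6, 2, 3]
# Symbolic version of the example in the text
print(reduce_to_root(["a", "b", "c"], op=lambda x, y: f"{x}+{y}"))  # ['a+b+c', 'b', 'c']
```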

All-Reduce

All-Reduce Operator with 3 Workers

All-Reduce gives every node the same aggregated value. Before the All-Reduce operation, Worker 1 holds a value of a, Worker 2 holds b, and Worker 3 holds c. After the All-Reduce operation, every worker holds the aggregated value (a+b+c for a sum).

The All-Reduce function allows all the workers to get the aggregated gradients from all the worker nodes. This gradient aggregation is the model synchronization procedure in the All-Reduce architecture. It guarantees that all the workers are using the same gradient to update the model in the current training iteration.
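Conceptually, All-Reduce is a Reduce followed by a broadcast. Below is a minimal sketch over per-worker gradient vectors, with plain Python lists standing in for tensors; all_reduce here is a hypothetical helper, not a library call. The average option mirrors the common practice of averaging gradients across workers.

```python
def all_reduce(grads_per_worker, average=False):
    """Simulated All-Reduce over per-worker gradient vectors (elementwise)."""
    n = len(grads_per_worker)
    # Reduce step: elementwise sum across workers
    total = [sum(col) for col in zip(*grads_per_worker)]
    if average:
        total = [t / n for t in total]
    # Broadcast step: every worker ends up with its own copy of the same result
    return [list(total) for _ in range(n)]


grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # one gradient vector per worker
print(all_reduce(grads))                 # [[9.0, 12.0], [9.0, 12.0], [9.0, 12.0]]
print(all_reduce(grads, average=True))   # [[3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
```

This naive version funnels everything through one node, which recreates the parameter server bottleneck; the ring variant in the next section avoids that.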

Ring All-Reduce

Ring All-Reduce has been widely adopted in deep learning frameworks such as PyTorch Distributed and TensorFlow.

Several popular implementations of Ring All-Reduce are as follows:

  • NVIDIA NCCL
  • Uber Horovod
  • Facebook Gloo
Ring All-Reduce with 3 Workers

Step 1: Worker 1 has a value of a, Worker 2 has a value of b, and Worker 3 has a value of c

Step 2: Worker 1 has a value of a. Worker 1 passes this value, a, to Worker 2. Worker 2 gets a+b. Worker 3 still has a value of c

Step 3: Worker 1 has a value of a. Worker 2 has a value of a+b, which it passes to Worker 3. Worker 3 now has a value of a+b+c

Step 4: Worker 3 passes a+b+c to Worker 1. Worker 1 now has a+b+c. Worker 2 now has a+b. Worker 3 now has a+b+c

Step 5: Worker 1, who has a+b+c, passes a+b+c to Worker 2. Worker 2 now has a+b+c and Worker 3 has a+b+c as well.
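The five steps above pass the whole value around the ring. The bandwidth-optimal form of Ring All-Reduce, commonly used in practice, instead splits each worker's vector into N chunks and runs two phases of N-1 steps each: reduce-scatter (each step, every worker sends one chunk to its right neighbour, which adds it) and all-gather (the fully reduced chunks are passed around and copied). Each worker then sends only 2(N-1)/N of its vector in total, however many workers join. The sketch below is a single-process simulation under those assumptions, with lists standing in for GPU buffers; ring_allreduce is a hypothetical name, and this is not NCCL's actual code.

```python
def ring_allreduce(buffers):
    """Sum-all-reduce equal-length vectors, one per simulated worker in a ring."""
    n = len(buffers)
    length = len(buffers[0])
    data = [list(b) for b in buffers]  # each worker's local copy
    # Chunk i covers indices bounds[i]:bounds[i + 1]
    bounds = [length * i // n for i in range(n + 1)]

    def chunk(i):
        return range(bounds[i], bounds[i + 1])

    # Phase 1, reduce-scatter: at step s, worker r sends chunk (r - s) % n to worker r + 1,
    # which adds it to its own copy of that chunk.
    for s in range(n - 1):
        # Snapshot all sends first so they happen "simultaneously"
        sends = [(r, (r - s) % n, [data[r][k] for k in chunk((r - s) % n)]) for r in range(n)]
        for r, c, payload in sends:
            dst = (r + 1) % n
            for k, v in zip(chunk(c), payload):
                data[dst][k] += v
    # Now worker r holds the fully reduced chunk (r + 1) % n.

    # Phase 2, all-gather: at step s, worker r forwards chunk (r + 1 - s) % n,
    # and the receiver overwrites its copy with the reduced values.
    for s in range(n - 1):
        sends = [(r, (r + 1 - s) % n, [data[r][k] for k in chunk((r + 1 - s) % n)]) for r in range(n)]
        for r, c, payload in sends:
            dst = (r + 1) % n
            for k, v in zip(chunk(c), payload):
                data[dst][k] = v
    return data


print(ring_allreduce([[1, 2, 3], [10, 20, 30], [100, 200, 300]]))
# [[111, 222, 333], [111, 222, 333], [111, 222, 333]]
```

Because every worker sends and receives at the same time on its own links, no single node becomes the bottleneck, unlike the single parameter server.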

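Sample code for data parallel training on a single machine with multiple GPUs.

Here we simply wrap the model in torch.nn.DataParallel and let PyTorch handle the data parallel training under the hood. DataParallel runs a single process with multiple threads, each responsible for running the training task on one GPU. Strictly speaking, it gathers gradients on the first GPU rather than performing a ring all-reduce (see the steps below), which is why the PyTorch documentation recommends torch.nn.parallel.DistributedDataParallel for multi-GPU training.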

Data partition and Model synchronization in nn.DataParallel
  1. First, the model is initialized on Worker 1 (the first GPU), and Worker 1 splits the input training data.
  2. Worker 1 sends a different input data partition to each of the other workers (Worker 2 and Worker 3) and broadcasts the model parameters to them. Then data parallel training starts on all the workers.
  3. After each worker computes its local gradients, it sends them to Worker 1, which aggregates them into gradients_sum and updates the model. The updated parameters are broadcast to the other workers again at the start of the next iteration.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        # Layers are created on the CPU; model.to(device) in main() moves them to the GPU
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.dropout1 = nn.Dropout2d(0.5)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout2 = nn.Dropout2d(0.75)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after max pooling
        self.fc2 = nn.Linear(128, 20)
        self.fc3 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.dropout1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.dropout2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)

        # Log-probabilities, so the matching loss is F.nll_loss
        output = F.log_softmax(x, dim=1)
        return output


transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
train_set = datasets.MNIST('./mnist_data', download=True, train=True, transform=transform)
test_set = datasets.MNIST('./mnist_data', download=True, train=False, transform=transform)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = DataLoader(train_set, batch_size=128,
                          shuffle=True, pin_memory=True)

train_epoch = 2


def main():
    model = MyNet()
    print("Using", torch.cuda.device_count(), "GPUs for data parallel training")
    # DataParallel splits each input batch across all visible GPUs
    model = nn.DataParallel(model)
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=5e-4)
    # Training
    for epoch in range(train_epoch):
        print(f"Epoch {epoch}")
        for idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()  # clear gradients from the previous batch
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            print(f"batch {idx}, loss {loss.item()}")
    print("Training Done!")


if __name__ == "__main__":
    main()

Parallel training on multiple machines with multiple GPUs.

Some concepts for multi-machine training:

  • Rank: a unique index assigned to each GPU across all the machines
  • local_rank: the index of a GPU within its own machine
  • world_size: the total number of GPUs across all the machines
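A minimal sketch of how the three quantities relate, assuming every machine has the same number of GPUs and runs one process per GPU. global_rank and world_size are hypothetical helper names, not PyTorch functions. Launchers such as torchrun expose these values to each process through the RANK, LOCAL_RANK and WORLD_SIZE environment variables.

```python
def global_rank(node_rank, local_rank, gpus_per_node):
    # Machines are numbered 0..num_nodes-1; GPUs inside a machine 0..gpus_per_node-1
    return node_rank * gpus_per_node + local_rank


def world_size(num_nodes, gpus_per_node):
    return num_nodes * gpus_per_node


# Two machines with four GPUs each
for node in range(2):
    for local in range(4):
        print(f"node {node}, local_rank {local} -> rank {global_rank(node, local, 4)}")
print("world_size =", world_size(2, 4))  # world_size = 8
```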

