Introduction to TorchShard

A Lightweight Library for Scaling-up the Training

Kaiyu Yue
PyTorch
7 min read · Jul 16, 2021


Author: Kaiyu Yue, Incoming Ph.D. Student at the Computer Science Department of the University of Maryland, College Park

TorchShard is a lightweight engine for slicing a PyTorch tensor into parallel shards. It can reduce GPU memory usage and scale up training when the model has massive linear layers (e.g., BERT and GPT) or a huge number of classes (millions). It has the same API design as PyTorch. In this blog post, we will introduce TorchShard and illustrate how to adopt it in our projects.

GitHub Repository: https://github.com/KaiyuYue/torchshard

Introduction

Training super-large models, such as BERT and GPT, is trending in Natural Language Processing (NLP) applications. One of the biggest problems with training such large models is memory constraints. To crack this nut, Megatron-LM and PyTorch-Lightning use model parallelism to scale up the training. However, Megatron-LM focuses only on training language models at scale, and PyTorch-Lightning shards only optimizer states and gradients, like DeepSpeed.

In computer vision tasks, we encounter the same problem when training huge transformer-based or MLP-based models, or when training a model with millions of classes.

Model parallelism could benefit these vision tasks as well. However, no standard library exists that lets us adopt model parallelism as easily as other state-of-the-art (SOTA) techniques such as mixed precision.

Therefore, TorchShard aims to:

  • be a standard PyTorch extension library for scaling up training with model parallelism.
  • be used in an easy and natural PyTorch way.

TorchShard is a ground-up rewrite of the model parallel unit (mpu), the core of Megatron-LM. FairScale also forks the mpu to extend its training capabilities with other SOTA features (Sharded Data-Parallel and ZeRO). However, FairScale integrates the mpu as an inner library and lacks documentation illustrating how to use it. In contrast, TorchShard provides both documentation and tutorials.

More than anything, TorchShard has the same API design as PyTorch, which means that all the sub-classes and sub-functions mirror those of PyTorch. For example, if you would like to make the original linear layer torch.nn.Linear parallel, just change torch into ts and call the sub-class nn.ParallelLinear with an extra dim parameter. It looks like:
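
The snippet below is a minimal sketch of that swap; the layer sizes are made up for illustration, and the exact keyword arguments are assumed to mirror torch.nn.Linear plus the dim parameter described above.

import torch
import torchshard as ts

# original PyTorch linear layer
fc = torch.nn.Linear(512, 1000, bias=True)

# parallel version: same call, with an extra dim argument telling
# torchshard along which dimension to slice the weight across ranks
fc = ts.nn.ParallelLinear(512, 1000, bias=True, dim=1)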

Besides, TorchShard offers features for working with DDP, saving and loading sharded checkpoints, initializing sharded parameters, and playing with tensors across multiple machines and GPUs. Specifically,

  • torchshard contains essential functions and operations, like the torch package.
  • torchshard.nn contains basic building blocks for graphs, like the torch.nn package.
  • torchshard.nn.functional contains the corresponding functional ops of torchshard.nn, like the torch.nn.functional package.
  • torchshard.distributed contains basic functions for processing distributed tensors and groups. It is easier to use than the torch.distributed package.

More details of API usage can be found in TorchShard API Documents.

Getting Started with TorchShard

Before installing, please make sure your environment meets the installation requirements:

  • Python >= 3.6 and PyTorch >= 1.9.0

Then the TorchShard library can be installed via pip:

pip install torchshard

Here we take training ResNet-50 on ImageNet as our example to show how to adopt TorchShard in a project with a few lines of code. Commonly, the ResNet-50 design has two parts: convolutional blocks and a fully connected layer, as shown in Figure 1. Generally, the last linear layer has more parameters than a convolutional block because the number of classes can be massive, depending on the dataset. So here, we slice the last linear layer and check how large it can be scaled.

Figure 1. Forward flow of training with DDP and DDP + TorchShard.

In Figure 1, the left side shows the traditional training paradigm with DDP. Assuming we have two ranks, DDP forces each rank to hold duplicated model parameters, whereas TorchShard slices the layer's parameters onto different ranks to reduce the overall GPU memory.

Let's add some code to the official ImageNet training script. The modified version is part of the torchshard projects; please check out torchshard/projects.

First, import torchshard.

import torchshard as ts

Then, we need to initialize the process groups for model parallelism, just as we initialize DDP process groups. Only one parameter needs to be set to tell torchshard how many shards should be sliced out from the target layer.

ts.distributed.init_process_group(group_size=args.world_size)

Next, we convert the model into a parallel version. We can directly feed the whole model into the convert helper function without any special processing.

import resnet

model = resnet.__dict__[args.arch](pretrained=args.pretrained)
ts.nn.ParallelLinear.convert_parallel_linear(
    model, dim=args.model_parallel_dim
)
print("=> paralleling model '{}'".format(args.arch))

Don't forget the loss function, torchshard.nn.ParallelCrossEntropyLoss, which switches between the original PyTorch mode and the parallel mode according to its input tensors. For instance, if the input tensors are produced by a torchshard parallel layer, torchshard.nn.ParallelCrossEntropyLoss calculates the loss value in a parallel manner.

criterion = ts.nn.ParallelCrossEntropyLoss().cuda(args.gpu)

When the model parallel mode (TorchShard) and the data-parallel mode (DDP) work together, we need to take care of the input to the parallel layers. Both the parameters and the training data on each rank are different. Therefore, we gather the input tensors before the parallel linear layer in the ResNet forward pass.

x = ts.distributed.gather(x, dim=0)  # gather input along the batch dimension
x = self.fc(x)

Similarly, we gather the target tensors before calculating the loss values.

output = model(images)
if args.enable_model_parallel:
    target = ts.distributed.gather(target, dim=0)
loss = criterion(output, target)

Last, it is super easy to save and load checkpoints with TorchShard functions. TorchShard offers torchshard.collect_state_dict for saving and torchshard.relocate_state_dict for loading.

  • For saving checkpoints:

state_dict = model.state_dict()

# collect states across all ranks
state_dict = ts.collect_state_dict(model, state_dict)
if ts.distributed.get_rank() == 0:
    torch.save(state_dict, 'resnet50.pt')  # save as before
  • For loading checkpoints:

if ts.distributed.get_rank() == 0:
    state_dict = torch.load('resnet50.pt')

# relocate state_dict() for all ranks
state_dict = ts.relocate_state_dict(model, state_dict)
model.load_state_dict(state_dict)  # load as before

Now we have finished adding code for sharded training on ImageNet. We can then scale up the training by increasing the number of classes, i.e., the output feature dimension of the last linear layer. The training scripts can be found in torchshard/project/imagenet. The following figure shows the scaling ability of training ResNet-50 on 8 NVIDIA TITAN-XP (12196 MiB) GPUs for up to 1,000,000 classes, and on 16 GPUs for 2,000,000 classes.

Figure 2. GPU memory cost with the standard ResNet training settings (i.e., input size 224 and batch size 256) under different parallel strategies.

Working with AMP and ZeRO

TorchShard works in an easy and natural PyTorch way with other techniques, such as automatic mixed precision (AMP) and ZeRO.

Please refer to the PyTorch AMP tutorial, "All together: Automatic Mixed Precision". In our project code, it looks like this.
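
The original post shows this step as a screenshot; the following is a minimal sketch assuming the standard torch.cuda.amp pattern from that tutorial, combined with the gather step introduced earlier (args.enable_model_parallel comes from the project script above).

scaler = torch.cuda.amp.GradScaler()

for images, target in train_loader:
    optimizer.zero_grad()
    # run the forward pass and the loss in mixed precision
    with torch.cuda.amp.autocast():
        output = model(images)
        if args.enable_model_parallel:
            target = ts.distributed.gather(target, dim=0)
        loss = criterion(output, target)
    # scale the loss, backpropagate, and step, as in the AMP tutorial
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()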

Figure 3. GPU memory cost with the standard ResNet training settings (i.e., input size 224 and batch size 256) under different parallel strategies and AMP.

ZeRO is the core of DeepSpeed. Please refer to the PyTorch distributed optimizer ZeroRedundancyOptimizer, which ships with PyTorch >= 1.9.0. If you would like to test this feature, please install the latest version to run the script. The code looks like this.
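
This step is also shown as a screenshot in the original post; a minimal sketch of wrapping the optimizer with ZeroRedundancyOptimizer could look like the following (the SGD hyperparameters are the usual ImageNet-script arguments and are assumptions here).

from torch.distributed.optim import ZeroRedundancyOptimizer

# shard the optimizer states across the data-parallel ranks
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.SGD,
    lr=args.lr,
    momentum=args.momentum,
    weight_decay=args.weight_decay,
)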

Figure 4. GPU memory cost with the standard ResNet training settings (i.e., input size 224 and batch size 256) under different parallel strategies and ZeRO optimizer.

Moreover, TorchShard offers basic Python APIs and corresponding template files to ease the implementation of customized parallel layers. For example, we have a comprehensive tutorial on how to write a customized parallel layer for face recognition training.

Looking Ahead

TorchShard will be progressively developed with more new features. For example, the next feature is a new data sampler, torchshard.utils.data.DistributedGroupSampler, which is named following torch.utils.data.DistributedSampler. DistributedGroupSampler aims to help users build M-way data parallelism and N-way model parallelism as easily as DistributedSampler with DDP. The only thing users need to do is set the model-parallel group number. Then DistributedGroupSampler will make sure that modules in the same model-parallel group receive the same training data, as sketched below.
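
Since the sampler is still under development, the snippet below is purely hypothetical usage, mirroring how torch.utils.data.DistributedSampler is normally plugged into a DataLoader; the constructor argument names are assumptions.

# hypothetical usage of the upcoming sampler; the group-size argument
# name is an assumption based on the description above
train_sampler = ts.utils.data.DistributedGroupSampler(
    train_dataset,
    group_size=args.model_parallel_group_size,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size,
    shuffle=False, sampler=train_sampler,
)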

If you are interested in torchshard, you are welcome to help:

  • polish code and develop new features.
  • develop high-quality tutorials, projects, and advanced materials.

We would love to hear your feedback, receive your contributions, and stay connected through GitHub issues.

Thanks for reading our blog post, and we hope it helps.

References

  • B. Lefaudeux, et al. FairScale: PyTorch Extensions For High Performance And Large Scale Training, 2020. [GitHub]
  • S. Rajbhandari, et al. ZeRO: Memory Optimizations Toward Training Trillion Parameter Models, 2019. [arXiv] [GitHub]
  • M. Shoeybi, et al. Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism, 2019. [arXiv] [GitHub]
  • J. Rasley, et al. DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters, 2020. [Webpage] [GitHub]
  • W. Falcon, et al. PyTorch-Lightning: The Lightweight PyTorch Wrapper For High-performance AI Research, 2019. [GitHub]
  • N. Shazeer, et al. Mesh TensorFlow: Model Parallelism Made Easier, 2018. [arXiv] [GitHub]
