
Sharded: A New Technique To Double The Size Of PyTorch Models

Dec 12, 2020


Giving it scale (Photo by Peter Gonzalez on Unsplash)

Deep learning models have been shown to improve with more data and more parameters. Even with OpenAI's latest GPT-3 model, which uses 175B parameters, we have yet to see performance plateau as the number of parameters grows.

For some domains, like NLP, the workhorse model has been the Transformer, which requires massive amounts of GPU memory. Realistic models simply don't fit in memory. The latest technique, called sharded training, was introduced in Microsoft's ZeRO paper, in which they develop a technique to bring us closer to 1 trillion parameters.

In this article, I will give the intuition behind sharded training and show you how to leverage it with PyTorch today to train models roughly twice as large, with just a few minutes of work.

This capability in PyTorch is now available thanks to a collaboration between Facebook AI Research’s FairScale team and the PyTorch Lightning team.

By the way, I write about the latest in deep learning, explain the intuition behind methods, and share tricks to optimize PyTorch. If you enjoy this type of article, follow me on Twitter for more content like this!

Outline

  • Who is this article for?
  • How to use sharded with PyTorch
  • Intuition behind sharded
  • Sharded vs model parallel

Who is this article for?

This article is for anyone using PyTorch to train models. Sharded training works on any type of model: NLP (transformers), vision (SimCLR, SwAV, ResNets), and even speech.

Here’s a quick snapshot of the performance gains you can see with sharded across these model types.

Image source: Sean Narenthiran (with modifications)

SwAV is the state-of-the-art method for self-supervised learning in computer vision.

DeepSpeech2 is a state-of-the-art method for speech.

Image GPT is a state-of-the-art method for vision.

Transformer is a state-of-the-art method for NLP.

How To Use Sharded With PyTorch

For those without much time to read through the intuition of how sharded training works, I'm going to explain up front how to use it with your PyTorch code. But I encourage you to read to the end of the article to understand how sharded works.

Sharded training is meant to be used with multiple GPUs to gain all the benefits. But training on multiple GPUs can be intimidating and a huge pain to set up.

The easiest way to supercharge your code with sharded is to convert your model to PyTorch Lightning (which is just a simple refactor). Here’s a quick 4-minute video that shows how to convert your PyTorch code to Lightning.
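If you've never seen a LightningModule, here is roughly what that refactor looks like. The model, layer sizes, and hyperparameters below are placeholders; your own nn.Module and training loop body drop in unchanged:

```python
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    """A plain PyTorch model wrapped as a LightningModule (placeholder architecture)."""

    def __init__(self):
        super().__init__()
        # Any nn.Module you already have can live here unchanged.
        self.net = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.net(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        # The body of your old training loop goes here.
        x, y = batch
        return F.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```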

Once you've done that, enabling sharded training on 8 GPUs is as simple as changing a single flag, because NO change to your code is required.
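As a rough sketch (the exact flag has moved around between Lightning versions, so check the docs for the one you have installed; in the 1.1-era API it was exposed as a plugin):

```python
from pytorch_lightning import Trainer

# Same LightningModule and data as before; only the Trainer flags change.
trainer = Trainer(gpus=8, accelerator="ddp", plugins="ddp_sharded")
trainer.fit(model, train_dataloader)
```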

If your model comes from another deep learning library (NVIDIA NeMo, fast.ai, Hugging Face Transformers), it will still work with Lightning. All you need to do is import that model into a LightningModule and hit train.
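For example, here's a hedged sketch of wrapping a Hugging Face Transformers model. It assumes a recent transformers version whose models return an output object with a .loss when labels are included in the batch:

```python
import torch
import pytorch_lightning as pl
from transformers import AutoModelForSequenceClassification


class LitTransformer(pl.LightningModule):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        # A pretrained model from another library is just an nn.Module here.
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)

    def training_step(self, batch, batch_idx):
        # Assumes the batch dict contains input_ids, attention_mask and labels.
        outputs = self.model(**batch)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-5)
```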

Intuition Behind Sharded

There are a few approaches to training efficiently across many GPUs. In one approach, data parallel (DP), each batch is split across the GPUs. Here is an illustration of DP, where each part of the batch goes to a different GPU and the model is copied onto every GPU.

DP training (Author’s own)

However, this approach is bad because the model weights are transferred across devices on every forward pass. In addition, the first GPU maintains all the optimizer states. For example, Adam keeps full extra copies of your model weights (its momentum and variance buffers).
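For reference, here's what DP looks like in plain PyTorch, with a placeholder model and batch:

```python
import torch
from torch import nn

# Placeholder model; any nn.Module works the same way.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).cuda()

# DataParallel copies the model onto every visible GPU and scatters each
# batch across them; gradients are gathered back onto the first GPU,
# which also holds the full optimizer state.
dp_model = nn.DataParallel(model)
optimizer = torch.optim.Adam(dp_model.parameters(), lr=1e-3)

x = torch.randn(64, 512).cuda()        # one full batch
loss = dp_model(x).sum()
loss.backward()
optimizer.step()
```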

In another method, distributed data parallel (DDP), each GPU trains on a subset of the data, and gradients are synced across the GPUs. This method also works across many machines (nodes). In this illustration, each GPU gets a subset of the data, and the model weights are initialized identically on every GPU. Then, after the backward pass, all gradients are synced and the weights are updated.

Distributed data parallel (Author’s own)

However, this method still has the problem that every GPU must maintain a copy of all the optimizer states (roughly 2–3x the number of model parameters) and also all the forward and backward activations.
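Here's a minimal DDP sketch in plain PyTorch, with a placeholder model and data, so you can see exactly what gets replicated on every GPU:

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def train(rank, world_size):
    # One process per GPU; every process holds a full replica of the
    # model, the gradients, and the optimizer state.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = nn.Linear(512, 10).cuda(rank)        # placeholder model
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

    x = torch.randn(32, 512).cuda(rank)          # this GPU's slice of the data
    loss = ddp_model(x).sum()
    loss.backward()                              # gradients are all-reduced across GPUs here
    optimizer.step()
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size)
```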

Sharded training removes these redundancies. It works the same way as DDP, except that all the overhead (gradients, optimizer state, etc.) is calculated only for a portion of the full parameters, and thus we remove the redundancy of storing the same gradients and optimizer states on every GPU.

So, each GPU stores only a subset of the activations, optimizer parameters, and gradient computations.
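Here's a hedged sketch of that idea using FairScale's building blocks directly (the API names are from FairScale at the time of writing and may have changed; with Lightning you never touch these yourself):

```python
import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

# Assumes torch.distributed is already initialized, one process per GPU,
# exactly as in the DDP sketch above.
model = torch.nn.Linear(512, 10).cuda()          # placeholder model

# OSS shards the Adam optimizer state across ranks instead of
# replicating it on every GPU.
optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-3)

# ShardedDataParallel reduces each gradient only to the rank that owns
# the corresponding optimizer-state shard.
model = ShardedDDP(model, optimizer)

loss = model(torch.randn(32, 512).cuda()).sum()  # placeholder batch
loss.backward()
optimizer.step()
```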

Using any distributed mode

Switching distributed modes is trivial in PyTorch Lightning
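Roughly, the switch looks like this (flag names reflect the 1.x-era Trainer API and may differ in your version):

```python
from pytorch_lightning import Trainer

trainer = Trainer(gpus=8, accelerator="dp")                           # DataParallel
trainer = Trainer(gpus=8, accelerator="ddp")                          # DistributedDataParallel
trainer = Trainer(gpus=8, accelerator="ddp", plugins="ddp_sharded")   # sharded DDP
trainer = Trainer(gpus=8, num_nodes=4, accelerator="ddp")             # DDP across 4 machines
```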

As you can see, there are many ways of squeezing maximum efficiency out of distributed training by using any of these optimization approaches.

The good news is that all of these modes are available in PyTorch Lightning with zero code changes required. You can try any of them and adjust based on your particular model as needed.

One approach that is missing here is model parallelism. However, a word of caution on model parallelism: it was found to be much less efficient than sharded training and should be used with care. For some cases it might work well, but overall you're best served just using sharded training instead.

An advantage of using Lightning is that you'll never fall behind the latest advancements in AI research! The team and the open-source community are dedicated to bringing you the latest advances via Lightning.

Acknowledgments:

Sharded training is now available in PyTorch Lightning thanks to the efforts of the Facebook AI FairScale team. Special thanks to Benjamin Lefaudeux from Facebook AI FairScale and Sean Narenthiran from PyTorch Lightning / Grid AI.

Sharded training was inspired by Microsoft's ZeRO paper.


Written by William Falcon

⚡️PyTorch Lightning Creator • PhD Student, AI (NYU, Facebook AI Research).
