Stoke — Providing a No-Code Configuration Based Accelerator Playground for PyTorch

Nicholas Cilfone
Published in PyTorch
12 min read · Oct 27, 2021
GitHub Repository: https://github.com/fidelity/stoke

Authors: Nicholas Cilfone (GitHub), Principal Data Scientist, AI Center of Excellence @ Fidelity Investments

Across an enterprise-grade organization, code quality and compute resources can be quite disparate. We found ourselves needing to support both on-prem GPU compute clusters (e.g. a k8s cluster with multiple attached Nvidia DGX-1 nodes) and cloud-based GPU solutions (e.g. AWS SageMaker) when building and training deep learning models. Given the architecture(s) at the time (especially AWS SageMaker’s dependence on Horovod and the networking of our on-prem cluster), this required supporting the PyTorch DP and DDP paradigms as well as Horovod via OpenMPI. On top of this, many software/optimization based ‘accelerators’ built on top of PyTorch started to emerge (e.g. DeepSpeed circa late 2019).

Supporting experimentation within this ever-changing landscape of devices (e.g. CPU, GPU, TPU, etc.), distributed methodologies (e.g. PyTorch DDP, Horovod, etc.), mixed precision (e.g. Nvidia APEX, PyTorch AMP), and software/optimization ‘accelerators’ (e.g. DeepSpeed ZeRO stages 0–3) started to become quite unwieldy. Each seemingly required its own slightly different syntax, logic, or context. For instance, experimenting with the two common mixed precision methods, Nvidia’s APEX and PyTorch’s native AMP (released with PyTorch 1.6.0 in July 2020), required a significant amount of embedded boolean logic, such as in the following example (Note: self.config.amp takes either amp or apex, which flips between the two methods):

Example boolean logic in initialization for switching between PyTorch AMP and Nvidia APEX

Here, when instantiating a class that handles all of the training operations, we needed to create a torch.cuda.amp.GradScaler if we wanted to use PyTorch AMP. However, if we wanted to use Nvidia’s APEX, we needed both the already-instantiated model (a torch.nn.Module) and optimizer (a torch.optim.Optimizer), and needed to disregard the torch.cuda.amp.GradScaler. In addition, to run in full precision we needed to make sure the torch.cuda.amp.GradScaler was still instantiated, but with enabled=False.
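In plain PyTorch, that initialization logic might look roughly like the sketch below. The helper name setup_precision and its amp_mode argument ('amp', 'apex', or None for full precision) are hypothetical stand-ins for self.config.amp. The APEX branch assumes Nvidia APEX is installed and only imports it when selected. Newer PyTorch releases spell the scaler torch.amp.GradScaler("cuda", ...), but the torch.cuda.amp form used in this article still works there, with a deprecation warning.

```python
import torch
from torch import nn


def setup_precision(model, optimizer, amp_mode=None, opt_level="O1"):
    """Hypothetical helper: amp_mode is 'amp', 'apex', or None (full precision)."""
    if amp_mode == "apex":
        # Nvidia APEX patches the already-instantiated model and optimizer
        from apex import amp  # assumes APEX is installed; imported only on this path

        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
        scaler = None  # APEX handles loss scaling internally, so no GradScaler
    else:
        # Native AMP needs a GradScaler; full precision keeps one, but disabled
        scaler = torch.cuda.amp.GradScaler(enabled=(amp_mode == "amp"))
    return model, optimizer, scaler


model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer, scaler = setup_precision(model, optimizer, amp_mode=None)
print(scaler.is_enabled())  # False: full precision
```

Note that every later call site has to know which of the three branches was taken, which is exactly where the boolean logic starts to spread.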

Then, when calculating the gradients for a given batch (with gradient clipping added for illustrative purposes), the code looks something like the following:

Calculating gradient using PyTorch AMP vs. Nvidia APEX

The above code encompasses the fundamental unit of training a deep learning model with PyTorch: getting a mini-batch, calculating the gradients, and then taking an optimizer step based on those gradients. Hopefully it’s clear how convoluted this becomes when dealing with the boolean logic required to switch between the two mixed precision methods. AMP needs the torch.cuda.amp.autocast context manager to wrap the forward() call, requires direct management of the torch.cuda.amp.GradScaler object (especially for gradient clipping), and changes how the step is taken (scaler.step(optimizer) instead of optimizer.step()). APEX forgoes the autocast context manager and direct management of a scaler object, but requires calls to its own amp methods to get the optimizer parameters.
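A plain-PyTorch sketch of such a step is below. The function name train_step and the amp_mode flag are hypothetical; the APEX branch assumes Nvidia APEX is installed and uses its amp.scale_loss and amp.master_params calls, while the AMP/full-precision branch relies on a GradScaler (disabled for full precision) being passed in. Recent PyTorch releases prefer torch.amp.autocast("cuda", ...), but the torch.cuda.amp.autocast spelling still works.

```python
import torch
from torch import nn


def train_step(model, optimizer, loss_fn, x, y, scaler=None, amp_mode=None, max_norm=1.0):
    """One mini-batch step; amp_mode is 'amp', 'apex', or None (full precision).

    For 'amp' and None, scaler must be a torch.cuda.amp.GradScaler
    (enabled for 'amp', disabled for full precision).
    """
    optimizer.zero_grad()
    if amp_mode == "apex":
        from apex import amp  # assumes Nvidia APEX is installed

        loss = loss_fn(model(x), y)
        # APEX scales the loss inside its own context manager
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        # Clip the FP32 master params that APEX keeps for the optimizer
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
        optimizer.step()
    else:
        # autocast wraps only the forward pass and loss, not backward()
        with torch.cuda.amp.autocast(enabled=scaler.is_enabled()):
            loss = loss_fn(model(x), y)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale first so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)  # replaces optimizer.step(); skips steps with inf/NaN grads
        scaler.update()
    return loss.item()


torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # full precision on CPU
x, y = torch.randn(16, 4), torch.randn(16, 1)
print(train_step(model, optimizer, nn.MSELoss(), x, y, scaler=scaler))
```

When the scaler is disabled, scale, unscale_, step, and update all degrade to no-ops or a plain optimizer.step(), which is why the same branch can serve both AMP and full precision.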

A Combinatorics Nightmare

From this simple example, one can see that there really isn’t much common ground between the two libraries, even though they both tackle the same ‘accelerator’ functionality of mixed precision. Now tack on all the other ‘accelerator’ libraries/frameworks across devices (e.g. CPU, GPU, TPU, etc.) that one might want to experiment with: distributed methodologies (e.g. PyTorch DDP, Horovod, DeepSpeed via DDP, etc.), mixed precision (e.g. Nvidia APEX, PyTorch AMP, DeepSpeed’s custom APEX implementation), and other ‘accelerator extensions’ (e.g. DeepSpeed ZeRO 0–3, Fairscale OSS and SDDP). It quickly becomes obvious that configuring every combination of these ‘accelerators’ together is extremely tedious and laborious.
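To put rough numbers on the combinatorics, the sketch below simply enumerates one option per category from the lists above. The option names are our own shorthand. Treating the extensions as mutually exclusive understates the real count, since some of them (e.g. Fairscale OSS and SDDP) can be stacked. The count says nothing about which combinations are actually compatible.

```python
from itertools import product

# Illustrative option lists drawn from the paragraph above (None = feature off);
# this is not a statement about which combinations work together.
options = {
    "device": ["cpu", "gpu", "tpu"],
    "distributed": [None, "ddp", "horovod", "deepspeed"],
    "precision": [None, "apex", "amp", "deepspeed_fp16"],
    "extension": [None, "zero", "fairscale_oss", "fairscale_sddp"],
}

# Every tuple is a configuration a hand-written training script would need to handle
combos = list(product(*options.values()))
print(len(combos))  # 3 * 4 * 4 * 4 = 192
```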

Enter PyTorch Lightning. Initially released in 2019, “Lightning disentangles PyTorch code to decouple the science from the engineering”. It is an opinionated library that helps remove the common boilerplate code normally needed to train deep learning models in PyTorch. Large-scale adoption followed, and it has become an invaluable tool within the PyTorch ecosystem. Most importantly, if you are using the Lightning ecosystem, it provides an additional API called Accelerators that helps manage switching between devices (CPU, GPU, TPU), mixed precision (PyTorch AMP and Nvidia’s APEX), and distributed backends (PyTorch DDP, Horovod).

However, a large portion of our codebases were either written before Lightning was released or, even afterwards, simply didn’t use Lightning and were pure PyTorch. Therefore, we started with a simple question:

Could we take the best of Lightning Accelerators and disconnect it from the rest of PyTorch Lightning?

The Journey Developing Stoke

We started by using the fundamentals of Lightning Accelerators as inspiration (at the time it supported the ‘accelerators’ PyTorch DDP, Horovod, AMP, and APEX): allow a user to mix and match ‘accelerator’ frameworks/libraries by simply changing a few declarative flags. Based on these declarative flags, Stoke builds an internal state representation of the available ‘accelerators’ and uses mixin-style classes to dynamically build an object that correctly supports the declared state of the underlying ‘accelerators’. It also automatically handles gradient accumulation, gradient clipping, model I/O, and device placement.
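The mixin idea can be illustrated with a toy example. This is a hypothetical sketch of the general pattern, not Stoke’s internal classes: each flag maps to a mixin, and type() composes the selected mixins into a new class at runtime, with cooperative super() calls chaining their behaviour.

```python
class BaseRunner:
    """Behaviour shared by every composed runner."""

    def describe(self):
        return []


class DDPMixin:
    def describe(self):
        return super().describe() + ["ddp"]


class AMPMixin:
    def describe(self):
        return super().describe() + ["amp"]


class OSSMixin:
    def describe(self):
        return super().describe() + ["oss"]


# Hypothetical registry mapping declarative flags to mixins
MIXINS = {"distributed": DDPMixin, "fp16": AMPMixin, "oss": OSSMixin}


def build_runner(**flags):
    """Compose a class from whichever flags are truthy, then instantiate it."""
    bases = tuple(MIXINS[name] for name, on in flags.items() if on)
    # BaseRunner goes last so every mixin's super() call eventually reaches it
    runner_cls = type("Runner", bases + (BaseRunner,), {})
    return runner_cls()


print(build_runner(distributed=True, fp16=True, oss=False).describe())  # ['amp', 'ddp']
```

The method resolution order is Runner, DDPMixin, AMPMixin, BaseRunner, so the innermost mixin contributes first; a real implementation would override methods such as backward or step instead of describe.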

As Stoke grew, we realized that there was also a lack of access to the configuration settings for each of the underlying ‘accelerator’ backends, as well as within Lightning. In addition, most documentation of the configuration settings was still buried in the docstrings/docs of each individual framework/library. Therefore, we designed a simple and unified approach for configuring each supported ‘accelerator’ library/framework using the attrs library (see all the config classes here).

For instance, these classes allow you to easily configure the underlying PyTorch AMP or Nvidia APEX backend:

Example Stoke configurations for the underlying backends of PyTorch AMP and Nvidia APEX
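The general pattern behind such config classes can be sketched as follows. Stoke’s own classes are built with attrs and have their own names and fields. This hypothetical AMPConfigSketch uses the standard-library dataclasses module to stay dependency-free, and its fields mirror the real keyword arguments (and defaults) of torch.cuda.amp.GradScaler.

```python
from dataclasses import asdict, dataclass

import torch


@dataclass(frozen=True)
class AMPConfigSketch:
    """Hypothetical config mirroring torch.cuda.amp.GradScaler's keyword arguments.

    init_scale: starting loss-scale factor
    growth_factor: multiplier applied after growth_interval steps without inf/NaN
    backoff_factor: multiplier applied when inf/NaN gradients are found
    growth_interval: number of clean steps before the scale grows
    """

    init_scale: float = 2.0**16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000


cfg = AMPConfigSketch(init_scale=2.0**10)
# Unpack the documented config straight into the backend object
scaler = torch.cuda.amp.GradScaler(**asdict(cfg), enabled=torch.cuda.is_available())
```

Keeping the parameters in one documented class means the backend’s options are discoverable from the config itself rather than from the backend’s docstrings.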

We reached an internal alpha version of Stoke, which essentially replicated the functionality of Lightning Accelerators, between late 2020 and early 2021. Around the same time, Fairscale started to gain traction with support for Optimizer State Sharding (an implementation similar to what DeepSpeed established with its ZeRO implementations), while PyTorch Lightning added support for DeepSpeed as another ‘accelerator’. We realized that Stoke could play another role as an open-source library: not only providing more easily configurable functionality akin to Lightning Accelerators, but also helping deal with the empirical nature of the current state of training deep learning models. Let’s go on a quick tangent to explain…

Most ‘accelerator’ methods are essentially ‘shortcuts’ around what can be described as the ‘infinite computing needed’ problem (thanks to Zach Semenov for this idea). In fact, the most common technique we use for training deep learning models is a ‘shortcut’ around this exact problem. Mini-batch gradient descent breaks a very large training dataset into much smaller batches and calculates the gradients across each smaller set of samples instead of across the entire training dataset. Why do we do this? Because the entire training dataset won’t fit in memory on most machines/devices, and processing it all at once would need some weird ‘infinite computing’ requirement. So we approximate the gradients via a sampling technique with no guaranteed bounds. Mixed precision is simply a ‘shortcut’ that reduces the floating point precision, and thus the memory requirements, of numeric representations, on the assumption that the ‘lost’ precision won’t significantly affect the optimizer’s traversal of the optimization landscape. Optimizer State Sharding is simply a communication ‘shortcut’ for memory management: it partitions the optimizer state across multiple devices instead of maintaining redundant copies of the same state on each device.
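The trade-off behind the mixed precision ‘shortcut’ is easy to see directly in plain PyTorch. This is an illustration we added, not code from Stoke: half precision halves the bytes per element, but it can silently drop small updates.

```python
import torch

# Half precision halves the bytes per element...
weights32 = torch.zeros(1_000_000, dtype=torch.float32)
weights16 = weights32.half()
print(weights32.element_size() * weights32.nelement())  # 4000000 bytes
print(weights16.element_size() * weights16.nelement())  # 2000000 bytes

# ...at the cost of precision: fp16 has a 10-bit mantissa, so near 1.0 the
# spacing between representable values is 2**-10, and 1 + 2**-11 rounds to 1.0.
tiny_update = 1 + 2**-11
print(torch.tensor(tiny_update, dtype=torch.float32).item())  # 1.00048828125
print(torch.tensor(tiny_update, dtype=torch.float16).item())  # 1.0
```

Loss scaling (GradScaler, or APEX’s equivalent) exists precisely to keep small gradient values from underflowing in this reduced range.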

The effects of most of these ‘shortcut’ methods on model training largely converge on one simple place: effective batch size (effective batch size = per_device_batch_size * num_devices). Using a distributed backend (e.g. DDP or Horovod) to horizontally scale training across multiple GPUs increases the effective batch size as the number of devices grows. Decreasing the memory footprint of the optimizer via Optimizer State Sharding typically allows the practitioner to increase the per-device batch size, which also increases the effective batch size. Each of the ‘shortcuts’ scales in its own way with respect to effective batch size. However, anyone who has tried to train large-scale deep learning models knows that increasing the effective batch size (say from 32 on a single GPU to 96 on two GPUs with Optimizer State Sharding) means changes to the learning rate, learning rate schedule, etc., and most likely alters model convergence and thus performance. So, long tangent short:
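The arithmetic behind the example above can be written out as follows. Two additions are assumptions on our part: the grad_accum_steps factor (gradient accumulation multiplies the effective batch size in the same way), and the linear scaling rule for the learning rate, which is one common heuristic rather than a guarantee. The 48-per-device figure is inferred from “96 on two GPUs”.

```python
def effective_batch_size(per_device_batch_size, num_devices, grad_accum_steps=1):
    """Samples contributing to each optimizer step.

    grad_accum_steps extends the formula above: accumulating gradients over
    k mini-batches before stepping multiplies the effective batch size by k.
    """
    return per_device_batch_size * num_devices * grad_accum_steps


def linearly_scaled_lr(base_lr, base_batch_size, new_batch_size):
    """Linear scaling heuristic: scale the learning rate with the batch size."""
    return base_lr * new_batch_size / base_batch_size


single_gpu = effective_batch_size(32, 1)
two_gpu_oss = effective_batch_size(48, 2)  # sharding frees memory for 48 per device
print(single_gpu, two_gpu_oss)  # 32 96
print(linearly_scaled_lr(0.1, single_gpu, two_gpu_oss))  # approximately 0.3 (floating point)
```

Whether the rescaled learning rate actually recovers the original convergence behaviour is an empirical question, which is the point the paragraph is making.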

It’s quite hard to know a priori what ‘shortcuts’ are going to lead to the desired outcome… simple and quick experimentation is critical to being able to leverage the ‘right shortcuts’

We therefore realized that the fundamental value proposition of Stoke (and to some extent Lightning Accelerators) as a library is to provide an easily accessible and configurable ‘Accelerator Playground’ for PyTorch, where any existing or newly emerging ‘accelerator’ can quickly and easily be tested for usefulness. Hence, the mantra for Stoke became:

No spaghetti-code, no model implementation re-writes, no complex embedded boolean logic. Just a simple declarative flag to ask ‘will this accelerator work with my model’?

What Stoke Provides

Stoke is a lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices (e.g. CPU, GPU), distributed modes, mixed-precision, and other PyTorch ‘accelerator’ extensions. It places no restrictions on code structure/style for model architecture, training/inference loops, loss functions, optimizer algorithm, etc. It simply ‘wraps’ your existing PyTorch code to automatically handle the necessary underlying wiring for all of the supported ‘accelerators’. This allows you to switch from local full-precision CPU to mixed-precision distributed multi-GPU with optimizer state sharding by simply changing a few declarative flags. In short, the main benefits are:

  • Declarative style API: declare the desired accelerator state(s) and let Stoke handle the rest
  • Wrapped API mirrors base PyTorch style model, loss, backward, and step calls
  • Automatic device placement of model(s) and data
  • Universal interface for saving and loading regardless of backend(s) or device(s)
  • Automatic handling of gradient accumulation and clipping
  • Common attrs interface for all backend configuration parameters (with helpful docstrings!)
  • A few extras: a custom distributed sampler, BucketedDistributedSampler, which buckets data by a sorted index and then randomly samples from specific bucket(s) to prevent situations like grossly mismatched sequence lengths leading to wasted computational overhead (i.e. excess padding); plus helper methods for printing synced losses, device-specific printing, counting model parameters, etc.
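The bucketing idea behind such a sampler can be sketched in pure Python. This hypothetical bucketed_batches function is not Stoke’s BucketedDistributedSampler. It sorts indices by length, splits them into buckets, forms batches only within a bucket, shuffles the batch order, and gives each rank a disjoint slice of batches.

```python
import random
from typing import List, Sequence


def bucketed_batches(lengths: Sequence[int], batch_size: int, num_buckets: int,
                     seed: int = 0, rank: int = 0, world_size: int = 1) -> List[List[int]]:
    """Group sample indices into batches of similar length, sharded by rank."""
    rng = random.Random(seed)  # same seed on every rank -> same global batch order
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    bucket_size = -(-len(order) // num_buckets)  # ceiling division
    batches = []
    for start in range(0, len(order), bucket_size):
        bucket = order[start:start + bucket_size]  # indices with neighbouring lengths
        rng.shuffle(bucket)  # randomize within the bucket
        batches.extend(bucket[i:i + batch_size] for i in range(0, len(bucket), batch_size))
    rng.shuffle(batches)  # randomize which bucket each step draws from
    return batches[rank::world_size]  # each rank takes a disjoint slice of batches


lengths = [7, 2, 9, 0, 5, 3, 8, 1, 6, 4]  # e.g. token counts per sequence
for batch in bucketed_batches(lengths, batch_size=2, num_buckets=2, seed=1):
    print(batch, [lengths[i] for i in batch])
```

Unlike torch.utils.data.distributed.DistributedSampler, this sketch does not pad so that every rank receives the same number of batches, which a real distributed implementation would need to handle.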

Remember the code to switch between AMP and APEX at the beginning of this post? Let’s do a simple comparison of what similar code looks like in a simple script…

Without Stoke:

Base PyTorch code that is switchable between AMP and APEX (APEX flagged)

With Stoke:

Stoke based code that is switchable between AMP and APEX (APEX flagged)

Stoke supports the following ‘accelerators’:

Certain combinations of backends/functionality are not compatible with each other. The table below indicates which combinations have been tested together:

Stoke Compatibility Matrix

Building a CIFAR10 Model with Stoke

(Note: A full working example can be found here)

The CIFAR-10 dataset consists of 60K 32x32 color images in 10 classes, with 6K images per class. There are 50K training images and 10K test images. For simplicity, we borrow model code from Torchvision and use the built-in ResNet-152 model (arxiv). Let’s start by getting this up and running with no bells and whistles, just on a CPU (perhaps mimicking your local dev environment). For simplicity, we will also hard-code most model, optimizer, etc. parameters (check out Spock for a helpful parameter configuration library). Let’s create the model, optimizer, and loss function. Stoke requires a slightly different way to define the optimizer (as it handles instantiation internally) via StokeOptimizer: pass in the uninstantiated torch.optim.* class object and any **kwargs that need to be passed to its __init__ call:
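Why take the class instead of an instance? A plausible reason, sketched below under our own assumptions, is that a wrapper can defer instantiation until the model is final (wrapped, moved to a device), or hand the class to a backend that builds the optimizer itself; Fairscale’s OSS and PyTorch’s ZeroRedundancyOptimizer, for example, both take an optimizer class. DeferredOptimizer and build_optimizer are hypothetical names, not Stoke’s implementation, and a small nn.Linear stands in for Torchvision’s ResNet-152 (e.g. torchvision.models.resnet152(num_classes=10)) so the sketch runs without torchvision.

```python
from typing import Any, Dict, Type, TypedDict

import torch
from torch import nn


class DeferredOptimizer(TypedDict):
    # Hypothetical stand-in: the optimizer class plus its constructor kwargs
    optimizer: Type[torch.optim.Optimizer]
    optimizer_kwargs: Dict[str, Any]


def build_optimizer(spec: DeferredOptimizer, model: nn.Module) -> torch.optim.Optimizer:
    # Instantiated only once the model is final (wrapped, moved to device, etc.),
    # so the optimizer references the parameters that will actually be trained
    return spec["optimizer"](model.parameters(), **spec["optimizer_kwargs"])


# Illustrative hard-coded hyperparameters
spec = DeferredOptimizer(optimizer=torch.optim.SGD,
                         optimizer_kwargs={"lr": 0.01, "momentum": 0.9, "weight_decay": 1e-4})
model = nn.Linear(3 * 32 * 32, 10)  # placeholder for ResNet-152 on flattened CIFAR-10 inputs
optimizer = build_optimizer(spec, model)
```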

Now we create the Stoke object that will be the primary interface to the base PyTorch style API functions: model, loss, backward, and step. Here, since we are just using the CPU with no additional ‘accelerators’, the instantiation is pretty simple: we pass in the model, optimizer, loss function, and batch size. We also set the verbosity so that we get some useful debugging output.
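To make the “mirrored API” idea concrete, here is a hypothetical minimal wrapper exposing model, loss, backward, and step calls. It is not Stoke’s implementation: it only runs in full precision on whatever device the model is already on, and it handles gradient accumulation and norm clipping. A real library would dispatch to AMP/APEX/DDP/etc. inside these methods.

```python
import torch
from torch import nn


class TinyWrapper:
    """Hypothetical sketch of a wrapper mirroring model/loss/backward/step calls."""

    def __init__(self, model, optimizer, loss, grad_accum_steps=1, max_norm=None):
        self._model, self._optimizer, self._loss = model, optimizer, loss
        self.grad_accum_steps = grad_accum_steps
        self.max_norm = max_norm
        self._backward_calls = 0

    def model(self, x):
        return self._model(x)

    def loss(self, output, target):
        return self._loss(output, target)

    def backward(self, loss):
        # Divide so the accumulated gradient is the average over the window
        (loss / self.grad_accum_steps).backward()
        self._backward_calls += 1

    def step(self):
        # Only step the optimizer once every grad_accum_steps backward calls
        if self._backward_calls % self.grad_accum_steps != 0:
            return False
        if self.max_norm is not None:
            nn.utils.clip_grad_norm_(self._model.parameters(), self.max_norm)
        self._optimizer.step()
        self._optimizer.zero_grad()
        return True


torch.manual_seed(0)
net = nn.Linear(4, 1)
wrapped = TinyWrapper(net, torch.optim.SGD(net.parameters(), lr=0.1), nn.MSELoss(),
                      grad_accum_steps=2, max_norm=5.0)
for _ in range(4):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    out = wrapped.model(x)
    loss = wrapped.loss(out, y)
    wrapped.backward(loss)
    wrapped.step()  # the optimizer only steps on every second call
```

Because the training loop calls step() every iteration, accumulation and clipping stay out of user code entirely.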

Next, make some pipelines for the image datasets and then retrieve them using torchvision.datasets:

Finally, we need to create a torch.utils.data.DataLoader object. Similar to the optimizer definition this has to be done a little differently with Stoke for it to correctly handle each of the different backends. The main Stoke object provides a mirrored wrapper to the native torch.utils.data.DataLoader class (as the DataLoader method) that will return a correctly configured torch.utils.data.DataLoader object.

At this point, we’ve successfully configured Stoke to run on a single CPU. The following simple training loop should look fairly standard, except that the model forward, loss, backward, and step calls are all made on the Stoke object instead of on each individual component (as it internally maintains the model, loss, and optimizer along with all the necessary code for each backend/functionality/extension). We also use some of Stoke’s built-in print functionality:

Let’s make sure to save the model as well. Conveniently, Stoke provides a unified interface to save and load model checkpoints regardless of backend/functionality/extensions:
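One reason a unified save/load interface helps is that wrapped models (e.g. nn.DataParallel or DistributedDataParallel) prefix their state_dict keys with "module.". The plain-PyTorch sketch below shows one way to paper over that. The helper names are hypothetical and this is not Stoke’s API. It assumes a model without its own submodule literally named module.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, optimizer, **extras):
    # Unwrap DataParallel/DistributedDataParallel so keys lack the "module." prefix
    core = getattr(model, "module", model)
    torch.save({"model": core.state_dict(), "optimizer": optimizer.state_dict(), **extras}, path)


def load_checkpoint(path, model, optimizer, map_location="cpu"):
    # map_location lets a GPU-trained checkpoint be restored on CPU (and vice versa)
    ckpt = torch.load(path, map_location=map_location)
    getattr(model, "module", model).load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt


net = nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pt")
    save_checkpoint(path, net, opt, epoch=3)
    restored = nn.Linear(4, 2)
    restored_opt = torch.optim.SGD(restored.parameters(), lr=0.5, momentum=0.9)
    ckpt = load_checkpoint(path, restored, restored_opt)
```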

Adding ‘Accelerators’ with Stoke

This is hopefully where the ease of configuring ‘accelerators’ with Stoke shines. For this example, let’s assume we have 4 GPUs available. Let’s use native PyTorch AMP, FairScale’s Sharded DDP, and FairScale’s Optimizer State Sharding, each configured to our liking. In addition, let’s add gradient accumulation and gradient clipping. Creating the custom configurations and instantiating the Stoke object then becomes:

When we make the DataLoader we now need to pass in a DistributedSampler since we are using a distributed method. Since the Stoke object manages the backend(s) we can actually pass the world_size and rank directly from the object regardless of the distributed backend (i.e. Horovod and PyTorch DDP require different semantics without Stoke):
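For reference, this is what the sampler does in plain PyTorch. torch.utils.data.distributed.DistributedSampler accepts num_replicas and rank explicitly. Here they are hard-coded so the sketch runs in a single process without initializing a process group; in a real run they come from the distributed backend (or, as described above, from the Stoke object).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8).float().unsqueeze(1))

world_size = 2  # hard-coded here; normally provided by the distributed backend
for rank in range(world_size):
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    print(rank, list(sampler), [batch[0].squeeze(1).tolist() for batch in loader])
# rank 0 gets indices [0, 2, 4, 6]; rank 1 gets [1, 3, 5, 7]
```

With shuffle=True, call sampler.set_epoch(epoch) at the start of each epoch so every rank draws the same new permutation.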

Since Stoke handles wrapping/building your torch.nn.Module and torch.utils.data.DataLoader, device placement is handled automatically (in this example the model and data are moved to the correct GPUs).

The training loop and model save code remain exactly the same as in the single CPU case. Notice what hasn’t changed… anything related to your model code. And there is only one change to your data pipeline (which could be removed with a simple if statement). The model went from running on a single CPU to 4 GPUs in a custom-configured sharded DDP fashion, using PyTorch AMP and Optimizer State Sharding, with only a few additional declarative configurations/statements.

Now let’s imagine another scenario. You’ve now got access to 8 GPUs but they are split across 4 different clusters. You have access to Horovod via OpenMPI and you want to test Nvidia’s APEX O1 implementation instead of PyTorch AMP, but turn off all the Fairscale extensions. It’s as simple as changing your declarative options to the following:

Voila! A completely different set of ‘accelerators’ through simple changes to declarative configurations/statements!

Closing & Looking Forward

Having spent almost a year building and testing Stoke internally at Fidelity, we are happy to release it as an open-source tool for the broader ML community. Development will continue in the open-source domain. We envision Stoke as the go-to ‘Accelerator Playground’ with its unified interface, simple configuration API, and continued support for any cutting-edge PyTorch based ‘accelerator’ functionality released in the open-source domain (e.g. the recent support for Full Model Sharding incorporated into Fairscale). We also aim to maintain functional parity with the Lightning Accelerators API for those who prefer to just use base PyTorch.

Additionally, we hope that there will be a constant turn-over of ‘accelerator’ features that Stoke currently and eventually supports into GA release versions of PyTorch. For instance, PyTorch 1.8.0 introduced a beta Optimizer State Sharding feature with the addition of torch.distributed.optim.ZeroRedundancyOptimizer (now stable as of 1.10!) that overlaps with some of the functionality supported in Stoke (via Fairscale and DeepSpeed). Stoke will always attempt to be more on the ‘bleeding-edge’ of accelerators and allow for more ‘experimental’ functionality than GA PyTorch.

Near-term vision for the rest of 2021 and into early 2022:

  • Bug squash to get to a V1.0 stable release
  • Unit tests on CPU/GPU hardware for the entire ‘accelerator’ combinatorics space
  • TPU device support
  • Fairscale FSDP support (done: currently on master!)

Thanks for reading!

If you are interested, we welcome contributions from the community! Any contributions to Stoke should come through a submitted pull request. If you come across any bugs/issues please open an issue!

Acknowledgements

  1. Thanks to Jeff Brown, Rich DiBiasio, Zach Semenov, and Amit Shavit for discussions, proof-reading, and making sure this whole article wasn’t just nonsense!
  2. A few other open-source libraries from Fidelity helped drive the creation of Stoke: Spock, a framework that helps manage complex parameter configurations, and textwiser, a unified framework for text featurization. Check out Fidelity’s GitHub for all of the available open-source software.
  3. Note: We would be remiss not to mention that HuggingFace Accelerate exists and attempts to fill a similar gap. If you prefer to stay within the HF universe, Accelerate should provide similar capabilities to Stoke; however, it lacks Nvidia APEX, Horovod, and Fairscale support, and its DeepSpeed support is still experimental. In addition, it has a more limited configuration interface compared to Stoke.
