Cover photo: an adult cheetah running in a field of grass with a baby cheetah watching. Photo by Sammy Wong on Unsplash.

How to Load PyTorch Models 340 Times Faster with Ray

Fred Reiss
IBM Data Science in Practice
Aug 23, 2021


One of the challenges of using deep learning in production is managing the cost of loading models for inference. In this article, we’ll show how you can reduce this cost almost to zero by leveraging features of PyTorch and Ray.

Introduction

Deep learning models are big and cumbersome. Because of their size, they take a long time to load. Managing this loading cost requires complex systems for deploying models in production. Model inference platforms like TFX, TorchServe, and IBM Spectrum Conductor Deep Learning Impact run deep learning models inside dedicated, long-lived processes and containers, with additional layers to manage the containers and to pass data between them.

Block diagram of the TorchServe model inference platform: a front end exposes the management and inference APIs, while a back end runs a pool of dedicated, long-lived worker processes and model handlers backed by a model store. TorchServe assigns such a pool to each model in order to amortize model loading costs, and these process pools add complexity to the layers above and below them. Source: https://github.com/pytorch/serve; License: Apache V2

But what if this conventional wisdom isn’t entirely correct? What if there was a way to load a deep learning model in a tiny fraction of a second? It might open the door to much simpler ways to run deep learning in production.

Let’s see how fast we can make model loading go.

Background: BERT

For the examples in this article, we’ll use the BERT masked language model. BERT belongs to a group of general-purpose models that capture the nuances of human language in a (relatively) compact format. You can use these models to do many different natural language processing (NLP) tasks, ranging from document classification to machine translation. However, to perform any task with high accuracy, you need to start with a model trained on your target language and fine-tune the model for the task.

Tuning a BERT model for a task effectively creates a new model. If your application needs to perform three tasks in three different languages, you’ll need nine copies of BERT: one for each combination of language and task. This proliferation of models creates headaches in production. Being able to load and unload BERT-based models quickly would save a lot of trouble.

Let’s start by loading up a BERT model in the most straightforward way.

Loading a BERT Model

The transformers library from Huggingface provides convenient ways to load different variants of BERT. The code snippet that follows shows how to load bert-base-uncased, a medium-sized model with about 420 MB of parameters.
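A minimal sketch of that load looks like the following (it assumes the transformers package is installed; the model downloads on first use and is cached locally afterward):

    import transformers

    # Construct a ready-to-use BertModel, reading the bert-base-uncased
    # weights from the local cache (or downloading them on first use).
    bert = transformers.BertModel.from_pretrained("bert-base-uncased")
    bert.eval()  # switch to inference mode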

The transformers.BertModel.from_pretrained() method follows PyTorch's recommended practice for loading models: First, construct an instance of your model, which should be a subclass of torch.nn.Module. Then use torch.load() to load a PyTorch state dictionary of model weights. Finally, call your model's load_state_dict() method to copy the model weights from the state dictionary into your model's torch.Tensor objects.
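As a rough illustration of that three-step pattern (MyModel and the weights file name here are placeholders, not part of the BERT example):

    import torch

    class MyModel(torch.nn.Module):  # stand-in for your real model class
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

    model = MyModel()                            # 1. construct an instance of the model
    state_dict = torch.load("model_weights.pt")  # 2. load the state dictionary from disk
    model.load_state_dict(state_dict)            # 3. copy the weights into the model's Tensors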

This method takes about 1.4 seconds to load BERT, provided that the model is on local disk. That’s fairly impressive for a model that’s over 400MB in size, but it’s still a long time. For comparison, running inference with this model only takes a fraction of a second.

The main reason this method is so slow is that it is optimized for reading models in a portable way over a slow network connection. It copies the model’s parameters several times while building the state dictionary, then it copies them some more while installing the weights into the model’s Python object.

PyTorch has an alternate model loading method that gives up some compatibility but only copies model weights once. Here’s what the code to load BERT with that method looks like:
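A sketch of that alternative, which pickles the entire model object rather than just its state dictionary (the file name bert.pt is an arbitrary choice here):

    import torch
    import transformers

    # One-time setup: save the whole model object, not just its weights.
    bert = transformers.BertModel.from_pretrained("bert-base-uncased")
    torch.save(bert, "bert.pt")

    # Later: load the pickled model in one step; the weights are copied only once.
    bert = torch.load("bert.pt")
    bert.eval()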

This method loads BERT in 0.125 seconds on the same machine. That’s 11 times faster.

If dropping the number of copies to 1 makes model loading that much faster, imagine what would happen if we dropped the number of copies to zero! Is it possible to do that?

Zero-Copy Model Loading

It turns out that we can indeed load PyTorch models while copying weights zero times. We can achieve this goal by leveraging some features of PyTorch and Ray.

First, some background on Ray. Ray is an open source system for building high-performance distributed applications. One of Ray’s unique features is its main-memory object store, Plasma, which uses shared memory to pass objects between processes on each machine in a Ray cluster. Ray uses Plasma to implement zero-copy transfer of NumPy arrays. If a Ray task needs to read a NumPy array from Plasma, the task can access the array’s data directly out of shared memory without copying any data into its local heap.
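Here's a small sketch of that zero-copy behavior in isolation (the array size is arbitrary; Ray keeps large arrays in Plasma rather than in the task's local heap):

    import numpy as np
    import ray

    ray.init()

    array = np.random.rand(1000, 1000)
    ref = ray.put(array)             # store the array in the Plasma object store

    view = ray.get(ref)              # a read-only view backed by shared memory, no copy
    print(view.flags.writeable)      # False: the data still lives in Plasma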

If we store the weights of a model as NumPy arrays on Plasma, we can access those weights directly out of Plasma’s shared memory segments, without making any copies.

But we still need to connect those weights to the rest of the PyTorch model, which requires wrapping them in PyTorch Tensor objects. The standard method of creating a Tensor involves copying the contents of the tensor, but PyTorch also has an alternate code path for initializing a Tensor without performing a copy. You can access this code path by passing your NumPy array to torch.as_tensor() instead of using Tensor.__new__().
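For example, a quick sketch of that shared-memory behavior:

    import numpy as np
    import torch

    array = np.ones((3, 4), dtype=np.float32)

    tensor = torch.as_tensor(array)  # shares memory with the NumPy array; no copy
    array[0, 0] = 42.0
    print(tensor[0, 0])              # the tensor sees the change, proving the data is shared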

With all of this background information in mind, here’s a high-level overview of how to do zero-copy model loading from Plasma. First, you need to load the model into the Plasma object store, which is a three-step process:

  1. Load the model from disk.
  2. Separate the original PyTorch model into its weights and its graph of operations, and convert the weights to NumPy arrays.
  3. Upload the NumPy arrays and the model (minus weights) to Plasma.

Once the model and its weights are in object storage, it becomes possible to do a zero-copy load of the model. Here are the steps to follow:

  1. Deserialize the model (minus weights) from Plasma
  2. Extract the weights from Plasma (without copying data)
  3. Wrap the weights in PyTorch Tensors (without copying)
  4. Install the weight tensors back in the reconstructed model (without copying)

If a copy of the model is in the local machine’s Plasma shared memory segment, these steps will load BERT in 0.004 seconds. That’s 340 times faster than loading the model with BertModel.from_pretrained().

Comparison of running times for three ways of loading a BERT language model: BertModel.from_pretrained() takes 1.4 seconds, torch.load() of a pickled model takes 0.125 seconds, and zero-copy loading takes 0.004 seconds. Timings are the average of 100 runs on a MacBook Pro with a 2.3 GHz Intel Core i9 processor. Compared with the standard BertModel.from_pretrained() model loading API, zero-copy model loading reduces the loading time by a factor of 340.

More importantly, this loading time is an order of magnitude less than the time it takes to run one inference request on this model with a general-purpose CPU. That means that you can load the model on demand with almost no performance penalty. There’s no need to spin up a dedicated model serving platform or a Ray actor pool, tying up resources for models that aren’t currently running inference.

The Details

Let’s break down how to implement each of the steps for zero-copy model loading, starting with getting the model onto Plasma in an appropriate format.

We’ve already covered how to load a PyTorch model from disk. The next step after that initial loading is to separate the model into its weights and its graph of operations, converting the weights to NumPy arrays. Here’s a Python function that will do all those things.
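Here's a sketch of such a function; the name extract_tensors() and the exact layout of the returned weights (one dict of parameters and buffers per submodule) are illustrative choices:

    import copy
    from typing import Dict, List, Tuple

    import torch


    def extract_tensors(m: torch.nn.Module) -> Tuple[torch.nn.Module, List[Dict]]:
        """Separate a model into a weight-free skeleton and NumPy weights."""
        tensors = []
        for _, module in m.named_modules():
            # Convert each submodule's parameters and buffers to NumPy arrays.
            params = {
                name: torch.clone(param).detach().numpy()
                for name, param in module.named_parameters(recurse=False)
            }
            buffers = {
                name: torch.clone(buf).detach().numpy()
                for name, buf in module.named_buffers(recurse=False)
            }
            tensors.append({"params": params, "buffers": buffers})

        # Make a copy of the original model and strip all weights out of the copy.
        m_copy = copy.deepcopy(m)
        for _, module in m_copy.named_modules():
            for name, _ in list(module.named_parameters(recurse=False)):
                setattr(module, name, None)
            for name, _ in list(module.named_buffers(recurse=False)):
                setattr(module, name, None)

        # Make sure the stripped copy is configured for inference, not training.
        m_copy.train(False)
        return m_copy, tensors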

Most PyTorch models are built on top of the PyTorch class torch.nn.Module. The model is a graph of Python objects, and every object is an instance of a subclass of Module.

The Module class provides two places to store model weights: parameters for weights that are trained by gradient descent, and buffers for weights that are trained in other ways. The function above iterates over the components of the model, pulls out each component’s parameters and buffers, and converts their values to NumPy arrays. It then creates a copy of the model and removes all the weights from the copy. Finally, it returns the copy and the converted weight tensors as a Python tuple.

We can pass the return value from this function directly to ray.put() to upload the model and its weights onto Plasma. Here's what the upload operation looks like.
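Assuming the extract_tensors() sketch above, the upload is a single ray.put() call:

    import ray
    import transformers

    ray.init()

    bert = transformers.BertModel.from_pretrained("bert-base-uncased")

    # Store the (stripped model, NumPy weights) tuple in Plasma.
    bert_ref = ray.put(extract_tensors(bert))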

The variable bert_ref here is a Ray object reference. We can retrieve the model and weights by passing this object reference to ray.get(), as in the following listing.
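Continuing the sketch above:

    # Pull the stripped model and its weights back out of Plasma.
    # The NumPy arrays in bert_weights are read directly from shared memory.
    bert_skeleton, bert_weights = ray.get(bert_ref)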

If the object that bert_ref points to isn't available on the current node of your Ray cluster, the first attempt to read the model will block while Ray downloads the object to the node's local shared memory segment. Subsequent calls to ray.get(bert_ref) will return the local copy immediately.

Now we need to convert bert_weights from NumPy arrays to torch.Tensor objects and attach them to the model in bert_skeleton, all without performing any additional copies. Here is a Python function that does those steps.
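Here's a sketch of such a function, written against the weight layout that the extract_tensors() sketch above produces:

    from typing import Dict, List

    import torch


    def replace_tensors(m: torch.nn.Module, tensors: List[Dict]) -> None:
        """Restore the weights that extract_tensors() removed, without copying them."""
        modules = [module for _, module in m.named_modules()]
        for module, tensor_dict in zip(modules, tensors):
            # There are separate APIs for installing parameters and buffers.
            for name, array in tensor_dict["params"].items():
                module.register_parameter(
                    name,
                    torch.nn.Parameter(torch.as_tensor(array), requires_grad=False))
            for name, array in tensor_dict["buffers"].items():
                module.register_buffer(name, torch.as_tensor(array))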

This function does roughly the same thing as PyTorch’s load_state_dict() function, except that it avoids copying tensors. The replace_tensors() function modifies the reconstituted model in place. After calling replace_tensors(), we can run the model, producing the same results as the original copy of the model. Here's some code that shows running a BERT model after loading its weights with replace_tensors().
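For example, a short end-to-end check (the tokenizer and the example sentence here are illustrative, not part of the loading path):

    # Restore the weights into the reconstituted model; no data is copied.
    replace_tensors(bert_skeleton, bert_weights)

    # Run one inference request to verify that the model works.
    tokenizer = transformers.BertTokenizerFast.from_pretrained("bert-base-uncased")
    inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
    with torch.no_grad():
        outputs = bert_skeleton(**inputs)
    print(outputs.last_hidden_state.shape)  # (batch size, sequence length, hidden size)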

Caveats

The first time you call the replace_tensors() function, PyTorch will print out a warning:

UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. [...]

Most PyTorch models don’t modify their own weights during inference, but PyTorch doesn’t prevent models from doing so. If you load your weights via the zero-copy method and your model modifies a weight tensor, it will change the copy of those weights in Plasma’s shared memory. Ray (as of version 1.4) always opens shared memory segments in read-write mode.

If you’re sure that your model does not modify its own weights during inference, you can safely ignore this warning. You can test for such modifications by comparing your model’s weights before and after inference. If your model does modify some of its weights, it’s important to copy the relevant tensors before running inference.

Another thing to note is that this method loads the model for CPU-based inference. To use GPU acceleration, you will need to copy the model’s weights once to load them onto GPU memory. This copy operation takes about 0.07 seconds, which is still three times faster than the second-fastest way to load the model onto a GPU.
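A sketch of that single extra copy:

    # Moving the model to the GPU copies each weight tensor out of shared
    # memory into GPU memory exactly once.
    if torch.cuda.is_available():
        bert_skeleton.to(torch.device("cuda"))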

And one final thing to note: The code above only works for PyTorch. Implementing zero-copy model loading with TensorFlow is possible in theory but much more difficult in practice.

Conclusion

We hope you’ve enjoyed this exploration of zero-copy model loading with Ray and PyTorch. Being able to load models in milliseconds opens up some interesting architectural possibilities for high performance model inference. If you’d like to find out more about what IBM Research is doing with Ray, check out the CodeFlare project.


Fred Reiss is a Principal Research Staff Member at IBM Research and Chief Architect at IBM’s Center for Open-Source Data and AI Technologies (CODAIT).