Photo by Fred Reiss

Easier Model Serving with zerocopy

Fred Reiss
IBM Data Science in Practice


Serving deep learning models is hard, but it doesn’t have to be that way.

In a previous post, we introduced the concept of zero-copy model loading. Zero-copy model loading involves keeping the weights of a machine learning model in shared memory so that different processes can load the model for inference instantly without copying data.
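The core idea can be illustrated without Ray or PyTorch. The sketch below is a minimal analogy, not how Plasma or zerocopy is implemented: it assumes Python's standard-library `multiprocessing.shared_memory` as a stand-in for Plasma and a flat list of float64 values as a stand-in for model weights. The helper names `publish()` and `load_zero_copy()` are hypothetical.

```python
from multiprocessing import shared_memory
import struct

ITEM = struct.calcsize("d")  # bytes per float64 "weight"


def publish(weights):
    """Hypothetical helper: copy the weights into shared memory once.

    This is the only copy that is ever made of the weight data.
    """
    shm = shared_memory.SharedMemory(create=True, size=len(weights) * ITEM)
    struct.pack_into(f"{len(weights)}d", shm.buf, 0, *weights)
    return shm


def load_zero_copy(name, count):
    """Hypothetical helper: attach to an existing segment by name and view
    its bytes as floats. No weight data is copied; the view maps the same
    memory the publisher wrote into."""
    shm = shared_memory.SharedMemory(name=name)
    view = shm.buf[: count * ITEM].cast("d")
    return shm, view
```

Because the reader's view is backed by the same memory as the publisher's segment, an update made through one handle is immediately visible through the other, and attaching costs no copying no matter how large the weights are.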

We showed that the Plasma object store integrated into Ray makes it easy to do zero-copy model loading for PyTorch models, and that implementing this technique on Ray can accelerate model loading by several orders of magnitude. If you’d like to find out more about the details of zero-copy model loading, follow this link to view the previous post.

In this follow-on post, we focus on how to use zero-copy model loading in deep learning model serving. We introduce zerocopy, a Python package that makes it extra simple to apply the technique with PyTorch and Ray. We show how to deploy models for inference using our library. And then we present an end-to-end model serving benchmark that shows how we can serve multiple large NLP models with a single cloud VM and achieve 7x better scalability with no tuning.

Why is Model Serving Hard?

The architecture of model serving systems stems largely from a single design constraint: Loading large models is expensive.

To run inference on a model, you need to have the model’s weights loaded into a process’s memory. But loading a modern deep learning model is orders of magnitude more expensive than running inference. So you need to keep the model perpetually loaded in a process’s memory. Designing a deep learning model inference system reduces to answering the question, “How do I manage models running in large, persistent processes?”

Different systems take different approaches to this problem. Some, like TorchServe and Ray Serve, run one process per model and have components for managing the resulting process pool. Others, such as TensorFlow Serving and KServe’s ModelMesh, maintain a pool of large processes and have components for mapping models to processes. Still others, such as Seldon Core, KServe’s ModelServer, and MLServer, run each model in its own container and have components for managing a container pool.

Block diagram for the TorchServe model inference engine, which runs a single process per model. Note the layers of complexity required to manage the resulting persistent process pools.

A Different Approach

Regardless of the path that you take, the journey to serving deep learning models starts out at “loading models is expensive” and ends up in a place with lots of moving parts and lots of knobs to tune.

But what if we could start that journey by moving in a different direction?

That’s the path we’ve been exploring with zero-copy model loading. Zero-copy model loading makes loading a deep learning model nearly instantaneous. As long as the weights are in the local segment of Ray’s Plasma object store, your code can load a copy of the model instantly, run a single inference request, and then unload the model. Model inference becomes a stateless process. This statelessness removes design constraints that drive complexity in model serving systems. There’s no need to manage a pool of processes or containers, because inference takes place in ephemeral tasks.
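The load, run, unload cycle can be sketched with the standard library. In this hypothetical example (not zerocopy's or Ray's code), `multiprocessing.shared_memory` stands in for Plasma, a thread pool stands in for Ray's task scheduler, and the "model" is a toy dot product. Each request handler keeps no state between calls: it attaches to the weights, computes one answer, and detaches.

```python
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import shared_memory
import struct

ITEM = struct.calcsize("d")  # bytes per float64 weight


def store_weights(weights):
    """Write float64 weights into a new shared-memory segment (done once, up front)."""
    shm = shared_memory.SharedMemory(create=True, size=len(weights) * ITEM)
    struct.pack_into(f"{len(weights)}d", shm.buf, 0, *weights)
    return shm


def infer_once(segment_name, num_weights, features):
    """Stateless request handler: attach to the weights, run one inference, detach."""
    shm = shared_memory.SharedMemory(name=segment_name)  # "load": map, don't copy
    try:
        weights = shm.buf[: num_weights * ITEM].cast("d")
        try:
            # Toy "model": a dot product between the weights and the input features.
            return sum(weights[i] * features[i] for i in range(num_weights))
        finally:
            weights.release()  # drop our view before detaching
    finally:
        shm.close()  # "unload": detach; the segment itself stays in shared memory
```

Since no handler owns the model, any number of them can run at once, and none has to stay alive after its request finishes.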

Introducing zerocopy

Our previous post included code snippets that show how to rewrite a PyTorch model to use zero-copy model loading. We’ve recently released a Python package, zerocopy, that lets you apply this technique to your models without having to copy and paste Python code. This package is part of IBM's Project CodeFlare, a framework to simplify the integration, scaling and acceleration of complex multi-step analytics and machine learning pipelines.

You can install this package by typing:

pip install zerocopy

Once it’s installed, using the zerocopy package is a three-step process:

  1. Import the package.
  2. Move your model’s weights onto the Plasma object store.
  3. Run your model in an asynchronous Ray task.

Let’s show these three steps in action.

Step 1 is just a Python import statement:

import zerocopy

Then it’s on to step 2: moving your model’s weights onto Plasma. You will, of course, need a PyTorch model for this step. As an example, let’s load the most popular intent detection model from the Hugging Face model marketplace.

import transformers

model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
    'mrm8488/t5-base-finetuned-e2m-intent')

To move this model’s weights onto Plasma, you first need to pass the model through zerocopy.extract_tensors(), which separates the weights from the model's Python objects. Then you need to copy the model and its weights to Plasma using the function ray.put(). You can do both of these operations with a single line of Python code.
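Why split the weights from the Python objects at all? zerocopy does this for PyTorch modules and tensors; the sketch below is only a hypothetical pure-Python analogue, and `ToyModel`, `extract_weights()`, and `rebuild()` are illustrative names, not zerocopy's API. The skeleton that remains is small and cheap to copy, while the bulky weights can live in shared memory and be reattached by reference.

```python
import copy


class ToyLayer:
    def __init__(self, weights):
        self.weights = weights  # stands in for a PyTorch tensor


class ToyModel:
    def __init__(self):
        self.encoder = ToyLayer([0.1, 0.2])
        self.head = ToyLayer([0.3])


def extract_weights(model):
    """Split a model into a weight-free skeleton and a {layer_name: weights} dict."""
    skeleton = copy.deepcopy(model)  # leave the caller's model untouched
    weights = {}
    for name, layer in vars(skeleton).items():
        weights[name] = layer.weights
        layer.weights = None  # the skeleton keeps the structure but no data
    return skeleton, weights


def rebuild(skeleton, weights):
    """Reattach weights to a copy of the skeleton without copying the weights."""
    model = copy.deepcopy(skeleton)  # cheap: the skeleton holds no weight data
    for name, layer_weights in weights.items():
        getattr(model, name).weights = layer_weights  # by reference, no copy
    return model
```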

import ray

ray.init()  # start (or connect to) Ray before using its object store
model_ref = ray.put(zerocopy.extract_tensors(model))

The return value from ray.put() is a Ray object reference. This object reference lets you load the model almost instantly from any location on your Ray cluster. This capability is what enables step 3: Running your model in an asynchronous Ray task.

In our previous post, we showed how you can define a stateless Ray task that loads the model, runs inference over an input, and returns the result. The zerocopy package includes a built-in function call_model() that lets you do all these steps in one line of code.

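# Invoke the model's `generate()` method from a remote Ray task.
# model_input holds the keyword arguments for generate(),
# such as the output of the model's tokenizer.
result_ref = zerocopy.call_model.remote(model_ref, [], model_input,
                                        'generate')
result = ray.get(result_ref)  # block until the task has finished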

As with any other Ray task, call_model.remote() returns a future — a Ray object reference to the place where the result will appear once the task has completed. You can retrieve the result with ray.get().

The time to invoke the rewritten model is almost the same as the time to run the model locally. If you run inference multiple times, each call to zerocopy.call_model.remote() launches a separate Ray task, so those inference requests can run in parallel.
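The calling convention can be mimicked locally to show what happens on each request. In this hypothetical sketch, `call_model_sketch()` looks up a method by name and calls it with a list of positional arguments and a dict of keyword arguments; that argument order is inferred from the call_model.remote() calls in this post, not taken from zerocopy's documentation. A `concurrent.futures` thread pool stands in for Ray, and `Future.result()` plays the role of `ray.get()`.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyGenerator:
    """Hypothetical stand-in for a model that has a generate() method."""

    def generate(self, input_text, max_length=5):
        return input_text.upper()[:max_length]


def call_model_sketch(model, args, kwargs, method_name):
    """Call model.<method_name>(*args, **kwargs), a local analogue of call_model."""
    return getattr(model, method_name)(*args, **kwargs)


def fan_out(model, inputs, method_name="generate"):
    """Submit one task per request; each submission immediately returns a future."""
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(call_model_sketch, model, [], kwargs, method_name)
                   for kwargs in inputs]
        # Like ray.get(), Future.result() blocks until that task's result is ready.
        return [future.result() for future in futures]
```

For example, `fan_out(ToyGenerator(), [{"input_text": "hello world"}, {"input_text": "abc"}])` returns `["HELLO", "ABC"]`, with the two requests free to run concurrently.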

Some Caveats

In order for the model to load instantly, a copy of the serialized model and its weights needs to be in the local node’s Plasma shared memory segment. If Ray’s scheduler schedules a task that uses the model onto a node that does not currently have the model’s data locally, the task will not run until Ray has replicated the relevant objects.

Fortunately, keeping a serialized model in Plasma’s shared memory is less memory-intensive than keeping a copy of the same model in a Python process’s heap memory. For example, the serialized data for the intent detection model we have been using occupies a total of 990 MB in Plasma. We wrote a test script that loads this model and performs a single inference request. When we measured the script’s peak memory usage, we found it to be 2350 MB, almost 2.4 times the zero-copy memory footprint.

Where does this factor of 2.4 come from? There are three main contributors. The first is the additional size of Python and C++ objects and bytecode, compared with the serialized Python model. The second is fragmentation and buffer reuse in PyTorch’s memory manager. And the third is the space that PyTorch needs for the temporary objects it creates while loading the model or running inference.

Even with this smaller footprint, the serialized models still need to fit within the amount of system memory that is allocated to the object store, so there is an upper limit on the number of models that can be kept at the ready. You may be able to use Plasma’s support for spilling to local disk to increase this number by paging out less frequently used models, although the current version of Ray only provides limited control over this spilling.
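The measurements above give a rough sense of how much the smaller footprint raises that upper limit. The sketch below assumes a hypothetical 16,000 MB memory budget (not a figure from our benchmark) and assumes every model has the intent model’s footprint. It counts how many models fit when each is held as serialized data in the object store versus at the peak size of a process that loads it. It ignores the memory that running inference tasks themselves need, so treat the numbers as illustrative only.

```python
PLASMA_FOOTPRINT_MB = 990  # serialized intent model in Plasma (measured in this post)
PROCESS_PEAK_MB = 2350     # peak memory of a script that loads it (measured in this post)


def models_that_fit(budget_mb, footprint_mb):
    """How many models of a given footprint fit in a fixed memory budget."""
    return int(budget_mb // footprint_mb)


# Ratio behind the "almost 2.4 times" figure.
ratio = round(PROCESS_PEAK_MB / PLASMA_FOOTPRINT_MB, 2)

HYPOTHETICAL_BUDGET_MB = 16_000  # assumption for illustration only
in_plasma = models_that_fit(HYPOTHETICAL_BUDGET_MB, PLASMA_FOOTPRINT_MB)
in_processes = models_that_fit(HYPOTHETICAL_BUDGET_MB, PROCESS_PEAK_MB)
```

Under these assumptions, 16 serialized models fit in the budget, versus 6 fully loaded processes.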

Serving Models with Zero-Copy Model Loading

We’ve just shown how the zerocopy library lets you transform code that loads and runs a PyTorch model into code that copies the model to the Plasma object store and farms out inference requests to ephemeral Ray tasks.

This process lets you transform model serving code into zero-copy model serving code. The pair of listings below shows how, by changing two lines of code, we can turn a Ray Serve deployment of our example intent model into a deployment that uses zero-copy model loading.

A comparison of the Ray Serve code that deploys our example intent detection model, before and after applying zero-copy model loading. Only two lines of code need to change. See the “Deploying the models” section for links to both code listings.

Let’s zoom in on the two lines that change.

The Ray Serve deployment on the left attaches the model to a Ray actor, a persistent Python object that resides in a dedicated process. The actor also handles incoming HTTP requests.

The constructor for the actor class contains this line of code, which loads the model from disk and stores a pointer to the resulting Python object.

self._model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
    'mrm8488/t5-base-finetuned-e2m-intent')

In the zero-copy version of the Ray Serve deployment, we replace this line of code with code that loads the model, copies the model to Plasma, and stores a Ray object reference locally.

self._model_ref = ray.put(zerocopy.extract_tensors(
    transformers.AutoModelForSeq2SeqLM
    .from_pretrained('mrm8488/t5-base-finetuned-e2m-intent')))

The second line that changes is in the actor’s __call__() method, which handles HTTP requests. The original version of this method contains this line of code, which invokes the generate() method on the PyTorch model directly.

raw_output = self._model.generate(**tokens)

We replace this line of code with a line that uses the zerocopy.call_model() function to run model inference in a stateless Ray background task.

raw_output = await zerocopy.call_model.remote(
    self._model_ref, [], tokens, 'generate')

Note the use of the await keyword. The revised __call__() method uses Python's asyncio framework to handle multiple requests concurrently. Each active request uses a separate Ray task to run model inference.
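In an async Ray actor, awaiting an object reference hands control back to the event loop until the result arrives. The same request-handling pattern can be sketched with the standard library alone. In this hypothetical example, `loop.run_in_executor()` stands in for launching a Ray task, and `stateless_inference()` stands in for the model call; neither is part of Ray Serve or zerocopy.

```python
import asyncio
import time


def stateless_inference(tokens):
    """Hypothetical stand-in for a background inference task."""
    time.sleep(0.2)  # simulate model compute time
    return f"output for {tokens['input_ids']}"


class LightweightDeployment:
    """Handles requests itself but hands the inference work to background jobs."""

    async def __call__(self, tokens):
        loop = asyncio.get_running_loop()
        # Await the background job instead of blocking the event loop,
        # so other requests can be accepted while this one is running.
        return await loop.run_in_executor(None, stateless_inference, tokens)


async def serve_burst(deployment, requests):
    """Deliver a burst of requests at once and collect the results in order."""
    return await asyncio.gather(*(deployment(r) for r in requests))
```

With four simultaneous requests, the total wall-clock time stays close to that of a single request rather than four times as long, because the handler never sits idle waiting on one inference.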

These two edits turn the Ray actor from a large, heavyweight process into a much more lightweight one that only performs HTTP request handling and data preprocessing. All the work of inference occurs in stateless background tasks.

A Simple Benchmark

We created a simple benchmark to test how well a model deployed with zero-copy model loading can adjust its configuration on the fly to handle a dynamic, bursty stream of application requests. The end-to-end scenario for our benchmark involves supporting an AI chatbot for customer care. The chatbot’s conversational AI uses a preprogrammed dialog flow to control the interaction with the customer. Some of the nodes of this flow use machine learning models to guide their decisions.

Our benchmark covers the model serving portion of the chatbot’s backend. This model serving layer runs four different types of models:

  • Intent detection models that determine the user’s goal.
  • Sentiment analysis models that monitor the user’s mood.
  • Question answering models that provide the answers to specific factual questions.
  • Natural language generation models that give the chatbot’s responses a less scripted flavor.

Because the chatbot speaks three different languages, there are three versions of each model deployed: one for each language. So the model serving layer runs a total of 12 models.

In a real application, you would want to train custom versions of each type of model for the topics your chatbot covers. Since we’re only interested in measuring throughput and latency, we skipped that customization step and just used the most popular pretrained model in each category from the Hugging Face model marketplace.

Each of these models uses a Transformer-based neural network, with a language model and a task-specific head, fine-tuned on a domain-specific training set. The table below summarizes the four models that we used.

Although all four models came from the same marketplace, they are quite diverse. The models use three different core language models: Text-to-Text Transfer Transformer (T5) from Google Research, RoBERTa from Facebook AI, and GPT-2 from OpenAI. The models vary in size by almost a factor of 3.

Deploying the models

As a baseline for comparison, we deployed each of the four model types using Ray Serve endpoints. The intent detection model in the benchmark is the same model that we have been using in our example code so far in this article, so we used the code above to deploy that model. We followed a similar process for the remaining three model types to deploy a total of 12 models. You can find the full code that we used to deploy all 12 models in this notebook.

We then modified this baseline deployment code to use zero-copy model loading. For each of our Ray Serve deployments, we changed the same two lines of code shown earlier in this post. Then we deployed three copies of each of the four model types, one per language, for a total of 12 zero-copy model deployments. The code we used can be found in this notebook.

Running the Benchmark

We wrote a simple discrete event simulation to model a variable number of customers interacting with the chatbot. Each simulated customer types a series of chat messages, pausing for a random “think time” between messages. Our simulation draws these think times from a Poisson distribution with a mean of 10 seconds. Each chat message results in a single model inference request, with the target model drawn from another Poisson distribution.
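A trace generator of this kind can be sketched with an event queue. This is a hypothetical reconstruction, not our benchmark code: it samples Poisson values with Knuth's multiplication method, treats think times as whole seconds, and maps the model choice to an index by clamping a Poisson draw (with an assumed mean of 3) to the 12 deployed models. Those last two details are assumptions made for illustration.

```python
import heapq
import math
import random


def poisson(rng, mean):
    """Draw a Poisson-distributed integer using Knuth's multiplication method."""
    limit = math.exp(-mean)
    k, product = 0, 1.0
    while True:
        product *= rng.random()
        if product <= limit:
            return k
        k += 1


def generate_trace(num_customers, messages_per_customer, num_models=12,
                   think_mean=10.0, model_mean=3.0, seed=0):
    """Simulate chat customers; return (time_s, customer, model) tuples in time order."""
    rng = random.Random(seed)
    # Event queue ordered by the time each customer sends their next message.
    queue = [(poisson(rng, think_mean), customer, 0)
             for customer in range(num_customers)]
    heapq.heapify(queue)
    trace = []
    while queue:
        now, customer, sent = heapq.heappop(queue)
        # Assumed mapping: a Poisson draw clamped to a valid model index.
        model = min(poisson(rng, model_mean), num_models - 1)
        trace.append((now, customer, model))
        if sent + 1 < messages_per_customer:
            # Schedule this customer's next message after another think time.
            heapq.heappush(queue, (now + poisson(rng, think_mean), customer, sent + 1))
    return trace
```

Because the heap always yields the earliest pending event, the resulting trace is already sorted by send time and can be replayed directly.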

The benchmark runs the simulation to generate a trace, then plays back the trace, sending requests to the backend under test and measuring the end-to-end latency of each request. We repeat this process of generating and playing back the trace, gradually ramping up the average request rate of the bursty traffic until requests start timing out. We used a timeout threshold of 5 seconds.
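The ramp-up procedure reduces to a loop that stops at the first load level where any request exceeds the timeout. In this sketch, `run_trace(level)` is a hypothetical callback that generates and replays a trace at the given load and returns the measured end-to-end latencies in seconds; the 5-second threshold is the one used in our benchmark.

```python
TIMEOUT_S = 5.0  # timeout threshold used in the benchmark


def timeout_fraction(latencies, threshold=TIMEOUT_S):
    """Fraction of requests whose end-to-end latency exceeds the threshold."""
    if not latencies:
        return 0.0
    return sum(1 for latency in latencies if latency > threshold) / len(latencies)


def ramp_until_timeouts(load_levels, run_trace):
    """Replay traces at increasing load; return the highest level with no timeouts.

    run_trace is a hypothetical callback: run_trace(level) -> list of latencies.
    Returns None if even the first level produces timeouts.
    """
    best = None
    for level in load_levels:
        if timeout_fraction(run_trace(level)) > 0:
            break  # requests started timing out; stop ramping
        best = level
    return best
```

The same `timeout_fraction()` value is what the Y axis of the results chart reports for each number of simulated customers.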

We ran this benchmark against our two model deployments, using the same trace of requests. Both runs used the same hardware, a 16-core IBM Cloud VM, to run the entire benchmark, including the client portion of the benchmark, the serving framework, and the processes that performed model inference.

The code that we used to implement the benchmark can be found here.

Benchmark Results

The chart below shows the results of this benchmark. The X axis of the chart measures the number of simulated customers interacting with the chatbot. The Y axis measures what fraction of users’ chat messages exceed the 5-second timeout limit.

The baseline deployment can handle up to 30 simultaneous chat sessions. Beyond 30 sessions, the number of timeouts increases rapidly and the system becomes unstable. The baseline deployment divides CPU and memory among the models in a fixed allocation that has not been extensively tuned. When there are more than 30 simulated customers on the line, the CPU resources dedicated to the model with the highest traffic can no longer keep up with the increasingly large bursts of chat messages.

Our deployment with zero-copy model loading, on the other hand, handles up to 220 sessions without any timeouts. This performance represents a 7x improvement in scalability without any tuning. The zero-copy based deployment is able to instantly redirect the underlying hardware resources to whatever models are currently experiencing high traffic.
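The 7x figure is the ratio of the number of sessions each deployment sustained before timeouts appeared:

```latex
\frac{220\ \text{sessions}}{30\ \text{sessions}} \approx 7.3
```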

This result shows how zero-copy model loading gives you zero-effort performance tuning.

All the code used in this blog post is available on Project CodeFlare’s GitHub repository here.


Fred Reiss is a Principal Research Staff Member at IBM Research and Chief Architect at IBM’s Center for Open-Source Data and AI Technologies (CODAIT).