At Snaptravel, we have built a chatbot that allows people to book hotels. This article is not about how we built the models; we already have a paper for that. Instead, it is about how we took our existing AllenNLP/PyTorch models, modified them slightly, and productionized them to reduce memory usage. As a result, we scaled down the number of EC2 instances we need, since each instance is now used to greater capacity.
Inefficient Use of PyTorch in Production
We have four machine learning models, all written with AllenNLP. We deploy all of them behind Gunicorn + Flask, so on server startup all the models are loaded into memory (see Figure 1). Each of our models runs BERT-base plus additional layers, which takes roughly 1 GB of memory. Due to the high traffic we receive, one worker on one instance is not enough. Since we are on AWS, this means we need either two EC2 large instances, or one EC2 xlarge instance with two workers. If we set up Gunicorn with two workers, the models are loaded twice, and the total memory exceeds the 8 GB maximum available on an EC2 large instance, so we would have had to upgrade to an xlarge instance. As a result, we ran two AWS EC2 large instances behind a load balancer, but we were hoping for a more clever solution.
AllenNLP is built on PyTorch, and it turns out that PyTorch supports multiprocessing with shared memory: it can place a tensor into shared memory, where any subprocess can access it. In fact, PyTorch ships a fork of the multiprocessing standard library, available as torch.multiprocessing. Most tutorials out there currently describe how PyTorch multiprocessing can be used to train models. This article is about how to take the PyTorch multiprocessing feature, integrate it with a trained model, and serve the model in an API in production.
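For example, a tensor moved into shared memory by one process can be read directly by a child process; a minimal illustration:

```python
import torch
import torch.multiprocessing as mp


def print_sum(shared_tensor):
    # Runs in a child process, but reads the same underlying storage.
    print(shared_tensor.sum())


if __name__ == "__main__":
    tensor = torch.ones(3, 3)
    tensor.share_memory_()  # move the tensor's storage into shared memory
    process = mp.Process(target=print_sum, args=(tensor,))
    process.start()
    process.join()
```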
Sharing Memory Across Cores
Once we have shared BERT, we can start loading the models into shared memory. If we naively run the architecture described above with two workers, each model is actually loaded into memory twice, since each worker loads its own instance of the models. There is no real reason this needs to happen. Ideally, each model is loaded into shared memory once, and each incoming thread/process accesses the model through that shared memory, regardless of which worker handles the request.
First, we need to put each model into shared memory. Since a tensor can be placed into shared memory, so can a model, which is just a collection of tensors. An AllenNLP model can be placed into shared memory very easily, with
model.share_memory()
In AllenNLP, Predictors are used at inference time, and they can be created from an archived model. When we load the model through the archive, we can put it into shared memory with archive.model.share_memory() before wrapping a Predictor around it. Below are two code snippets showing how to load one model the original way and how to load it with shared memory.
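As a minimal sketch of both approaches (the archive path and predictor name below are placeholders, not our actual configuration):

```python
from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor

ARCHIVE_PATH = "model.tar.gz"  # placeholder path to a trained archive

# Original way: this process loads its own private copy of the model.
archive = load_archive(ARCHIVE_PATH)
predictor = Predictor.from_archive(archive, "my_predictor")  # placeholder predictor name

# Shared-memory way: move the model's weights into shared memory first,
# so other processes can use the same copy instead of loading their own.
shared_archive = load_archive(ARCHIVE_PATH)
shared_archive.model.share_memory()
shared_predictor = Predictor.from_archive(shared_archive, "my_predictor")
```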
On a Linux machine, it is possible to view the shared memory used with the command
df -h /dev/shm
Thus, you can verify that memory sharing is happening by watching “df -h /dev/shm” before and after the model is loaded. Figures 4 and 5 below show the difference.
Docker Shared Memory
A pitfall we fell into was not realising that Docker by default allocates only 64 MB of shared memory, so loading gigabytes of model weights into shared memory will cause the container to fail. This can easily be rectified with the parameter
--shm-size=8GB
which sets the shared memory tmpfs to 8GB.
So instead of
docker run ...
It is run with
docker run --shm-size=8g ...
After that, deploying into a docker container works smoothly.
Publish-Subscribe for Real-Time Inference
The next issue arose when integrating Gunicorn. To run Flask in production, we (like so many others) use Gunicorn with gevent workers. We found that Gunicorn's use of gevent concurrency for each incoming request clashes with PyTorch's use of processes. When a gevent greenlet passed data to a child process to compute against the model in shared memory, the child process was unable to return the results to the greenlet. This manifested as the greenlet “hanging”, waiting forever for the result of the process.
We attempted several solutions to resolve this issue. One of them was Celery, which manages distributed workers. The initial idea was that since Celery can allocate resources to processes and workers, it should be able to distribute work into processes among its workers. Unfortunately, Celery uses billiard for multiprocessing, which is also a fork of the multiprocessing standard library, and it does not integrate well with PyTorch's forked multiprocessing library.
We could not find a suitable multiprocessing library that is compatible with gevent greenlets. As a result, we separated model inference from the Flask server and used ZeroMQ (zmq) to manage the publisher/subscriber communication between them. When a request comes in, the flow is as follows (Figure 6):
- Request comes into the Flask server (client)
- The Flask server (client) sends a message to the worker (server) containing the input data, the model it wishes to run, and a unique id
- The worker (server), having loaded all the models at startup, runs the model specified by the message on the data included in the message
- Once the worker (server) has completed the inference, it sends a message back to the Flask server (client)
- The Flask server (client) returns the results
Two ports have to be set up:
- One for the client to communicate the input data to the server (Step 2)
- One for the server to communicate the resulting data to the client (Step 4)
The first iteration of this used the client/server (request/reply) pattern: the Flask server acts as the client, the worker acts as the server, and zmq facilitates the communication between them. This pattern works if there is only one request at a time. The problem arose when trying to send multiple requests, because according to the documentation:
socket zmq.REQ will block on send unless it has successfully received a reply back.
socket zmq.REP will block on recv unless it has received a request.
Instead, we use the PubSub pattern provided by zmq, since it fits the many-to-many communication we need. Many requests can come into the Flask server, which publishes each of them to the worker. The worker's top-level process receives each request and hands it to a dedicated Python thread. Within the thread, the data is passed to a subprocess to be computed. When the subprocess is done, it sends the results back to the thread via a shared dictionary. The thread sends the results out through zmq, then closes itself. This way, the worker's top-level process never has to deal with the result of the computation; that is taken care of by the child thread, which manages the subprocess.
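As a rough sketch, this is the two-channel socket layout at this stage, before the forwarders described later are added (the port numbers are placeholders):

```python
import zmq

context = zmq.Context()

# Flask server (client) side: one socket per port.
request_pub = context.socket(zmq.PUB)       # Step 2: publish requests to the worker
request_pub.bind("tcp://*:5559")            # placeholder port
result_sub = context.socket(zmq.SUB)        # Step 4: receive results from the worker
result_sub.connect("tcp://localhost:5560")  # placeholder port
result_sub.setsockopt_string(zmq.SUBSCRIBE, "")

# Worker (server) side, in a separate process, mirrors this:
#   request_sub = context.socket(zmq.SUB); request_sub.connect("tcp://localhost:5559")
#   result_pub  = context.socket(zmq.PUB); result_pub.bind("tcp://*:5560")
```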
One complication of the PubSub pattern is routing results back to the right caller. Since only one port is used for results from the worker to the client, it is difficult to manage all the parallel waiting threads in Flask. For example, in Figure 7 below, imagine multiple calls coming into the API. The API calls the worker server, which runs the models in parallel.
The last thing we want is for a result to be returned to the wrong thread, as shown in Figure 8.
To combat that, the worker publishes each result together with the unique id that was passed to it. Every waiting thread checks the id; the thread whose id matches returns the result it received, while the other threads continue to wait for their own data.
This new architecture (Figure 9) puts the worker server in the same Docker container as the Gunicorn + Flask app, and they are instantiated together. Since shifting to this new architecture, we have been able to halve our EC2 instance usage, from 2 x t2.large to 1 x t2.large.
Thread Safety in Production
Every once in a while, we would get an error
Assertion failed: check () (bundled/zeromq/src/msg.cpp:248) (pyzmq 15.1.0)
It turns out that zmq sockets are not thread-safe. This is a pitfall you should be aware of when using zmq; I suggest reading that section of the ZeroMQ guide before implementing anything.
The solution can actually be quite simple: use a forwarder, which acts as a queue device for PubSub. We place a forwarder between the API and the worker server in each direction, one for messages from the API to the worker server, and one for messages from the worker server back to the API, as shown in Figure 10 below.
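A minimal sketch of one such forwarder using pyzmq's built-in proxy; the port numbers are placeholders, and in practice one forwarder runs per direction:

```python
# forwarder.py (sketch): run one of these per direction, e.g.
#   python forwarder.py 5550 5551   # API -> worker server
#   python forwarder.py 5552 5553   # worker server -> API
import sys

import zmq


def run_forwarder(frontend_port, backend_port):
    # Publishers connect to the XSUB frontend, subscribers to the XPUB backend;
    # the proxy shuttles messages between them through a single queue.
    context = zmq.Context()
    frontend = context.socket(zmq.XSUB)
    frontend.bind("tcp://*:%d" % frontend_port)
    backend = context.socket(zmq.XPUB)
    backend.bind("tcp://*:%d" % backend_port)
    try:
        zmq.proxy(frontend, backend)
    finally:
        frontend.close()
        backend.close()
        context.term()


if __name__ == "__main__":
    run_forwarder(int(sys.argv[1]), int(sys.argv[2]))
```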
Managing Queue Size
The load on our server fluctuates throughout the day. Until this point, we had no visibility into how many subprocesses were in use and how many requests were queued. To address that, we created a queue tracker.
The idea is to have a global QUEUE_SIZE variable that increases by 1 when a new message comes in, and decreases by 1 when a message returns and its thread closes. A naive method would be to use a plain global integer and increment or decrement it at entry and exit. Since each incoming message is handled by an independent asynchronous thread, that would not be thread-safe. Instead, we create an integer via multiprocessing's Value.
Value comes with get_lock(), which temporarily locks the value so it becomes safe to modify.
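A rough sketch of that counter, placed in a hypothetical utils/queue_tracker.py helper (an illustration rather than our exact code):

```python
# utils/queue_tracker.py (sketch)
from multiprocessing import Value

QUEUE_SIZE = Value("i", 0)  # shared, lock-protected integer, starts at 0


def increment_queue():
    # Called when a new message comes in.
    with QUEUE_SIZE.get_lock():
        QUEUE_SIZE.value += 1


def decrement_queue():
    # Called when a result has been sent back and the handling thread closes.
    with QUEUE_SIZE.get_lock():
        QUEUE_SIZE.value -= 1


def queue_size():
    # Read the current value; used by the healthcheck described below.
    with QUEUE_SIZE.get_lock():
        return QUEUE_SIZE.value
```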
In order to return the queue size back to the Flask API, we treat it like a “model” that has to be run and returned.
So when a user hits the Flask API's /healthcheck endpoint, the API sends a message to the worker server to fetch the QUEUE_SIZE via queue_size(), and returns it. We set up a cronjob to send the queue information to AWS CloudWatch so we can better monitor the service. An example of that chart is shown in Figure 11 below.
Note that the baseline for the queue size will always be 1 using this method, because the act of fetching QUEUE_SIZE itself puts a message in the queue.
Do it Yourself: Code Samples
The following is worker_server.py, the main code that runs the models.
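What is shown here is a simplified sketch rather than the exact production code: the ports, the (unique_id, model_name, data) message format, and the helper modules are assumptions carried over from the earlier sections.

```python
# worker_server.py (sketch)
import threading

import zmq

from utils.pickle import compress, decompress                                 # assumed helpers
from utils.queue_tracker import decrement_queue, increment_queue, queue_size  # assumed helpers
from utils.shared_memory import load_shared_models, run_model_in_subprocess   # assumed helpers

REQUEST_PORT = 5551  # placeholder: backend of the API -> worker forwarder
RESULT_PORT = 5552   # placeholder: frontend of the worker -> API forwarder


def handle_request(context, models, unique_id, model_name, data):
    # Each thread owns its own PUB socket, since zmq sockets are not thread-safe.
    result_pub = context.socket(zmq.PUB)
    result_pub.connect("tcp://localhost:%d" % RESULT_PORT)

    if model_name == "queue_size":
        result = queue_size()  # treat the queue size as just another "model"
    else:
        result = run_model_in_subprocess(models[model_name], data)

    result_pub.send(compress((unique_id, result)))
    result_pub.close()
    decrement_queue()  # message handled, thread is about to close


def main():
    models = load_shared_models()  # load every model once, into shared memory

    context = zmq.Context()
    request_sub = context.socket(zmq.SUB)
    request_sub.connect("tcp://localhost:%d" % REQUEST_PORT)
    request_sub.setsockopt_string(zmq.SUBSCRIBE, "")

    while True:
        unique_id, model_name, data = decompress(request_sub.recv())
        increment_queue()  # new message entered the queue
        threading.Thread(
            target=handle_request,
            args=(context, models, unique_id, model_name, data),
        ).start()


if __name__ == "__main__":
    main()
```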
The following code is for utils/shared_memory.py, which handles loading the models into shared memory and running them in subprocesses.
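Again a minimal sketch, assuming AllenNLP model archives and a manager dict for passing results back from the subprocess (the model names, archive paths, and predictor name are placeholders):

```python
# utils/shared_memory.py (sketch)
import torch.multiprocessing as mp
from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor

MODEL_ARCHIVES = {
    # placeholder model names, archive paths, and predictor names
    "intent": ("models/intent.tar.gz", "my_predictor"),
}


def load_shared_models():
    # Load each archived model once and move its weights into shared memory,
    # so subprocesses can run inference without loading their own copy.
    models = {}
    for name, (path, predictor_name) in MODEL_ARCHIVES.items():
        archive = load_archive(path)
        archive.model.share_memory()
        models[name] = Predictor.from_archive(archive, predictor_name)
    return models


def _predict(predictor, data, shared_results):
    # Runs inside the child process; the model weights live in shared memory.
    shared_results["result"] = predictor.predict_json(data)


def run_model_in_subprocess(predictor, data):
    # Run a single prediction in a subprocess and hand the result back
    # to the calling thread through a shared dictionary.
    manager = mp.Manager()
    shared_results = manager.dict()
    process = mp.Process(target=_predict, args=(predictor, data, shared_results))
    process.start()
    process.join()
    return shared_results.get("result")
```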
The following code is for utils/pickle.py, which is used to pickle and compress objects into binary blobs for sending through zmq, and vice versa.
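A small sketch using the standard pickle and zlib libraries:

```python
# utils/pickle.py (sketch)
import pickle
import zlib


def compress(obj):
    # Pickle an arbitrary Python object and compress it into bytes for zmq.
    return zlib.compress(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))


def decompress(payload):
    # Reverse of compress(): bytes from zmq back into a Python object.
    return pickle.loads(zlib.decompress(payload))
```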
The following code is for utils/msg_client.py, which is the interface the API uses to communicate with worker_server. Since msg_client is called by the API, which uses gevent, for compatibility we must use
import zmq.green as zmq
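The sketch below follows the same (unique_id, model_name, data) message format and placeholder forwarder ports as the other samples; the short pause after connecting the PUB socket is one simple way to avoid dropping the first published message.

```python
# utils/msg_client.py (sketch)
import uuid

import gevent
import zmq.green as zmq  # gevent-compatible bindings, since the API runs under gevent

from utils.pickle import compress, decompress  # assumed helpers from above

REQUEST_PORT = 5550  # placeholder: frontend of the API -> worker forwarder
RESULT_PORT = 5553   # placeholder: backend of the worker -> API forwarder

_context = zmq.Context()


def send_message(model_name, data):
    # Publish a request tagged with a fresh unique id, then wait for the
    # matching result; non-matching results belong to other waiting greenlets.
    unique_id = str(uuid.uuid4())

    result_sub = _context.socket(zmq.SUB)
    result_sub.connect("tcp://localhost:%d" % RESULT_PORT)
    result_sub.setsockopt_string(zmq.SUBSCRIBE, "")

    request_pub = _context.socket(zmq.PUB)
    request_pub.connect("tcp://localhost:%d" % REQUEST_PORT)
    gevent.sleep(0.1)  # give the fresh PUB socket a moment to finish connecting
    request_pub.send(compress((unique_id, model_name, data)))
    request_pub.close()

    try:
        while True:
            reply_id, result = decompress(result_sub.recv())
            if reply_id == unique_id:
                return result
            # not ours: keep waiting; other subscribers receive their own copies
    finally:
        result_sub.close()
```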
The following code is a sample of how the Flask API can be instantiated to use msg_client to call worker_server.
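A minimal sketch of that Flask app, assuming the send_message helper above (the endpoint paths are illustrative):

```python
# app.py (sketch)
from flask import Flask, jsonify, request

from utils.msg_client import send_message  # assumed helper from above

app = Flask(__name__)


@app.route("/predict/<model_name>", methods=["POST"])
def predict(model_name):
    # Forward the request body to the worker server and wait for its reply.
    result = send_message(model_name, request.get_json())
    return jsonify(result)


@app.route("/healthcheck", methods=["GET"])
def healthcheck():
    # The queue size is fetched like any other "model".
    return jsonify({"queue_size": send_message("queue_size", None)})
```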
Conclusion
Initially, our Flask API and our model inference were coupled. This prevented us from spinning up more API workers, as each worker would be forced to load its own instances of our models and increase our memory usage substantially. By running Docker with an extra flag, decoupling the Flask API from model inference, and using PubSub to communicate between them, we were able to cut our EC2 costs in half. We also gained better visibility into the queue size of our worker server on AWS.
If you have any questions or improvements based on what we’ve worked on, please let me know in the comments.