Deploying an AlphaZero-powered Connect Four AI with GraphPipe
In my last post, I presented AZFour, a web app that lets people interact with the pre-trained Connect Four models provided by the GraphPipe project. In this post, I’ll go into detail about how I deployed the AZFour app and took advantage of the GraphPipe protocol to make optimizations along the way.
As I mentioned in my last post, AZFour is a simple Connect Four implementation with one twist: instead of relying on a heuristic AI or solver to make move suggestions, it uses a neural network pre-trained with the AlphaZero algorithm. For any given board position, the pre-trained network can tell us the best move to play next (the Policy) and the likely outcome of the game (the Value).
To achieve the highest-quality play, AlphaZero typically relies on Monte Carlo Tree Search (MCTS) to improve upon its neural network's predictions during competitive one-on-one play. However, it turns out that a well-trained neural network can evaluate game positions without MCTS and still get pretty good results. For the AZFour app, I skipped MCTS playouts to keep the deployment simple.
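Without MCTS, choosing a move reduces to reading the policy output directly. The sketch below assumes (this is not confirmed by the az4 model documentation) that the policy is a 7-element vector with one probability per column, and that the board is a 6x7 array with row 0 at the top and 0 marking an empty cell. `pick_move` is a hypothetical helper, not part of GraphPipe.

```python
import numpy as np


def pick_move(policy, board):
    """Greedy move selection straight from the policy, with no MCTS.

    Hypothetical conventions: `policy` has one probability per column (7),
    `board` is 6x7 with row 0 at the top and 0 meaning an empty cell.
    """
    board = np.asarray(board)
    policy = np.asarray(policy, dtype=np.float64)
    legal = board[0] == 0  # a column is playable while its top cell is empty
    masked = np.where(legal, policy, -np.inf)  # never suggest a full column
    return int(np.argmax(masked))
```

Masking matters because the network can assign non-zero probability to a full column; taking the argmax over legal columns only keeps the suggestion playable.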
At the simplest level, the application architecture for AZFour looks like this:
Stop and think for a second about how amazingly simple this architecture is: rather than dealing with software that understands all about the game of Connect Four, and maybe implements solvers, simulators, etc., all we need to do is pass our board positions to a standard neural network model, get the position evaluation (Policy and Value), and we are done. Beautiful!
GraphPipe published a series of pre-trained Connect Four models, which are available in the graphpipe-tf-py repo here. These models are snapshots at various stages of AlphaZero training. The higher the Generation number, the longer the network has been learning, and the better the model is at playing Connect Four.
Because the provided models are pretty small (< 900 KB), they don’t reach the peak accuracy described in my related article: the Generation 50 model achieves only 97.5% validation accuracy, while a larger model can exceed 99%. Still, 97.5% validation accuracy yields a fairly formidable Connect Four opponent!
These models are actually small enough to run in a web browser using something like TensorFlow.js, but I decided to take this opportunity to illustrate how one might solve this problem on the server side.
As a proof of concept, all we need on the backend for this app is a simple model server. We can load one of the example az4 networks with graphpipe-tf-py like this:
docker run -it --rm \
-v "$PWD:/models/" \
-p 127.0.0.1:9000:9000 \
Note that the above command uses the CPU version of graphpipe-tf to load the model, rather than the one with GPU support (graphpipe-tf:gpu). Why? My low-end hobby deployment has no GPUs! Fortunately, graphpipe-tf:cpu provides Intel MKL acceleration, which gives a nice speedup for most CPU-only deployments.
The most basic client to talk to this graphpipe server looks like this:
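from graphpipe import remote
import numpy as np

# An empty Connect Four board: a batch of 1, with 2 planes of 6 rows x 7 columns.
# TensorFlow models generally expect float32 input, so set the dtype explicitly.
board_state = np.zeros([1, 2, 6, 7], dtype=np.float32)
result = remote.execute("http://127.0.0.1:9000", board_state)
print(result)  # outputs policy and value for this board position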
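To query a real position rather than an empty board, the board has to be packed into the same [1, 2, 6, 7] shape. The plane convention below (plane 0 for the side to move, plane 1 for the opponent) is an assumption for illustration only, and `encode_board` is a hypothetical helper; check the model's training code for the actual ordering before relying on it.

```python
import numpy as np


def encode_board(board, player):
    """Encode a 6x7 board as a (1, 2, 6, 7) float32 tensor.

    Hypothetical convention: `board` holds 0 (empty), 1 or 2 (pieces) and
    `player` is the side to move. Plane 0 marks that player's pieces and
    plane 1 marks the opponent's.
    """
    board = np.asarray(board)
    opponent = 2 if player == 1 else 1
    planes = np.stack([board == player, board == opponent]).astype(np.float32)
    return planes[np.newaxis]  # add the leading batch dimension
```

The result can be handed straight to the client, e.g. `remote.execute("http://127.0.0.1:9000", encode_board(board, player=1))`.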
To get an idea of my server’s performance, I wrote a slightly more involved Python 3 script (see it here). This script repeatedly makes concurrent requests to a GraphPipe server running an az4 model, sweeping across various batch sizes. Running this script against graphpipe-tf on my test machine (a 2-core budget VM), I measured a throughput of about 100 requests/second when each request contained a single row. Not bad!
One drawback of the AZFour app from a performance standpoint is that each AZFour user only needs one position evaluation at a time; neural networks are much more efficient when they can process multiple rows at once.
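The benchmark script itself isn’t reproduced here, but a minimal sketch of the same idea, assuming a thread pool is an acceptable stand-in for concurrent clients, might look like this. `measure_throughput` is a hypothetical helper; `infer` can be any callable, for example `lambda x: remote.execute("http://127.0.0.1:9000", x)`.

```python
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def measure_throughput(infer, batch_size, n_requests=200, concurrency=8):
    """Return rows/sec achieved by `n_requests` concurrent calls to `infer`.

    Every call sends the same all-zeros batch of shape (batch_size, 2, 6, 7),
    i.e. `batch_size` empty boards; `infer` is a stand-in for the real client.
    """
    batch = np.zeros((batch_size, 2, 6, 7), dtype=np.float32)
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        # list() waits for every request to finish and re-raises any error.
        list(pool.map(lambda _: infer(batch), range(n_requests)))
    elapsed = time.perf_counter() - start
    return n_requests * batch_size / elapsed
```

Calling it for a handful of batch sizes (1, 3, 10, 40, ...) produces the rows/sec sweep. If row caching is enabled on the server, identical boards would inflate the numbers, so varying the input is safer in that case.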
Suppose that our batch size were larger than one: how much more throughput could we get? Let’s see:
With a batch size of 1, the throughput for my test setup is about 100 rows/sec. With a batch size of 3, it jumps to ~230 rows/sec. Throughput continues to climb as the batch size increases, leveling off at ~500 rows/sec by the time the batch size is 40.
So, a lot of throughput could be gained if I could find a way to batch client requests together before sending them to the inference engine.
Batching requests together should improve throughput, but where is the best place to implement it?
Batching Coupled with Inference Server
One possibility is to add batching functionality to the inference server itself, which in this case is graphpipe-tf. In fact, this is how TensorFlow Serving implements a similar feature.
This would likely solve the problem at hand. But there are several other plausible architectures for batching worth considering.
Batching at the client
When we implemented AlphaZero at work, we optimized MCTS node expansions by implementing batching functionality in the client (more on this here):
The advantage of this approach is that row aggregation is performed before requests get sent over the network, which reduces protocol overhead and latency. If our AZFour client did need multiple evaluations in a small window of time, it would be smart to consider this approach.
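Client-side batching amounts to stacking several positions into one request along the leading batch dimension. A minimal sketch, assuming `infer` returns a single array whose first dimension matches the batch (a real az4 call returns both policy and value, which would need splitting the same way); `evaluate_many` is a hypothetical name.

```python
import numpy as np


def evaluate_many(infer, positions):
    """Evaluate several encoded positions with a single request.

    `positions` is a list of (2, 6, 7) arrays. They are stacked into one
    (N, 2, 6, 7) batch so the network sees all rows in one call, and the
    per-row results are split back out in the original order.
    """
    batch = np.stack(positions).astype(np.float32)
    result = np.asarray(infer(batch))
    return [result[i] for i in range(len(positions))]
```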
Batching before a load balancer
Taking this in another direction, you can imagine one more plausible architecture:
Putting the batcher in front of a load balancer could make sense in a scenario where inference is quite expensive. Doing this also opens up interesting possibilities for caching, monitoring, etc.
So, depending on your application, you may want to put a batcher at various places in your deep learning pipeline.
To allow flexibility when deploying batching functionality, I put together graphpipe-batcher, a composable GraphPipe batching server. graphpipe-batcher uses the GraphPipe protocol for inputs and outputs, so it is easy to plug it into any GraphPipe data pipeline, wherever it happens to make sense.
Here is an example for how you can run graphpipe-batcher:
docker run --rm \
-p 127.0.0.1:10000:10000 \
The above docker command will launch a batching server with parameters compatible with the az4 models mentioned above. Let’s go over the params briefly:
- --target-url: the address where your inference server is running
- --inputs: a comma-separated list of inputs that the target server expects. Unlike graphpipe-tf and graphpipe-onnx, graphpipe-batcher has no model from which to infer the inputs, so you must specify them. You can curl your graphpipe-tf/graphpipe-onnx model to see what inputs/outputs are available.
- --outputs: a comma-separated list of outputs you are requesting from the inference server
- --batch-size: how many rows to accumulate before forwarding a batch to the target
- --timeout: how long (in milliseconds) to wait for a batch to fill before sending it to the target
- --workers: how many worker threads consume incoming requests
- --listen: the address:port that graphpipe-batcher should bind to
Batch-size, timeout, and workers need to be balanced for your workload.
Performance with Batching
For my deployment, I chose batch-size=10, workers=2, and timeout=200. These values are a bit arbitrary, but they should let me take advantage of batching when traffic spikes without delaying users for too long otherwise.
With these settings, each worker thread waits up to 200 milliseconds to collect 10 requests. If it gets 10 requests before the timeout expires, it ships the batch immediately; otherwise, it ships the incomplete batch once the timeout expires.
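graphpipe-batcher’s actual implementation may differ; the Python sketch below only illustrates the accumulate-until-full-or-timeout behavior described above. `Batcher`, `submit`, and `forward` are hypothetical names, and in this sketch the timeout clock starts when a worker picks up the first request of a new batch.

```python
import queue
import threading
import time
from concurrent.futures import Future

import numpy as np


class Batcher:
    """Toy batcher mirroring the batch-size / timeout / workers knobs.

    `forward` stands in for the target inference server: it receives an
    (N, ...) array and must return something indexable by row.
    """

    def __init__(self, forward, batch_size=10, timeout_ms=200, workers=2):
        self.forward = forward
        self.batch_size = batch_size
        self.timeout = timeout_ms / 1000.0
        self.requests = queue.Queue()
        for _ in range(workers):
            threading.Thread(target=self._worker, daemon=True).start()

    def submit(self, row):
        """Queue one row and return a Future that will hold its result."""
        fut = Future()
        self.requests.put((row, fut))
        return fut

    def _worker(self):
        while True:
            # Block until a request arrives, then start the timeout clock.
            items = [self.requests.get()]
            deadline = time.monotonic() + self.timeout
            while len(items) < self.batch_size:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break  # timeout expired: ship the incomplete batch
                try:
                    items.append(self.requests.get(timeout=remaining))
                except queue.Empty:
                    break
            rows, futures = zip(*items)
            try:
                outputs = self.forward(np.stack(rows))
            except Exception as exc:  # report the failure to every caller
                for fut in futures:
                    fut.set_exception(exc)
                continue
            for i, fut in enumerate(futures):
                fut.set_result(outputs[i])  # each caller gets its own row
```

Raising `batch_size` or `timeout_ms` favors throughput under load at the cost of worst-case latency, since a lone request can wait the full timeout; more workers let several partial batches fill in parallel.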
Let’s see how our performance looks now:
With server-side batching, we get a > 3X lift in inference throughput. Woohoo! Of course, once our incoming batch size reaches 10, batching brings no further benefit, since the batcher is just proxying requests forward at that point.
In training Connect Four with AlphaZero, we saw that a high proportion of moves were repeated from game to game, especially in the early positions. In fact, when a large number of random games are played, you can expect ~50% of the positions to be seen more than once. This makes Connect Four a good candidate for caching.
Row-level caching is provided as part of the graphpipe-go library, which makes it easy to add caching to any server that you care to implement. This caching functionality is built into graphpipe-tf, graphpipe-onnx, and now graphpipe-batcher.
Although it would have been possible to enable caching at either the batching layer or the inference server layer, I ultimately enabled it in the batching component. In this configuration, cache hits return immediately rather than potentially having to wait for a batch to fill, which makes for a snappier user experience.
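The repetition effect is easy to probe with a quick simulation. The sketch below plays uniformly random games and reports the hit rate an unbounded position cache would see, i.e. the fraction of positions reached that had already appeared earlier. This is one possible way to count repeats; it is not the measurement behind the ~50% figure, and its output depends on the number of games and the seed.

```python
import random

ROWS, COLS = 6, 7


def wins(board, r, c, player):
    """True if the piece just placed at (r, c) completes four in a row."""
    for dr, dc in ((0, 1), (1, 0), (1, 1), (1, -1)):
        count = 1
        for sign in (1, -1):  # walk both ways along this direction
            rr, cc = r + sign * dr, c + sign * dc
            while 0 <= rr < ROWS and 0 <= cc < COLS and board[rr][cc] == player:
                count += 1
                rr += sign * dr
                cc += sign * dc
        if count >= 4:
            return True
    return False


def random_game(rng):
    """Play one uniformly random game; return each position as a hashable key."""
    board = [[0] * COLS for _ in range(ROWS)]
    heights = [0] * COLS  # pieces already stacked in each column
    positions = []
    player = 1
    while True:
        legal = [c for c in range(COLS) if heights[c] < ROWS]
        if not legal:
            return positions  # board full: draw
        c = rng.choice(legal)
        r = ROWS - 1 - heights[c]  # row 0 is the top, pieces fill from the bottom
        board[r][c] = player
        heights[c] += 1
        positions.append(tuple(map(tuple, board)))
        if wins(board, r, c, player):
            return positions
        player = 3 - player


def cache_hit_rate(n_games, seed=0):
    """Fraction of position evaluations that repeat an earlier position."""
    rng = random.Random(seed)
    seen = set()
    hits = total = 0
    for _ in range(n_games):
        for pos in random_game(rng):
            total += 1
            if pos in seen:
                hits += 1
            else:
                seen.add(pos)
    return hits / total
```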
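A row-level cache can be sketched as a dictionary keyed on each row’s raw bytes, consulted before anything is handed to the batcher. This is not graphpipe-go’s implementation: `RowCache`, `lookup`, and `max_entries` are hypothetical, eviction is a simple LRU, and the class is not thread-safe, so a concurrent server would need a lock around it.

```python
from collections import OrderedDict

import numpy as np


class RowCache:
    """LRU cache in front of a slower per-row evaluator (e.g. a batcher).

    Hits return immediately; only misses are passed to `evaluate`.
    """

    def __init__(self, evaluate, max_entries=100_000):
        self.evaluate = evaluate
        self.max_entries = max_entries
        self.store = OrderedDict()

    def lookup(self, row):
        row = np.ascontiguousarray(row)
        # dtype and shape are part of the key so equal bytes can't collide.
        key = (row.dtype.str, row.shape, row.tobytes())
        if key in self.store:
            self.store.move_to_end(key)  # mark as most recently used
            return self.store[key]  # cache hit: no waiting for a batch
        result = self.evaluate(row)
        self.store[key] = result
        if len(self.store) > self.max_entries:
            self.store.popitem(last=False)  # evict the least recently used row
        return result
```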
Implementing the frontend client
All together now…
Putting all this together, I settled on the following final architecture for AZFour:
Note a few additional details in the final configuration: the NGINX server that gates public traffic (terminating SSL, routing requests, logging, etc.), and the multiple batching inference groups, one for each of the Model Generations in the UI.
You can see the final product at azfour.com. Thanks for reading!