Make your model faster on CPU using ONNX
How to double the CPU inference speed of a TensorFlow model?
Neural networks are very powerful, but also infamous for the large amount of computing power they need. In general, more parameters in a model mean not only a longer training time, but also a longer inference time. Half a second might not sound like much, but if you run the model hundreds of times a day, it quickly adds up, not to mention the networks that spend even longer on a single inference. One solution would be to simply throw a GPU at the problem. However, that is not always possible, nor the most cost-effective option. In this short article, I will show you how to double CPU inference speed by simply switching runtimes, using ONNX and the ONNX Runtime. This results in lower costs, lower response times, and a more portable algorithm.
All steps in this article are extensively documented in a Jupyter notebook.
What are ONNX and the ONNX Runtime?
Many neural networks are developed using the popular library TensorFlow. However, as the title suggests, the speed-up will come from using ONNX. So what exactly is ONNX? ONNX stands for "Open Neural Network Exchange" and is an open representation format for machine learning algorithms. It allows for portability; in other words, an ONNX model can run everywhere. You can simply import and export ONNX models in popular tools like PyTorch and TensorFlow, for example. This is great on its own, but the added benefit is that you can choose the runtime. That means you can optimize the model for your purposes, based on its needs and the way it is run. We are not going into complicated optimizations in this article; instead, we will make a very simple one. We will replace the standard TensorFlow library, with all of its unnecessary bloat, with something more optimized: the ONNX Runtime, an optimized runtime for ONNX models that is easily used from Python. Is TensorFlow really bloated? Well, it can be for pure inference purposes, which is exactly where the ONNX Runtime excels, because inference is all it does.
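To give an impression of how self-contained the format is: once you have a converted file (we will create model.onnx in the next section), you can load and validate it with the standalone onnx package, without TensorFlow or PyTorch installed. A minimal sketch, where model.onnx is just a placeholder file name:

import onnx

# Load the serialized ONNX graph from disk (a plain protobuf file)
model = onnx.load("model.onnx")

# Verify that the graph is a well-formed ONNX model
onnx.checker.check_model(model)

# Print a human-readable summary of the operators in the graph
print(onnx.helper.printable_graph(model.graph))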
Preparing the TensorFlow model
It is quite easy to convert a network in the `SavedModel` format from TensorFlow to ONNX. You can use the handy Python package tf2onnx to do this; it does all the hard work for you. As long as you do not have a very exotic neural network, the following line will probably work:
python3 -m tf2onnx.convert --saved-model model --opset 13 --output model.onnx
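If you prefer to stay inside Python, tf2onnx also exposes a conversion API. The sketch below builds a trivial Keras model as a stand-in for your own trained network, saves it in the `SavedModel` format expected by the command above, and then converts it directly from the Keras object; the layer sizes and file names are only placeholders:

import tensorflow as tf
import tf2onnx

# A stand-in Keras model; in practice this is your trained network
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(1),
])

# With TF 2.x this writes the SavedModel directory used by the CLI command above
model.save("model")

# Alternatively, convert straight from the Keras object without the CLI
spec = (tf.TensorSpec((None, 10), tf.float32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(
    model, input_signature=spec, opset=13, output_path="model.onnx"
)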
Switching runtimes
Assuming you are running something similar to this for inference using TensorFlow:
from tensorflow.keras.models import load_model
model = load_model("model")
out = model.predict(x)
We now have to use the ONNX Runtime with our converted network instead:
import onnxruntime as rt
sess = rt.InferenceSession("model.onnx")
input_name = sess.get_inputs()[0].name
out = sess.run(None, {input_name: x})[0]
It does not get much simpler than this. The only real difference is syntax. You might notice that the ONNX Runtime is a bit more explicit about input names, but these are stored in the ONNX file itself, so we can easily look them up with the `get_inputs()` method.
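One practical detail worth noting: the ONNX Runtime expects plain NumPy arrays with the exact dtype the graph was exported with (usually float32), so it can be useful to inspect the expected input before calling run. A minimal sketch, assuming the converted model.onnx from above and a placeholder input shape of (1, 10):

import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("model.onnx")

# Names, shapes, and dtypes come straight from the ONNX file
inp = sess.get_inputs()[0]
print(inp.name, inp.shape, inp.type)  # e.g. "input", [None, 10], tensor(float)

# Make sure the input matches the expected dtype before running inference
x = np.random.rand(1, 10).astype(np.float32)
out = sess.run(None, {inp.name: x})[0]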
Results
Now some proof that this actually works. The easiest way is to simply run the two different scripts a few times with a stopwatch and see which one takes longer. I did something a bit more accurate: I created two similar deployments on our deployment platform UbiOps. UbiOps is a platform that allows everyone to quickly run a piece of R or Python code in a professional production environment.
I then sent 100 requests to both of them and looked at the average time spent on computing one request. Here are the results:
Figures 1 and 2: model performance of the TensorFlow deployment and the ONNX Runtime deployment.
As you can see, this roughly doubles the performance of the model with minimal effort. Try for yourself by downloading the Jupyter notebook with all the steps.
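If you want a rough comparison on your own machine rather than on UbiOps, a simple timing loop already shows the difference. A minimal sketch, assuming the Keras model in the `model` directory and the converted model.onnx from earlier, with a placeholder input shape of (1, 10):

import time
import numpy as np
import onnxruntime as rt
from tensorflow.keras.models import load_model

keras_model = load_model("model")
sess = rt.InferenceSession("model.onnx")
input_name = sess.get_inputs()[0].name

# Dummy input matching the model's input shape
x = np.random.rand(1, 10).astype(np.float32)

def timed(fn, runs=100):
    # Average wall-clock time per inference over a number of runs
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs

print("TensorFlow:  ", timed(lambda: keras_model.predict(x, verbose=0)))
print("ONNX Runtime:", timed(lambda: sess.run(None, {input_name: x})[0]))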
If you want to know more about UbiOps, take a look at our product page and our Tutorials page (a repository of example Jupyter notebooks).