Debug ONNX GPU Performance

What to do when your model is slower than expected

David Mezzetti
NeuML
6 min read · Nov 30, 2022


Photo by Thomas Foster on Unsplash

The steps described in this article are also documented in this GitHub issue.

The ONNX Runtime is a cross-platform inference and training machine-learning accelerator. It provides a single, standardized format for executing machine learning models.

To give an idea of the breadth of support, the image below shows all the current build platforms.

Source: https://github.com/microsoft/onnxruntime

ONNX has a lot of promise and is a great project. It widens the ways a model can be executed. But it’s not always straightforward.

This article will review a case of exporting a PyTorch model to ONNX and what was done to improve GPU performance.
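For background, the model in this article was exported with espnet_onnx, but a plain PyTorch-to-ONNX export typically looks like the following minimal sketch. The model and example input here are placeholders for illustration only, not the actual ESPnet export.

import torch

# Placeholder model and example input - the real export in this article is
# handled by espnet_onnx, this only illustrates the general export pattern
model = torch.nn.Linear(16, 4).eval()
example = torch.randn(1, 16)

# Export to ONNX, marking the batch dimension as dynamic
torch.onnx.export(
    model,
    example,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
)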

Background

txtai recently added support for text-to-speech models. The models chosen were PyTorch models in ESPnet exported to ONNX using espnet_onnx. These ONNX models are available on the Hugging Face Hub.

More on this can be read at the link below.

The Problem

The following two code sections show minimal examples of running inference with ESPnet directly (PyTorch) and running the same model through ONNX.

First the code using ESPnet directly and PyTorch.

import time

from espnet2.bin.tts_inference import Text2Speech

model = "espnet/kan-bayashi_ljspeech_vits"
model = Text2Speech.from_pretrained(model, device="cuda")

def run(text):
    start = time.time()
    output = model(text)
    speech = output["wav"].cpu().numpy()
    print("Time:", time.time() - start)

Running the code above after warmup (the 2nd run) results in the following:

>>> run("warmup")
>>> run("Text to speech models have recently made great strides in quality")
Time: 0.16863441467285156

Next the model is run with ONNX.

import time

import onnxruntime
import yaml

from ttstokenizer import TTSTokenizer

with open("ljspeech-vits-onnx/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

tokenizer = TTSTokenizer(config["token"]["list"])

model = onnxruntime.InferenceSession(
    "ljspeech-vits-onnx/model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

def run(text):
    print(text)

    # Tokenize text to phoneme token ids
    inputs = tokenizer(text)

    start = time.time()
    outputs = model.run(None, {"text": inputs})
    speech = outputs[0]
    print("Time:", time.time() - start)

Running the ONNX code above after warmup (the 2nd run) results in the following:

>>> run("warmup")
>>> run("Text to speech models have recently made great strides in quality")
Time: 1.134232521057129

0.17s in PyTorch vs 1.13s in ONNX, both with GPUs enabled. PyTorch is almost 7x faster.

Let’s see if we can get to the bottom of this.

Attempt #1 — IO Binding

After a couple of web searches for "PyTorch vs ONNX slow", the most common culprit that came up was CPU-to-GPU data transfer. While the inputs to this model are small, the output wav data can easily be 100KB+, so this could be the issue. More on passing data to ONNX is available in the API guide.

We’ll modify our run method as follows and re-run.

def run(text):
    print(text)

    # Tokenize text to phoneme token ids
    inputs = tokenizer(text)

    io_binding = model.io_binding()
    io_binding.bind_cpu_input("text", inputs)
    io_binding.bind_output("wav", "cuda")

    start = time.time()

    # Run model
    model.run_with_iobinding(io_binding)
    outputs = io_binding.copy_outputs_to_cpu()
    print("Time:", time.time() - start)

>>> run("warmup")
>>> run("Text to speech models have recently made great strides in quality")
Time: 1.1370353698730469

About the same run time. At least in this case, the issue doesn't seem to be related to data transfer.

Enable profiling

We'll make the following change to the ONNX program to enable profiling, then re-run to get more information on the slowdown.

opts = onnxruntime.SessionOptions()
opts.enable_profiling = True

model = onnxruntime.InferenceSession(
    "ljspeech-vits-onnx/model.onnx",
    opts,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
)

A JSON file named onnxruntime_profile_{DATE}.json should now be written to the working directory. This file has a long list of JSON events with the execution duration of each node.

Let's run the following to extract the durations and find the slowest nodes.

$ jq . onnxruntime_profile.json | grep dur | cut -d ":" -f2 | sort --numeric

We’ll take the largest times and look for them in the onnxruntime_profile file.
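Alternatively, a short Python script can do both steps at once. This is a minimal sketch that assumes the profile is a flat JSON array of event objects like the element shown further below, and the file path is a placeholder for whatever file the profiler actually wrote.

import json

# Placeholder path - use the actual onnxruntime_profile_{DATE}.json file
with open("onnxruntime_profile.json", "r", encoding="utf-8") as f:
    events = json.load(f)

# Keep node-level events and print the 10 longest durations (microseconds)
nodes = [e for e in events if e.get("cat") == "Node" and "dur" in e]
for event in sorted(nodes, key=lambda e: e["dur"], reverse=True)[:10]:
    print(event["dur"], event["name"])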

The first couple of entries were related to the full model run and session initialization. But after that, this one stood out.

    "cat": "Node",
"pid": 1027615,
"tid": 1027615,
"dur": 313964,
"ts": 4425685,
"ph": "X",
"name": "/w_1/Conv_kernel_time",

And there were more entries like this. The node in question is a convolutional layer.

Attempt #2 — CUDA Settings

Let’s take a look at the documentation and see what CUDA options are available.

https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cudnn_conv_algo_search

cudnn_conv_algo_search is the option that stood out the most. Its default value of EXHAUSTIVE, which the documentation describes as expensive, also seemed relevant.

Let’s try changing this setting and re-running.

opts = onnxruntime.SessionOptions()
opts.enable_profiling = True

model = onnxruntime.InferenceSession(
    "ljspeech-vits-onnx/model.onnx",
    opts,
    providers=[
        ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),
        "CPUExecutionProvider"
    ]
)

>>> run("warmup")
>>> run("Text to speech models have recently made great strides in quality")
Time: 0.1624901294708252

Now the runtime is the same as, if not slightly better than, PyTorch!

Here is the full code used to achieve these results.

import time

import onnxruntime
import yaml

from ttstokenizer import TTSTokenizer

with open("ljspeech-vits-onnx/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

tokenizer = TTSTokenizer(config["token"]["list"])

model = onnxruntime.InferenceSession(
    "ljspeech-vits-onnx/model.onnx",
    providers=[
        ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),
        "CPUExecutionProvider"
    ]
)

def run(text):
    print(text)

    # Tokenize text to phoneme token ids
    inputs = tokenizer(text)

    start = time.time()
    outputs = model.run(None, {"text": inputs})
    speech = outputs[0]
    print("Time:", time.time() - start)

Making sense of the results

Let’s take a look at PyTorch and see if there are any hints on what could account for the difference with the default settings.

The file Conv_v7.cpp in PyTorch has the following logic that appears relevant.

// Code reduced to these statements for clarity
static constexpr auto DEFAULT_ALGO =
    CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;

if (!benchmark) {
  AT_CUDNN_CHECK_WITH_SHAPES(cudnnGetConvolutionForwardAlgorithm_v7(
  )
} else {
  AT_CUDNN_CHECK_WITH_SHAPES(cudnnFindConvolutionForwardAlgorithmEx(
  )
}

Opening up a new Python REPL and running the following shows:

>>> import torch
>>> torch.backends.cudnn.deterministic
False
>>> torch.backends.cudnn.benchmark
False

This means cudnnGetConvolutionForwardAlgorithm_v7 is the function used by default. Even with benchmark set to True, PyTorch will cache the result if the input shapes are static. In the case of this model, the input shapes are dynamic. See this discussion thread for more.
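As a quick illustration of why the shapes are dynamic here, the tokenized phoneme sequence grows with the text, so each sentence can present a new input shape to cuDNN. A minimal sketch, assuming the tokenizer from the ONNX example is loaded and returns a flat array of token ids:

# Each text produces a different number of phoneme token ids,
# so the model sees a different input shape on nearly every call
for text in ["hi", "text to speech", "a much longer example sentence"]:
    print(text, "->", len(tokenizer(text)))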

Let’s verify this using the original ESPnet PyTorch example.

>>> run("warmup")
>>> run("Text to speech models have recently made great strides in quality")
Time: 0.17191290855407715
>>>
>>> import torch
>>> torch.backends.cudnn.benchmark
False
>>> torch.backends.cudnn.benchmark = True
>>> torch.backends.cudnn.benchmark
True
>>> run("warmup")
>>> run("Text to speech models have recently made great strides in quality")
Time: 1.0886101722717285

Setting torch.backends.cudnn.benchmark to True matches the original ONNX result.

https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#cudnn_conv_algo_search

Looking back at the ONNX documentation above, it states:

HEURISTIC (1)

lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7

It appears that, unless otherwise set, PyTorch is not doing an exhaustive search. The PyTorch default corresponds to what ONNX Runtime calls HEURISTIC. In testing, both HEURISTIC and DEFAULT gave the same performance.
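For completeness, here is a minimal sketch of the session configuration with the HEURISTIC setting, swapping only the option value from the earlier example.

# Same session setup as before, using the heuristic cuDNN algorithm search
model = onnxruntime.InferenceSession(
    "ljspeech-vits-onnx/model.onnx",
    providers=[
        ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "HEURISTIC"}),
        "CPUExecutionProvider"
    ]
)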

Wrapping up

This article covered how to debug ONNX model performance, specifically when running on the GPU. The most common area to look at is CPU-to-GPU data transfer and making it more efficient with IO Binding. But in this case, changing a CUDA provider setting was the solution.

It’s possible this is an isolated case with this model or the hardware the tests were run on. But it’s at least something to test if this problem is encountered in the future!
