Building fused CUDA kernels for RNNs

kevin zhang
Nov 7 · 7 min read

Introduction

One day, you might decide to implement your own RNN. PyTorch offers a convenient way to do this using the torch.nn.functional module. For example, here is how we’d implement the forward pass of a GRU:
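The embedded code snippet isn’t reproduced here, so below is a minimal plain-Python sketch of the same mathematics. The function names and the bias-free signature are my own simplification for illustration, not the actual PyTorch code:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def matvec(w, v):
    # Plain-Python matrix-vector product.
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

def gru_forward(x, h, w_ih, w_hh):
    """One GRU step for a single example (biases omitted for brevity).

    w_ih and w_hh each stack the reset, update, and new-gate weights,
    so both projections have length 3 * state_size.
    """
    gi = matvec(w_ih, x)   # input projections
    gh = matvec(w_hh, h)   # hidden-state projections
    s = len(h)
    i_r, i_i, i_n = gi[:s], gi[s:2*s], gi[2*s:]
    h_r, h_i, h_n = gh[:s], gh[s:2*s], gh[2*s:]
    new_h = []
    for k in range(s):
        r = sigmoid(i_r[k] + h_r[k])          # reset gate
        z = sigmoid(i_i[k] + h_i[k])          # update gate
        n = math.tanh(i_n[k] + r * h_n[k])    # candidate state
        new_h.append((1.0 - z) * n + z * h[k])
    return new_h
```

The real implementation expresses the same steps with batched tensor operations, but the per-element arithmetic is identical.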

I won’t walk through the code for the forward pass, since this is not the purpose of the post. It suffices to say that the underlying mathematics for this implementation is the same as that of the PyTorch GRU module. So we would expect the performance of the two to be similar, right?

Not quite. Here are the average times it takes for the forward pass of each implementation to execute on the GPU:

PyTorch Library: 93 μs.
Python: 332 μs.

Several factors contribute to the performance gap. First, the operations in our forward pass implementation don’t know about each other, so PyTorch must execute them one at a time. Each individual operation call, which may involve launching a CUDA kernel, carries a fixed amount of overhead, and that overhead adds up across many calls. Second, the Python interpreter itself slows our program down.

On the other hand, the PyTorch library implements its GRU forward pass as a fused kernel. This means that multiple operations in the forward pass are placed into the same kernel, which results in fewer kernel calls. The library also implements the kernel with CUDA to take advantage of the parallelism GPUs provide.
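The effect of fusion can be sketched in plain Python. This is a toy model, not PyTorch’s actual mechanism: each list comprehension in the unfused version stands in for a separate kernel launch and a separate pass over memory, while the fused version does everything in one pass:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def unfused(a, b, h):
    # Three separate passes over the data, like three kernel launches.
    t1 = [x + y for x, y in zip(a, b)]       # launch 1: add
    t2 = [sigmoid(x) for x in t1]            # launch 2: sigmoid
    return [x * y for x, y in zip(t2, h)]    # launch 3: multiply

def fused(a, b, h):
    # One pass: the same three operations combined per element.
    return [sigmoid(x + y) * z for x, y, z in zip(a, b, h)]
```

Both produce identical results; the fused version simply avoids the per-launch overhead and the intermediate round trips through memory.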

Therefore, a reliable way to speed up a custom RNN is to write a fused CUDA kernel for it. PyTorch has an official tutorial on how to do this, so I won’t repeat the details. However, the tutorial assumes readers already understand GPU programming, and it doesn’t explain the underlying logic of the RNN CUDA kernels. I’d like to fill in this knowledge gap.

CUDA Optimization

We will continue using the forward pass of a GRU as our implementation example.

First, note that we can rewrite lines 12–15 from the Python implementation as:

Eq. 1:
r = sigmoid(i_r + h_r)
z = sigmoid(i_i + h_i)
n = tanh(i_n + r * n_h) where r * n_h is pointwise, written here as
n = tanh(i_n + r * h_n)
new_h = (1 − z) * n + z * h

Two important properties of the right-hand side are: 1) all the constituents (i_n, i_r, i_i, h_n, h_r, h_i) are vectors of the same dimension; and 2) all the operations are pointwise. It follows that the operations for a given index are independent of the operations at every other index. This principle can be visualized in the following diagram, where the red arrow suggests that the operations on the second index happen independently. We will call such an index a pointwise index.

Pointwise operations for index 2
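This independence can be made concrete in plain Python: the new state at index k is a function of index k of each input alone. The sketch below (names are mine, for illustration) computes the indices in two different orders and gets the same answer, which is exactly what lets a GPU compute them simultaneously:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def pointwise_at(k, i_r, i_i, i_n, h_r, h_i, h_n, h):
    # The new state at index k reads only index k of each input;
    # no other index is touched.
    r = sigmoid(i_r[k] + h_r[k])
    z = sigmoid(i_i[k] + h_i[k])
    n = math.tanh(i_n[k] + r * h_n[k])
    return (1.0 - z) * n + z * h[k]

def new_state(order, i_r, i_i, i_n, h_r, h_i, h_n, h):
    # Fill in the output indices in the given order; because each index
    # is independent, any order produces the same result.
    out = [None] * len(h)
    for k in order:
        out[k] = pointwise_at(k, i_r, i_i, i_n, h_r, h_i, h_n, h)
    return out
```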

The same principle applies when the batch size is greater than 1: the pointwise operations at a particular index of a particular batch element are independent of all other operations:

Pointwise operations for index 2 in Batch 1

We can take advantage of this fact by parallelizing the vector operations across multiple threads on a GPU — one thread for each index computation. This is the key to why GPU programming is so effective for RNNs, since pointwise operations are prevalent in most recurrent architectures.

Let’s see how this is done. In CUDA programming, the GPU is conceptually broken down into blocks of threads. Here’s a visualization of 4096 blocks, with 256 threads in each block.

Source: An Even Easier Introduction to CUDA

The grid of blocks can have more than one dimension. Here’s what a two-dimensional grid of blocks looks like:

Source: CUDA neural network implementation

Now we need to assign the computation for each pointwise index to a particular thread. Let’s dive into the CUDA code for the GRU forward pass to see how this is done.

The code above consists of two functions. gru_cuda_forward is the entry point; it is executed on the CPU. It calls the kernel function gru_cuda_forward_kernel, which is executed on the GPU. Kernel functions have the __global__ keyword in their declarations. I’ve left out the implementation of the kernel function and will come back to it later.

Since gru_cuda_forward is the entry point, we will go over it first.

Lines 12–19 are essentially the same as lines 4–7 from our PyTorch implementation. The implementation for the matrix multiplication operator torch::addmm is highly optimized, so we make use of it. Lines 18–19 simply reshape our results into a format that will be easier to work with.

Lines 21–24 initialize vectors, including the new state, that the kernel function will populate later.

Lines 26–27 partition the GPU into a grid of (m × n) blocks, each block containing 1024 threads. The horizontal grid dimension n equals the batch size, while the vertical grid dimension m is the minimum number of blocks whose combined thread count matches or exceeds the state size. This way, each pointwise index can be mapped to a unique thread:
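The grid arithmetic is a ceiling division. Here is the same calculation as a small Python sketch (the function name is mine):

```python
def grid_dims(state_size, batch_size, threads_per_block=1024):
    # Ceiling division: the minimum number of blocks whose combined
    # thread count covers the state size.
    blocks_per_column = (state_size + threads_per_block - 1) // threads_per_block
    # (vertical m, horizontal n): one column of blocks per batch element.
    return blocks_per_column, batch_size
```

For example, a state size of 2049 needs three blocks of 1024 threads per column, even though the third block is mostly idle.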

A subtle point is that, due to the way the number of blocks is calculated on line 27, the total number of threads in a column may be greater than the actual state size. This is acceptable, since it has no harmful effects. The reverse, however, does not hold: if a column contained fewer threads than the state size, there would be fewer threads than pointwise indices, making it impossible to map every pointwise index to its own thread.

On lines 29–38, the kernel function is called and executed on all the threads of the grid simultaneously. Inside the kernel function, the programmer has access to the block and thread IDs of the current thread, which can be used to map the thread back to its pointwise index.

All the inputs to the kernel function are represented as one-dimensional vectors, which makes indexing a bit convoluted. For example, in order to access the element gate[batch][column] inside the kernel, we need to write gate[batch * state_size + column].

Besides being verbose, this expression requires additional variables such as state_size to be passed into the kernel as arguments.
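The flat-indexing rule is easy to check in plain Python (toy values, hypothetical names):

```python
state_size = 4

# A 2 x 4 "gate" tensor flattened row-major into one list,
# which is how the kernel receives it.
gate_2d = [[10, 11, 12, 13],
           [20, 21, 22, 23]]
gate_flat = [v for row in gate_2d for v in row]

batch, column = 1, 2
# Flat indexing reproduces gate_2d[batch][column].
assert gate_flat[batch * state_size + column] == gate_2d[batch][column]
```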

Fortunately, through packed accessors, ATen provides a way to index into a tensor efficiently without converting it to a single pointer. On lines 31–37, the input vectors are transformed into accessors. Let’s dissect the last accessor, on line 37:

new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>())

Here, an accessor is created for new_h, while asserting that it is a 2-dimensional tensor of scalar elements. This allows us to write new_h[batch][column] inside the kernel.

Now, let’s take a look at the kernel function itself:

It takes accessors as input arguments. On lines 11 and 13, the block and thread IDs are mapped to the pointwise index (which consists of the batch index and the column index). The calculation of the column index should make sense given that blockDim.x equals the number of threads in a block.

Line 15 ensures that computation only takes place if the column index is smaller than the state size. This check is necessary because the number of threads in a column may be greater than the state size. Inside the conditional, we simply compute the operations from Eq. 1 for the pointwise index and place the answers into the output vectors at that index.
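The mapping and the bounds check can be simulated in plain Python. The sketch below (illustrative only, not part of the real code) enumerates every (block, thread) pair in the grid, applies the kernel’s index calculation and guard, and confirms that every pointwise index is covered exactly once:

```python
def covered_indices(state_size, batch_size, threads_per_block=4):
    # Simulate the CUDA grid: one column of blocks per batch element
    # (blockIdx.y), with enough blocks vertically (blockIdx.x) to
    # cover the state size.
    blocks_x = (state_size + threads_per_block - 1) // threads_per_block
    seen = []
    for block_y in range(batch_size):                  # blockIdx.y -> batch
        for block_x in range(blocks_x):                # blockIdx.x
            for thread_x in range(threads_per_block):  # threadIdx.x
                column = block_x * threads_per_block + thread_x
                if column < state_size:                # the guard on line 15
                    seen.append((block_y, column))
    return seen
```

Threads whose column index falls past the state size simply do nothing, which is exactly what the guard in the real kernel arranges.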

Once all the threads complete their kernel executions, the output vectors are returned to the caller.

That’s it! We have essentially optimized our forward pass by 1) fusing multiple operations into the same kernel; and 2) parallelizing the kernel function across multiple threads on a GPU.

Performance Comparison

Our hope was to speed up our forward pass by building for it a customized fused CUDA kernel. Let’s see if that holds true.

Here’s the average time it takes for the forward pass to execute on GPU for various implementations:

PyTorch Library: 93 μs
CUDA: 178 μs
Python: 332 μs

The CUDA optimization nearly halved the forward-pass time relative to the Python implementation (332 μs down to 178 μs). The interesting thing to note, however, is that the custom kernel still lags behind the PyTorch library in performance.

I plan to investigate the cause for the difference in the coming weeks, and will follow up with another post if the investigation yields fruitful results. Meanwhile, I hope this article helps you understand how GPU programming can help speed up RNNs, both in principle and in practice.

The code used in this tutorial can be found here.
