FlashAttention — Techniques for Efficient Inference of LLMs (III/IV)

Andrei Apostol
MantisNLP
Nov 1, 2023

Last time in this series, we discussed pruning (removing unneeded weights from a network) and PagedAttention (optimizing memory access). With the right hardware configuration, these two methods can drastically reduce both the runtime and the memory usage of an LLM.

In this third part of our blog series on techniques for efficient inference of large language models (LLMs), we explore FlashAttention, a method that optimizes the self-attention operation itself. FlashAttention addresses the memory-access bottleneck by using tiling and fused kernels to carry out the computation in fast on-chip SRAM.

Doing so achieves impressive speedups, making FlashAttention a go-to tool in any practitioner’s toolbox. Importantly, it is an exact method, in the sense that the mathematics of the neural network remain unchanged (unlike pruning and quantization): it computes the same attention output as the standard implementation, up to floating-point rounding. This guarantees that accuracy will not degrade.

Motivation for FlashAttention

The quadratic cost of attention in the sequence length has prompted many researchers to develop sparse or low-rank approximations. In practice, however, these approximations often do not translate into wall-clock speedups.

The issue, as identified in the FlashAttention paper, is memory access. In other words, the bottleneck is the IO operations: reading and writing between tiers of memory.

By tiers of memory, we mean levels in the memory hierarchy, illustrated below.

Fig. 1, left, from the FlashAttention paper

Typically, when people talk about GPU memory or VRAM, they are referring to HBM (high-bandwidth memory), the middle tier of the hierarchy. It has a much smaller capacity than CPU memory (DRAM), but significantly higher bandwidth.

One level above is the GPU’s on-chip SRAM. It is even smaller than HBM, but its bandwidth is roughly an order of magnitude higher. (The numbers in the figure above are for an A100 GPU.)

Perhaps surprisingly, the self-attention layer is memory-bound rather than compute-bound. This is because it includes elementwise operations, namely dropout, masking and softmax, which perform very little computation for each value they read from and write to HBM.

Breakdown of time spent per operation. Fig. 1 right from paper.

In fact, this diagram from the paper shows that most of the time is spent on these elementwise operations, even though the matrix multiplications do the computational “heavy lifting”.
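One way to see why elementwise operations end up memory-bound is a rough roofline estimate: compare how many FLOPs an operation performs per byte it moves to and from HBM against the ratio of peak compute to memory bandwidth (the “ridge point”). The sketch below is a back-of-the-envelope calculation, not a benchmark. The hardware numbers (312 TFLOP/s FP16 peak, 1.5 TB/s HBM bandwidth) are assumed, rounded A100-class figures, and the matrix multiplication is a generic large square one rather than the exact shapes inside attention.

```python
# Back-of-the-envelope roofline check. The hardware numbers are assumed,
# rounded A100-class figures, not measurements.
PEAK_FP16_FLOPS = 312e12   # assumed FP16 tensor-core peak, FLOP/s
HBM_BANDWIDTH = 1.5e12     # assumed HBM bandwidth, bytes/s
BYTES_PER_ELEM = 2         # fp16

# An op needs at least this many FLOPs per byte of HBM traffic to be compute-bound.
RIDGE_POINT = PEAK_FP16_FLOPS / HBM_BANDWIDTH


def arithmetic_intensity(flops: float, bytes_moved: float) -> float:
    """FLOPs performed per byte read from or written to HBM."""
    return flops / bytes_moved


def elementwise_intensity(n: int) -> float:
    # e.g. scaling or masking an n x n score matrix: ~1 FLOP per element,
    # each element read once and written once.
    elems = n * n
    return arithmetic_intensity(elems * 1, elems * BYTES_PER_ELEM * 2)


def square_matmul_intensity(n: int) -> float:
    # (n x n) @ (n x n): 2 * n^3 FLOPs; read two inputs, write one output.
    return arithmetic_intensity(2 * n**3, 3 * n * n * BYTES_PER_ELEM)


if __name__ == "__main__":
    n = 2048
    print(f"ridge point:         {RIDGE_POINT:.0f} FLOPs/byte")
    print(f"elementwise op:      {elementwise_intensity(n):.2f} FLOPs/byte")
    print(f"large square matmul: {square_matmul_intensity(n):.0f} FLOPs/byte")
```

Under these assumptions the elementwise operation sits nearly three orders of magnitude below the ridge point, so the GPU mostly waits on memory, while a large matmul sits comfortably above it.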

Given the speed of SRAM memory, and the fact that we are bottlenecked by memory-bound operations, one may naturally want to perform those operations in SRAM.

Recall, however, that the self-attention formula involves materializing an NxN matrix (where N is the sequence length) when computing the QKᵀ product. For realistic sequence lengths, this is far too large to fit in SRAM. The FlashAttention paper [15] therefore proposes tiling, which we discuss in the following section.
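To put a number on “too large”, the snippet below computes the size of one score matrix in fp16 for a single attention head and a single sequence (a full layer multiplies this by the batch size and the number of heads). The SRAM figures are assumptions: roughly 20 MB of SRAM in total on an A100, as quoted in the FlashAttention paper, split across streaming multiprocessors of about 192 KB each.

```python
# Size of the N x N score matrix S = Q K^T for one head and one sequence.
# The SRAM sizes below are assumed A100-class figures.
BYTES_PER_ELEM = 2                # fp16
TOTAL_SRAM_BYTES = 20 * 1024**2   # assumed: all on-chip SRAM on the GPU
SRAM_PER_SM_BYTES = 192 * 1024    # assumed: SRAM per streaming multiprocessor


def score_matrix_bytes(seq_len: int) -> int:
    """Bytes needed to materialize one N x N score matrix."""
    return seq_len * seq_len * BYTES_PER_ELEM


if __name__ == "__main__":
    for n in (1024, 2048, 4096, 8192):
        size = score_matrix_bytes(n)
        print(
            f"N={n:5d}: {size / 1024**2:6.1f} MiB "
            f"({size / SRAM_PER_SM_BYTES:6.1f}x one SM's SRAM, "
            f"{size / TOTAL_SRAM_BYTES:4.1f}x all SRAM)"
        )
```

Even at N = 1024 the matrix is about ten times larger than the SRAM of a single multiprocessor, which is where a tile is actually processed, and it grows fourfold every time the sequence length doubles.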

Tiling

Rather than performing the matrix multiplication directly, which would produce an NxN matrix too large to fit into SRAM, FlashAttention performs the operation in chunks (tiles).

This idea is captured in the diagram below:

Using two nested loops, blocks of the Q, K and V matrices are loaded from HBM into SRAM. The attention computation is performed on those blocks, and the result updates the corresponding block of the output matrix (the dotted square in the middle of the figure above). This is repeated until the entire output matrix is complete. Softmax is not applied as a separate final step on the output; instead, it is computed block by block, which requires some care.

Special handling is needed for the softmax operation, because computing it for any one element requires all the elements of the row, due to the sum in the denominator:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}$$

For full details on how the softmax operation is handled, we refer the reader to the original paper or this excellent in-depth blogpost.
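The core trick can be shown on a single vector. The sketch below, a minimal illustration rather than the paper’s implementation, streams over the input in chunks while keeping only a running maximum `m` and a running denominator `l`; whenever a new chunk raises the maximum, the sum accumulated so far is rescaled by `exp(m_old - m_new)`. The result matches the ordinary softmax. For clarity it makes a second pass to produce the outputs; FlashAttention instead applies the same rescaling to a running output accumulator, so no second pass over the scores is needed.

```python
import math
from typing import Iterable, List


def softmax(xs: List[float]) -> List[float]:
    """Reference softmax over the whole vector (numerically stable)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def online_softmax_stats(chunks: Iterable[List[float]]):
    """Stream over chunks, keeping only a running max m and a running
    denominator l, rescaling l whenever the max changes."""
    m = -math.inf
    l = 0.0
    for chunk in chunks:
        m_new = max(m, max(chunk))
        # Rescale the old sum to the new max, then add this chunk's terms.
        l = l * math.exp(m - m_new) + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    return m, l


def chunked_softmax(xs: List[float], chunk_size: int) -> List[float]:
    """Softmax whose normalizing statistics are gathered chunk by chunk."""
    chunks = [xs[i:i + chunk_size] for i in range(0, len(xs), chunk_size)]
    m, l = online_softmax_stats(chunks)
    return [math.exp(x - m) / l for x in xs]


if __name__ == "__main__":
    xs = [1.0, 3.0, -2.0, 0.5, 4.0]
    ref = softmax(xs)
    out = chunked_softmax(xs, chunk_size=2)
    print("max |difference|:", max(abs(a - b) for a, b in zip(ref, out)))
```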

For now, however, this is the main idea behind FlashAttention: tiling, and performing the computations in fast SRAM.
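Putting tiling and the online softmax together gives the following minimal pure-Python sketch of a FlashAttention-style forward pass for a single head, with no masking or dropout. It is meant to show the bookkeeping, not performance: in the real CUDA kernel each block lives in SRAM, whereas here “a block” is just a slice of a Python list. The loop order (query blocks outer, key/value blocks inner, with the final normalization deferred to the end) is a simplification closer to FlashAttention 2.0; the original paper’s algorithm puts the key/value loop outermost and rescales the output at every step. Both orders compute the same result, and the usage code checks the sketch against naive attention for several block sizes.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def naive_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference: each score row is computed over all N keys at once."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out: Matrix = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        p = [math.exp(s - m) for s in scores]
        total = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) / total
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q: Matrix, k: Matrix, v: Matrix, block: int) -> Matrix:
    """FlashAttention-style forward pass: scores are only ever computed for one
    (query block, key block) tile at a time. Per query row we keep a running
    max m, a running softmax denominator l and an unnormalized output acc."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out: Matrix = []
    for r0 in range(0, len(q), block):              # outer loop: query blocks
        q_blk = q[r0:r0 + block]
        m = [-math.inf] * len(q_blk)
        l = [0.0] * len(q_blk)
        acc = [[0.0] * dv for _ in q_blk]
        for c0 in range(0, len(k), block):          # inner loop: key/value blocks
            k_blk, v_blk = k[c0:c0 + block], v[c0:c0 + block]
            for i, qi in enumerate(q_blk):
                s = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
                m_new = max(m[i], max(s))
                alpha = math.exp(m[i] - m_new)      # rescales earlier partial sums
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * alpha + sum(p)
                acc[i] = [acc[i][c] * alpha
                          + sum(p[j] * v_blk[j][c] for j in range(len(v_blk)))
                          for c in range(dv)]
                m[i] = m_new
        # Normalize once per row at the end.
        out.extend([x / l[i] for x in acc[i]] for i in range(len(q_blk)))
    return out


def rand_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    random.seed(0)
    q, k, v = rand_matrix(10, 4), rand_matrix(10, 4), rand_matrix(10, 4)
    ref = naive_attention(q, k, v)
    for block in (1, 3, 4, 10):
        out = tiled_attention(q, k, v, block)
        err = max(abs(a - b) for ra, rb in zip(ref, out) for a, b in zip(ra, rb))
        print(f"block={block:2d}  max |difference| = {err:.1e}")
```

Note that the only per-row state carried across key/value blocks is `m`, `l` and `acc`, which is why the extra memory grows linearly rather than quadratically with the sequence length.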

Note also that the tiling mechanism allows us to avoid materializing the NxN matrix, which reduces the additional memory required by self-attention from O(N²) to O(N) in the sequence length. This is evidenced in the plot below:

FlashAttention displays a 10x memory reduction at a sequence length of 2048, and 20x at 4096. This is due to the space complexity being reduced from quadratic to linear.

There is one more ingredient worth mentioning here, namely kernel fusion, without which this method would be ineffective.

Kernel Fusion

In a typical (unfused) setup, each CUDA kernel writes its result to HBM once it is done, only for the next kernel to read that result back from memory, perform its own computation, and write its output back again.

To illustrate, consider the operation:
f(x) = sin(x)² + 12

All operations performed for a simple function.

As you can see, there are many redundant read/write operations to and from HBM. This is because each CUDA kernel used here performs a single operation before writing its result back to memory (there are three kernels in the example: sine, squaring, and adding 12). This makes sense when we need flexibility and don’t know ahead of time which operations will be performed.

However, for cases where we have some fixed operators that we want to run, one can write fused kernels which perform multiple operations in one go before going back to memory.

Let’s say we want to perform the same operation as above, but with a fused kernel. It would look like:

You can see that we skip several read/write operations here. This is illustrated in this excellent blogpost on accelerating machine learning from first principles.

Fig. from the acceleration blogpost representing the first case, i.e. non-fused kernels
Fig. from the acceleration blogpost representing the second case, i.e. a fused kernel
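The difference in memory traffic can be made concrete with a toy model. In the sketch below, `ToyHBM` is a hypothetical stand-in for global memory that simply counts how many elements are read and written; `unfused` mimics three single-operation kernels for f(x) = sin(x)² + 12, and `fused` does all three steps per element before storing. Being pure Python, it says nothing about actual speed, only about the number of trips to memory.

```python
import math
from typing import List


class ToyHBM:
    """Hypothetical stand-in for GPU global memory that only counts traffic."""

    def __init__(self) -> None:
        self.reads = 0
        self.writes = 0

    def load(self, buf: List[float]) -> List[float]:
        self.reads += len(buf)
        return list(buf)

    def store(self, values: List[float]) -> List[float]:
        self.writes += len(values)
        return list(values)


def unfused(x: List[float], hbm: ToyHBM) -> List[float]:
    # Three single-op "kernels", each reading its input from HBM and writing back.
    t1 = hbm.store([math.sin(val) for val in hbm.load(x)])   # kernel 1: sin
    t2 = hbm.store([val * val for val in hbm.load(t1)])      # kernel 2: square
    return hbm.store([val + 12 for val in hbm.load(t2)])     # kernel 3: add 12


def fused(x: List[float], hbm: ToyHBM) -> List[float]:
    # One "kernel": load each element once, keep intermediates local, store once.
    out = []
    for val in hbm.load(x):
        s = math.sin(val)
        out.append(s * s + 12)
    return hbm.store(out)


if __name__ == "__main__":
    x = [i / 100 for i in range(1000)]
    for name, fn in (("unfused", unfused), ("fused", fused)):
        hbm = ToyHBM()
        fn(x, hbm)
        print(f"{name:8s} reads={hbm.reads:5d} writes={hbm.writes:5d}")
```

For 1,000 inputs the unfused version reads and writes 3,000 elements each, and the fused version 1,000 each, while producing identical results.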

Neural networks are one such case where this can be applied, since the operations to be performed are known ahead of time. Fused kernels exist, for example, for convolution and self-attention layers.

Coming back to FlashAttention, fusing the kernels makes it possible to run all the operations in a single pass while the inputs are in SRAM, without swapping to and from HBM after every operation (e.g. matmul, mask, softmax, dropout).

Doing so enables the remarkable speedups obtained by FlashAttention, which we will present in the next section.

In Practice

Compared with the vanilla PyTorch implementation, this achieves a speedup of more than 2x across several GPU configurations and benchmarks.

Fig. from the FlashAttention paper

Of course, the exact speedup depends on several factors, including:

  • Sequence length. The longer the sequence, the more FlashAttention is expected to help
  • SRAM size. The larger your SRAM, the larger the chunks you can use, and therefore the fewer transfers between HBM and SRAM are needed
  • Head dimension
  • Whether causal masking is on or off
  • Whether running in forward-only or forward+backward mode
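The SRAM-size and head-dimension factors can be read off the IO analysis in the FlashAttention paper, which gives Θ(Nd + N²) HBM accesses for standard attention and Θ(N²d²/M) for FlashAttention, where d is the head dimension and M the SRAM size in elements (for d ≤ M ≤ Nd). The sketch plugs illustrative numbers into these asymptotic expressions with all constant factors dropped, so only the trends are meaningful: doubling M halves FlashAttention’s HBM traffic, while doubling d quadruples it. The M values are hypothetical, not a particular GPU.

```python
# Asymptotic HBM-access counts from the FlashAttention paper's IO analysis,
# with all constant factors dropped: only trends between rows are meaningful.
#   standard attention: Theta(N*d + N^2)
#   FlashAttention:     Theta(N^2 * d^2 / M), M = SRAM size in elements


def standard_hbm_accesses(n: int, d: int) -> float:
    return n * d + n * n


def flash_hbm_accesses(n: int, d: int, m: int) -> float:
    return n * n * d * d / m


if __name__ == "__main__":
    # Hypothetical settings chosen to satisfy d <= M <= N*d.
    for n, d, m in [(4096, 64, 100_000), (4096, 128, 100_000), (4096, 64, 200_000)]:
        std = standard_hbm_accesses(n, d)
        fa = flash_hbm_accesses(n, d, m)
        print(f"N={n} d={d:3d} M={m:7d}  standard/flash ~ {std / fa:5.1f}")
```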

Luckily, utilizing FlashAttention is simple, since it is already implemented in many ML frameworks:

Or one can install and use the official repository.

It is also worth noting that FlashAttention 2.0 was recently released, improving on the original implementation by a further 2x or so in speed.

Fig. from the official repository

Wrapping Up

FlashAttention is a powerful technique that optimizes the self-attention operation in transformers by leveraging tiling and kernel fusion.

Computations are performed in the faster SRAM, resulting in impressive speedups of over 2x compared to the vanilla PyTorch implementation. The level of speedup depends on factors such as sequence length, SRAM size, head dimension, causal masking, and mode of operation. FlashAttention 2.0 offers further improvements in speed.

Importantly, this method is exact. The output of the network remains unchanged when applying FlashAttention, which guarantees that accuracy will not degrade. This is similar to PagedAttention, but different from pruning and quantization, which can affect accuracy.

Stay tuned for the final part of our blog series, where we’ll explore knowledge distillation, and the many forms that it can take.
