FlashAttention — one, two, three!

Najeeb Khan
5 min read · Sep 2, 2024

An Overview of Efficient Attention Mechanisms Powering LLMs

Large Language Models (LLMs) rely on the transformer architecture, where attention mechanisms are vital. Due to their computational demands, inference with attention is performed on GPUs. While software reusability is important, using existing GPU kernels like matrix multiplication and softmax for attention often sacrifices performance. In this article, we’ll explore “FlashAttention,” an efficient implementation that avoids approximations and restructures computations to align better with GPU architecture.

Standard Attention

Before exploring optimizations, let’s review the standard implementation. First, we read the N x d matrices Q and K from global memory, compute the N x N logits matrix S = QK^T with a matrix multiply, and write S back to global memory.

We then read the logits matrix S from global memory and compute the row-wise softmax P = softmax(S), writing the result back to global memory.

Finally, we read the values matrix V and the softmax output P to compute the final N x d output matrix O = PV.
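To make the three passes concrete, here is a minimal NumPy sketch. It is only an illustration: NumPy arrays stand in for global memory, the function name and the sizes N = 1024, d = 64 are my own choices, and like the description above it leaves out the usual 1/sqrt(d) scaling of the logits.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Three-pass attention; S and P are each a full N x N intermediate."""
    S = Q @ K.T                                    # pass 1: logits, N x N
    # pass 2: row-wise softmax (row max subtracted for numerical stability)
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    O = P @ V                                      # pass 3: output, N x d
    return S, P, O

rng = np.random.default_rng(0)
N, d = 1024, 64                                    # illustrative sizes
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
S, P, O = standard_attention(Q, K, V)

# Number of elements each array occupies
print("Q, K or V:", N * d)    # 65536
print("S or P:   ", N * N)    # 1048576
```

Even at this modest sequence length, each N x N intermediate is 16 times larger than any of the inputs, and it is written to and read back from memory between passes.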

These operations are shown in the sequence diagram below.

FlashAttention v1.

Implementing standard attention with existing kernels results in a large amount of data transfer between the GPU’s compute units and its slow global memory (see my previous story for an understanding of GPU architectures). FlashAttention uses three optimization techniques to reduce IO operations to global memory.

  1. Recomputation: rather than storing intermediate results for the backward pass, it recomputes them when they are needed.
  2. Kernel fusion: it combines all the sub-operations (the matrix multiplication, the softmax, and the final multiplication with the values matrix) into a single kernel.
  3. Tiling: Q, K and V can be quite large, so FlashAttention makes kernel fusion feasible by dividing the input matrices into blocks and running the fused kernel on one block at a time, storing only a few constants in global memory.

The fused kernel first reads a block of Q and a block of K and computes the corresponding block of the logits matrix S.

To compute the softmax we need the sum over the whole row of the logits matrix S, but the kernel only holds one block of S at a time. To deal with this, FlashAttention exponentiates each block of logits without the softmax denominator and tracks the denominator separately as ℓ: for each row, ℓ is the running sum of the exponentiated logits over the blocks processed so far.

Finally, the kernel loads the matching block of the values matrix V, multiplies it by the exponentiated logits, and uses ℓ to normalize the accumulated output.
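The block-wise procedure can be sketched in NumPy as follows. This is a simplified illustration, not the actual CUDA kernel: the function name and block sizes are my own, and the sketch also tracks a running row maximum m (not discussed above) so that the exponentials stay numerically stable. As in the first version of FlashAttention, the output is kept normalized after every block.

```python
import numpy as np

def tiled_attention(Q, K, V, block_q=64, block_k=64):
    """Block-wise attention that never forms the full N x N matrix."""
    N, d = Q.shape
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)    # running row maximum (added for stability)
    l = np.zeros(N)            # running softmax denominator (ℓ in the text)
    for j in range(0, N, block_k):                 # loop over K/V blocks
        Kj, Vj = K[j:j + block_k], V[j:j + block_k]
        for i in range(0, N, block_q):             # loop over Q blocks
            rows = slice(i, i + block_q)
            S = Q[rows] @ Kj.T                     # one block of logits
            m_new = np.maximum(m[rows], S.max(axis=1))
            P = np.exp(S - m_new[:, None])         # numerator only, no denominator
            scale = np.exp(m[rows] - m_new)        # re-bases old sums onto the new max
            l_new = scale * l[rows] + P.sum(axis=1)
            # undo the old normalization, add the new block, normalize again
            O[rows] = (O[rows] * (scale * l[rows])[:, None] + P @ Vj) / l_new[:, None]
            m[rows], l[rows] = m_new, l_new
    return O, m, l

rng = np.random.default_rng(0)
N, d = 256, 32                                     # illustrative sizes
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O, m, l = tiled_attention(Q, K, V)

# Reference: softmax over complete rows
S_full = Q @ K.T
P_ref = np.exp(S_full - S_full.max(axis=1, keepdims=True))
P_ref /= P_ref.sum(axis=1, keepdims=True)
print(np.allclose(O, P_ref @ V))

# Recomputation: any block of the softmax can be rebuilt from Q, K, m and l
P_block = np.exp(Q[:64] @ K[:64].T - m[:64, None]) / l[:64, None]
print(np.allclose(P_block, P_ref[:64, :64]))
```

The last two lines hint at why recomputation is cheap: with only the per-row constants m and ℓ saved, any block of the softmax matrix can be rebuilt on demand instead of being stored.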

The figure below shows a simplified view of this process.

FlashAttention speeds up the attention computation by roughly 7x over the standard PyTorch implementation on GPT-2.

FlashAttention v2.

The original FlashAttention improved memory efficiency by reducing global memory accesses, but it did not make full use of the GPU’s compute cores: on GPUs like the A100, it reached only around 25% of the theoretical peak FLOPS.

FlashAttention v2 improves GPU utilization by reducing non-matmul operations and optimizing how work is partitioned across warps.

Defer non-matmul operations: GPUs execute matrix multiplies on Tensor Cores much faster than other operations, so every non-matmul operation is comparatively expensive. Rather than normalizing each block as in the original FlashAttention, v2 applies the normalization once, at the very end. That is, we first accumulate the unnormalized output across all the blocks, keeping the GPU cores busy with matmuls, and then rescale each output block by its normalization constant.
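A minimal NumPy sketch of the deferred normalization, under the same caveats as any CPU illustration of a GPU kernel: the function name and block sizes are hypothetical, and a running row maximum m is tracked for numerical stability. The accumulator is still re-based when the row maximum changes, but the division by ℓ happens only once per block of rows.

```python
import numpy as np

def tiled_attention_deferred(Q, K, V, block_q=64, block_k=64):
    """Block-wise attention that divides by the denominator only at the end."""
    N, d = Q.shape
    O = np.empty((N, d))
    for i in range(0, N, block_q):                 # one block of query rows
        Qi = Q[i:i + block_q]
        acc = np.zeros((Qi.shape[0], d))           # unnormalized output
        m = np.full(Qi.shape[0], -np.inf)          # running row maximum
        l = np.zeros(Qi.shape[0])                  # running denominator
        for j in range(0, N, block_k):             # sweep over K/V blocks
            S = Qi @ K[j:j + block_k].T
            m_new = np.maximum(m, S.max(axis=1))
            P = np.exp(S - m_new[:, None])
            scale = np.exp(m - m_new)              # re-base onto the new max
            l = scale * l + P.sum(axis=1)
            acc = scale[:, None] * acc + P @ V[j:j + block_k]
            m = m_new
        O[i:i + block_q] = acc / l[:, None]        # the only division by l
    return O

rng = np.random.default_rng(0)
N, d = 256, 32                                     # illustrative sizes
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = tiled_attention_deferred(Q, K, V)

# Reference: softmax over complete rows
S_full = Q @ K.T
P_ref = np.exp(S_full - S_full.max(axis=1, keepdims=True))
P_ref /= P_ref.sum(axis=1, keepdims=True)
print(np.allclose(O, P_ref @ V))
```

Compared with normalizing after every block, the inner loop here has one fewer element-wise pass over the output tile, which is the kind of non-matmul work v2 tries to avoid.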

Additionally, FlashAttention v2 optimizes work partitioning within each thread block to reduce communication overhead between warps, and it parallelizes across the sequence length in addition to the batch size and the number of heads.

FlashAttention v2 achieves 50–70% of the theoretical maximum FLOPS. The improved GPU utilization makes FlashAttention v2 about 2x faster than the original FlashAttention.
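A quick back-of-the-envelope calculation shows why the extra parallel dimension matters. The helper below is hypothetical and simply counts independent thread blocks. The problem size and the 128-row query block are illustrative assumptions, while 108 is the number of streaming multiprocessors on an A100.

```python
import math

def thread_blocks(batch, heads, seq_len, block_q=None):
    """Count independent thread blocks in the launch grid.

    block_q=None: one thread block per (batch, head) pair.
    block_q=int:  additionally one thread block per block of query rows.
    """
    per_head = 1 if block_q is None else math.ceil(seq_len / block_q)
    return batch * heads * per_head

NUM_SMS = 108                              # streaming multiprocessors on an A100
batch, heads, seq_len = 1, 16, 8192        # illustrative long-sequence workload

grid_batch_heads = thread_blocks(batch, heads, seq_len)
grid_with_seq = thread_blocks(batch, heads, seq_len, block_q=128)
print(grid_batch_heads, grid_with_seq)     # 16 1024
```

With a small batch and a long sequence, parallelizing over batch and heads alone produces only 16 thread blocks for 108 multiprocessors, whereas also splitting the sequence produces enough blocks to keep the whole GPU busy.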

FlashAttention v3.

FlashAttention v3 is a further optimization of FlashAttention v2, motivated by the advanced features of the Hopper GPU architecture, such as support for low-precision matrix multiplies on Tensor Cores and asynchronous execution of Tensor Core matmuls and memory transfers. FlashAttention v3 overlaps the QK multiply, the softmax, and the PV multiply, and efficiently manages memory and computation resources.

FlashAttention v3 achieves 75% utilization of the H100 cores, compared to only 35% with FlashAttention v2. It speeds up FP16 attention by 1.5–2x over FlashAttention v2 and reaches close to 1.2 PFLOPS in FP8 precision.
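To put these percentages in absolute terms, the short calculation below converts between utilization and throughput. The peak figures are an assumption on my part (dense Tensor Core peaks for the H100 SXM as published by NVIDIA), and the helper names are hypothetical.

```python
# Assumed dense Tensor Core peaks for the H100 SXM, in TFLOPS
H100_PEAK_TFLOPS = {"fp16": 989.0, "fp8": 1979.0}

def achieved_tflops(utilization, precision="fp16"):
    """Throughput implied by a utilization fraction."""
    return utilization * H100_PEAK_TFLOPS[precision]

def utilization(tflops, precision="fp16"):
    """Utilization fraction implied by a measured throughput."""
    return tflops / H100_PEAK_TFLOPS[precision]

v2 = achieved_tflops(0.35)
v3 = achieved_tflops(0.75)
print(round(v2), round(v3), round(v3 / v2, 2))   # 346 742 2.14
print(round(utilization(1200.0, "fp8"), 2))      # 0.61
```

Under these assumed peaks, going from 35% to 75% utilization corresponds to roughly a 2x gain in FP16 throughput, and 1.2 PFLOPS in FP8 is about 61% of the FP8 peak.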

Summary

In this article, we reviewed FlashAttention, an exact and efficient implementation of attention. We explored the key optimizations in each version: tiling and recomputation in v1, fewer non-matmul operations and better work partitioning in v2, and asynchronous execution and low-precision arithmetic in v3. Each version boosts speed and hardware utilization, reaching about 75% utilization on the H100. The table below summarizes and compares the different versions of FlashAttention.

Further Reading

  1. Dao, Tri, et al. “FlashAttention: Fast and memory-efficient exact attention with io-awareness.” 2022.
  2. Dao, Tri. “FlashAttention-2: Faster attention with better parallelism and work partitioning.” 2023.
  3. Shah, Jay, et al. “FlashAttention-3: Fast and accurate attention with asynchrony and low-precision.” 2024.
