How FlashAttention enables scaling up training and inference with zero cost

Yiftach Beer
Theator Tech
Mar 3, 2024


Nowadays, transformers are everywhere, from natural language processing to computer vision to bioinformatics and beyond. Because they are so prominent, we want to make them better: faster and with lower memory requirements, so that we can fit more work on the same setup.

When looking for ways to scale up transformers, deep learning engineers face an endless list of suggestions: inference engines, shiny new optimizers, approximate attention mechanisms, and more. Most of them give only a small boost or come at the cost of accuracy.

FlashAttention (introduced in Dao et al. 2022) may seem like yet another “speed up” method. However, it both speeds up models and dramatically lowers their memory requirements, and, most importantly, it computes exact attention rather than an approximation, so its results match the standard implementation up to floating-point rounding. The entire speedup comes from a hardware-aware implementation, at no cost to accuracy.

When I applied it in my domain, video analysis of long videos, it cut the GPU memory a setup needed from ~70 GB to ~20 GB, allowing a 4x larger batch size while making training 1.5x faster.

If you are not familiar with FlashAttention, are not convinced you should use it, or want to understand how it works before making the switch, read on.

Background

It is well known that the memory requirements of attention scale quadratically with the number of input tokens N, since the attention matrix has N x N entries per head. New variants that approximate the same result with a cheaper computation are constantly being proposed. Some of the inefficiencies of the original computation, however, are quite avoidable. For example, consider the standard implementation in the timm library before FlashAttention was introduced: it computes the full matrix of query-key scores and then applies a softmax to it.
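A minimal PyTorch sketch of this kind of vanilla computation, modeled on the multi-head attention forward pass timm used before adding FlashAttention support (simplified, not the verbatim timm source; the class name VanillaAttention is hypothetical):

```python
import torch
import torch.nn as nn


class VanillaAttention(nn.Module):
    """Multi-head self-attention that materializes the full N x N attention matrix.

    Hypothetical, simplified module in the style of timm's pre-FlashAttention code.
    """

    def __init__(self, dim, num_heads=8, attn_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        # Project to q, k, v and reshape to (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N): quadratic in N
        attn = attn.softmax(dim=-1)  # allocates a second tensor of the same size
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)


x = torch.randn(2, 197, 384)  # (batch, tokens, embedding dim)
print(VanillaAttention(384, num_heads=6)(x).shape)
```

Every intermediate here (scores, probabilities, dropout output) is a full-size tensor written to GPU memory by a separate operation.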

In the softmax step, a new tensor the size of the full attention matrix is computed from the scores. While this computation runs, the memory of the input tensor cannot be freed, so at peak we need twice the memory. To avoid this, we could compute the softmax row by row, but that would be much slower.
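To put numbers on the quadratic growth and the 2x peak during softmax, here is a back-of-the-envelope calculation. The batch size of 8, the 12 heads and fp16 storage (2 bytes per element) are illustrative assumptions:

```python
GIB = 2 ** 30


def attention_matrix_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Bytes for one (batch, heads, seq_len, seq_len) attention tensor (fp16 by default)."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


# Hypothetical configuration: batch 8, 12 heads, fp16
for n in (1024, 2048, 4096, 8192):
    one = attention_matrix_bytes(batch=8, heads=12, seq_len=n)
    # During softmax the input scores and the output probabilities coexist,
    # so the peak is roughly twice the size of a single tensor.
    print(f"N={n:5d}: {one / GIB:6.2f} GiB per tensor, ~{2 * one / GIB:6.2f} GiB at softmax peak")
```

With these settings a single attention tensor for N = 4096 takes exactly 3 GiB, so the softmax step of a single layer needs about 6 GiB at peak, and doubling N quadruples both numbers.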

Additionally, a GPU has two main levels of memory: a small but fast on-chip Static Random Access Memory (SRAM), and a larger but slower High Bandwidth Memory (HBM). Neither of the tensors above fits entirely in SRAM, so most of the time is spent not on the computations themselves, but on moving data from HBM to SRAM and back for each operation. This is inefficient: ideally, we would load the data once, perform all the computations, and write the result back once. Combining several operations into a single kernel in this way is called operator fusion (or kernel fusion).
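To make the cost of data movement concrete, here is a rough count of HBM traffic (in elements) for a single attention head with sequence length n and head dimension d. It assumes the standard computation runs as three separate kernels (matmul, softmax, matmul), each reading its inputs from HBM and writing its output back, and it ignores scaling and dropout:

```python
def standard_attention_hbm_traffic(n, d):
    """Elements read from / written to HBM by unfused attention for one head."""
    step1 = 2 * n * d + n * n        # S = Q K^T: read Q and K, write S
    step2 = n * n + n * n            # P = softmax(S): read S, write P
    step3 = n * n + n * d + n * d    # O = P V: read P and V, write O
    return step1 + step2 + step3     # = 4nd + 4n^2


def ideal_fused_hbm_traffic(n, d):
    """Idealized single kernel: read Q, K, V once and write O once."""
    return 4 * n * d


n, d = 4096, 64  # hypothetical sequence length and head dimension
unfused = standard_attention_hbm_traffic(n, d)
fused = ideal_fused_hbm_traffic(n, d)
print(f"unfused: {unfused:,} elements, fused lower bound: {fused:,} ({unfused / fused:.0f}x less)")
```

The ratio works out to 1 + n/d (65x here), so it grows with sequence length. The fused figure is an idealized lower bound that assumes everything fits in SRAM; for long sequences FlashAttention has to reload some blocks, so its actual traffic is higher, though still substantially lower than the unfused count for typical head dimensions.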

Copyrights for all images in this post belong to the paper’s authors

Enter FlashAttention

FlashAttention is a hardware-aware implementation of exact attention. It computes the same result as the vanilla implementation (up to floating-point rounding), but organizes the computation around the GPU memory hierarchy to avoid overheads like the ones above.

To achieve those savings, the authors use a couple of methods:

Tiling
Instead of repeatedly copying data between HBM and SRAM, FlashAttention splits the inputs (the query, key and value matrices) into blocks, loads each block into SRAM, and carries out all the steps of attention on it (the matrix multiplications and the softmax) before writing the result back. This way, intermediate results do not make a round trip to HBM at every step of the computation. While the matrix multiplications naturally factorize into blocks, the softmax is trickier, because its normalization depends on an entire row of scores. FlashAttention handles this by keeping running summary statistics for each row (the maximum score and the sum of exponentials) and rescaling the partial results as new blocks are processed. Because the final result is computed block by block, the full N x N attention matrix is never materialized in HBM, so the extra memory grows linearly rather than quadratically with sequence length, saving a lot of memory.
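A minimal single-head NumPy sketch of tiling with an online softmax: it processes keys and values one block at a time, keeps the running row-wise maximum and sum of exponentials, and rescales the partial output whenever a new block raises the maximum. It runs on the CPU and only illustrates the arithmetic; the real kernel does this in SRAM inside one CUDA kernel. The block size of 16 and the loop order (query blocks outermost, as in FlashAttention-2) are simplifying choices, and the function names are hypothetical:

```python
import numpy as np


def naive_attention(q, k, v):
    """Reference: materializes the full (n, n) score matrix."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v


def tiled_attention(q, k, v, block=16):
    """Same result, computed block by block with an online softmax."""
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.empty((n, v.shape[1]))
    for qs in range(0, n, block):
        q_blk = q[qs:qs + block]
        rows = q_blk.shape[0]
        row_max = np.full(rows, -np.inf)      # running max score per query row
        row_sum = np.zeros(rows)              # running sum of exp(score - row_max)
        acc = np.zeros((rows, v.shape[1]))    # running unnormalized output
        for ks in range(0, k.shape[0], block):
            s_blk = (q_blk @ k[ks:ks + block].T) * scale  # at most (block, block) scores
            new_max = np.maximum(row_max, s_blk.max(axis=-1))
            correction = np.exp(row_max - new_max)         # rescale earlier blocks
            p_blk = np.exp(s_blk - new_max[:, None])
            row_sum = row_sum * correction + p_blk.sum(axis=-1)
            acc = acc * correction[:, None] + p_blk @ v[ks:ks + block]
            row_max = new_max
        out[qs:qs + block] = acc / row_sum[:, None]      # normalize once at the end
    return out


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((100, 32)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v), naive_attention(q, k, v)))  # True
```

On the first block the running maximum is -inf, so the correction factor is exp(-inf) = 0 and the empty accumulators are simply replaced; afterwards each new block only rescales what has been accumulated so far.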

Recomputation
Some values computed during the forward pass, most notably the attention scores and probabilities, are needed again in the backward pass. The straightforward solution is to store them in HBM during the forward pass and read them back later. FlashAttention instead discards them and recomputes them block by block during the backward pass, using Q, K, V and the small per-row softmax statistics saved in the forward pass. Counterintuitively, this is faster: reading a large attention tensor from HBM is much slower than recomputing it from values that are already in SRAM. If this sounds familiar, it is the same idea as Activation Checkpointing, which you can read more about in this post.
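A minimal single-head NumPy sketch of the recomputation idea: the forward pass saves only the output and one log-sum-exp value per query row (which packs the per-row softmax statistics into a single number), and the backward pass rebuilds the softmax probabilities from Q, K and those saved values. For brevity the forward pass here builds the full score matrix at once, whereas the real kernel computes the same statistics block by block; the function names are hypothetical:

```python
import numpy as np


def forward_keep_stats(q, k, v):
    """Forward pass that keeps only the output and one logsumexp value per row."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=-1)
    lse = m + np.log(np.exp(s - m[:, None]).sum(axis=-1))  # log(sum_j exp(s_ij)), computed stably
    p = np.exp(s - lse[:, None])                            # softmax probabilities
    return p @ v, lse   # s and p are dropped: n extra values saved, not n * n


def recompute_probs(q, k, lse):
    """Backward-pass helper: rebuild the softmax probabilities from Q, K and lse."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    return np.exp(s - lse[:, None])


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 16)) for _ in range(3))
out, lse = forward_keep_stats(q, k, v)
p = recompute_probs(q, k, lse)
print(lse.shape, np.allclose(p.sum(axis=-1), 1.0), np.allclose(p @ v, out))
```

The trade is one extra matrix multiplication in the backward pass in exchange for never storing or re-reading an n x n tensor.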

Additionally, because the whole attention block runs as a single CUDA kernel, overhead is reduced: instead of dispatching several separate operations from Python/PyTorch, the GPU spends most of its time running one fused kernel.

To the user, all of this is exposed as an ordinary Python module or function.

Putting FlashAttention to Work

If you’re using a recent version of PyTorch, you’re in luck: since PyTorch 2.0, FlashAttention has been incorporated into the scaled_dot_product_attention function and is selected automatically when the inputs and hardware support it, and since PyTorch 2.2, FlashAttention-2 is built in too. So if you have not upgraded PyTorch recently, this might be a good reason to do so. Libraries such as timm and Hugging Face transformers are also gradually adding support for it across their models.
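A minimal sketch of using scaled_dot_product_attention as a drop-in replacement for the manual computation and checking that both agree. The tensor sizes are arbitrary assumptions; on a CPU PyTorch uses a non-flash backend, while on a supported NVIDIA GPU with fp16 inputs it may pick the FlashAttention kernel:

```python
import math

import torch
import torch.nn.functional as F


def manual_attention(q, k, v):
    """Vanilla attention: materializes the full (..., N, N) score matrix."""
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    return scores.softmax(dim=-1) @ v


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
# The flash backend only handles half-precision inputs; use fp32 on the CPU
dtype = torch.float16 if device == "cuda" else torch.float32
# Layout expected by SDPA: (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 4, 256, 64, device=device, dtype=dtype) for _ in range(3))

out_sdpa = F.scaled_dot_product_attention(q, k, v)  # PyTorch picks the backend
out_manual = manual_attention(q, k, v)
tol = 1e-2 if dtype == torch.float16 else 1e-5
print(torch.allclose(out_sdpa, out_manual, atol=tol, rtol=tol))
```

To make sure the flash backend is really used rather than silently falling back, you can restrict the allowed backends with a context manager (torch.nn.attention.sdpa_kernel in PyTorch 2.3 and later, torch.backends.cuda.sdp_kernel in earlier 2.x releases); PyTorch then raises an error if no allowed backend can handle the inputs.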

If you need an older version of PyTorch or want to work with the source directly, use the official implementation (the flash-attn package). It also supports a few other attention variants and is worth checking out.
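If you go this route, note that the official package's flash_attn_func expects tensors laid out as (batch, seq_len, heads, head_dim), while PyTorch's scaled_dot_product_attention expects (batch, heads, seq_len, head_dim). Below is a hedged sketch of a small wrapper that calls flash_attn_func when the package is installed and the inputs are half-precision CUDA tensors, and falls back to scaled_dot_product_attention otherwise. The wrapper name and the dispatch conditions are assumptions; the package has its own GPU and head-dimension requirements that this check does not cover:

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # official package, installed as flash-attn
except ImportError:
    flash_attn_func = None


def attention(q, k, v, causal=False):
    """Hypothetical wrapper. q, k, v: (batch, heads, seq_len, head_dim).

    Assumes queries and keys have the same sequence length, so both
    libraries agree on what the causal mask means.
    """
    use_flash = (
        flash_attn_func is not None
        and q.is_cuda
        and q.dtype in (torch.float16, torch.bfloat16)
    )
    if use_flash:
        # flash_attn_func wants (batch, seq_len, heads, head_dim); transposing
        # dims 1 and 2 keeps the last dimension contiguous, as the kernel requires
        out = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal)
        return out.transpose(1, 2)
    return F.scaled_dot_product_attention(q, k, v, is_causal=causal)


q, k, v = (torch.randn(1, 2, 32, 16) for _ in range(3))
print(attention(q, k, v, causal=True).shape)  # CPU tensors take the fallback path
```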

The gains we will see in practice depend on many factors such as sequence length, model architecture and GPU type. Here for example are some of the speed improvements reported by the authors on an NVIDIA A100:

As we can see, using FlashAttention greatly improves both the speed and memory usage, hardly requiring any changes to the code. Importantly, such a large reduction in the memory footprint allows using cheaper GPUs, which might be more readily available, to train the same models.
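Since the gains depend on sequence length, architecture and GPU, it is worth measuring on your own setup. A rough timing and peak-memory sketch comparing the manual computation with scaled_dot_product_attention follows; the shapes and iteration count are hypothetical, deliberately small defaults, and on a GPU you would increase the sequence lengths to see the memory gap clearly:

```python
import time

import torch
import torch.nn.functional as F


def manual_attention(q, k, v):
    """Vanilla attention that materializes the full (..., N, N) score matrix."""
    scores = (q @ k.transpose(-2, -1)) * q.size(-1) ** -0.5
    return scores.softmax(dim=-1) @ v


@torch.no_grad()
def benchmark(fn, q, k, v, iters=10):
    """Return (mean seconds per call, peak CUDA bytes or None on CPU)."""
    fn(q, k, v)  # warm-up (backend selection, caching allocator)
    if q.is_cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        fn(q, k, v)
    if q.is_cuda:
        torch.cuda.synchronize()  # CUDA calls are asynchronous; wait before stopping the clock
    elapsed = (time.perf_counter() - start) / iters
    peak = torch.cuda.max_memory_allocated() if q.is_cuda else None
    return elapsed, peak


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    for n in (256, 512, 1024):
        # Hypothetical shape: batch 2, 4 heads, head_dim 64
        q, k, v = (torch.randn(2, 4, n, 64, device=device, dtype=dtype) for _ in range(3))
        for name, fn in (("manual", manual_attention), ("sdpa", F.scaled_dot_product_attention)):
            secs, peak = benchmark(fn, q, k, v)
            mem = f"{peak / 2**20:.0f} MiB" if peak is not None else "n/a"
            print(f"N={n:5d} {name:6s} {secs * 1e3:8.2f} ms  peak {mem}")
```

The reported peak includes the input tensors themselves, so compare the two implementations against each other rather than reading the numbers in isolation.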

Final Thoughts

While it is very encouraging that the implementation is a free improvement, it does come with some obvious yet important caveats.

First, you cannot easily modify the implementation yourself, for example to add other attention variants, since that means changing a CUDA kernel.

Second, since the implementation is a fairly new CUDA kernel, you might run into bugs, especially on new hardware, as we did. The kernel also has to be written and compiled specifically for your hardware, which might limit you if you need to use old or exotic GPUs.

For further reading, see the original paper. You might also be interested in follow-up work such as FlashAttention-2 and PagedAttention.

AI Research Engineer. I like building things involving software, electronics and math.