
Understanding Flash Attention: Writing the Algorithm from Scratch in Triton

Find out how Flash Attention works. Afterward, we’ll refine our understanding by writing a GPU kernel of the algorithm in Triton.

Alex Dremov
Published in TDS Archive · 8 min read · Jan 15, 2025

Read for free at alexdremov.me

Flash Attention is a revolutionary technique that dramatically accelerates the attention mechanism in transformer-based models, delivering processing speeds many times faster than naive methods. By cleverly tiling data and minimizing memory transfers, it tackles the notorious GPU memory bottleneck that large language models often struggle with.

In this post, we’ll dive into how Flash Attention leverages efficient I/O-awareness to reduce overhead, then take it a step further by crafting a block-sparse attention kernel in Triton.

💥 I will provide a simple explanation of how Flash Attention works. We will then implement the explained algorithm in Triton!

What is Attention?

The attention mechanism (or scaled dot-product attention) is a core element of transformer models, the leading architecture for language modeling. All popular models, like GPT, LLaMA, and BERT, rely on attention.

The formula is pretty simple:
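$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where Q, K, and V are the query, key, and value matrices, and d_k is the dimension of the keys.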

The rest is history.

Even though the formula looks simple, computing it involves multiplying large tensors and moving a lot of data around. Since attention is a core part of the transformer architecture, optimizing this algorithm significantly improves the performance of the model as a whole.

In the naive implementation, attention requires O(n²) additional memory and O(n²) compute time, where n is the sequence length. That’s a lot!
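To make that cost concrete, here is a minimal naive implementation in PyTorch. It’s a single-head, unbatched sketch for illustration, not a production version; the full (n, n) scores tensor it materializes is exactly the O(n²) memory term:

```python
import torch

def naive_attention(Q, K, V):
    """Naive scaled dot-product attention: softmax(Q Kᵀ / √d_k) V."""
    d_k = Q.shape[-1]
    # The full (n, n) score matrix is materialized here: the O(n²) memory cost
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
    weights = torch.softmax(scores, dim=-1)  # row-wise softmax over keys
    return weights @ V
```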

Flash Attention

Core Idea

The main idea of Flash Attention can be summarized in a simple quote from the original paper:

We argue that a missing principle is making attention algorithms IO-aware — accounting for reads and writes between levels of GPU memory.

That is, modern GPUs have several types of memory:

  • SRAM — fast, on-chip, small

  • HBM — slower, off-chip, large
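In naive attention, the full n × n score matrix is written to HBM and read back between steps, and that memory traffic, not the arithmetic, dominates the runtime. Flash Attention restructures the computation: blocks (tiles) of K and V are streamed through fast SRAM while a running “online” softmax is maintained per query row, so the score matrix never exists in full. As a preview, here is a pure-PyTorch sketch of that accumulation for a single unbatched head; the block size, names, and the Python loop are simplifications for illustration, not the Triton kernel we’ll write:

```python
import math
import torch

def flash_attention_sketch(Q, K, V, block_size=64):
    """Tiled attention with an online softmax: K/V are processed block by
    block, so the full (n, n) score matrix is never materialized."""
    n, d = Q.shape
    scale = 1.0 / math.sqrt(d)

    out = torch.zeros_like(Q)                  # running (unnormalized) output
    row_max = torch.full((n,), float("-inf"))  # running max of scores per row
    row_sum = torch.zeros(n)                   # running softmax denominator

    for start in range(0, n, block_size):      # stream one K/V tile at a time
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]

        scores = Q @ Kb.T * scale              # (n, block_size) score tile

        new_max = torch.maximum(row_max, scores.max(dim=1).values)
        correction = torch.exp(row_max - new_max)  # rescale old accumulators
        p = torch.exp(scores - new_max[:, None])   # stabilized exponentials

        row_sum = row_sum * correction + p.sum(dim=1)
        out = out * correction[:, None] + p @ Vb
        row_max = new_max

    return out / row_sum[:, None]              # normalize once at the end
```

In a real kernel, each tile’s scores and the three accumulators live in SRAM, and only the final output is written back to HBM; that reduction in memory traffic is the whole trick. You can verify the sketch against the naive implementation above with torch.allclose on random inputs.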
