What is Flash Attention?

Improved Attention Mechanism for Language Models

Mehul Gupta
Data Science in your pocket
4 min read · Jul 12, 2024


“Attention is All You Need” has been a breakthrough paper, and it can easily be considered the paper of the decade, if not the century. Why? Because it eventually led the world to ChatGPT and Generative AI.

How is Attention related to GenAI?

The attention mechanism used in the Transformer model discussed in the paper has been the core of every LLM. It is the attention mechanism that helps LLMs understand the context of a prompt and answer accordingly.

My debut book, LangChain in your Pocket, is out now!

To read more about Attention, check this:

A great idea, but it still has some major limitations, especially when it comes to time and space complexity:

  1. Quadratic Memory Requirement: The standard attention mechanism has a memory requirement that scales quadratically with the sequence length, which limits its applicability to long sequences (see the sketch after this list).
  2. Computational Complexity: The attention computation itself has a time complexity that scales quadratically with the sequence length, leading to slower processing times, especially for large models.
  3. Memory Inefficiency: Traditional attention mechanisms require substantial memory to store the relationships between all parts of the input data, leading to high memory usage.
  4. Numerical Instability: Attention computations can suffer from numerical stability issues, especially when working with large sequences and models, leading to inaccurate results.
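
To make points 1 and 2 concrete, here is a minimal NumPy sketch of standard attention (a single head, no batching, my own simplification): it materializes the full N x N score matrix, which is exactly where the quadratic memory and compute come from.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention for one head: materializes the full N x N score matrix."""
    d = Q.shape[-1]
    scale = np.float32(1.0 / np.sqrt(d))
    scores = Q @ K.T * scale                        # (N, N) -> O(N^2) memory and compute
    scores -= scores.max(axis=-1, keepdims=True)    # usual softmax stabilization
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V                              # (N, d)

N, d = 4096, 64
Q, K, V = (np.random.randn(N, d).astype(np.float32) for _ in range(3))
out = naive_attention(Q, K, V)
# The (N, N) score matrix alone holds over 16 million floats (~64 MB at float32)
# for just one head, and it quadruples every time the sequence length doubles.
print(out.shape)   # (4096, 64)
```
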

What is Numerical Instability?

Numerical stability is a desirable property of numerical algorithms where small perturbations in the input data or rounding errors do not lead to large deviations in the final output. In other words, numerical stability ensures that the algorithm is robust and does not magnify errors during the computation.

In simpler words,

Imagine you’re trying to solve a math problem, but you’re using a calculator that sometimes makes small mistakes. If the problem is “stable,” these small mistakes won’t make a big difference in the final answer. But if the problem is “unstable,” even small mistakes can make the final answer completely wrong.
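
Here is a tiny NumPy illustration of the problem, and of the standard fix (subtracting the maximum before exponentiating) that FlashAttention has to keep applying tile by tile:

```python
import numpy as np

scores = np.array([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows, producing inf / nan.
naive = np.exp(scores) / np.exp(scores).sum()
print(naive)        # [nan nan nan] (with an overflow warning)

# Stable softmax: subtract the max first; mathematically the result is identical.
shifted = scores - scores.max()
stable = np.exp(shifted) / np.exp(shifted).sum()
print(stable)       # [0.09003057 0.24472847 0.66524096]
```
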

FlashAttention optimizes the attention mechanism in transformers by leveraging advanced memory-access and computation techniques, improving the time and space complexity without hampering the performance of the model.

How Flash Attention Works

FlashAttention improves attention’s time and space complexity by introducing the changes below.

1. Tiling: Dividing the large attention matrix into smaller, more manageable tiles. This reduces the memory footprint by processing one tile at a time instead of the whole matrix.

2. Efficient Memory Access: FlashAttention optimizes the way data is accessed in memory, minimizing cache misses and improving data locality, which speeds up the computation. It leverages the GPU memory hierarchy, using the fast on-chip SRAM instead of the larger but slower high-bandwidth memory (HBM).

3. Parallelization: Uses parallel computing techniques to perform multiple calculations simultaneously on the tiled matrices, reducing the computation time.

4. Numerical Stability: Implements techniques to maintain numerical stability during computations, such as careful scaling and normalization (rescaling partial softmax results as new tiles arrive). This ensures accurate results even with large sequences and models.
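
Putting the four ideas together, here is a simplified single-head NumPy sketch of the tiling plus online-softmax scheme. It is a didactic approximation under my own simplifications (no batching, no heads, a hypothetical tile_size parameter), not the actual fused CUDA kernel, but it shows how the full score matrix is never materialized.

```python
import numpy as np

def flash_attention_sketch(Q, K, V, tile_size=128):
    """Tiled attention with an online (streaming) softmax.

    Didactic sketch only; the real FlashAttention runs as a fused GPU kernel
    that keeps each tile in on-chip SRAM.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(Q)

    for q_start in range(0, N, tile_size):
        q_end = min(q_start + tile_size, N)
        q_tile = Q[q_start:q_end]                         # (Tq, d)

        running_max = np.full(q_end - q_start, -np.inf)   # row-wise max seen so far
        running_sum = np.zeros(q_end - q_start)           # row-wise softmax denominator so far
        acc = np.zeros((q_end - q_start, d))              # unnormalized weighted sum of values

        for k_start in range(0, N, tile_size):
            k_end = min(k_start + tile_size, N)
            k_tile, v_tile = K[k_start:k_end], V[k_start:k_end]

            scores = q_tile @ k_tile.T * scale            # only a (Tq, Tk) tile, never (N, N)
            new_max = np.maximum(running_max, scores.max(axis=1))
            correction = np.exp(running_max - new_max)    # rescale earlier tiles to the new max
            p = np.exp(scores - new_max[:, None])

            running_sum = running_sum * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ v_tile
            running_max = new_max

        out[q_start:q_end] = acc / running_sum[:, None]
    return out

N, d = 512, 64
Q, K, V = (np.random.randn(N, d) for _ in range(3))
out = flash_attention_sketch(Q, K, V, tile_size=128)
print(out.shape)   # (512, 64), computed without ever forming the full 512 x 512 score matrix
```

Run against the naive version shown earlier, this produces the same output up to floating-point error while only ever holding one tile_size x tile_size block of scores; the real kernel additionally keeps each tile in SRAM and processes tiles in parallel.
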

Example

Let’s consider a sequence of 4 tokens: [A, B, C, D].

Standard Attention

  1. Compute Attention Scores:
  • For each pair of tokens, compute the attention score (using the dreaded Q, K, V matrices).
  • Results in a 4x4 matrix.
|    | A  | B  | C  | D  |
|----|----|----|----|----|
| A | 1 | 2 | 3 | 4 |
| B | 2 | 1 | 3 | 4 |
| C | 3 | 2 | 1 | 4 |
| D | 4 | 2 | 3 | 1 |

  2. Apply Softmax and Weighting:

  • Apply softmax to the scores to get attention weights.
  • Use these weights to compute the weighted sum of values.
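
As a quick hand-rolled check (my own toy numbers), the snippet below treats the 4x4 table above as already-computed scores and uses hypothetical 2-dimensional value vectors for A, B, C and D:

```python
import numpy as np

# The 4x4 scores from the table above, treated as already-computed Q·K^T values.
scores = np.array([
    [1, 2, 3, 4],
    [2, 1, 3, 4],
    [3, 2, 1, 4],
    [4, 2, 3, 1],
], dtype=np.float64)

# Hypothetical value vectors for tokens A, B, C, D (chosen arbitrarily here).
V = np.array([
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.5, 0.5],
])

# Row-wise softmax turns scores into attention weights.
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)

# Each output row is the weighted sum of the value vectors.
output = weights @ V
print(weights.round(3))
print(output.round(3))   # row A comes out around [0.591, 0.646]
```
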

FlashAttention

  1. Tiling:
  • Divide the 4x4 matrix into smaller tiles. For simplicity, let’s use 2x2 tiles.
Tile 1 (rows A-B, columns A-B):
| 1 | 2 |
| 2 | 1 |

Tile 2 (rows A-B, columns C-D):
| 3 | 4 |
| 3 | 4 |

Tile 3 (rows C-D, columns A-B):
| 3 | 2 |
| 4 | 2 |

Tile 4 (rows C-D, columns C-D):
| 1 | 4 |
| 3 | 1 |

2. Efficient Memory Access and Parallelization:

  • Process each tile individually using optimized memory access patterns.
  • Perform computations in parallel across different tiles.

3. Numerical Stability:

  • Compute the softmax incrementally across tiles (an “online” softmax): keep a running maximum and running sum per row, and rescale the partial results from earlier tiles whenever a new tile raises the maximum. This is what keeps the computation numerically stable.
  • Aggregate the rescaled per-tile results to form the final attention weights.

4. Combine Results:

  • Combine the weighted sums from each tile to produce the final output.
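
Here is the same toy example worked tile by tile (using the same hypothetical value vectors as before): the running maximum and running sum let each 2x2 tile be folded in one at a time, and the final output matches the standard-attention result above.

```python
import numpy as np

scores = np.array([[1., 2., 3., 4.],
                   [2., 1., 3., 4.],
                   [3., 2., 1., 4.],
                   [4., 2., 3., 1.]])
V = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]])  # same toy values as before
output = np.zeros((4, 2))

# Outer loop over 2-row query blocks, inner loop over 2-column key blocks:
# this visits exactly Tile 1..4 from above.
for q in range(0, 4, 2):
    running_max = np.full(2, -np.inf)   # row-wise max seen so far
    running_sum = np.zeros(2)           # row-wise softmax denominator so far
    acc = np.zeros((2, 2))              # unnormalized weighted sum of values

    for k in range(0, 4, 2):
        tile = scores[q:q + 2, k:k + 2]             # one 2x2 tile of scores
        v_tile = V[k:k + 2]

        new_max = np.maximum(running_max, tile.max(axis=1))
        correction = np.exp(running_max - new_max)  # rescale earlier tiles to the new max
        p = np.exp(tile - new_max[:, None])

        running_sum = running_sum * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ v_tile
        running_max = new_max

    output[q:q + 2] = acc / running_sum[:, None]

print(output.round(3))   # matches the standard attention output above (row A ≈ [0.591, 0.646])
```
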

In essence, FlashAttention makes the attention mechanism more efficient and scalable, enabling better performance for large-scale transformer models. Recently, several state-of-the-art LLMs released on Hugging Face have started using Flash Attention, which you can check out on the official Hugging Face website.
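
For example, at the time of writing, Hugging Face transformers lets you opt in via the attn_implementation argument. This sketch assumes a recent transformers version, a model with Flash Attention 2 support (Mistral-7B is used here only as an example), a CUDA GPU, and the separate flash-attn package installed:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"   # example of a model with Flash Attention 2 support

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                # Flash Attention requires fp16/bf16
    attn_implementation="flash_attention_2",  # errors out if flash-attn is missing or unsupported
    device_map="auto",                        # requires the accelerate package
)
```
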

Until next time
