LLM Inference Series: 3. KV caching explained

Pierre Lienhart
11 min read · Dec 22, 2023


In the previous post, we gave a high-level overview of the text generation algorithm of Transformer decoders, emphasizing its two phases: the single-step initiation phase, during which the prompt is processed, and the multi-step generation phase, during which the completion tokens are generated one by one.

In this post, we will take a deeper look at the first challenge of LLM inference — the quadratic scaling of compute required by the attention layer (a.k.a. self-attention layer) with the total sequence length (prompt tokens and generated completion tokens). Fortunately, many of these computations are redundant across generation steps, allowing us to cache appropriate results and reduce compute requirements. This caching transforms the formerly quadratic scaling attention layer into one that scales linearly with total sequence length.

A quick refresher on the Transformer’s attention layer

Let’s start by reminding ourselves of a few facts about what happens in the multi-head attention (MHA) layer of the vanilla Transformer (Figure 1).

Figure 1 — Detailed view of a Transformer decoder layer (above) and of a two-head (self-)attention layer (below) with an input sequence of length 3

For simplicity, we assume we only process a single sequence of length t (i.e. batch size is 1):

  • At every point in the process, each token in the input sequence (prompt) is represented by a dense vector (light yellow in Figure 1).
  • The input of the attention layer is a sequence of dense vectors, one for each input token, produced by the preceding decoder block.
  • For each input vector, the attention layer produces a single dense vector of the same dimension (light blue in Figure 1).

Now, considering a single attention head:

  • First, we produce three lower-dimensional dense vectors per input vector using three different projections: the query, the key and the value (leftmost light gray vectors in Figure 1). Overall, we have t query, t key and t value vectors.
  • For each query, we produce an output vector equal to a linear combination of the values, the coefficients of this linear combination being the attention scores. In other words, for each query, the corresponding output vector is an attention-weighted average of the values. For a given query, the attention scores are derived from the dot products of that query with each of the keys. This way, the representation we build for each token includes information from the other tokens, i.e. it is a contextual representation of that token.
  • In the context of auto-regressive decoding, however, we cannot use all the values to build the output representation for a given query. When computing the output for the query associated with a specific token, we cannot use the value vectors of tokens that appear later in the sequence. This restriction is enforced by a technique called masking, which essentially zeroes out the attention scores of the forbidden value vectors, i.e. of the forbidden (future) tokens.

Finally, the outputs of all attention heads are concatenated and passed through a final linear projection to yield the attention layer’s output.
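
To make this concrete, here is a minimal single-head masked attention sketch in PyTorch. The toy dimensions and random projection matrices are purely illustrative (a real implementation batches sequences, fuses the heads and uses learned weights):

```python
import torch

t, d_model, d_head = 3, 8, 4           # sequence length, model dim, head dim (toy values)
x = torch.randn(t, d_model)            # one input vector per token, produced by the preceding block

# Per-head projections (learned in a real model, random here)
W_q, W_k, W_v = (torch.randn(d_model, d_head) for _ in range(3))
q, k, v = x @ W_q, x @ W_k, x @ W_v    # t query, t key and t value vectors

scores = (q @ k.T) / d_head ** 0.5     # (t, t) attention score matrix
mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float("-inf"))   # masking: a token cannot attend to future tokens
weights = torch.softmax(scores, dim=-1)            # forbidden positions get exactly zero weight

out = weights @ v                      # one output vector per token: attention-weighted average of the values
print(out.shape)                       # torch.Size([3, 4])
```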

Quadratic scaling of attention computation

Let’s have a look at the number of FLOPs required to compute the attention scores. For a given head, for each sequence of total length t (prompt plus generated completion) in a batch of size b, the attention score matrix is created by multiplying a query tensor of shape (t, d_head) with a transposed key tensor of shape (d_head, t).

How many FLOPs are spent in a single matrix multiplication? Multiplying a matrix of shape (n, p) with a matrix of shape (p, m) involves approximately 2.n.p.m operations. In our case, a single-head, single-sequence attention score computation therefore amounts to approximately 2.d_head.t^2 FLOPs. Overall, attention score computations require 2.b.n_layers.n_heads.d_head.t^2 = 2.b.n_layers.d_model.t^2 FLOPs. The quadratic scaling with t now appears clearly.

Looking at real numbers, in Meta’s Llama2–7B for example, n_layers=32 and d_model=4096.

Note: The multiplication of the masked attention score matrix with the value tensor requires the same amount of FLOPs as computed above.

What about the matrix multiplications involving the model weights? Using a similar analysis, we can show that their computational complexity is O(b.n_layers.d_model^2.t), i.e. their compute requirements scale linearly with the total sequence length t.

To understand the severity of this quadratic scaling, let’s look at an example. To generate the 1,001st token, the model must perform roughly 100x more FLOPs in the attention layers than to generate the 101st token. This rapid growth in compute quickly becomes prohibitive. Fortunately for us, and thanks to masking, a lot of these computations can actually be spared between steps.
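
As a back-of-the-envelope check, we can plug the Llama2–7B figures into the 2.b.n_layers.d_model.t^2 formula above (plain Python, assuming b=1):

```python
n_layers, d_model, b = 32, 4096, 1   # Llama2-7B figures from above

def attention_score_flops(t):
    # 2 . b . n_layers . d_model . t^2 (attention scores only, no KV caching)
    return 2 * b * n_layers * d_model * t**2

print(f"{attention_score_flops(101):.2e}")    # ~2.67e+09 FLOPs to generate the 101st token
print(f"{attention_score_flops(1001):.2e}")   # ~2.63e+11 FLOPs to generate the 1,001st token, ~100x more
```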

Masking induces redundant computations in the generation phase

We are now getting to the crux of the problem. Due to masking, the output representation of a given token is generated using the representations of previous tokens only. Since those previous tokens are identical across generation steps, the output representation for that particular token is also identical across all subsequent steps, which implies redundant computations.

Let’s use the sequence of tokens from the previous post as an example (it has the nice feature of having one token per word). Let’s say we just generated “is ” from the “What color is the sky? The sky ” input sequence. In the previous iteration, “sky ” was the last token of the input sequence; the output representation associated with this token was therefore produced using the representations of all the tokens in the sequence, i.e. the value vectors for “What”, “ color”, “ is”, “ the”, “ sky”, “?”, “The ” and “sky ”.

The input sequence for the next iteration will be “What color is the sky? The sky is ” but because of masking, from “sky ”’s perspective, it appears as though the input sequence is still “What color is the sky? The sky ”. The output representations generated for “sky ” will therefore be identical to the ones from the previous iteration.

Now for an illustrated example (Figure 2) using the chart from Figure 1. The initiation step is assumed to process an input sequence of length 1. Redundantly computed elements are highlighted in light red and light purple; the light purple elements correspond to the redundantly computed keys and values.

Figure 2 — Redundant computations in the attention layer in the generation phase

Coming back to our example, for the new iteration that uses “What color is the sky? The sky is ” as input, the only representation we have not already computed in previous steps is the one for the last token of the input sequence, “is ”.

More concretely, what material do we need to do just that?

  • A query vector for “is “.
  • Key vectors for “What”, “ color”, “ is”, “ the”, “ sky”, “?”, “The ”, “sky ” and “is ” to compute the attention scores.
  • Value vectors for “What”, “ color”, “ is”, “ the”, “ sky”, “?”, “The ”, “sky ” and “is ” to compute the output.

Regarding the key and value vectors, they have already been computed during previous iterations for all tokens but “is ”. We could therefore save (i.e. cache) and reuse the key and value vectors from previous iterations. This optimization is simply called KV caching. Computing an output representation for “is ” would then be as simple as:

  1. Computing a query, a key and a value for “is ”.
  2. Fetching the key and value vectors for “What”, “ color”, “ is”, “ the”, “ sky”, “?”, “The ” and “sky ” from the cache and concatenating them with the key and value we just computed for “is ”.
  3. Computing the attention scores using the “is ” query and all the keys.
  4. Computing the output vector for “is ” using the attention scores and all the values.
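
Put into toy code, a single cached generation step for one attention head could look like the sketch below. The cache contents and projection matrices are random placeholders; only the mechanics matter here:

```python
import torch

d_model, d_head = 8, 4
W_q, W_k, W_v = (torch.randn(d_model, d_head) for _ in range(3))

# KV cache built during previous iterations: one key/value vector per past token
k_cache = torch.randn(8, d_head)     # keys for "What", " color", ..., "The ", "sky "
v_cache = torch.randn(8, d_head)     # values for the same eight tokens

x_new = torch.randn(1, d_model)      # representation of the last generated token, "is "

# 1. Compute a query, a key and a value for "is " only
q_new, k_new, v_new = x_new @ W_q, x_new @ W_k, x_new @ W_v

# 2. Fetch the cached keys/values and concatenate the new ones
k_all = torch.cat([k_cache, k_new], dim=0)   # (9, d_head)
v_all = torch.cat([v_cache, v_new], dim=0)   # (9, d_head)

# 3. Attention scores for the single new query against all keys (no mask needed: only past tokens are present)
weights = torch.softmax((q_new @ k_all.T) / d_head ** 0.5, dim=-1)   # (1, 9)

# 4. Output vector for "is ": attention-weighted average of all the values
out = weights @ v_all                        # (1, d_head)
```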

Looking at our inputs, we actually no longer need the previous tokens as long as we can use their key and value vectors. When KV caching is used, the actual inputs to the model are the last generated token (rather than the whole sequence) and the KV cache. Figure 3 below provides an illustrated example of this new way of running the attention layer during the generation phase.

Figure 3 — Generation step with KV caching enabled

Coming back to the two phases from the previous post:

  • The initiation phase is actually unaffected by the KV caching strategy since there are no previous steps.
  • For the decoding phase however, the situation now looks very different. We no longer use the whole sequence as input but only the last generated token (and the KV cache).

In the initiation phase, the attention layer processes all of the prompt’s tokens in one go, as opposed to the decoding steps, where it processes only one token at a time. In the literature [1], the first setup is called batched attention (sometimes misleadingly parallel attention), while the second is called incremental attention.

When resorting to KV caching, the initiation phase actually computes and (pre-)fills the KV cache with the keys and values of all the input tokens, and is therefore often called the pre-fill phase instead. In practice, the terms pre-fill and initiation phase are used interchangeably; we will favor the former from now on.

This difference between the initiation and generation phases is not merely conceptual. At each generation step, in addition to the weight matrices, we now have to fetch an ever-growing cache from memory, only to process a single token per sequence. Notice that using GPU kernels optimized for each phase brings better performance than using the same kernel for both the pre-fill phase and the decoding phase with KV caching enabled (cf. for example [2]).

KV caching enables linear attention scaling

How does attention scale now? The transposed key tensor is still of shape (d_head, t), but the query tensor is now of shape (1, d_head). A single-head, single-sequence attention score computation therefore requires 2.d_head.t FLOPs and, overall, attention score computations require 2.b.n_layers.d_model.t FLOPs. Attention now scales linearly with the total sequence length!
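
Extending the earlier back-of-the-envelope script, we can compare the per-step attention score cost with and without the cache (same illustrative assumptions: Llama2–7B figures, b=1):

```python
n_layers, d_model, b = 32, 4096, 1

def score_flops_no_cache(t):
    return 2 * b * n_layers * d_model * t ** 2   # full recomputation at every step

def score_flops_with_cache(t):
    return 2 * b * n_layers * d_model * t        # a single query attends to t keys

t = 1_000
print(score_flops_no_cache(t) / score_flops_with_cache(t))   # 1000.0: a t-fold saving at step t
```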

Are we done with quadratic scaling? Not if you discard the cache and need to recompute it. Imagine for example that you developed a chatbot application [3] and keep the cache in memory between conversation rounds. Now, one client has been idle for quite some time. Since GPU memory is limited, you implemented a cache eviction policy that discards stale conversations. Unfortunately, the client then resumes the conversation, so you must recompute the cache for the entire history. This recomputation incurs a computational cost quadratic in the total conversation length.

The example above highlights how (KV) caching is a compromise and therefore not a free lunch. Specifically, we trade higher memory usage and data transfer for reduced computation. As we will see in upcoming posts, the memory footprint cost of caching can be substantial.

Revisiting the chatbot example, designing an efficient cache eviction policy is challenging since it requires balancing two expensive options: either consuming more of a scarce resource (GPU memory and bandwidth) or paying a quadratic amount of computation to rebuild evicted caches.

An example of KV caching in real life with HuggingFace Transformers

In practice, how does this look? Can we enable or disable the KV cache? Let’s take the HuggingFace Transformers library as an example. All the model classes dedicated to text generation (i.e. the XXXForCausalLM classes) implement a generate method that is used as the entry point. This method accepts a lot of configuration parameters, mainly to control the token search strategy. KV caching is controlled by the use_cache boolean parameter (True by default).
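
For instance, toggling the cache from generate could look like the snippet below. This is only a sketch: the model name is a placeholder and loading it assumes you have access to the corresponding weights.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"   # placeholder, any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("What color is the sky? The sky", return_tensors="pt")

# KV caching enabled (the default) vs. disabled: same greedy output, very different decoding speed
fast = model.generate(**inputs, max_new_tokens=20, use_cache=True)
slow = model.generate(**inputs, max_new_tokens=20, use_cache=False)
print(tokenizer.decode(fast[0], skip_special_tokens=True))
```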

Going one level deeper and looking at a model’s forward method (for example, here are the docs of LlamaForCausalLM.forward), we find the use_cache boolean argument as expected. With KV caching enabled, we have two inputs: the last generated token and the KV cache, passed via the input_ids and past_key_values arguments respectively. The new KV values (i.e. including the ones computed as part of the current iteration) are returned as part of the forward method outputs, to be reused in the next iteration.
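
A minimal hand-rolled greedy decoding loop makes these two inputs explicit. This is a sketch reusing the model and tokenizer loaded above; note that recent Transformers versions may return the cache as a Cache object rather than the legacy tuple of tensors, but the forward interface stays the same:

```python
import torch

input_ids = tokenizer("What color is the sky? The sky", return_tensors="pt").input_ids
past_key_values, generated = None, []

for _ in range(20):
    with torch.no_grad():
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)   # greedy pick
    past_key_values = outputs.past_key_values   # updated cache, reused at the next step
    input_ids = next_token                      # only the last generated token is fed back
    generated.append(next_token)

print(tokenizer.decode(torch.cat(generated, dim=-1)[0]))
```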

What do these returned KV values look like? Let’s do some tensor counting. With KV caching enabled, the forward method returns a list of tensor pairs (one for the keys, one for the values). There are as many pairs as decoder blocks in the model (more commonly named decoder layers and noted n_layers). For each token of each sequence in the batch, there is one key/value vector of dimension d_head per attention head; each key/value tensor is therefore of shape (batch_size, n_heads, seq_length, d_head).

Looking at real numbers, in Meta’s Llama2–7B for example, n_layers=32, n_heads=32 and d_head=128. We will look at the details of the KV cache size in the next post, but we now have a first intuition about how large it can get.
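
Continuing from the decoding loop above, a quick inspection confirms this tensor counting (assuming the cache comes back in the legacy tuple format; newer Transformers versions wrap it in a Cache object that exposes the same per-layer tensors):

```python
print(len(past_key_values))        # n_layers, i.e. 32 for Llama2-7B
keys, values = past_key_values[0]  # key/value pair of the first decoder layer
print(keys.shape)                  # (batch_size, n_heads, seq_length, d_head)
```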

Conclusion

Let’s summarize what we just learned. The attention score computation scales quadratically with the total sequence length. However, thanks to masking, at each generation step we can spare ourselves the recomputation of the keys and values of all past tokens: only the last generated token requires new ones. Every time we compute new keys and values, we can cache them in GPU memory for future reuse, hence saving the FLOPs that would otherwise be spent recomputing them.

The main benefit of this strategy is to make the FLOPs requirements of the (self-)attention mechanism scale linearly rather than quadratically with the total sequence length.

As mentioned above, KV caching is a tradeoff and raises new problems we will investigate in the next posts:

  • The KV cache consumes GPU memory and can become very large. Unfortunately, GPU memory is scarce, even for reasonably small LLMs. The KV cache is therefore the main technical obstacle to increasing the total sequence length (context window size) or the number of sequences processed at once, i.e. the throughput, and hence to improving cost efficiency.
  • KV caching greatly reduces the amount of computation performed during a single generation step compared to the amount of data that has to be moved from memory: we fetch big weight matrices and an ever-growing KV cache only to perform meager matrix-vector operations. On modern hardware, we unfortunately end up spending more time loading data than actually crunching numbers, which results in under-utilization of the compute capacity of our GPUs. In other words, we achieve low GPU utilization and therefore low cost efficiency.

The next post will investigate the KV cache size issue. The subsequent posts will look in more detail at the hardware utilization issues.

Change log:

  • 18/12/2023: First release
  • 21/02/2024: Rework the “How many operations does KV caching actually spare?” section
  • 02/03/2024: Major blog post restructuring. “How many operations does KV caching actually spare?” section has been removed.

[1]: See for example Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019)

[2]: For example, since release 2.2.0, the reference implementation of the widely adopted Flash-Attention algorithm features a dedicated kernel for the decoding phase with KV caching enabled (flash_attn_with_kvcache), also referred to as Flash-Decoding.

[3]: Blog post — Scaling ChatGPT: Five Real-World Engineering Challenges (Orosz, 2024)


Pierre Lienhart

GenAI solution architect @AWS - Opinions and errors are my own.