LLM Inference Series: 4. KV caching, a deeper look

Pierre Lienhart
18 min read · Jan 15, 2024


In the previous post, we introduced KV caching, a common optimization of the LLM inference process that makes the compute requirements of the (self-)attention mechanism scale linearly rather than quadratically with the total sequence length (prompt + generated completions).

More concretely, KV caching consists in avoiding the recomputation of the key and value tensors of past tokens at each generation step by storing (“caching”) these tensors in GPU memory as they get computed during the generation process.

KV caching is a compromise: we trade memory for compute. In this post, we will see how big the KV cache can grow, what challenges it creates and what the most common strategies to tackle them are.

How big can the KV cache grow?

This is quite simple: for each token of each sequence in the batch, we need to store two vector tensors (one key tensor and one value tensor) of size d_head for each attention head of each attention layer. The space required by each tensor parameter depends on the precision: 4 bytes/parameter in full-precision (FP32), 2 bytes/parameter in half-precision (BF16, FP16), 1 byte/parameter for 8-bit data types (INT8, FP8), etc.

Let b be the batch size, t the total sequence length (prompt + completion), n_layers the number of decoder blocks / attention layers, n_heads the number of attention heads per attention layer, d_head the hidden dimension of each attention head, and p_a the precision (in bytes per parameter). The per-token memory consumption (in bytes) of the KV cache of a multi-head attention (MHA) model is:

2 · n_layers · n_heads · d_head · p_a

Note: Recall that in MHA models, n_heads · d_head = d_model, but we won’t use this to simplify the formula above.

The total size of the KV cache (in bytes) is therefore:

2 · b · t · n_layers · n_heads · d_head · p_a

One of the first challenges of the KV cache appears: it grows linearly with the batch size and, most importantly, with the total sequence length. Since it grows with the total sequence length, the KV cache size is virtually unbounded while our GPU memory is obviously limited. Even worse, since the total sequence length cannot be known ahead of time, the KV cache memory requirements are unknown as well, which makes memory management particularly challenging.

Let’s look at some numbers for popular MHA models (Table 1), namely Meta’s Llama-2 [1] and OPT [2], MosaicML’s MPT [3] and BigScience’s BLOOM [4]:

Table 1 — Specifications of popular multi-head attention (MHA) models

Let’s assume the parameters are stored in half precision (FP16, BF16) and pick a smaller model (Llama-2-7B) and a larger one (BLOOM-176B). For Llama-2-7B (resp. BLOOM-176B), the KV cache memory consumption amounts to ~0.5MB/token (resp. ~4MB/token).
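To make the arithmetic concrete, here is a minimal sketch of the per-token computation (2 · n_layers · n_heads · d_head · p_a, i.e. one key and one value vector per head per layer). The function name is hypothetical, and the architecture numbers are the ones published in the models' public configurations, which I am assuming match Table 1.

```python
def kv_bytes_per_token(n_layers: int, n_heads: int, d_head: int,
                       bytes_per_param: int = 2) -> int:
    """KV cache bytes per token for an MHA model.

    The leading factor 2 accounts for one key and one value vector
    per attention head and per layer.
    """
    return 2 * n_layers * n_heads * d_head * bytes_per_param


# Assumed architecture numbers (public model configs), half precision.
llama2_7b = kv_bytes_per_token(n_layers=32, n_heads=32, d_head=128)
bloom_176b = kv_bytes_per_token(n_layers=70, n_heads=112, d_head=128)

print(f"Llama-2-7B: {llama2_7b / 2**20:.2f} MiB/token")   # 0.50
print(f"BLOOM-176B: {bloom_176b / 2**20:.2f} MiB/token")  # 3.83
```

Multiplying the result by b · t gives the total cache size for a batch.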

Let’s focus on Llama-2-7B. Using half precision, loading the model weights consumes ~14GB of memory, the same as caching the keys and values of 28k tokens. 28k tokens could, for example, correspond to a batch of 56 sequences of length 512, which is not particularly extreme.

We can see from the numbers above that the KV cache memory consumption can grow very large and, for long sequences or large batches, even exceed the amount of memory required to load the model weights.

Now let’s compare these numbers to the memory capacity of common NVIDIA data center GPUs (Table 2):

Table 2 — Specifications of NVIDIA data center GPUs commonly used for training and/or serving LLMs

Let’s pick the rather cost-efficient A10 GPU, stick to Llama-2-7B and compute the maximum KV cache capacity. Once the model weights have been loaded, 24 - 2 × 7 = 10 GB remain available for the KV cache, i.e. a total capacity of ~20k tokens, prompts included. This obviously does not allow us to serve many concurrent requests, especially when processing or generating long sequences.
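The same back-of-the-envelope capacity estimate can be written as a small helper. This is only a sketch: the function name is hypothetical, and it assumes decimal gigabytes and ignores activations, CUDA context and framework overhead, so real capacity will be lower.

```python
def max_cached_tokens(gpu_mem_bytes: int, n_params: int,
                      bytes_per_weight: int, kv_bytes_per_token: int) -> int:
    """Upper bound on the number of tokens whose keys/values fit in GPU memory
    once the model weights have been loaded (all other overheads ignored)."""
    free_bytes = gpu_mem_bytes - n_params * bytes_per_weight
    return max(0, free_bytes // kv_bytes_per_token)


GB = 10**9
tokens = max_cached_tokens(
    gpu_mem_bytes=24 * GB,        # A10
    n_params=7_000_000_000,       # Llama-2-7B
    bytes_per_weight=2,           # half precision
    kv_bytes_per_token=524_288,   # ~0.5 MB/token
)
print(tokens)          # 19073, i.e. roughly 20k tokens
print(tokens // 2048)  # 9 concurrent sequences of 2048 tokens
```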

We now understand that the KV cache prevents us from processing or generating very long sequences (i.e. it is an obstacle to long context windows) and/or from processing large batches, and therefore from maximizing our hardware efficiency.

In that perspective, maximizing our processing capacity means having as much room as possible for the KV cache which can be achieved by:

  • Reducing the model weight memory footprint (weight quantization)
  • Reducing the KV cache memory footprint (cf. below)
  • Pooling memory from multiple devices by sharding our model over multiple GPUs at the cost of network communication (model parallelism) or using other kinds of storage like CPU memory or disk (offloading)

Since the model weights and the ever-growing KV cache have to be loaded on each forward pass, decoding steps involve very large data transfers and, as we will see in the next posts, are actually memory-bandwidth bound, i.e. we spend more time moving data than doing useful work (compute). In such a regime, latency can only be improved by either having more memory bandwidth (i.e. better hardware) or by transferring less data. Smaller model weights and a smaller KV cache also free up memory for more sequences and therefore make it possible to increase throughput (and/or the maximum sequence length).

In that regard, memory footprint reduction strategies are triply useful: they allow us to increase our hardware utilization, and therefore our cost efficiency, while reducing latency and increasing throughput.

Digression - Why am I billed for my input tokens? (Table 3)

Table 3 — Sample of OpenAI rates (checked on 12/01/2024)

At this point you should get a feeling as to why you are billed for both input and output tokens. Once the input prompt has been processed, i.e. at the end of the prefill phase, we have already consumed both GPU memory (to store the key and the value tensors of each input token) and compute (to pass the prompt tokens through the model).

Let’s have a look at some real numbers. Assuming the total FLOPs count of the forward pass of a model with P parameters is approximately 2·P FLOPs/token [5], processing a prompt using Llama-2-7B consumes ~0.5 MB/token of GPU memory (cf. above) and ~14 GFLOPs/token of GPU compute. For a 1000-token prompt (a bit less than a two-pager), that’s ~500MB of memory and 14 TFLOPs of compute, and we have not generated anything yet.
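As a quick sanity check of these figures, the sketch below applies the 2·P FLOPs/token approximation and the ~0.5 MB/token KV cache cost to a prompt. The function name is hypothetical and the result is an order-of-magnitude estimate, not a measurement.

```python
def prefill_cost(n_prompt_tokens: int, n_params: int,
                 kv_bytes_per_token: int) -> tuple:
    """Approximate (FLOPs, KV cache bytes) consumed by the prefill phase."""
    flops = 2 * n_params * n_prompt_tokens          # ~2.P FLOPs per token
    kv_bytes = kv_bytes_per_token * n_prompt_tokens
    return flops, kv_bytes


flops, kv_bytes = prefill_cost(
    n_prompt_tokens=1000,
    n_params=7_000_000_000,      # Llama-2-7B
    kv_bytes_per_token=524_288,  # half precision, cf. above
)
print(f"{flops / 1e12:.0f} TFLOPs")     # 14 TFLOPs
print(f"{kv_bytes / 2**20:.0f} MiB")    # 500 MiB
```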

Now let’s have a look at all the ways we can reduce the memory footprint of the KV cache by taking the formula above and looking at each of its terms in turn:

What about reducing the batch size?

In most cases, we don’t want to decrease the batch size: while it helps with the KV cache memory footprint and hence with latency, it decreases our hardware utilization and therefore our cost efficiency. In the following posts we will indeed see that, on the contrary, we want to increase the batch size as much as we can.

What about reducing the dependency on the total sequence length?

One reason not to store the keys and the values for all the tokens in the sequence would be that we explicitly choose to recompute the missing ones on each iteration because it is worth spending the FLOPs instead of consuming GPU memory (for example because we are memory-bandwidth bound, which is the case during the auto-regressive phase). To the best of my knowledge, this is not done in practice, so we won’t dive deeper in that direction.

Another perspective is that we could simply not store the keys and the values of tokens the model pays no or very little attention to. This could be the case by design for models trained to attend to only part of the sequence (for example Mistral AI’s Mistral-7B) or as part of a compromise between memory consumption and model accuracy. Let me explain.

Models like Mistral-7B [6] are trained not to pay attention to the whole sequence. Mistral-7B attention layers indeed build token representations by attending to the last (4096) neighboring tokens only. This variant of the attention mechanism is called sliding window attention (SWA) or local attention. By design, local attention guarantees that we will never store more tensor pairs per sequence in the KV cache than the window size (e.g. 4096).
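To illustrate why the window size caps the cache, here is a minimal sketch of a fixed-capacity cache for one sequence. The class name is hypothetical, and a `collections.deque` stands in for the preallocated rolling buffer a real implementation would use; strings stand in for the key and value tensors.

```python
from collections import deque


class SlidingWindowKVCache:
    """Keeps the (key, value) pairs of the last `window` tokens only."""

    def __init__(self, window: int):
        # A deque with maxlen silently drops the oldest entry when full.
        self.entries = deque(maxlen=window)

    def append(self, key, value) -> None:
        self.entries.append((key, value))

    def __len__(self) -> int:
        return len(self.entries)


cache = SlidingWindowKVCache(window=4)
for pos in range(10):
    cache.append(f"k{pos}", f"v{pos}")  # placeholders for real tensors

print(len(cache))        # 4: never more than the window size
print(cache.entries[0])  # ('k6', 'v6'): oldest token still cached
```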

Another approach consists in taking advantage of patterns in the way attention layers spread their attention over the tokens in the sequence. It is indeed known that attention modules disproportionately and consistently allocate more attention to a handful of tokens in the sequence (Figure 1). By contrast, many tokens consistently contribute very little to the output, so why bother storing their keys and values at all?

Figure 1 — Example of attention (heat)map from the StreamingLLM paper: A lot of attention is consistently allocated to the first token and to the last neighboring tokens (local attention)

By discarding these tokens, we de facto set the corresponding attention scores to zero and approximate the attention matrix with a sparser one. A successful approximation would minimize the approximation error and therefore the impact on model accuracy (measured using perplexity for example).

Let’s have a look at a few methods that emerged over the past few months and which are readily applicable without any retraining or fine-tuning: the StreamingLLM framework, H2O (Heavy-Hitter Oracle), Scissorhands and FastGen. To the best of my knowledge however, none of them is yet supported by any popular LLM inference framework.

Targeting models trained with a finite-length context window, the StreamingLLM framework [7] builds on the observation that the initial tokens collect a large amount of attention. The framework therefore builds a sliding window by only keeping the very first positional tokens (“sink tokens”) and the last neighboring tokens (local attention) in the cache. The StreamingLLM KV cache is therefore of fixed length with both a fixed part (typically 1 to 4 tokens) and a sliding part.
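The selection rule described above (a few sink tokens plus a recent window) can be sketched as a function returning the positions that stay in the cache. The function and parameter names are hypothetical, and this only illustrates which entries are kept; it does not cover how positional information is handled inside the cache.

```python
def streaming_keep_positions(seq_len: int, n_sink: int = 4,
                             window: int = 8) -> list:
    """Positions kept in the cache: the first `n_sink` tokens (attention
    sinks) plus the last `window` tokens (local attention)."""
    if seq_len <= n_sink + window:
        return list(range(seq_len))  # everything still fits
    sinks = list(range(n_sink))
    recent = list(range(seq_len - window, seq_len))
    return sinks + recent


kept = streaming_keep_positions(seq_len=20, n_sink=2, window=4)
print(kept)  # [0, 1, 16, 17, 18, 19]
```

Whatever the sequence length, at most n_sink + window entries are kept, which is what makes the cache fixed-length.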

The similar H2O [8] and Scissorhands [9] methods explicitly aim at compressing the KV cache by setting a maximum number of cached tokens (budget) and by discarding tokens every time the cache budget has been reached. The H2O algorithm only discards one token at a time while Scissorhands drops as many tokens as required by a target compression ratio (e.g. 30% KV cache size reduction).

Both approaches build on the observation that tokens that are influential at a given step (“pivotal tokens” or “heavy hitters”) remain influential at future steps (what the Scissorhands authors name the Persistence of Importance Hypothesis). In other words, we assume that the discarded low-influence tokens would have remained relatively ignored at future steps, so they can be safely dropped.

A key aspect of both algorithms is obviously the cache eviction policy. Scissorhands simply keeps the most recent tokens and the tokens with the highest attention scores within a history window. H2O discards the token with the lowest cumulative attention score and therefore only keeps the tokens that consistently achieve high attention scores across iterations. Both author teams have shown that their algorithms achieve up to 80% KV cache size reduction with negligible model accuracy loss.
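To give a feel for a cumulative-score eviction policy, here is a deliberately simplified sketch in the spirit of H2O. It is not the authors' implementation: the class name, the single-head view, the dictionary of scores and the `n_recent` protection of the latest tokens are all assumptions made for illustration.

```python
from typing import Dict, Optional


class HeavyHitterCache:
    """Tracks cumulative attention per cached position and evicts the
    lowest-scoring one whenever the budget is exceeded."""

    def __init__(self, budget: int, n_recent: int = 1):
        if budget < n_recent:
            raise ValueError("budget must be at least n_recent")
        self.budget = budget
        self.n_recent = n_recent  # most recent tokens are never evicted
        self.cum_score: Dict[int, float] = {}

    def step(self, attention: Dict[int, float]) -> Optional[int]:
        """`attention` maps each cached position (new token included) to the
        attention weight it received at this step. Returns the evicted
        position, or None if the cache is still within budget."""
        for pos, weight in attention.items():
            self.cum_score[pos] = self.cum_score.get(pos, 0.0) + weight
        if len(self.cum_score) <= self.budget:
            return None
        positions = sorted(self.cum_score)
        candidates = positions[:-self.n_recent] if self.n_recent else positions
        victim = min(candidates, key=self.cum_score.__getitem__)
        del self.cum_score[victim]
        return victim


cache = HeavyHitterCache(budget=3, n_recent=1)
cache.step({0: 1.0})
cache.step({0: 0.7, 1: 0.3})
cache.step({0: 0.5, 1: 0.1, 2: 0.4})
evicted = cache.step({0: 0.4, 1: 0.05, 2: 0.35, 3: 0.2})
print(evicted)                  # 1: lowest cumulative attention
print(sorted(cache.cum_score))  # [0, 2, 3]
```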

The FastGen method [10] (not to be confused with the unrelated DeepSpeed-FastGen) still builds on attention patterns but takes another approach by not setting a cache budget but a maximum approximation error for the attention matrix hence focusing on model accuracy preservation.

FastGen is a two-step approach. First, the model’s attention layers are profiled at the end of the prefill phase to determine the set of compression policies that meets the error target. Like the other methods, it assumes that the identified attention patterns will hold in future generation steps. Compression policies include: keep special tokens, keep punctuation tokens, keep last neighboring tokens (local attention), etc. (Figure 2). If the error target is too stringent and cannot be met, FastGen falls back to regular KV caching. Then, the chosen compression policies are applied to the KV cache at each generation step.

Figure 2 — Example of set of compression policies from the FastGen paper: Special tokens (green) + Punctuation tokens (orange) + Local attention (blue). Discarded tokens are colored in gray.

Notice that contrary to other methods, FastGen builds a compression policy tailored to each prompt. The FastGen authors show that for a given KV cache compression ratio, they better preserve model accuracy than H2O and Scissorhands.

In any case, breaking the dependency on the unpredictable total sequence length is a relief since it allows us to give each sequence a memory budget and therefore greatly eases memory management. Since data transfer is the main contributor to latency, not having a KV cache that grows linearly with the sequence length can bring spectacular speedups, especially for longer sequences.

What about reducing the number of layers?

There is not much to gain here. Smaller models usually have fewer layers (Table 4), so if a smaller model performs well on your use case, simply go for it.

Table 4 — Specifications of Llama-2 models

What about reducing the number of attention heads?

Since for a given model architecture, the model size is mainly controlled by the number of layers and the number of heads, reducing the number of heads can mean opting for a smaller model (cf. Table 4).

However, if we take a closer look, we notice that we only need to reduce the number of key and value heads: the number of query heads does not impact the KV cache size. This is precisely the idea behind the multi-query attention (MQA) [11] and grouped-query attention (GQA) [12] architectures. The sole motivation of these variants of multi-head attention (MHA) is KV cache size reduction.

MQA was introduced first in 2019. In MQA, all query heads share the same single key and value heads. In other words, all query heads compute their attention scores using the same keys and all head outputs are computed using the same values (but not the same attention scores) (Figure 3).

Figure 3 — Multi-head attention (above) vs. Multi-query attention (below) (Two attention heads)

Stripping all heads is however relatively more aggressive for larger models. For example, going from 64 heads down to 1 is comparatively a bigger cut in the model’s representation capacity than going from 32 heads down to 1. GQA solves the problem by providing a midway solution: instead of having all the query heads share a single KV head, we split them into groups of g query heads, and the query heads of a given group share one KV head. In other words, instead of downsizing from n_heads to 1 KV head, the number of KV heads is cut from n_heads down to n_heads/g, with 1 < g < n_heads.

From that perspective, both MHA and MQA are particular cases of GQA (g=1 and g=n_heads respectively). GQA makes it possible to navigate the compromise between model accuracy and KV cache size (which is connected to both latency and throughput) more smoothly between the two extreme cases, MHA and MQA.
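The grouping can be made explicit with a one-line mapping from query heads to KV heads. This sketch assumes g is the number of query heads per group and that consecutive query heads are grouped together; the function name is hypothetical.

```python
def kv_head_for_query_head(query_head: int, g: int) -> int:
    """Index of the KV head shared by the group `query_head` belongs to."""
    return query_head // g


n_heads = 8
mha = [kv_head_for_query_head(h, g=1) for h in range(n_heads)]
gqa = [kv_head_for_query_head(h, g=4) for h in range(n_heads)]
mqa = [kv_head_for_query_head(h, g=n_heads) for h in range(n_heads)]

print(mha)  # [0, 1, 2, 3, 4, 5, 6, 7]: one KV head per query head
print(gqa)  # [0, 0, 0, 0, 1, 1, 1, 1]: two KV heads
print(mqa)  # [0, 0, 0, 0, 0, 0, 0, 0]: a single KV head
```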

Accounting for this new parameter g, the KV cache size formula becomes:

2 · b · t · n_layers · (n_heads / g) · d_head · p_a
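As a rough illustration of the savings, the sketch below computes the per-token cache size with n_heads / g KV heads. The function name is hypothetical, and the architecture numbers (Llama-2-70B: 80 layers, 64 query heads, 8 KV heads; Mistral-7B: 32 layers, 32 query heads, 8 KV heads; d_head = 128 for both) come from the public model configurations rather than from this post.

```python
def gqa_kv_bytes_per_token(n_layers: int, n_heads: int, d_head: int,
                           g: int, bytes_per_param: int = 2) -> int:
    """Per-token KV cache bytes when each group of g query heads shares
    one KV head (g=1 is MHA, g=n_heads is MQA)."""
    if n_heads % g != 0:
        raise ValueError("n_heads must be divisible by g")
    n_kv_heads = n_heads // g
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_param


llama2_70b_gqa = gqa_kv_bytes_per_token(80, 64, 128, g=8)
llama2_70b_mha = gqa_kv_bytes_per_token(80, 64, 128, g=1)  # hypothetical MHA variant
mistral_7b = gqa_kv_bytes_per_token(32, 32, 128, g=4)

print(llama2_70b_gqa / 2**20)  # 0.3125 MiB/token
print(llama2_70b_mha / 2**20)  # 2.5 MiB/token, i.e. 8x larger
print(mistral_7b / 2**20)      # 0.125 MiB/token
```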

In practice, the MQA/GQA architecture has notably been implemented by Google Research’s PaLM [13], TII’s Falcon [14] models, Meta’s Llama-2 [1] (70B only) and Mistral AI’s Mistral-7B [6] (Table 5).

Table 5 — Model families using either MQA or GQA

What about the hidden dimension of attention heads?

Once again, there is nothing much to gain here if you are not ready to opt for another model. Depending on the model family, the head hidden dimension can be constant across model sizes (e.g. Llama-2, Falcon) so going for a smaller variant from the same family won’t help.

What about using fewer bytes per parameter?

Quantizing the KV cache is indeed a great way to drastically reduce its size. However, weight-only quantization algorithms like AWQ [15] or GPTQ [16] won’t help by definition. Only algorithms that quantize both weights and “activations” (i.e. anything that is not a weight) such as LLM.int8() [17] or SmoothQuant [18] would produce a quantized KV cache.

Notice that one of the aims of quantization algorithms that work on both weights and activations is to perform the compute-intensive matrix multiplications in lower precision. This gives a performance boost when we are compute bound, as during training, but as we will see in the next posts, the autoregressive phase of inference is actually memory-bandwidth bound, so being able to compute faster does not bring much value. We are therefore mostly interested in the reduction of the memory footprint since it means less data transfer.

From that perspective, quantization algorithms like LLM.int8() or SmoothQuant are a bit overkill: quantizing the cached tensors before moving them to GPU memory and dequantizing the same tensors after having fetched them from GPU memory (at the cost of additional overhead) should be enough.
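The quantize-on-write / dequantize-on-read idea can be sketched in a few lines. This is a toy per-vector absmax INT8 scheme in pure Python, chosen only for illustration; it is an assumption on my part and not the scheme used by any of the frameworks or papers named in this post.

```python
def quantize_int8(values):
    """Absmax quantization: map floats to integers in [-127, 127]."""
    scale = max(abs(v) for v in values) / 127 or 1.0  # avoid a zero scale
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def dequantize_int8(quantized, scale):
    """Recover approximate floats from the stored integers and scale."""
    return [q * scale for q in quantized]


key_vector = [0.5, -1.27, 0.0, 1.0]          # placeholder for a cached key
stored, scale = quantize_int8(key_vector)    # before writing to the cache
restored = dequantize_int8(stored, scale)    # after reading from the cache

print(stored)  # [50, -127, 0, 100]
max_error = max(abs(a - b) for a, b in zip(key_vector, restored))
print(max_error <= scale / 2 + 1e-9)  # True: error bounded by half a step
```

Each stored element takes 1 byte instead of 2 in half precision, at the cost of one extra scale per vector and the (de)quantization overhead.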

A few LLM inference systems already include such a KV cache quantization feature. For example, FlexGen [19] quantizes and stores both the KV cache and the model weights in a 4-bit data format. NVIDIA TensorRT-LLM is capable of quantizing the KV cache in 8-bit data formats (INT8 or FP8). The popular vLLM framework has also supported KV cache (FP8) quantization since version 0.3.0. Since quantization is performed dynamically at each iteration, no calibration step is required.

On the importance of efficient memory management

Until now, we implicitly assumed that there is no waste in memory: all the reserved memory is used to store tokens and all the available memory can be reserved. In practice, naive memory management strategies can lead to a significant part of the memory being wasted (the PagedAttention paper [20] showed that the actual effective memory utilization could be as low as 20%, i.e. 80% waste!):

  • Since the total sequence length of a request is unknown in advance, we could reserve contiguous memory chunks large enough to fit the maximum sequence length. A significant part of this allocation will surely never be used and, since it is unavailable to other requests, is wasted (internal memory fragmentation).
  • Even if the sequence length is known in advance, since memory is consumed gradually but the memory chunks are reserved for the request’s lifetime, shorter requests cannot use the still-unused memory chunks.
  • If we use decoding strategies that produce multiple sequences per request like beam search, the multiple candidate sequences could actually partially share their KV cache. If we do not account for this scenario, we will inevitably waste memory by storing duplicate KV entries that could have been shared.

These shortcomings are exactly what the now popular PagedAttention algorithm aims to solve. PagedAttention allocates fixed-size and relatively small memory chunks called blocks. Each block can contain a fixed number of tokens and, if necessary, be shared across different requests. On-demand allocation and the small block size alleviate internal memory fragmentation while same-size blocks eliminate external memory fragmentation.
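The on-demand block allocation can be sketched with a toy allocator that only does the bookkeeping (which physical block belongs to which sequence). All names are hypothetical, block sharing and copy-on-write are left out, and this is not vLLM's actual implementation.

```python
class BlockAllocator:
    """Toy paged KV cache bookkeeping: blocks are handed out one at a time,
    only when a sequence actually needs room for a new token."""

    def __init__(self, n_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free = list(range(n_blocks))  # physical block ids
        self.block_tables = {}             # seq_id -> list of block ids
        self.lengths = {}                  # seq_id -> number of cached tokens

    def append_token(self, seq_id: str) -> None:
        length = self.lengths.get(seq_id, 0)
        if length % self.block_size == 0:  # last block is full (or none yet)
            if not self.free:
                raise MemoryError("no free KV cache block")
            self.block_tables.setdefault(seq_id, []).append(self.free.pop())
        self.lengths[seq_id] = length + 1

    def release(self, seq_id: str) -> None:
        """Return all blocks of a finished sequence to the free pool."""
        self.free.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

    def wasted_slots(self, seq_id: str) -> int:
        """Reserved but unused token slots (internal fragmentation)."""
        reserved = len(self.block_tables[seq_id]) * self.block_size
        return reserved - self.lengths[seq_id]


allocator = BlockAllocator(n_blocks=8, block_size=16)
for _ in range(20):
    allocator.append_token("request-a")

n_blocks_used = len(allocator.block_tables["request-a"])
waste = allocator.wasted_slots("request-a")
print(n_blocks_used, waste)  # 2 12: waste is bounded by one block
allocator.release("request-a")
print(len(allocator.free))   # 8
```

With this scheme, the waste per sequence is at most block_size - 1 token slots, whatever the maximum sequence length.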

Overall, PagedAttention achieves a near-zero waste in KV cache memory (less than 4% [21]). The previously wasted memory can now be used to fit more requests and therefore to increase throughput. The throughput improvement figures from PagedAttention when it came out were as spectacular as the levels of memory waste were high at the time.

PagedAttention was first implemented by the vLLM inference system but is now supported by all the major inference frameworks (e.g. HuggingFace TGI, NVIDIA TensorRT-LLM, LMDeploy TurboMind, etc.).

Another possible optimization not covered by PagedAttention is reusing the key-value cache across requests. This would apply when the prompts share a common prefix, which commonly occurs in multi-round use cases like chat and agents or when using prompt templates (Figure 4).

Figure 4 — KV cache sharing example (multi-turn chat) from the SGLang paper totaling four generation requests. Blue boxes represent shareable prompt parts.

Being able to reuse the KV cache across requests would enable significant latency (especially first token latency) and throughput (by greatly reducing the memory footprint of concurrent requests with a shared prefix) improvements.

An example of such KV cache reuse is achieved through the RadixAttention technique introduced in the LMSYS SGLang paper [22].

Instead of discarding the KV cache after finishing a generation request, the RadixAttention algorithm keeps it in GPU memory and adds a new entry to a dedicated data structure (radix tree) that maps the sequence of tokens to their KV cache tensors. When a new request comes in, the scheduler uses the radix tree for prefix matching. If there is a cache hit, the scheduler reuses the cached KV tensors to fulfill the request.

Since GPU memory is limited, cached KV tensors cannot be retained forever. The RadixAttention algorithm therefore includes an eviction policy (e.g. least recently used (LRU) eviction policy). Optimal cache reuse may not be compatible with schedules such as first-come-first-serve. RadixAttention therefore comes with a modified scheduler which prioritizes requests that match cached prefixes (cache-aware scheduling).
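The prefix-matching part of this idea can be sketched with a plain token-level trie. This is a simplification and not the SGLang implementation: a real radix tree compresses chains of tokens into single edges, and the eviction policy and cache-aware scheduling described above are left out. All names are hypothetical and strings stand in for the KV tensors.

```python
class PrefixCacheNode:
    def __init__(self):
        self.children = {}  # token id -> PrefixCacheNode
        self.kv = None      # placeholder for this token's cached KV tensors


class PrefixCache:
    """Maps token sequences to cached KV entries, one trie node per token."""

    def __init__(self):
        self.root = PrefixCacheNode()

    def insert(self, tokens, kv_entries) -> None:
        node = self.root
        for token, kv in zip(tokens, kv_entries):
            node = node.children.setdefault(token, PrefixCacheNode())
            node.kv = kv

    def match_prefix(self, tokens) -> list:
        """KV entries of the longest cached prefix of `tokens`."""
        node, hits = self.root, []
        for token in tokens:
            node = node.children.get(token)
            if node is None:
                break
            hits.append(node.kv)
        return hits


prefix_cache = PrefixCache()
prefix_cache.insert([1, 2, 3, 4], ["kv1", "kv2", "kv3", "kv4"])  # finished request

reused = prefix_cache.match_prefix([1, 2, 9])  # new request, shared prefix [1, 2]
print(reused)  # ['kv1', 'kv2']: only token 9 needs a prefill pass
```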

Note: Naming for both PagedAttention and RadixAttention is a bit misleading since contrary to what one might think, they are not optimizations of the model’s attention layer (like FlashAttention) but operate at the model server level (they help the serving application to better manage the KV cache on the host).

If we are short on GPU memory, why not “just” use multiple GPUs? Or offload to CPU memory or even to disk?

These are two different but equally valid approaches.

First, about offloading to more abundant but slower storage (CPU memory and disk). Not all inference frameworks support this feature; those that do include HuggingFace Accelerate, DeepSpeed-Inference and the more advanced FlexGen. Since it involves using much slower storage, offloading comes at the price of a strong latency hit, so this option should obviously not be favored for latency-sensitive use cases. Offloading systems are usually intended for throughput-oriented use cases like offline batch processing.

Regarding using multiple GPUs (which cannot be avoided for larger models), sharding the model over multiple devices relieves the memory pressure by benefiting from both the aggregated memory capacity and the aggregated memory bandwidth.

If opting for pipeline parallelism [23], both the model and the KV cache are sharded across the layer dimension. If opting for tensor parallelism [24] (more common for inference), the KV cache is sharded across the heads dimension. Notice that MQA becomes quite inefficient in that setup: since we cannot shard a single head across multiple devices, the KV cache has to be replicated on all the devices therefore losing the benefits of MQA. An alternative for models implementing MQA is to shard the KV cache across the batch size dimension [25].
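To see why MQA loses its benefit under tensor parallelism, the sketch below estimates the per-device, per-token KV cache size when KV heads are split across devices. It assumes that a KV head cannot be split and is therefore replicated when there are fewer KV heads than devices; the function name and the example model (32 layers, d_head = 128) are hypothetical.

```python
def kv_bytes_per_token_per_device(n_layers: int, n_kv_heads: int, d_head: int,
                                  tp_degree: int, bytes_per_param: int = 2) -> int:
    """Per-device KV cache bytes per token under head-wise sharding."""
    if n_kv_heads >= tp_degree:
        if n_kv_heads % tp_degree != 0:
            raise ValueError("n_kv_heads must be divisible by tp_degree")
        heads_per_device = n_kv_heads // tp_degree
    else:
        heads_per_device = 1  # a single head cannot be sharded: replicate it
    return 2 * n_layers * heads_per_device * d_head * bytes_per_param


# MHA with 32 KV heads: 8 devices each hold 1/8 of the cache.
mha_1 = kv_bytes_per_token_per_device(32, 32, 128, tp_degree=1)
mha_8 = kv_bytes_per_token_per_device(32, 32, 128, tp_degree=8)
# MQA with 1 KV head: every device holds the full cache.
mqa_1 = kv_bytes_per_token_per_device(32, 1, 128, tp_degree=1)
mqa_8 = kv_bytes_per_token_per_device(32, 1, 128, tp_degree=8)

print(mha_1 // mha_8)  # 8: per-device cache shrinks with the TP degree
print(mqa_1 // mqa_8)  # 1: no per-device reduction for MQA
```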

In any case, all the options above assume a single host: we are still bounded by the storage capacity of the largest multi-GPU instance we can get our hands on. To the best of my knowledge, no inference framework supports multi-host model parallelism yet. If we were able to shard both the model and the KV cache over multiple hosts, the amount of available memory and the maximum sequence length we could process would become virtually unlimited. This is the problem that the Infinite-LLM paper [26] aims to solve by both introducing a new distributed attention algorithm (DistAttention) and adapting the Ray framework to build a multi-host distributed KV cache management and scheduling system (DistKV-LLM).

Summary

In this post, we learned how opting for KV caching creates additional challenges. The KV cache of multi-head attention (MHA) models indeed consumes a lot of GPU memory, on the order of ~1MB/token, and can easily grow larger than the model weights.

Given how limited GPU memory is, the KV cache memory pressure has spurred many initiatives in different directions: novel attention architectures (MQA, GQA, SWA), cache compression strategies (H2O, Scissorhands, FastGen), efficient memory management (PagedAttention, RadixAttention), quantization and storage capacity expansion (offloading systems, single- and multi-host model parallelism).

As we will see in the following posts, reducing the KV cache size is key not only because of limited GPU memory but because the amount of data movement is actually the main contributor to the latency of each autoregressive step and therefore of the generation process as a whole.

In the next post we will look at the different kinds of bottlenecks that can affect a model’s latency and throughput. See you there!

Change log:

  • 15/01/2024: First release
  • 18/01/2024: Add RadixAttention
  • 01/02/2024: Add vLLM support for KV cache quantization. Fix H100 and H200 numbers in Table 2.
  • 02/03/2024: Slight introduction changes

[1]: Llama 2: Open Foundation and Fine-Tuned Chat Models (Touvron et al., 2023)

[2]: OPT: Open Pre-trained Transformer Language Models (Zhang et al., 2022)

[3]: Release blog posts for: MPT-7B (May 2023) and MPT-30B (June 2023)

[4]: BLOOM: A 176B-Parameter Open-Access Multilingual Language Model (BigScience, 2023)

[5]: Scaling Laws for Neural Language Models (Kaplan et al., 2020)

[6]: Mistral 7B (Jiang et al., 2023)

[7]: Efficient Streaming Language Models with Attention Sinks (Xiao et al., 2023) + GitHub repository

[8]: H_2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models (Zhang et al., 2023) + GitHub repository

[9]: Scissorhands: Exploiting the Persistence of Importance Hypothesis for LLM KV Cache Compression at Test Time (Liu et al. 2023)

[10]: Model Tells You What to Discard: Adaptive KV Cache Compression for LLMs (Ge et al., 2023)

[11]: Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019)

[12]: GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023)

[13]: PaLM: Scaling Language Modeling with Pathways (Chowdhery et al., 2022)

[14]: The Falcon Series of Open Language Models (Almazrouei et al., 2023)

[15]: AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration (Lin et al., 2023) + GitHub repository

[16]: GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers (Frantar et al., 2022) + GitHub repository

[17]: LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale (Dettmers et al., 2022) + GitHub repository

[18]: SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models (Xiao et al., 2022) + GitHub repository

[19]: FlexGen: High-Throughput Generative Inference of Large Language Models with a Single GPU (Sheng et al., 2023) + GitHub repository

[20]: Efficient Memory Management for Large Language Model Serving with PagedAttention (Kwon et al., 2023) + GitHub repository

[21]: vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention (Kwon et al., 2023)

[22]: Efficiently Programming Large Language Models using SGLang (Zheng et al., 2023) + Blog post

[23]: GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism (Huang et al., 2018)

[24]: Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (Narayanan et al., 2021)

[25]: Efficiently Scaling Transformer Inference (Pope et al., 2022)

[26]: Infinite-LLM: Efficient LLM Service for Long Context with DistAttention and Distributed KVCache (Lin et al., 2024)


Pierre Lienhart

GenAI solution architect @AWS - Opinions and errors are my own.