LLM Inference Optimizations — Chunked Prefill and Decode-Maximal Batching

Don Moon
Published in Byte-Sized AI · 8 min read · Aug 28, 2024

Large Language Model (LLM) inference involves two key phases: the prefill phase, which processes the input prompt, and the decode phase, which generates output tokens one at a time in an autoregressive manner. The prefill phase uses GPU compute efficiently even at small batch sizes. The decode phase, by contrast, suffers from low GPU utilization because each request contributes only one token per iteration. In addition, the differing durations of prefill and decode iterations can unbalance the micro-batches in pipeline parallelism, which creates pipeline bubbles. The figure below illustrates how conventional continuous batching [3] suffers from pipeline bubbles due to the workload imbalance between prefill and decode requests.

Two GPUs in a pipeline-parallel setting handle four requests (A, B, C, D) using continuous batching [1]

Chunked Prefill — SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefill

To address these inefficiencies of conventional continuous batching, SARATHI uses “chunked-prefills,” which divide a prefill request into equal-sized chunks, together with “decode-maximal batching,” which forms a batch from one prefill chunk plus as many decode requests as fit. During inference, the prefill chunk saturates GPU compute while the decode requests piggyback on it, which costs far less than running those decodes in a batch of their own. Because a single prefill request is split into several chunks, it can host several decode-maximal batches, giving more decode requests a chance to piggyback. Additionally, the consistent compute load of these batches reduces the imbalance across micro-batches and therefore the pipeline bubbles.

Two GPUs in a pipeline-parallel setting handle four requests (A, B, C, D) using chunked-prefills and decode-maximal batching [1]
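The batching policy described above can be sketched as a small generator. This is a toy illustration, not SARATHI's scheduler: the names `Batch` and `decode_maximal_batches` are hypothetical, decodes never finish, the last chunk of a prompt may be shorter than the rest, and decodes beyond the B − 1 free slots simply wait.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Batch:
    prefill: Tuple[str, int, int]  # (request id, first token, end token) of one prefill chunk
    decodes: List[str]             # requests that each generate one token in this iteration


def decode_maximal_batches(prefills, running_decodes, chunk_size, max_batch):
    """Yield one batch per iteration: a single prefill chunk plus up to max_batch - 1 decodes.

    prefills: list of (request_id, prompt_len) waiting to be prefilled.
    running_decodes: ids of requests already in the decode phase.
    """
    running = list(running_decodes)
    for req_id, prompt_len in prefills:
        for start in range(0, prompt_len, chunk_size):
            end = min(start + chunk_size, prompt_len)
            # Fill the remaining B - 1 slots with ongoing decodes (piggybacking).
            yield Batch((req_id, start, end), running[: max_batch - 1])
        # Hypothetical policy: once fully prefilled, the request joins the decode pool.
        running.append(req_id)


for batch in decode_maximal_batches([("A", 600)], ["B", "C", "D"], chunk_size=256, max_batch=3):
    print(batch.prefill, batch.decodes)
# ('A', 0, 256) ['B', 'C']
# ('A', 256, 512) ['B', 'C']
# ('A', 512, 600) ['B', 'C']
```

One 600-token prompt yields three iterations, and each one carries two piggybacked decodes instead of leaving them to a separate, memory-bound decode-only batch.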

Chunked-prefills

Chunked-prefills is a mechanism for splitting the prefill phase of LLM inference, based on two key insights. First, for a given model and GPU, prefill throughput shows diminishing returns as the number of prefill tokens grows. As depicted in the figure below, a single LLaMA-13B layer on an A6000 GPU reaches peak prefill throughput at 512 or more prefill tokens, while a chunk size of 256 tokens reduces throughput only slightly. As the model's hidden dimension grows, the number of tokens needed to saturate GPU compute decreases, so a properly sized prefill chunk is enough to form a compute-saturating batch. In contrast, decode throughput keeps increasing with batch size (the total number of decode tokens) up to about 512, at which point decode also becomes compute-bound, as in the prefill case.

Throughput of a single layer of LLaMA-13B on an A6000 GPU [1]

Second, in practical applications the prefill is often long (1K–4K tokens in production), so it can be split into several smaller compute units. Implementing chunked-prefills requires setting the attention masks carefully. For example, if a 1K-token prompt is split into four chunks of 256 tokens each, the mask for each subsequent chunk must let every query token attend to itself and to all preceding tokens, including those from earlier chunks that are read from the KV cache, but not to any later tokens. This makes the chunked-prefill computation mathematically equivalent to processing the full prefill in one pass.

Attention masks across different chunked-prefill iterations [1]; q and k denote query and key tokens. The same mask determines which value (v) vectors each query aggregates.
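The equivalence claim can be checked numerically with a minimal single-head, pure-Python sketch (not SARATHI's kernel; all function names are hypothetical). Each chunk appends its keys and values to a KV cache, and its mask lets query i see key positions 0..i across the cached prefix and the current chunk. The result is compared against ordinary causal attention over the whole prompt.

```python
import math
import random


def attend(q, keys, values, allowed):
    """Softmax attention for one query over the key positions where `allowed` is True."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) * scale if ok else -math.inf
              for k, ok in zip(keys, allowed)]
    peak = max(scores)                              # finite: a query always sees itself
    weights = [math.exp(s - peak) for s in scores]  # masked positions get weight 0.0
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def chunk_mask(start, end):
    """Rows for queries start..end-1 over keys 0..end-1 (cached prefix + current chunk)."""
    return [[j <= i for j in range(end)] for i in range(start, end)]


def chunked_prefill(qs, ks, vs, chunk_size):
    k_cache, v_cache, out = [], [], []
    for start in range(0, len(qs), chunk_size):
        end = min(start + chunk_size, len(qs))
        k_cache += ks[start:end]  # this chunk's keys/values join the KV cache
        v_cache += vs[start:end]
        for i, row in zip(range(start, end), chunk_mask(start, end)):
            out.append(attend(qs[i], k_cache, v_cache, row))
    return out


def full_prefill(qs, ks, vs):
    """Reference: ordinary causal attention over the whole prompt in one pass."""
    n = len(qs)
    return [attend(qs[i], ks, vs, [j <= i for j in range(n)]) for i in range(n)]


random.seed(0)
n, d = 10, 4
qs, ks, vs = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
chunked = chunked_prefill(qs, ks, vs, chunk_size=4)
reference = full_prefill(qs, ks, vs)
print("max abs difference:", max(abs(x - y) for a, b in zip(chunked, reference) for x, y in zip(a, b)))
```

Real implementations apply the same mask inside fused attention kernels. The sketch only shows why the chunk boundaries do not change the output.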

Chunked-prefills introduces two potential overheads. First, as chunk size decreases, the arithmetic intensity of the computation drops, potentially reducing GPU utilization and prefill efficiency. However, this issue can be mitigated by profiling the prefill throughput for various chunk sizes on a specific model-hardware setup and selecting an optimal chunk size that maximizes overall throughput.
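A minimal sketch of the selection step, assuming a profile that maps chunk size to measured prefill throughput already exists. The selection rule is an assumption of this sketch, not a rule stated by SARATHI: take the smallest chunk whose throughput is within a tolerance of the best, since smaller chunks leave more batches for decodes to piggyback on. The function name `pick_chunk_size` and the throughput numbers are made-up placeholders.

```python
def pick_chunk_size(profile, tolerance=0.05):
    """Return the smallest chunk size whose throughput is within `tolerance` of the best.

    profile: mapping chunk_size -> measured prefill throughput (tokens/s).
    """
    best = max(profile.values())
    return min(c for c, tput in profile.items() if tput >= (1 - tolerance) * best)


# Placeholder numbers only; replace them with your own profiling results.
profile = {64: 4100.0, 128: 6900.0, 256: 8600.0, 512: 8900.0, 1024: 8950.0}
print(pick_chunk_size(profile))                  # 256
print(pick_chunk_size(profile, tolerance=0.01))  # 512
```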

Second, chunked-prefills add a slight attention overhead because the KV cache of earlier chunks is read repeatedly: each subsequent chunk must reload the KV pairs of all previous tokens from GPU memory, which increases attention time. However, this extra time has minimal impact on overall prefill efficiency, because attention accounts for only a small portion of the total forward-pass time, as presented in the figure below.

Per-token prefill and decode time with different batch sizes (sequence length=1024)
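To show why smaller chunks raise this overhead, the following sketch counts how many cached token entries are reloaded across one prompt. It uses a simplified model assumed here: each chunk reads the K and V of every earlier token exactly once, while reads within a chunk and byte sizes are ignored. The function name `kv_reloads` is hypothetical.

```python
def kv_reloads(prompt_len, chunk_size):
    """Cached tokens whose K/V are re-read from memory, summed over all chunks of one prompt."""
    total, cached = 0, 0
    for start in range(0, prompt_len, chunk_size):
        total += cached  # this chunk attends over every previously cached token
        cached += min(chunk_size, prompt_len - start)
    return total


print(kv_reloads(1024, 256))   # 1536 = 0 + 256 + 512 + 768
print(kv_reloads(1024, 1024))  # 0: a single pass has no earlier chunk to reload
```

The count grows quadratically with the number of chunks. Because attention is a small share of the forward pass, this extra traffic still costs little in practice.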

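Note: The compute-saturating token count depends on the GPU's compute throughput relative to its memory bandwidth. This work used an NVIDIA RTX A6000, whose 310 TFLOPS figure is the FP16 Tensor Core peak with structured sparsity (about 155 TFLOPS dense). Newer GPUs such as the H100 reach around 2000 TFLOPS (FP16, again with sparsity), which results in higher compute-saturating prefill and decode token counts.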
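The note can be made concrete with a first-order roofline estimate. This is a standard model, not one taken from the SARATHI paper. A linear layer becomes compute-bound once its FLOPs per byte exceed the GPU's ridge point, which is peak FLOPS divided by memory bandwidth. The sketch counts FP16 weight and activation traffic once per matmul. The function names and the datasheet figures in the usage lines are assumptions to verify against your own hardware: roughly 155 TFLOPS dense FP16 Tensor Core and 768 GB/s for the A6000, and about 989 TFLOPS and 3.35 TB/s for the H100 SXM.

```python
def linear_intensity(n_tokens, h_in, h_out, bytes_per_elem=2):
    """FLOPs per byte of an (n_tokens x h_in) @ (h_in x h_out) matmul, each operand moved once."""
    flops = 2 * n_tokens * h_in * h_out
    bytes_moved = bytes_per_elem * (h_in * h_out + n_tokens * h_in + n_tokens * h_out)
    return flops / bytes_moved


def saturating_tokens(peak_flops, mem_bw, h_in, h_out, bytes_per_elem=2):
    """Smallest token count whose intensity reaches the ridge point peak_flops / mem_bw."""
    ridge = peak_flops / mem_bw
    # Intensity approaches this limit as n_tokens grows; beyond it the layer never saturates.
    limit = 2 * h_in * h_out / (bytes_per_elem * (h_in + h_out))
    if ridge >= limit:
        raise ValueError("layer never becomes compute-bound under this model")
    n = 1
    while linear_intensity(n, h_in, h_out, bytes_per_elem) < ridge:
        n += 1
    return n


# Assumed dense FP16 peaks and memory bandwidths; 5120 is LLaMA-13B's hidden size.
for name, flops, bw in [("A6000", 154.8e12, 768e9), ("H100 SXM", 989e12, 3.35e12)]:
    print(name, saturating_tokens(flops, bw, 5120, 5120))
```

Under this model, the required token count falls as the hidden size grows and rises with the FLOPS-to-bandwidth ratio. Both trends match the text, but actual saturation points should come from profiling, as in the figures.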

Decode-Maximal Batching

To piggyback decode operations onto a prefill, two steps are necessary. First, determine the maximum number of decodes that can be piggybacked, along with the number of tokens in the prefill chunk. Second, fuse the linear-layer computations of the prefill chunk and the decodes in the batch into a single operation, so that the decodes benefit from the GPU-saturating prefill computation.

The maximum decode batch size that can be piggybacked with a prefill chunk is determined by the available GPU memory (MG), the memory required by the model's parameters per GPU (MS), and the model's maximum supported sequence length (L). The combined number of prefill (P) and decode (D) tokens per request must not exceed this sequence length. The maximum permissible batch size (B) follows from these factors together with the KV-cache memory required per token (mkv).
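Written with the symbols above, and assuming each request reserves KV-cache memory for the full sequence length L, the constraint is:

```latex
P + D \le L, \qquad
B = \left\lfloor \frac{M_G - M_S}{L \cdot m_{kv}} \right\rfloor
```

Each of the B requests may need up to L · m_kv bytes of KV cache, so B is the number of full-length reservations that fit in the memory left over after the parameters.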

In SARATHI, the number of decodes is limited to B-1 because they piggyback alongside one prefill chunk.
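The bound and the B − 1 decode slots can be computed with a short sketch. The helper names are hypothetical, and the model figures used below are illustrative assumptions: LLaMA-13B with 40 layers and hidden size 5120, an FP16 KV cache, a 48 GB A6000 taken as 48 * 2**30 bytes, and roughly 26 GB of FP16 weights. The calculation ignores activation memory and framework overhead, so treat the result as an upper bound.

```python
def kv_bytes_per_token(n_layers, hidden_size, bytes_per_elem=2):
    """m_kv: one K and one V vector of length hidden_size per layer, per token."""
    return 2 * n_layers * hidden_size * bytes_per_elem


def max_batch_size(gpu_mem, param_mem, max_seq_len, kv_per_token):
    """B = floor((M_G - M_S) / (L * m_kv)): full-length KV reservations that fit in free memory."""
    return (gpu_mem - param_mem) // (max_seq_len * kv_per_token)


def decode_slots(batch_size):
    """One slot holds the prefill chunk; the remaining B - 1 carry piggybacked decodes."""
    return max(batch_size - 1, 0)


m_kv = kv_bytes_per_token(n_layers=40, hidden_size=5120)  # 819,200 bytes per token
B = max_batch_size(gpu_mem=48 * 2**30, param_mem=26 * 10**9, max_seq_len=1024, kv_per_token=m_kv)
print(m_kv, B, decode_slots(B))
```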

Note: This work assumes that KV-cache memory large enough for the maximum sequence length of every request is allocated upfront. Modern inference engines such as vLLM instead use PagedAttention [4], which allocates KV-cache blocks on demand as sequences grow.

In decode-maximal batching, all linear operations are fused together, but the attention computations for the prefill and decodes are handled separately. The attention for decode requests is batched together, while the attention for the prefill chunk is processed independently.

Orca (continuous batching) processes two decode requests (x1, x2) and two prefill requests (x3, x4). Linear operations (QKV Linear, Attention-out Linear) are fused across both prefill and decode requests, while attention computations are handled separately [2]. SARATHI instead processes one prefill chunk and multiple decode requests in one iteration, fusing all linear operations and, as in Orca, computing attention for the prefill and decode requests separately.
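The linear-layer fusion in this figure can be illustrated with a pure-Python sketch. The helper names are hypothetical, and on a GPU this would be a single GEMM kernel rather than nested loops. Prefill-chunk rows and decode rows are stacked and multiplied by the weight matrix in one call, so the weights are streamed once. The output is then split back so that each part can go to its own attention path.

```python
def matmul(x, w):
    """(n x k) @ (k x m) for lists of lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


def fused_linear(prefill_chunk, decode_tokens, w):
    """One matmul over stacked prefill and decode rows, split back into the two groups."""
    stacked = prefill_chunk + decode_tokens  # (C + num_decodes) x hidden
    out = matmul(stacked, w)                 # weights are read once for the whole batch
    # Downstream, prefill rows go to chunked-prefill attention and decode rows
    # to batched decode attention; that step is not modeled here.
    return out[: len(prefill_chunk)], out[len(prefill_chunk):]


w = [[1, 2], [3, 4], [5, 6]]     # 3 x 2 weight matrix
prefill = [[1, 0, 0], [0, 1, 0]]  # two prefill-chunk tokens
decodes = [[0, 0, 1]]             # one decode token
print(fused_linear(prefill, decodes, w))  # ([[1, 2], [3, 4]], [[5, 6]])
```

The fused result is identical to running the two groups separately. The gain comes purely from memory traffic, because the decode rows no longer need their own pass over the weights.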

Decode-maximal batching combines decode tokens with prefill tokens in a single matrix multiplication, so the model weights no longer have to be loaded separately for decoding. This shifts decoding from being memory-bound to compute-bound. As shown in the table below, decode-maximal batching improves decode throughput by approximately 10x. However, this gain comes at the cost of higher decode latency (238.4 ms compared to 49.96 ms). It is therefore crucial to verify that the system's latency SLOs for decodes are still met.

Per-token prefill and decode time (in ms) for LLaMA-13B on an A6000 GPU [1]. The rows show operation times for (1) a prefill-only batch of 4 requests with prompt size 1024, (2) a decode-only batch of size 4 with sequence length 1024, and (3) a mixed batch of a single 1021-token prefill chunk and 3 decodes.

Evaluations

Single-GPU Throughput — The figure below compares SARATHI with Orca's iteration-level scheduler for LLaMA-13B on an A6000 GPU. SARATHI outperforms Orca by 1.3x. For each sequence length, the largest batch size that fits in GPU memory is used: 18 for 1K, 10 for 2K, and 6 for 3K.

SARATHI outperforms conventional Orca by 1.3x [1]

Effect of Batch and Chunk Size — The ideal prefill chunk size (C) and batch size (B) depend on the workload’s prefill-to-decode ratio, specifically the total number of prefill tokens (P) relative to the total number of decode tokens (D). SARATHI achieves optimal performance when all decode tokens are effectively piggybacked onto prefill chunks. This occurs when the number of prefill chunks (P/C) matches the required number of decode iterations (D/(B−1)), i.e., when P : D = C/(B−1). As illustrated in the figure below, SARATHI’s highest normalized throughput is achieved at the point where this condition is met. Before reaching the optimal P:D ratio, decode-only batches exist, and beyond the optimal P:D ratio, prefill-only batches occur, both leading to suboptimal throughput. Yet, SARATHI usually outperforms Orca regardless of the ratio.

Note: Accurately predicting the workload's P:D ratio in advance can be challenging. However, LLM service operators often have access to traces from popular workloads, such as chatbot and summarization tasks, which can be used to estimate the P:D ratio and determine the optimal prefill chunk size (C) and batch size (B).

Varying P:D ratio (sequence length=1K, batch size=18); same experimental setup as the previous figure [1]
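The optimality condition P : D = C/(B−1) can be explored with a rough counting sketch. The helper names are hypothetical, and the model is simplified: decode tokens are treated as interchangeable slots, ignoring that each request's decodes are sequential. Decode-only batches are assumed to fill all B slots. The batch size of 18 comes from the caption, while the chunk size of 256 is an assumption.

```python
def ceil_div(a, b):
    return -(-a // b)


def optimal_pd_ratio(chunk_size, batch_size):
    """P:D at which prefill chunks (P / C) exactly match decode iterations (D / (B - 1))."""
    return chunk_size / (batch_size - 1)


def batch_mix(P, D, chunk_size, batch_size):
    """Rough count of mixed, prefill-only, and decode-only iterations for a workload."""
    chunks = ceil_div(P, chunk_size)
    mixed = min(chunks, ceil_div(D, batch_size - 1))
    leftover_decodes = max(D - mixed * (batch_size - 1), 0)
    return {
        "mixed": mixed,
        "prefill_only": chunks - mixed,  # chunks with no decodes left to piggyback
        "decode_only": ceil_div(leftover_decodes, batch_size),  # no chunk, so all B slots decode
    }


print(optimal_pd_ratio(256, 18))      # about 15.06
print(batch_mix(2560, 170, 256, 18))  # at the optimum: every iteration is mixed
print(batch_mix(2560, 340, 256, 18))  # P:D below optimum: decode-only batches appear
print(batch_mix(5120, 170, 256, 18))  # P:D above optimum: prefill-only batches appear
```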

Multi-GPU Throughput — The study reports results from a simulated deployment across 64 A100 GPUs on eight servers connected by InfiniBand, evaluating three scenarios: (1) 8-way tensor-parallel (TP) within a node and 8-way pipeline-parallel (PP) across nodes using Orca-style scheduling, (2) the same TP-PP setup with SARATHI's chunked-prefills and decode-maximal batching, and (3) 8 parallel replicas, each with 8-way TP. The maximum batch sizes used were 27 for TP+PP (1 and 2) and 11 for TP only (3). The P:D ratio was fixed at 10, with sequence lengths ranging from 1K to 4K sampled from a Zipf distribution. The chunk size was set to 256.

The figure below illustrates two key performance metrics for the experiment. In (a), the CDF of pipeline bubble time per request shows that SARATHI reduces median bubble time by 6.29× by creating equal-compute units of work. In (b), the time to complete 10K requests is plotted. TP-PP execution needs less parameter memory per GPU than TP-only, which allows a 2.45× larger batch size. However, TP-only is 1.28× faster than baseline TP-PP with Orca scheduling because it has no pipeline bubbles. SARATHI's chunked-prefills and decode-maximal batching speed up inference by 1.91× compared to baseline TP-PP and 1.3× compared to TP-only, making pipeline parallelism more efficient for LLM inference by minimizing pipeline bubbles.

Impact of SARATHI on pipeline bubbles (a) and request completion times (b) for GPT-3 deployed on DGX A100(s) in simulation [1]

Summary

Chunked-prefills and decode-maximal batching significantly boost decode and end-to-end throughput by splitting a prefill request into chunks and piggybacking multiple decode requests onto each prefill chunk. The improvement comes primarily from (1) compute-saturating batches, which increase GPU utilization within a batch, and (2) equal-compute batches, which reduce pipeline bubbles in multi-server pipeline-parallel inference. However, because piggybacked decodes wait for a longer iteration, these techniques may not suit latency-sensitive applications. System operators should carefully assess whether these optimizations meet their per-request objectives for decode or end-to-end latency.
