Techniques for Efficient Inference of LLMs (II/IV)

Andrei Apostol
MantisNLP
Oct 18, 2023 · 10 min read


Last time we talked about quantization, a compression technique used to reduce the bitwidth of neural networks by representing the weights in a lower precision format. This is, however, not the only technique used to compress models.

Today, we talk about pruning, a method that achieves the same goal (i.e., reducing the total information in the model) by directly removing weights and/or neurons.

We will also discuss paged attention, a systems method that allows for more efficient and flexible memory allocation.

Pruning

On the feasibility of pruning

Pruning involves the removal of weights in a model, with the purpose of reducing computational workload and storage. By removal, we mean setting a subset of weights to exactly 0 and keeping them frozen at this value throughout the rest of the training process.

Conceptually, pruning amounts to finding a binary mask tensor, with the same dimensions as the weights, that keeps the model's accuracy unchanged when applied. This task can be formulated as a constrained optimization problem and, in fact, a principled solution based on a second-order approximation of the loss was first laid out in LeCun's seminal Optimal Brain Damage [8], as far back as 1989.

The method works as follows:

  • Estimate the saliency of each individual weight, i.e., how much removing it would increase the loss
  • Remove the least salient weight
  • Update the remaining weights to compensate for the removal
  • Repeat until the desired sparsity level is reached
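For intuition, the loop above can be sketched in NumPy for a toy model whose loss is assumed to be exactly quadratic around the trained weights, with a known Hessian H. The saliency and compensation formulas used here come from Optimal Brain Surgeon, the inverse-Hessian successor to Optimal Brain Damage (which uses only the Hessian's diagonal). `obs_prune` is a hypothetical name, and this is an illustration rather than a reference implementation.

```python
import numpy as np

def obs_prune(w, H, n_prune):
    """Greedy second-order pruning in the style of Optimal Brain Surgeon.

    Assumes the loss is quadratic around the trained weights `w` with Hessian `H`.
    Returns the pruned weights and a boolean keep-mask.
    """
    w = np.asarray(w, dtype=float).copy()
    keep = np.ones(w.size, dtype=bool)
    for _ in range(n_prune):
        idx = np.flatnonzero(keep)
        # Inverse Hessian restricted to the weights that are still free to move
        H_inv = np.linalg.inv(H[np.ix_(idx, idx)])
        # 1. Saliency: loss increase from removing each weight (after compensation)
        saliency = w[idx] ** 2 / (2.0 * np.diag(H_inv))
        # 2. Pick the least salient weight
        q = int(np.argmin(saliency))
        # 3. Update the remaining weights to compensate for its removal
        delta = -(w[idx[q]] / H_inv[q, q]) * H_inv[:, q]
        w[idx] += delta
        w[idx[q]] = 0.0  # the update drives it to ~0; make it exactly 0
        keep[idx[q]] = False
    return w, keep
```

Note the `np.linalg.inv` call on a matrix with one row and one column per weight: that is exactly the cost that stops scaling.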

You might think, then, that pruning is a solved problem. Unfortunately, that's not really the case. The issue with this method (and with other methods from that period, such as Optimal Brain Surgeon) is that the saliency estimates and weight updates rely on the Hessian of the loss, or its inverse. While this was feasible at the scale of models back then, it is completely impractical for today's GPT-like models.

The reason is the time and space complexity of the Hessian computations. As an example, for a ResNet-50 with 25 million trainable parameters, storing the Hessian in 32-bit floats (4 bytes per entry) takes up:

(25 × 10⁶)² × 4 bytes = 2.5 × 10¹⁵ bytes

So, storing the Hessian directly takes around 2.5 petabytes of memory. That is before counting the cost of computing it, let alone inverting it, which scales cubically with the number of parameters.
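The arithmetic above works for any model size. The small helper below makes the quadratic scaling explicit. It is hypothetical and assumes 4-byte fp32 entries, as in the estimate above. The 7-billion-parameter case is an illustrative model size.

```python
def hessian_bytes(n_params: int, bytes_per_entry: int = 4) -> int:
    """Memory needed for a dense n_params x n_params Hessian (fp32 by default)."""
    return n_params ** 2 * bytes_per_entry

resnet50 = hessian_bytes(25_000_000)
print(f"{resnet50:.2e} bytes = {resnet50 / 1e15:.1f} PB")  # 2.50e+15 bytes = 2.5 PB

llm_7b = hessian_bytes(7_000_000_000)  # an illustrative 7B-parameter LLM
print(f"{llm_7b:.2e} bytes")  # 1.96e+20 bytes
```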

As such, many later works sidestep this issue by approximating the Hessian. One widely used approach simply approximates the Hessian with the identity matrix, so that pruning is based on weight magnitude alone [9].

It works as follows:

  • Remove the p% of weights with the lowest magnitude
  • Fine-tune the model until accuracy is restored
  • Repeat until the desired sparsity level is reached
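Here is a minimal NumPy sketch of this loop. It treats all weights as one flat pool (global pruning) and assumes `step_fraction` plays the role of p, applied to the weights that are still alive. `fine_tune` is a hypothetical hook standing in for whatever training loop you use.

```python
import numpy as np

def iterative_magnitude_pruning(weights, target_sparsity, step_fraction=0.2,
                                fine_tune=None):
    """Repeatedly remove the smallest-magnitude surviving weights.

    `fine_tune` is a hypothetical callback; pruned weights are forced back
    to exactly 0 after every call, so they stay frozen.
    """
    weights = np.asarray(weights, dtype=float).copy()
    mask = np.ones(weights.shape, dtype=bool)
    target_pruned = int(round(target_sparsity * weights.size))
    while (~mask).sum() < target_pruned:
        surviving = np.flatnonzero(mask)
        # Remove p% of the surviving weights, without overshooting the target
        n_remove = max(1, int(step_fraction * surviving.size))
        n_remove = min(n_remove, target_pruned - int((~mask).sum()))
        order = np.argsort(np.abs(weights.flat[surviving]))
        mask.flat[surviving[order[:n_remove]]] = False
        weights *= mask
        if fine_tune is not None:
            # Fine-tune to restore accuracy, keeping pruned weights at 0
            weights = fine_tune(weights) * mask
    return weights, mask
```

Without a fine-tuning step, the loop ends with simply the largest-magnitude weights. The fine-tuning in between is what lets the network redistribute information before the next round.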

How does pruning help?

It may seem obvious that having fewer parameters in a model leads to faster inference. This, however, is not so straightforward to translate into practice. After all, sparse matrices do not yield faster multiplication unless there are optimized kernels that exploit the sparsity.

Here, the way the pruning algorithm removes weights determines how (and whether) it achieves a speedup. This will make more sense in a minute! 🙂

From the Neural Magic blog

In a nutshell, unstructured pruning deals with individual connections, whereas structured pruning removes groups (e.g., neurons in linear layers, filters in convolutional layers, heads in attention layers).

In unstructured pruning, individual weights (i.e., connections) are removed. This has the advantage of being very flexible in where the sparsity is allocated within the matrix, and it typically reaches higher sparsity without accuracy degradation than structured approaches do. The drawback, however, is that turning this sparsity into a speedup relies on sparse matrix multiplication kernels, which may or may not be available depending on the hardware.

While such kernels are typically available for CPUs, GPUs have only recently begun to support such operations, with the introduction of sparse tensor cores in NVIDIA's Ampere architecture [10]. These also require a specific 2:4 sparsity pattern (i.e., 2 out of every block of 4 consecutive weights must be zero), so some modifications to the pruning algorithm itself are required.
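To make the 2:4 pattern concrete, the NumPy sketch below imposes it by magnitude. It assumes groups of four consecutive weights along each row, and `prune_2_4` is a hypothetical name. It only produces the masked dense matrix. Running on sparse tensor cores additionally requires NVIDIA's libraries to convert the matrix into their compressed storage format, and the model is usually fine-tuned afterwards to recover accuracy.

```python
import numpy as np

def prune_2_4(weights):
    """Zero the 2 smallest-magnitude weights in every group of 4 along each row."""
    rows, cols = weights.shape
    if cols % 4 != 0:
        raise ValueError("the number of columns must be a multiple of 4")
    groups = weights.reshape(rows, cols // 4, 4)
    # Rank the 4 entries of each group by magnitude and drop the 2 smallest
    order = np.argsort(np.abs(groups), axis=-1)
    mask = np.ones_like(groups, dtype=bool)
    np.put_along_axis(mask, order[..., :2], False, axis=-1)
    return (groups * mask).reshape(rows, cols), mask.reshape(rows, cols)
```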

Fig. from [10]

While this is certainly a viable method, with performance improvements of up to 20% [11], it is important to keep the hardware requirements in mind. This 20% is well below the theoretical 2x (i.e., 100%) speedup from skipping half of the weights, and it is expected to grow as hardware and software optimizations for sparse GEMMs mature.

The magnitude pruning introduced earlier [9] is an example of unstructured pruning.

Structured pruning, on the other hand, aims to prune entire groups of weights, such as neurons in a linear layer or heads in an attention layer. The purpose is to achieve dimensionality reduction in a very direct fashion. The disadvantage, of course, is that you are restricted in how you can allocate the sparsity in the matrix. From the point of view of the constrained optimization problem we spoke about earlier, pruning an entire neuron may be suboptimal compared to pruning the same number of weights scattered around the network.

As a result, structured pruning typically starts degrading accuracy at much lower sparsity levels than unstructured pruning. The upside, however, is that no specialized hardware or kernels are required. Whole groups can be removed by deleting rows or columns of the weight matrices, so we can keep using dense matrix multiplication as usual.

Extending the magnitude pruning example to the structured case, one might score each neuron by the sum of the absolute values of its incoming weights (the L1 norm of the corresponding row of the weight matrix) and then remove the p% least salient neurons.
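Here is a NumPy sketch of this structured variant for a two-layer MLP. It assumes the (out_features, in_features) weight layout used by, e.g., PyTorch's `nn.Linear`, so each hidden neuron owns one row of the first matrix and one column of the second. The function names are hypothetical.

```python
import numpy as np

def prune_neurons_l1(W1, b1, W2, fraction):
    """Remove the `fraction` of hidden neurons whose incoming weights have the smallest L1 norm.

    W1: (hidden, in)  first linear layer, one row per hidden neuron
    b1: (hidden,)     first-layer bias
    W2: (out, hidden) next linear layer, one column per hidden neuron
    """
    saliency = np.abs(W1).sum(axis=1)  # L1 norm of each row
    n_remove = int(fraction * W1.shape[0])
    keep = np.sort(np.argsort(saliency)[n_remove:])  # survivors, in original order
    # Deleting rows of W1 and the matching columns of W2 gives smaller *dense* matrices
    return W1[keep], b1[keep], W2[:, keep]

def mlp(x, W1, b1, W2):
    """Two-layer ReLU MLP, used to compare the original and pruned networks."""
    return np.maximum(x @ W1.T + b1, 0.0) @ W2.T
```

Note that this simple saliency ignores the bias. The pruned layers are ordinary, smaller dense matrices, which is why no special kernels are needed.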

LLM-Pruner

A practical solution for pruning LLMs is the LLM-Pruner package, an implementation of the paper with the same name [12]. In the paper, the authors introduce a data-dependent importance estimator for groups of weights, which significantly outperforms the commonly used L2 norm.

The authors estimate the importance of each group of connections from gradient information and an approximation of the Hessian matrix. The groups deemed least important, up to the chosen pruning ratio, are removed, and the model is then briefly fine-tuned to recover accuracy (according to the authors, 3 hours of fine-tuning is enough).
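To illustrate the idea behind Taylor-expansion importance, here is a toy NumPy sketch for a single linear layer trained with mean squared error on a small calibration batch. This is not LLM-Pruner's code, which works on coupled groups inside a transformer and approximates the Hessian. Because this toy loss is exactly quadratic in the weights, the second-order expansion is exact here, which makes the estimate easy to check. The function names are hypothetical.

```python
import numpy as np

def taylor_row_importance(W, X, Y):
    """Importance of each output row of a linear layer y = W x under an MSE loss.

    Estimates |L(W) - L(W with row i set to 0)| via a second-order Taylor
    expansion, |g_i . w_i - 0.5 * w_i^T H w_i|, on a calibration batch X, Y.
    """
    N = X.shape[0]
    residual = X @ W.T - Y  # (N, out)
    grad = residual.T @ X / N  # dL/dW, shape (out, in)
    first_order = (grad * W).sum(axis=1)  # g_i . w_i for each row
    # The Hessian w.r.t. each row is X^T X / N, so w_i^T H w_i = ||X w_i||^2 / N
    second_order = 0.5 * ((X @ W.T) ** 2).sum(axis=0) / N
    return np.abs(first_order - second_order)

def mse_loss(W, X, Y):
    return 0.5 * ((X @ W.T - Y) ** 2).sum() / X.shape[0]
```

Rows (or, in LLM-Pruner's case, groups) with the lowest scores are the candidates for removal.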

Pruning a LLaMA model can be done simply by calling the provided script:

python hf_prune.py --pruning_ratio 0.25 \
--block_wise \
--block_mlp_layer_start 4 --block_mlp_layer_end 30 \
--block_attention_layer_start 4 --block_attention_layer_end 30 \
--pruner_type taylor \
--test_after_train \
--device cpu --eval_device cuda \
--save_ckpt_log_name llama_prune

The package supports many models (LLaMA 2 included!) and pruning strategies. Readers interested in pruning their own models are encouraged to check it out!

Paged Attention

The KV Cache

Understanding Paged Attention first requires understanding the KV cache, a common optimization trick in decoder-style transformers.

Recall the formulation of the self-attention mechanism [13]:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Before reaching this step, however, the query, key and value matrices are computed for every input token, by multiplying the input with their respective weight matrices. This is the first step of the attention mechanism. We visually showcase this below for an input context of two tokens:

Matrix calculation of the self-attention mechanism, from The Illustrated Transformer

The 4 rows of Wq, Wk, Wv visually represent the dimensionality of the embeddings, 512 in the original implementation.

For readers not fully familiar with the inner workings of the transformer architecture, or in need of a refresher, we recommend this excellent blog post by Jay Alammar.

The formula referenced above can then be computed, in order to obtain Z, the output vector of the self-attention layer:

Notice how the dimensions of the matrices depend on the input size (2 rows for Q, K and V) as well as on the query/key/value dimension d_k (64 in the original Attention Is All You Need [13] paper, illustrated above as 3 columns).
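The computation above fits in a few lines of NumPy. This sketch uses a single head with random weights and the dimensions from the original paper (d_model = 512, d_k = 64). It omits the causal mask that decoder-style models apply.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention (no causal mask)."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v  # each (n, d_k)
    scores = Q @ K.T / np.sqrt(K.shape[-1])  # (n, n): quadratic in n
    return softmax(scores) @ V  # Z, shape (n, d_k)

rng = np.random.default_rng(0)
d_model, d_k = 512, 64
X = rng.normal(size=(2, d_model))  # two input tokens
W_q, W_k, W_v = (rng.normal(size=(d_model, d_k)) / np.sqrt(d_model) for _ in range(3))
Z = self_attention(X, W_q, W_k, W_v)
print(Z.shape)  # (2, 64)
```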

Denoting the sequence length by n and the query/key dimension by d, the time complexity of the self-attention layer is quadratic in the sequence length, i.e., O(n²d): QKᵀ is an n × n matrix, and each of its entries is a d-dimensional dot product.

Due to the autoregressive nature of transformers, the sequence length here makes no distinction between prompt and generated tokens. The next two scenarios have equivalent complexity:

While one can limit the prompt length (e.g., through summarization), the length of the generated output is variable, and performance degrades as more and more tokens are generated.

For this reason, caching the key and value (KV) tensors is a simple and effective optimization. Going back to our illustrated calculation of the Q, K, V values, imagine we have an X matrix of 2 rows and generate a token. When generating the next token, the calculation is redone with an X matrix of 3 rows, and the resulting Q, K and V matrices each gain an extra row as well.

However, the previous tokens have not changed, so the first two rows of X, and therefore the first two rows of Q, K and V, are the same as before. In other words, these rows can be cached instead of recomputed.

Moreover, note that we only need the latest row from Q, the query matrix, since we are only interested in the attention computation for the current token. We thus end up caching only the K and V values.
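Here is a toy single-head sketch of the idea. It assumes the input rows X are fixed, whereas in a real model each new row comes from the token just generated. The class and function names are hypothetical. The cached version computes only one new key and value row per step, yet matches a full recomputation.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

class KVCache:
    """Stores the key/value rows of all previous tokens for one attention head."""
    def __init__(self, W_k, W_v):
        self.W_k, self.W_v = W_k, W_v
        self.K = np.empty((0, W_k.shape[1]))
        self.V = np.empty((0, W_v.shape[1]))

    def append(self, x):
        # Only the new token's key and value are computed; earlier rows are reused.
        # (vstack copies the whole cache each step; fine for a toy, not for production.)
        self.K = np.vstack([self.K, x @ self.W_k])
        self.V = np.vstack([self.V, x @ self.W_v])

def decode_step(x, W_q, cache):
    """Attention output for the newest token, using only its query row."""
    cache.append(x)
    q = x @ W_q
    weights = softmax(cache.K @ q / np.sqrt(q.shape[-1]))
    return weights @ cache.V
```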

This is illustrated below:

Figure from Speeding up the GPT — KV Cache

As such, KV caching is a method that trades off memory for speed.

PagedAttention

The KV cache, while handy, has two issues: it is large (taking up VRAM), and its size varies, which makes efficient implementations technically difficult. PagedAttention has been proposed to solve this. From the vLLM blog post [14]:

In vLLM, we identify that the performance of LLM serving is bottlenecked by memory. In the autoregressive decoding process, all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to generate next tokens. These cached key and value tensors are often referred to as KV cache. The KV cache is

- Large: Takes up to 1.7GB for a single sequence in LLaMA-13B.

- Dynamic: Its size depends on the sequence length, which is highly variable and unpredictable. As a result, efficiently managing the KV cache presents a significant challenge. We find that existing systems waste 60% — 80% of memory due to fragmentation and over-reservation.
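The 1.7GB figure can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes LLaMA-13B's configuration (40 layers, hidden size 5120), a cache stored in fp16, and a full 2048-token sequence. It ignores any framework overhead, and the helper name is hypothetical.

```python
def kv_cache_bytes(n_layers, hidden_size, seq_len, bytes_per_value=2):
    """KV cache size for one sequence: one key and one value vector of size
    hidden_size per token, per layer (fp16 by default)."""
    return 2 * n_layers * hidden_size * seq_len * bytes_per_value

# Assumed LLaMA-13B configuration: 40 layers, hidden size 5120, 2048-token context
size = kv_cache_bytes(n_layers=40, hidden_size=5120, seq_len=2048)
print(f"{size / 1e9:.2f} GB")  # 1.68 GB
```

Each token costs about 0.8 MB of cache under these assumptions, so the size grows linearly with however many tokens a request ends up producing.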

As such, the authors propose the PagedAttention mechanism. Inspired by operating systems' paging and virtual memory mechanisms, it partitions each sequence's KV cache into equal-sized blocks, each holding the keys and values of a fixed number of consecutive tokens. The blocks themselves can be stored in non-contiguous memory.

This process is demonstrated in the animation below:

Animation from the vLLM blogpost [14]

During the attention computation, the blocks are fetched in their logical order. The current token's query vector is multiplied with the keys of the tokens in each block to obtain attention scores, which then weight the corresponding values.

This non-contiguous mapping to memory allows for more flexible (and thus efficient) handling of keys and values. Physical memory blocks are allocated on demand as new tokens are produced, and a block table tracks the mapping between each sequence's logical blocks and the physical blocks.
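Here is a toy NumPy sketch of the bookkeeping, assuming a single attention head and a tiny block size. The class and its methods are hypothetical and do not mirror vLLM's internals. Each sequence's logical blocks are mapped through a block table onto whichever physical blocks are free, so interleaved sequences end up with non-contiguous blocks. Attention over the gathered blocks still matches attention over a contiguous cache.

```python
import numpy as np

BLOCK_SIZE = 4  # tokens per block (a hypothetical, small value for illustration)

class PagedKVCache:
    """Toy paged KV cache: a shared pool of fixed-size physical blocks plus a
    per-sequence block table mapping logical block i -> physical block id."""
    def __init__(self, num_blocks, d_k):
        self.k_pool = np.zeros((num_blocks, BLOCK_SIZE, d_k))
        self.v_pool = np.zeros((num_blocks, BLOCK_SIZE, d_k))
        self.free = list(range(num_blocks))
        self.block_tables = {}  # seq_id -> list of physical block ids
        self.lengths = {}       # seq_id -> number of cached tokens

    def append(self, seq_id, k, v):
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % BLOCK_SIZE == 0:
            # Last block is full (or none exists yet): take any free physical block
            # (raises IndexError if the pool is exhausted)
            table.append(self.free.pop())
        block, slot = table[n // BLOCK_SIZE], n % BLOCK_SIZE
        self.k_pool[block, slot], self.v_pool[block, slot] = k, v
        self.lengths[seq_id] = n + 1

    def attend(self, seq_id, q):
        """Attention for query q over all cached tokens of one sequence."""
        n = self.lengths[seq_id]
        table = self.block_tables[seq_id]
        # Gather keys/values via the block table in logical order, dropping the
        # unused tail of the last block (a real kernel works block by block instead)
        K = np.concatenate([self.k_pool[b] for b in table])[:n]
        V = np.concatenate([self.v_pool[b] for b in table])[:n]
        scores = K @ q / np.sqrt(q.shape[-1])
        w = np.exp(scores - scores.max())
        return (w / w.sum()) @ V
```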

Example generation process for a request with PagedAttention. From the blogpost [14]

This makes memory usage highly efficient: all allocated memory is used, except for the unfilled portion of each sequence's last physical block. In practice, this lets the system batch more sequences together, improving GPU utilization.
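A quick back-of-the-envelope comparison makes the point concrete. The sketch assumes 16-token blocks and a baseline that reserves a contiguous 2048-token buffer per request. The request lengths and helper names are hypothetical. It only captures over-reservation, not fragmentation between requests.

```python
import math

def paged_waste(seq_len, block_size):
    """Unused token slots with paging: only the tail of the last block."""
    return math.ceil(seq_len / block_size) * block_size - seq_len

def reserved_waste(seq_len, max_len):
    """Unused slots when each request pre-reserves a contiguous max_len buffer."""
    return max_len - seq_len

# Hypothetical request lengths, 16-token blocks, and a 2048-token reservation
lengths = [37, 250, 1000]
print([paged_waste(n, 16) for n in lengths])       # [11, 6, 8]
print([reserved_waste(n, 2048) for n in lengths])  # [2011, 1798, 1048]
```

With paging, the waste per sequence is always smaller than one block, regardless of how long the sequence grows.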

The resulting speedups are impressive: up to 24x higher throughput than plain Hugging Face Transformers, and up to 3.5x higher than Hugging Face Text Generation Inference (TGI).

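The method is also beneficial for complex sampling algorithms such as parallel sampling and beam search, where several output sequences are generated from the same prompt. Such methods are normally hard to scale. With PagedAttention, however, these sequences can share physical blocks (e.g., those holding the prompt) instead of duplicating them, which cuts their memory usage by up to 55% and can increase throughput by up to 2.2x.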

For full details on the method, we refer the reader to the associated blogpost [14]. An official paper will also be made available shortly.

vLLM implementation of PagedAttention

The PagedAttention mechanism sits at the heart of the vLLM inference engine, which can be installed via:

pip install vllm

Running inference locally is then as simple as:

from vllm import LLM

prompts = ["Hello, my name is", "The capital of France is"]  # Sample prompts.
llm = LLM(model="lmsys/vicuna-7b-v1.3")  # Create an LLM.
outputs = llm.generate(prompts)  # Generate texts from the prompts.

# Each RequestOutput holds the original prompt and its generated completion(s).
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)

It is also possible to spin up a server via a provided entrypoint directly from the command line. We refer the reader to the quickstart guide.

Wrapping Up

In the second part of our blog series on efficient inference techniques for large language models (LLMs), we explored two crucial optimization methods.

Pruning removes weights and/or neurons to reduce the information in the model while maintaining accuracy. Unstructured pruning allows for higher sparsity but needs sparse kernels and hardware support to deliver speedups. Structured pruning removes entire groups of weights, shrinking the dense matrices directly at some cost in accuracy. The LLM-Pruner package offers a practical solution with an improved importance estimator.

Paged attention, inspired by operating systems' paging mechanisms, efficiently manages the KV cache during autoregressive decoding, leading to impressive throughput gains and much better memory utilization.

Stay tuned for Part III, where we’ll delve into FlashAttention, another systems method that offers impressive gains.
