Cache-Augmented Generation (CAG) from Scratch

A step-by-step tutorial on implementing Cache-Augmented Generation in Python. In this article, we’ll briefly explore the theory behind CAG to understand its core concepts. According to the original paper, CAG offers an alternative paradigm to Retrieval-Augmented Generation (RAG). The primary goal of both RAG and CAG is to enhance language models by integrating external knowledge.

Basic concept

In short, RAG’s strategy involves storing external knowledge encoded as vectors in a vector database. Before querying the LLM, the input query is also encoded into a vector, and the knowledge vectors with the highest similarity to the query vector are retrieved. This retrieved information is then added to the prompt given to the LLM to generate a response. This approach is powerful and theoretically scalable to very large knowledge sources. However, it introduces potential errors in document selection, which depends on how documents are chunked and the quality of the embedding model used to create the vector database.

CAG offers a simpler approach. If your external knowledge base is of a manageable size, CAG involves directly including the entire knowledge base within the prompt along with the query. The LLM can then process both the query and the knowledge base to generate a response. This strategy eliminates the need for a vector database and similarity calculations. CAG benefits from the recent advancements in LLMs, such as models like Llama, Mixtral, and Gemma, which demonstrate improved performance and efficiency with larger context windows.

However, a naive implementation of CAG, where the entire knowledge base is included in every prompt, would lead to very slow inference times. This is because LLMs typically generate one token at a time, and each prediction depends on the entire preceding context. Here’s where the key innovation of CAG comes in: by preloading the knowledge base into the model’s context and using a dynamic caching strategy (specifically Key-Value caching), we can avoid repeatedly processing the knowledge base for each new query. The model effectively “remembers” the processed knowledge, allowing it to focus only on the query during inference.

Here is an overview of the basic concepts:

Figure: Retrieval-Augmented Generation (RAG) in comparison to Cache-Augmented Generation (CAG)

Code Tutorial: Implementing CAG

This section dives into the practical implementation of the CAG concept. We’ll base our code on the work of hhhuang, particularly their kvcache.py script, available on their GitHub repository. The core ideas are from the original CAG research paper.

The code uses the same LLM as the research paper, Llama-3.1-8B-Instruct, and has been successfully tested in a Kaggle notebook environment, so it can be easily adapted to your own project.

After setting up the environment, we’ll delve into the details of the kvcache.py script, which creates and utilizes the key-value cache to implement CAG with the chosen LLM.

Before delving into the code itself, let’s ensure we have the necessary libraries installed:

#!pip install -U bitsandbytes

import torch
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
)

import bitsandbytes as bnb
from transformers.cache_utils import DynamicCache

Used versions:

transformers : 4.44.2
bitsandbytes : 0.45.0
torch        : 2.4.1+cu121

GPU            : Kaggle T4 x2
CUDA Version   : 12.6
Driver Version : 560.35.03

Logging in to Hugging Face

This is necessary to use the Llama-3.1 model, which is gated on Hugging Face:

  1. Create an Account: Visit https://huggingface.co/ and sign up for a free account.
  2. Generate an Access Token: Go to your profile settings (top right corner) -> Access Tokens -> Create a new token. This token grants access to gated models such as Llama-3.1.

from huggingface_hub import notebook_login
notebook_login()
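
The functions below expect a loaded model and tokenizer. Here is a minimal sketch of loading Llama-3.1-8B-Instruct with 4-bit quantization so it fits into the Kaggle T4 GPUs; the exact repository name and quantization settings are assumptions and may need adjustment for your setup:

model_name = "meta-llama/Llama-3.1-8B-Instruct"  # assumed checkpoint name; requires granted access

# 4-bit quantization so the 8B model fits into the T4 GPUs' memory
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # automatic device mapping across the available GPUs
)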

Prepare Knowledge

For this demonstration, we’ll give the model some background information to work with. This information consists of simulated clinical reports and incidents related to medical devices. It’s important to emphasize that all of this data is completely synthetic and not based on real events.

This knowledge is very specific to the medical device domain. A standard, pre-trained LLM wouldn’t be able to answer questions about these reports without being given this context first. In other words, the model needs this specific knowledge to understand and respond to questions about the reports.

knowledge = """
Incident 1: Glucose Meter Malfunction Leads to Hyperglycemia

Patient: John Miller, 62 years old
Device: GlucoFast Ultra glucose meter, manufactured by MediTech Solutions Inc.
Incident: Mr. Miller, a known diabetic, used his GlucoFast Ultra meter to check his blood glucose level before dinner.
The meter displayed a reading of 90 mg/dL, which was within his target range.
However, shortly after eating, he began experiencing symptoms of hyperglycemia, including excessive thirst, frequent urination, and blurred vision.
A subsequent check with a hospital-grade blood glucose analyzer revealed a blood glucose level of 250 mg/dL.
Investigation: It was determined that the GlucoFast Ultra meter was providing falsely low readings, likely due to a faulty batch of test strips. MediTech Solutions Inc.
issued a recall for the affected lot of test strips.
Outcome: Mr. Miller was treated for hyperglycemia and recovered fully.

Incident 2: Heart Pump Failure During Surgery

Patient: Jane Doe, 58 years old
Device: CardioAssist Ventricular Assist Device (VAD), manufactured by HeartLife Technologies.
Incident: Ms. Doe was undergoing a heart transplant surgery. During the procedure,
the CardioAssist VAD, which was supporting her circulation, suddenly malfunctioned, causing a critical drop in blood pressure.
Investigation: The investigation revealed a software glitch in the VAD's control system, causing it to unexpectedly shut down.
HeartLife Technologies issued a software update to address the issue.
Outcome: The surgical team was able to stabilize Ms. Doe and complete the transplant successfully.
However, the incident caused a delay in the procedure and increased the risk of complications.

Incident 3: X-Ray Machine Overexposure

Patient: Robert Smith, 45 years old
Device: XR-5000 Digital X-Ray System, manufactured by Imaging Dynamics Corp.
Incident: Mr. Smith was undergoing a routine chest X-ray. Due to a malfunction in the X-Ray system's calibration,
he received a significantly higher dose of radiation than intended.
Investigation: The investigation revealed a faulty sensor in the X-ray machine's control panel,
which led to an incorrect radiation output. Imaging Dynamics Corp. issued a service bulletin to inspect and recalibrate all affected XR-5000 systems.
Outcome: Mr. Smith was informed of the overexposure and monitored for any potential long-term effects of the increased radiation dose.
The immediate risk was considered low, but long-term risks could not be fully excluded.
"""

Preloading Knowledge

We’ll now create a simple function to preload the prepared knowledge into the model. This process uses Hugging Face’s dynamic caching mechanism (specifically key-value caching) to store the processed knowledge efficiently. This preloading step is crucial for Cache-Augmented Generation (CAG) as it allows the model to “remember” the knowledge and avoid redundant computations during inference.

The function will essentially take the prepared knowledge text as input and process it through the model once. The resulting key-value states from the attention layers are then stored in the cache. Subsequent queries can then leverage this cached information, significantly speeding up the generation process.

In essence, the function returns the “keys” and “values” that represent the pre-processed knowledge, ready to be used during the generation phase. This is how the model efficiently incorporates the external knowledge without having to reprocess it for every new query.

def preprocess_knowledge(
        model,
        tokenizer,
        prompt: str) -> DynamicCache:
    """
    Prepare knowledge kv cache for CAG.
    Args:
        model: HuggingFace model with automatic device mapping
        tokenizer: HuggingFace tokenizer
        prompt: The knowledge to preprocess, which is basically a prompt

    Returns:
        DynamicCache: KV Cache
    """
    embed_device = model.model.embed_tokens.weight.device  # check which device is used
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(embed_device)
    past_key_values = DynamicCache()
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_attentions=False,
            output_hidden_states=False)
    return outputs.past_key_values

Preparing Knowledge and Creating Key-Value Cache Data

Before generating the key-value (KV) cache data, we need to format the prompt and provide instructions to the model. The structure of this prompt, including any specific instructions or special tokens, is crucial and depends heavily on the chosen model.

Different language models have different input requirements. Some models use unique special tokens (like <s>, [CLS], or <bos>) to denote the beginning of a sequence, separate different parts of the input, or signal specific tasks. Therefore, it's essential to tailor the prompt and instructions to the specific model you're using.

In our case, we’ll format the prompt and instructions according to the requirements of Llama-3.1-8B-Instruct. This ensures that the model correctly processes the knowledge and generates the appropriate KV cache data.

def prepare_kvcache(documents, answer_instruction: str = None):
    # Prepare the knowledge kvcache
    if answer_instruction is None:
        answer_instruction = "Answer the question with a super short answer."

    knowledges = f"""
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
You are a medical assistant for giving short answers
based on given reports.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Context information is below.
------------------------------------------------
{documents}
------------------------------------------------
{answer_instruction}
Question:
"""
    # Get the knowledge cache
    kv = preprocess_knowledge(model, tokenizer, knowledges)
    kv_len = kv.key_cache[0].shape[-2]
    print("kvlen: ", kv_len)
    return kv, kv_len


knowledge_cache, kv_len = prepare_kvcache(documents=knowledge)
# kvlen: 610

After preloading knowledge into the key-value (KV) cache, we store its length. This is crucial because queries extend the KV cache. To maintain a consistent context of just the preloaded knowledge for subsequent queries, we truncate the KV cache back to its original length after each query. This ensures each query operates on the intended knowledge base, preventing unwanted interactions between queries.

Query Answering

Having preloaded our knowledge into the LLM’s key-value (KV) cache, we’re now ready to answer questions about the reports. A crucial first step is implementing a clean_up function. As described above, this function is responsible for restoring the KV cache to its original state (containing only the preloaded knowledge) after each query.

def clean_up(kv: DynamicCache, origin_len: int):
    """
    Truncate the KV Cache to the original length.
    """
    for i in range(len(kv.key_cache)):
        kv.key_cache[i] = kv.key_cache[i][:, :, :origin_len, :]
        kv.value_cache[i] = kv.value_cache[i][:, :, :origin_len, :]

This function handles the prediction process, which includes using the preloaded knowledge (stored in the KV cache) to answer queries:

def generate(
        model,
        input_ids: torch.Tensor,
        past_key_values,
        max_new_tokens: int = 300
) -> torch.Tensor:
    """
    Generate text with greedy decoding.

    Args:
        model: HuggingFace model with automatic device mapping
        input_ids: Input token ids
        past_key_values: KV Cache for knowledge
        max_new_tokens: Maximum new tokens to generate
    """
    embed_device = model.model.embed_tokens.weight.device

    origin_ids = input_ids
    input_ids = input_ids.to(embed_device)

    output_ids = input_ids.clone()
    next_token = input_ids

    with torch.no_grad():
        for _ in range(max_new_tokens):
            outputs = model(
                input_ids=next_token,
                past_key_values=past_key_values,
                use_cache=True
            )
            next_token_logits = outputs.logits[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1).unsqueeze(-1)
            next_token = next_token.to(embed_device)

            past_key_values = outputs.past_key_values

            output_ids = torch.cat([output_ids, next_token], dim=1)

            if (next_token.item() in model.config.eos_token_id) and (_ > 0):
                break
    return output_ids[:, origin_ids.shape[-1]:]

Starting the Prediction Process

We are now ready to begin the prediction process. This involves using the preloaded knowledge, stored efficiently in the key-value (KV) cache, to generate answers to user queries.

query = 'Which patient experienced issues with a blood glucose meter, and what was the problem?'

clean_up(knowledge_cache, kv_len)
input_ids = tokenizer.encode(query, return_tensors="pt").to(model.device)
output = generate(model, input_ids, knowledge_cache)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(f"Response of the model:\n {generated_text}")
Response of the model:
assistant

Mr. Miller experienced issues with the blood glucose meter.
The problem was that it
provided falsely low readings due to a faulty batch of test strips
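
Because clean_up truncates the cache back to the knowledge-only prefix, the same cache can serve any number of follow-up questions. A short illustrative example (the second question and its wording are hypothetical):

# A second query reusing the same preloaded knowledge cache
clean_up(knowledge_cache, kv_len)   # restore the cache to the knowledge-only state
query = 'Which device malfunctioned during surgery, and how did the manufacturer respond?'
input_ids = tokenizer.encode(query, return_tensors="pt").to(model.device)
output = generate(model, input_ids, knowledge_cache)
print(tokenizer.decode(output[0], skip_special_tokens=True))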

Conclusion

With the final code snippet, you can now pose various questions to the model, and it will generate responses based on the pre-cached knowledge. This article provided a basic and simplified overview of implementing Cache-Augmented Generation (CAG). For a deeper understanding, please refer to the original research paper and the associated repository.

This demonstration used a small knowledge base with a limited number of examples. However, if you’re working with a significantly larger dataset (e.g., more than 1,000 examples), preloading the model and generating the KV cache can become computationally expensive. In such cases, it’s highly recommended to store the generated KV cache data to disk. This allows you to load the precomputed cache directly, avoiding the need to regenerate it each time, which is crucial for scalability in large-scale applications. While not essential for this small-scale demonstration, this optimization is a necessity for practical, real-world implementations of CAG.

def write_kv_cache(kv: DynamicCache, path: str):
    """
    Write the KV Cache to a file.
    """
    torch.save(kv, path)

def read_kv_cache(path: str) -> DynamicCache:
    """
    Read the KV Cache from a file.
    """
    kv = torch.load(path, weights_only=True)
    return kv
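
A short usage sketch (the file path is illustrative; depending on your torch version, loading a DynamicCache object with weights_only=True may require allow-listing the class or loading trusted files with weights_only=False):

# Persist the preloaded cache once, then reload it in later sessions
# instead of re-encoding the whole knowledge base.
write_kv_cache(knowledge_cache, "knowledge_cache.pt")
knowledge_cache = read_kv_cache("knowledge_cache.pt")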
