From Search to Synthesis: Enhancing RAG with BM25 and Reciprocal Rank Fusion
In this blog, we will enhance RAG with BM25, Reciprocal Rank Fusion, and Sparse Priming Representation.
Let us begin the journey with a brief refresher on Retrieval-Augmented Generation (RAG).
In the realm of artificial intelligence, large language models (LLMs) stand as monumental achievements, showcasing an uncanny ability to mimic human-like text generation. However, their prowess is not without limitations. At times, these behemoths exhibit a remarkable grasp of the intricacies of language, delivering precise answers to complex queries. Yet, in other instances, they falter, spewing out arbitrary facts from their training data or sounding utterly clueless. The inconsistency stems from a fundamental truth: LLMs understand the statistical relationships between words but lack a genuine comprehension of their meaning. They are, in essence, merely regurgitating patterns learned during training without a grounded understanding of the information they possess.
Enter Retrieval-Augmented Generation (RAG) — a novel AI framework crafted to address the shortcomings of LLMs by anchoring them to external knowledge sources. Unlike traditional LLMs, which solely rely on pre-trained representations, RAG introduces a two-step dance of retrieval and generation to the language processing endeavor. Initially, a retriever scours a corpus to fetch pertinent documents or passages. Following this, a generator, armed with the retrieved information and the original query, concocts a well-informed response. This synergy of retrieval and generation aims to bestow upon LLMs a semblance of awareness, enabling them to pull from real-world facts when crafting responses.
The implementation of RAG in an LLM-based question-answering system heralds a twofold boon: it not only facilitates access to the latest and most reliable facts but also extends a transparency olive branch to users by providing insight into the sources backing the model’s assertions. This dual advantage addresses the trust deficit often associated with LLMs, paving the way for more accurate, verifiable, and ultimately, trustworthy machine-generated responses.
The promise of RAG extends beyond merely patching the inconsistencies of LLMs. It represents a stride towards more contextually grounded, informative, and reliable machine intelligence. As LLMs continue to burgeon both in size and capability, integrating frameworks like RAG could very well be the linchpin in transitioning from statistically driven text regurgitation to genuinely insightful and reliable machine-generated discourse. Through the lens of RAG, we begin to glimpse a future where AI not only talks the talk but walks the walk, providing responses that are as informed as they are articulate.
Let us enhance RAG with the following tools:
BM25
BM25 is a well-regarded ranking function that search engines use to rank documents by their relevance to a user’s query. It computes a score for each document from the frequency of the query terms in that document and their rarity across the dataset, with a normalization for document length. Introducing BM25 to the retrieval stage of RAG could enhance the model’s ability to fetch relevant documents, which in turn could improve the accuracy and relevance of generated responses.
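To make the scoring idea concrete, here is a minimal pure-Python sketch of a common textbook variant of BM25. It is not the exact formula used by the rank-bm25 package that appears later in this post, and the function name `bm25_scores`, the toy corpus, and the defaults `k1=1.5` and `b=0.75` are assumptions chosen for illustration.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(query: List[str], corpus: List[List[str]],
                k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Score every tokenized document in `corpus` against the tokenized `query`."""
    n_docs = len(corpus)
    avg_len = sum(len(doc) for doc in corpus) / n_docs
    # Document frequency: in how many documents each term appears.
    doc_freq = Counter(term for doc in corpus for term in set(doc))
    scores = []
    for doc in corpus:
        term_freq = Counter(doc)
        score = 0.0
        for term in query:
            if term not in term_freq:
                continue
            # Rare terms receive a larger inverse document frequency weight.
            idf = math.log((n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5) + 1)
            tf = term_freq[term]
            # Term frequency saturates through k1; b controls length normalization.
            score += idf * tf * (k1 + 1) / (tf + k1 * (1 - b + b * len(doc) / avg_len))
        scores.append(score)
    return scores


# Toy corpus, already lowercased and split into tokens (illustrative only).
corpus = [
    "memgpt manages memory tiers for llms".split(),
    "operating systems use virtual memory".split(),
    "the weather is sunny today".split(),
]
scores = bm25_scores("memgpt memory".split(), corpus)
best = max(range(len(corpus)), key=lambda i: scores[i])
print(scores, best)
```

The first document matches both query terms, including the rarer one, so it should rank first; the third document shares no terms with the query and scores zero.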
Reciprocal Rank Fusion (RRF)
Reciprocal Rank Fusion (RRF) is a technique used in information retrieval to combine the results from multiple searches into a single ranked list of results. This method is particularly useful when the searches are conducted across different datasets or using different search algorithms, and there is a need to amalgamate the results into a coherent ranking. We will use RRF in the retrieval stage by retrieving documents for multiple queries and using reciprocal rank fusion to reorder the documents.
For more details on RRF, please refer to the original paper: https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
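As a small worked example, RRF gives each document a fused score equal to the sum, over all ranked lists, of 1 / (k + rank), where rank is the document's 1-based position in that list. The sketch below assumes three hypothetical ranked lists and uses k=60, the value suggested in the paper; the function name `rrf` and the document ids are made up for illustration.

```python
from collections import defaultdict
from typing import Dict, List


def rrf(rankings: List[List[str]], k: int = 60) -> Dict[str, float]:
    """Fuse several ranked lists of document ids into one score per document."""
    fused = defaultdict(float)
    for ranking in rankings:
        # Ranks are 1-based: the top document of each list adds 1 / (k + 1).
        for rank, doc in enumerate(ranking, start=1):
            fused[doc] += 1.0 / (k + rank)
    # Highest fused score first.
    return dict(sorted(fused.items(), key=lambda item: item[1], reverse=True))


# Hypothetical results for three different queries.
rankings = [
    ["doc_a", "doc_b", "doc_c"],
    ["doc_b", "doc_a", "doc_d"],
    ["doc_b", "doc_c", "doc_a"],
]
fused = rrf(rankings)
print(list(fused))
```

Here doc_b ends up first because it is near the top of every list, while doc_d appears only once, in last place, and ends up at the bottom. Only ranks are used, so the raw scores of the individual searches never need to be comparable.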
Sparse Priming Representation (SPR)
Sparse Priming Representation (SPR) applies the idea of priming to Large Language Models (LLMs). Instead of verbose or overly detailed inputs, SPR uses a succinct, focused set of cues intended to activate the relevant regions of an LLM’s latent space, so that the desired information or response is elicited more accurately. The term “sparse” refers to minimizing the input while maximizing its relevance to the model. We will use SPR to compress the documents that are ranked by RRF before sending them to the LLM to form the final answer.
Reference for SPR — https://github.com/daveshap/SparsePrimingRepresentations
Let’s dive into the implementation
These are the dependencies that need to be installed. I am using Python 3.11.
langchain==0.0.318
pydantic==1.10.12
typing-extensions==4.7.1
faiss-cpu==1.7.4
rank-bm25==0.2.2
I will use the following PDF for demonstration — https://arxiv.org/pdf/2310.08560.pdf
You can copy each of the code blocks below into a Jupyter Notebook, one by one, and run them in order.
from typing import List, Dict
from dataclasses import dataclass

from pydantic import validator
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel


@dataclass
class Llm:
    # Bundles an instantiated chat model with its name and constructor arguments.
    llm: BaseChatModel
    llm_name: str
    llm_args: dict

    def __str__(self):
        return f"llm: {self.llm_name} \n llm_args: {self.llm_args}"


class OpenAI(ChatOpenAI):
    model_name: str = "gpt-3.5-turbo"
    temperature: float = 0
    openai_api_key: str
    streaming: bool = True

    @staticmethod
    def get_display_name():
        return "OpenAI"

    @staticmethod
    def get_valid_model_names():
        valid_model_names = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4", "gpt-4-0613", "gpt-4-32k-0613", "gpt-4-32k"}
        return valid_model_names

    @validator("model_name")
    def validate_model_name(cls, request):
        # Reject model names that are not in the allow-list above.
        valid_model_names = cls.get_valid_model_names()
        if request not in valid_model_names:
            raise ValueError(f"invalid model name given - {request} , valid ones are {valid_model_names}")
        return request


class LangchainLlms:
    def __init__(self):
        # Registry of supported LLMs: the class to instantiate and the schema that validates its arguments.
        self.__llms = {
            "OpenAI": {
                "llm": OpenAI,
                "schema": OpenAI
            }
        }

    def get_llm(self, llm_name: str, **llm_kwargs) -> Llm:
        if llm_name not in self.__llms:
            raise ValueError(f"invalid llm name given - {llm_name} , must be one of {list(self.__llms.keys())}")
        llm = self.__llms[llm_name]["llm"]
        # llm_kwargs contains the API key, so it is deliberately not printed.
        llm_args = self.__llms[llm_name]["schema"](**llm_kwargs)
        llm_obj = llm(**dict(llm_args))
        return Llm(llm=llm_obj, llm_args=dict(llm_args), llm_name=llm_name)
from faiss import IndexFlatL2
from langchain.vectorstores import FAISS
from langchain.docstore import InMemoryDocstore
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.base import VectorStore

openai_api_key = "<YOUR OPENAI API KEY>"
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
vector_db = FAISS(
    embedding_function=embeddings,
    index=IndexFlatL2(1536),  # 1536 is the vector size of the default OpenAI embedding model
    docstore=InMemoryDocstore({}),
    index_to_docstore_id={},
)
pdf_path = "mem_gpt.pdf"

from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# chunk_size and chunk_overlap are measured in characters because length_function is len.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=5000,
    chunk_overlap=1000,
    length_function=len,
)
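doc_loader = PyPDFLoader(pdf_path)
# load() returns one document per page. load_and_split() would first apply its own
# default splitter, leaving the text_splitter configured above with little to do.
pages = doc_loader.load()
docs = text_splitter.split_documents(pages)
vector_db.add_documents(docs)
bm25_corpus = [doc.page_content for doc in docs]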
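One detail worth keeping in mind: the corpus above holds raw page text, and the BM25 index in the next block tokenizes it by splitting on spaces. That keeps case and punctuation, so a query term such as "memgpt" would not match "MemGPT" in a document. The sketch below shows a possible alternative tokenizer; `simple_tokenize` is a hypothetical helper, not part of the code in this post, and if you adopt it you would need to apply it to both the corpus and the query.

```python
import re
from typing import List


def simple_tokenize(text: str) -> List[str]:
    """Lowercase the text and keep runs of ASCII letters and digits as tokens."""
    return re.findall(r"[a-z0-9]+", text.lower())


sentence = "MemGPT manages memory, like an OS."
naive = sentence.split(" ")             # keeps case and trailing punctuation
normalized = simple_tokenize(sentence)  # lowercase tokens without punctuation
print(naive)
print(normalized)
```

This is only one of many reasonable choices; stemming or stop-word removal could be layered on top, at the cost of more moving parts.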
from typing import List, Dict
from rank_bm25 import BM25Okapi


class VectorDbWithBM25:
    def __init__(self):
        # The documents were already added to vector_db above, so they are not added again here.
        self.__vector_db = vector_db
        self.__bm25_corpus = bm25_corpus
        tokenized_corpus = [doc.split(" ") for doc in bm25_corpus]
        self.__bm25 = BM25Okapi(tokenized_corpus)

    def vector_db_search(self, query: str, k=3) -> Dict[str, float]:
        # Semantic search: map each chunk's text to its relevance score, best first.
        search_result = dict()
        docs_and_scores = self.__vector_db.similarity_search_with_relevance_scores(query=query, k=k)
        for doc, score in docs_and_scores:
            search_result[doc.page_content] = score
        return {doc: score for doc, score in sorted(search_result.items(), key=lambda x: x[1], reverse=True)}

    def bm25_search(self, query: str, k=3) -> Dict[str, float]:
        # Keyword search: score every chunk with BM25 and keep the top k.
        tokenized_query = query.split(" ")
        doc_scores = self.__bm25.get_scores(tokenized_query)
        docs_with_scores = dict(zip(self.__bm25_corpus, doc_scores))
        sorted_docs_with_scores = sorted(docs_with_scores.items(), key=lambda x: x[1], reverse=True)
        return dict(sorted_docs_with_scores[:k])

    def combine_results(self, vector_db_search_results: Dict[str, float],
                        bm25_search_results: Dict[str, float]) -> Dict[str, float]:
        def normalize_dict(input_dict):
            # Min-max scale the scores into [a, b] so both result sets are comparable.
            min_value = min(input_dict.values())
            max_value = max(input_dict.values())
            a, b = 0.05, 1
            if max_value == min_value:
                return {k: b if max_value > 0.5 else a for k in input_dict.keys()}
            return {k: a + ((v - min_value) / (max_value - min_value)) * (b - a) for k, v in input_dict.items()}

        norm_vector_db_search_results = normalize_dict(vector_db_search_results)
        norm_bm25_search_results = normalize_dict(bm25_search_results)

        # Combine the dictionaries, keeping the higher normalized score for chunks found by both searches
        combined_dict = {}
        for k, v in norm_vector_db_search_results.items():
            combined_dict[k] = v
        for k, v in norm_bm25_search_results.items():
            if k in combined_dict:
                combined_dict[k] = max(combined_dict[k], v)
            else:
                combined_dict[k] = v
        return combined_dict

    def search(self, query: str, k=3, do_bm25_search=True) -> Dict[str, float]:
        vector_db_search_results = self.vector_db_search(query, k=k)
        if do_bm25_search:
            bm25_search_results = self.bm25_search(query, k=k)
            if bm25_search_results:
                combined_search_results = self.combine_results(vector_db_search_results, bm25_search_results)
                sorted_docs_with_scores = sorted(combined_search_results.items(), key=lambda x: x[1], reverse=True)
                return dict(sorted_docs_with_scores)
        return vector_db_search_results


vector_db_with_bm25 = VectorDbWithBM25()
langchain_llm = LangchainLlms()
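To see what the combine_results step above does with two differently scaled score sets, here is a small standalone sketch. The chunk names and score values are hypothetical, and `min_max_scale` and `combine_by_max` are illustrative stand-ins that mirror the logic of normalize_dict and the max-based merge.

```python
import math
from typing import Dict


def min_max_scale(scores: Dict[str, float], low: float = 0.05, high: float = 1.0) -> Dict[str, float]:
    """Rescale scores linearly so the smallest maps to `low` and the largest to `high`."""
    smallest, largest = min(scores.values()), max(scores.values())
    if largest == smallest:
        # Degenerate case, mirroring the rule used in normalize_dict.
        return {doc: high if largest > 0.5 else low for doc in scores}
    span = largest - smallest
    return {doc: low + (value - smallest) / span * (high - low) for doc, value in scores.items()}


def combine_by_max(first: Dict[str, float], second: Dict[str, float]) -> Dict[str, float]:
    """Merge two score dictionaries, keeping the higher score for shared keys."""
    combined = dict(first)
    for doc, value in second.items():
        combined[doc] = max(combined.get(doc, value), value)
    return combined


# Hypothetical scores: relevance scores from the vector store and raw BM25 scores.
vector_scores = {"chunk_1": 0.82, "chunk_2": 0.78, "chunk_3": 0.74}
keyword_scores = {"chunk_2": 7.5, "chunk_4": 3.0, "chunk_5": 1.5}

combined = combine_by_max(min_max_scale(vector_scores), min_max_scale(keyword_scores))
for doc, value in sorted(combined.items(), key=lambda item: item[1], reverse=True):
    print(doc, round(value, 4))
```

Without the rescaling, the raw BM25 value of 7.5 would dwarf every vector score. After rescaling, both result sets live in the same [0.05, 1] range, and chunk_2, which both searches found, keeps the higher of its two normalized scores.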
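import re
import asyncio
from typing import Dict, List
from langchain.schema import (
    AIMessage,
    HumanMessage,
    SystemMessage, LLMResult)


def remove_bullet_points(text):
    # Strip leading list markers (digits, dots, dashes, asterisks) from each line of the LLM output.
    lines = text.strip().split('\n')
    cleaned_lines = [re.sub(r'^[\d\.\-\*\s]+', '', line).strip() for line in lines]
    # Drop blank lines so that empty strings are never used as search queries.
    return [line for line in cleaned_lines if line]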
class RagFusion:
    def __init__(self, vector_store):
        self.__vectorstore = vector_store
        self.__llm = langchain_llm.get_llm("OpenAI",
                                           openai_api_key=openai_api_key, model_name="gpt-3.5-turbo-16k").llm

    async def generate_queries(self, query: str) -> List[str]:
        system_prompt = "You are a helpful assistant that generates multiple search queries based on a single input query."
        human_message = f"Generate 4 search queries related to: {query}"
        messages = []
        messages.append(SystemMessage(content=system_prompt))
        messages.append(HumanMessage(content=human_message))
        response = await self.__llm.agenerate(messages=[messages])
        if response and isinstance(response, LLMResult):
            generations = response.flatten()
            llm_result = generations[0].generations[0][0].text
            parsed_result = remove_bullet_points(llm_result)
            return parsed_result
        return []

    async def rewrite_query(self, query: str) -> str:
        prompt = f"""Provide a better search query for web search engine to answer the given question. End \
the queries with '**'. Question: ``` {query} ``` """
        messages = [HumanMessage(content=prompt)]
        response = await self.__llm.agenerate(messages=[messages])
        if response and isinstance(response, LLMResult):
            generations = response.flatten()
            llm_result = generations[0].generations[0][0].text
            # Remove surrounding whitespace and the '**' end marker requested in the prompt.
            return llm_result.strip().strip("*").strip()
        return ""

    def vector_db_search(self, query: str, k=3) -> Dict[str, float]:
        search_result = dict()
        docs_and_scores = self.__vectorstore.search(query, do_bm25_search=True, k=k)
        for doc, score in docs_and_scores.items():
            search_result[doc] = score
        return {doc: score for doc, score in sorted(search_result.items(), key=lambda x: x[1], reverse=True)}

    def retrieve_multiple_responses(self, similar_queries: List[str], k=3) -> Dict[str, Dict[str, float]]:
        all_results = dict()
        for query in similar_queries:
            search_results = self.vector_db_search(query, k=k)
            all_results[query] = search_results
        return all_results

    def reciprocal_rank_fusion(self, search_results_dict, k=60) -> Dict[str, float]:
        # k=60 is taken from the paper https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf
        fused_scores = {}
        for query, doc_scores in search_results_dict.items():
            ranked_docs = sorted(doc_scores.items(), key=lambda x: x[1], reverse=True)
            # Ranks are 1-based, as in the paper: each list contributes 1 / (k + rank).
            for rank, (doc, score) in enumerate(ranked_docs, start=1):
                fused_scores[doc] = fused_scores.get(doc, 0) + 1 / (rank + k)
        reranked_results = {doc: score for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)}
        return reranked_results

    async def run_spr(self, content_to_compress: str) -> str:
        spr_system_prompt = """# MISSION
You are a Sparse Priming Representation (SPR) writer. An SPR is a particular kind of use of language for advanced NLP, NLU, and NLG tasks, particularly useful for the latest generation Large Language Models (LLMs). You will be given information by the USER which you are to render as an SPR.
# THEORY
LLMs are a kind of deep neural network. They have been demonstrated to embed knowledge, abilities, and concepts, ranging from reasoning to planning, and even to theory of mind. These are called latent abilities and latent content, collectively referred to as latent space. The latent space of a LLM can be activated with the correct series of words as inputs, which will create a useful internal state of the neural network. This is not unlike how the right shorthand cues can prime a human mind to think in a certain way. Like human minds, LLMs are associative, meaning you only need to use the correct associations to "prime" another model to think in the same way.
# METHODOLOGY
Render the input as a distilled list of succinct statements, assertions, associations, concepts, analogies, and metaphors. The idea is to capture as much, conceptually, as possible but with as few words as possible. Write it in a way that makes sense to you, as the future audience will be another language model, not a human."""
        human_message = f"this is the input content that you need to distill - ``` {content_to_compress} ``` "
        messages = []
        messages.append(SystemMessage(content=spr_system_prompt))
        messages.append(HumanMessage(content=human_message))
        response = await self.__llm.agenerate(messages=[messages])
        if response and isinstance(response, LLMResult):
            generations = response.flatten()
            llm_result = generations[0].generations[0][0].text
            return llm_result
        return ""

    async def form_final_result(self, spr_results: List[str], original_query: str) -> str:
        spr_results = "\n ****************** \n".join(spr_results)
        prompt = f"""Answer the user's question based only on the following context:
<context>
{spr_results}
</context>
Question: ``` {original_query} ```
DO NOT MAKE UP ANY FALSE INFORMATION. USE ONLY THE GIVEN CONTEXT"""
        messages = [HumanMessage(content=prompt)]
        response = await self.__llm.agenerate(messages=[messages])
        if response and isinstance(response, LLMResult):
            generations = response.flatten()
            llm_result = generations[0].generations[0][0].text
            return llm_result
        return ""

    async def arun(self, query: str, rewrite_original_query=False):
        if rewrite_original_query:
            rephrased_query = await self.rewrite_query(query)
            if rephrased_query:
                query = rephrased_query
                print("rephrased_query: ", rephrased_query)
                print()
        similar_queries_list = await self.generate_queries(query)
        print("similar_queries_list: ", similar_queries_list)
        print()
        if similar_queries_list:
            search_results = self.retrieve_multiple_responses(similar_queries_list)
            reranked_results = self.reciprocal_rank_fusion(search_results)
            # here I am using all the reranked results, you can select the top N
            spr_tasks = []
            spr_results = []
            for result, score in reranked_results.items():
                spr_task = asyncio.create_task(self.run_spr(result))
                spr_tasks.append(spr_task)
            # Compress all documents concurrently, waiting at most 180 seconds.
            done, pending = await asyncio.wait(spr_tasks, timeout=180)
            # Iterate over the task list rather than the unordered "done" set to keep the RRF order.
            for spr_task in spr_tasks:
                if spr_task in done and spr_task.exception() is None:
                    spr_results.append(spr_task.result())
            for pending_task in pending:
                pending_task.cancel()
            if spr_results:
                for spr_content in spr_results:
                    print(spr_content)
                    print()
                    print("*" * 100)
                final_result = await self.form_final_result(spr_results, query)
                print("final result: ")
                print(final_result)


rag = RagFusion(vector_store=vector_db_with_bm25)
The “arun” method of the RagFusion class is called to answer a query. Let me run through the steps that are executed in this method.
1. The “rewrite_query” method rephrases the given user query. This step runs only when rewrite_original_query is set to True.
2. The “generate_queries” method generates queries similar to the query produced in step 1.
3. The “retrieve_multiple_responses” method retrieves documents for all the queries generated in step 2.
4. The “reciprocal_rank_fusion” method ranks all the documents from step 3.
5. The ranked documents from step 4 are run through SPR in the “run_spr” method.
6. The final response is generated by the “form_final_result” method.
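The query generation step depends on turning the LLM's free-form list into clean query strings. The sketch below applies the same regular expression as remove_bullet_points to a made-up LLM response, and also drops blank lines. `parse_queries` and `sample_output` are hypothetical names used only for this illustration.

```python
import re
from typing import List


def parse_queries(text: str) -> List[str]:
    """Strip leading list markers from each line and drop empty lines."""
    lines = text.strip().split("\n")
    cleaned = [re.sub(r"^[\d\.\-\*\s]+", "", line).strip() for line in lines]
    return [line for line in cleaned if line]


# Made-up LLM output mixing numbered items, a blank line, and bullet markers.
sample_output = """1. What is MemGPT?
2. MemGPT memory hierarchy explained

- How does MemGPT manage context?
* MemGPT vs. long-context LLMs"""

queries = parse_queries(sample_output)
for q in queries:
    print(q)

# Caveat: leading digits are treated as list markers, so they are stripped too.
print(parse_queries("2023 LLM trends"))
```

The caveat at the end is worth noting: a generated query that legitimately starts with a number loses that number, so a stricter pattern may be preferable if your queries often begin with years or counts.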
Let’s run some queries.
await rag.arun("what is memgpt", rewrite_original_query=True)
Response —
MemGPT is a system that manages memory tiers to provide extended context
within the limited window of a large language model (LLM). MemGPT
is inspired by hierarchical memory systems in operating systems and aims
to overcome the limitations of traditional LLMs by effectively managing memory
and enabling analysis of large documents and creating conversational agents
that remember and evolve.
await rag.arun("how does memgpt work", rewrite_original_query=True)
Response —
MemGPT is a language model that incorporates a hierarchical memory system and
functions to manage its own memory. It uses a memory hierarchy and control
flow similar to traditional operating systems. The main context, similar
to RAM in an operating system, is input to the LLM processor, which produces
text interpreted by a parser. The processor can generate function calls to
transfer data between the main context and the external context, similar to
disk memory in an operating system. MemGPT allows for the exchange of
information between context windows, similar to the concept of "main memory"
in operating systems, and external storage. It effectively manages memory
tiers to provide extended context within the LLM's limited window. MemGPT
can query past conversation history in its Recall Memory to answer specific
questions about prior conversations. It overcomes the limitations of finite
context and outperforms existing LLM-based approaches in document analysis
and conversational agent tasks.
distracted_query = "man that sam bankman fried trial was crazy! how does memgpt work?"
await rag.arun(distracted_query, rewrite_original_query=True)
Response —
MemGPT works by incorporating a hierarchical memory system and functions
to manage its own memory. The LLM processor takes the main context as
input and produces text interpreted by a parser, resulting in a yield or a
function call. MemGPT uses functions to transfer data between the main context
and the external context. When the processor generates a function call,
it can request control ahead of time to chain together functions.
When yielding, the processor is paused until the next external event occurs.
MemGPT allows for the exchange of information between context windows,
similar to the concept of "main memory" in operating systems, and external
storage. The design of MemGPT enables repeated context modifications during a
single task, allowing the agent to effectively utilize its limited context.
RAG stands as a testament to the fusion of information retrieval and large language models, redefining how we perceive the search experience. However, like all technologies, RAG isn’t without its shortcomings. The integration of tools like BM25, Reciprocal Rank Fusion, and Sparse Priming Representation serves as an innovative approach to magnify RAG’s capabilities and address its inherent limitations. The dynamic landscape of LLM, prompt engineering, NLP, and machine learning promises that this is just the beginning, and I eagerly anticipate the wave of groundbreaking ideas that the future holds for us.
Thank you for investing the time to read through this article. I trust you’ve found the insights and code examples beneficial. While the provided code serves as a foundational framework, there are avenues for enhancement.
Your feedback is invaluable to the continuous improvement of this resource, and I welcome any comments or suggestions you may have. Thank you once again for your time and engagement.