Optimizing Retrieval Augmentation with Dynamic Top-K Tuning for Efficient Question Answering

Training a cross-encoder to intelligently predict retrieval top-k, enhancing the precision and resource efficiency of question-answering systems.

Saurav Joshi
14 min read · Nov 3, 2023

In question-answering systems, the static nature of top-k retrieval poses a real challenge. Every query receives the same number of passages, which leads to either information overload or scarcity and, in turn, hurts both the efficiency and the accuracy of the generated responses. This one-size-fits-all approach ignores that queries differ in how much context they need, wasting language-model capacity and compute. Our solution replaces the fixed setting with a dynamic one: we train a cross-encoder that adjusts the retrieval breadth for each query. By assessing each query against its document, the system predicts the most appropriate top-k value for retrieval, so that each question receives a tailored, precise, and resource-efficient response. The goal is to streamline retrieval while preserving, or improving, the overall performance of the question-answering system.

Top-K tuning System Architecture. Image by author.

Here is an overview of the process:

Training

  1. We segment each document into chunks and index these in a vector database.
  2. For each question, we iterate over k = 1 to 10: we retrieve the top-k passages and prompt the language model to answer the question using them as context.
  3. Using a custom RankCorrectnessEvaluator, we rank the ten candidate answers against the reference answer and the query.
  4. We compile a training dataset of (query, document, best top-k) tuples, where the label is the k value of the best-ranked answer.
  5. We use this dataset to train a cross-encoder that predicts the top-k value from the query and the document.
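The label-generation steps above can be sketched in plain Python. This is a minimal sketch under stated assumptions: `answer_with_k` is a hypothetical callable standing in for "retrieve the top-k passages and ask the LLM", and `overlap_score` is a crude token-overlap ranker standing in for the GPT-4 RankCorrectnessEvaluator (token overlap is only one of the criteria in its prompt). Breaking ties in favour of the smaller k is an assumption of this sketch.

```python
from typing import Callable, Dict, List, Tuple


def overlap_score(answer: str, reference: str) -> float:
    """Fraction of reference tokens that also appear in the answer (stand-in for the LLM ranker)."""
    ref_tokens = set(reference.lower().split())
    if not ref_tokens:
        return 0.0
    ans_tokens = set(answer.lower().split())
    return len(ref_tokens & ans_tokens) / len(ref_tokens)


def best_k_label(
    document: str,
    question: str,
    reference: str,
    answer_with_k: Callable[[str, int], str],  # hypothetical: retrieve top-k, then generate
    max_k: int = 10,
) -> Tuple[str, str, str]:
    # Step 2: one candidate answer per k value
    responses: Dict[int, str] = {k: answer_with_k(question, k) for k in range(1, max_k + 1)}
    # Step 3: rank best-to-worst; on ties, prefer the smaller k (less context)
    ranked: List[int] = sorted(
        responses, key=lambda k: (-overlap_score(responses[k], reference), k)
    )
    # Step 4: the best-ranked k becomes the label, stored as a string as in the retrieval loop
    return (document, question, str(ranked[0]))


# Usage with a fake generator that only finds the answer once k >= 3
def fake_answer(question: str, k: int) -> str:
    return "the model uses BERT" if k >= 3 else "I do not know"


print(best_k_label("paper text", "Which model is used?", "the model uses BERT", fake_answer))
# → ('paper text', 'Which model is used?', '3')
```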

Inference

  1. We begin by dividing the document into chunks and indexing these in the vector database.
  2. For a given document and query, we predict the top-k value with the cross-encoder and pass it to the retriever.
  3. The large language model receives the top k passages as context to generate the answer.
  4. We assess the effectiveness of static versus dynamic retrieval strategies by employing evaluation metrics such as CorrectnessEvaluator, SimilarityEvaluator, and TokenCounter.

Retrieval Augmented Generation

Retrieval-Augmented Generation (RAG) is a powerful technique that combines the strengths of information retrieval and language generation to enhance the quality of machine-generated text. By retrieving relevant document passages and using them as additional context, RAG models can produce more accurate, informative, and contextually relevant answers. This approach is particularly crucial for complex question-answering tasks where a deep understanding of a broad range of topics is necessary to generate coherent and factually correct responses. Moreover, since LLMs are trained up to a certain point in time, they lack the ability to access or incorporate post-training events or knowledge, making retrieval-augmented methods essential for up-to-date and expanded content understanding.
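To make the idea concrete, here is a minimal sketch of how retrieved passages become extra context for the LLM. The template text is only loosely modeled on LlamaIndex's default question-answering prompt, not copied from it; `build_rag_prompt` is a hypothetical helper, and it assumes the passages are already sorted by relevance.

```python
from typing import List


def build_rag_prompt(question: str, passages: List[str], k: int) -> str:
    """Build a QA prompt from the k most relevant passages (assumed sorted best-first)."""
    context = "\n\n".join(f"[{i}] {p}" for i, p in enumerate(passages[:k], start=1))
    return (
        "Context information is below.\n"
        "---------------------\n"
        f"{context}\n"
        "---------------------\n"
        "Given the context information and not prior knowledge, answer the query.\n"
        f"Query: {question}\n"
        "Answer: "
    )


passages = ["First passage.", "Second passage.", "Third passage."]
print(build_rag_prompt("What is discussed?", passages, k=2))
```

Raising k only lengthens the context block, which is where the extra tokens (and cost) of a large static top-k come from.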

Weaviate Vector Database

A vector database is a specialized type of database designed to handle vector embeddings, which are high-dimensional representations of data points, typically generated by machine learning models. These embeddings capture the semantic relationships between data points, enabling efficient similarity searches. This is particularly important for retrieval-augmented generation (RAG) systems, which rely on quickly finding the most relevant document passages in a large corpus to generate accurate and contextually relevant responses. Weaviate is an open-source vector database. It lets you store data objects and vector embeddings from your favorite ML models and scale to billions of data objects. You can learn more about Weaviate in its documentation.
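At its core, the operation a vector database provides is nearest-neighbour search over embeddings. The sketch below does it by brute force with NumPy and cosine similarity; it is illustrative only, since Weaviate uses an approximate index (HNSW by default) to stay fast at scale. The toy 2-D vectors stand in for real embeddings.

```python
import numpy as np


def top_k_cosine(query_vec, chunk_vecs, k):
    """Return (chunk_index, cosine_score) pairs for the k most similar chunks, best first."""
    q = np.asarray(query_vec, dtype=float)
    m = np.asarray(chunk_vecs, dtype=float)
    # Normalise so that a dot product equals cosine similarity
    q = q / np.linalg.norm(q)
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    scores = m @ q
    order = np.argsort(-scores, kind="stable")[:k]
    return [(int(i), float(scores[i])) for i in order]


chunks = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(top_k_cosine([1.0, 0.1], chunks, k=2))  # chunk 0 first, then chunk 2
```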

LlamaIndex End-to-End Evaluation

LlamaIndex provides a framework for evaluating retrieval-augmented generation systems end to end, from data retrieval to final response generation. Its evaluators offer a suite of metrics that can diagnose and improve the performance of RAG systems without exhaustive manual review. You can learn more about the evaluation metrics in the LlamaIndex documentation. For this experiment, I used:

  1. Correctness Evaluator: uses an LLM judge to assess the relevance and correctness of a generated answer against a reference answer, returning a score from 1 to 5.
  2. Embedding Similarity Evaluator (SemanticSimilarityEvaluator): measures the semantic similarity between a generated answer and the reference answer (cosine similarity of their embeddings by default).

Apart from these two metrics, I also used the total LLM token count as an additional metric to assess the different strategies.

Implementation

Refer to my GitHub repo for the complete Jupyter notebook.

Let's begin by importing the necessary libraries.

import os
import re
import pickle
import openai
import tiktoken
import random
import ast
import time
import pandas as pd
import weaviate
import seaborn as sns
import matplotlib.pyplot as plt
from datasets import load_dataset
from llama_index.vector_stores import WeaviateVectorStore
from IPython.display import Markdown, display
from llama_index import QueryBundle
from llama_index.retrievers import BaseRetriever, VectorIndexRetriever
from llama_index import Document
from typing import Any, List, Optional
from tqdm.auto import tqdm
from llama_index import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    ServiceContext,
    Response,
    set_global_service_context,
)
from llama_index.storage.storage_context import StorageContext
from llama_index.vector_stores import VectorStoreQuery
from llama_index.schema import NodeWithScore
from llama_index.embeddings import OpenAIEmbedding
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.llms import ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
from llama_index.evaluation import SemanticSimilarityEvaluator
from llama_index.embeddings import SimilarityMode
from llama_index.evaluation import CorrectnessEvaluator
from llama_index.evaluation.eval_utils import get_responses, get_results_df
from llama_index.callbacks import CallbackManager, TokenCountingHandler
from dotenv import load_dotenv

load_dotenv()  # load OPENAI_API_KEY (and any other secrets) from a local .env file

Data Preparation

The experiment uses the QASPER dataset, a question-answering dataset built on scientific research papers. The code below extracts the paper text, questions, and answers, and builds the train and test sets.

# Download QASPER dataset from HuggingFace https://huggingface.co/datasets/allenai/qasper
dataset = load_dataset("allenai/qasper")

# Split the dataset into train, validation, and test splits
train_dataset = dataset["train"]
validation_dataset = dataset["validation"]
test_dataset = dataset["test"]

random.seed(42) # Set a random seed for reproducibility

# Randomly sample 800 rows from the training split
train_sampled_indices = random.sample(range(len(train_dataset)), 800)
train_samples = [train_dataset[i] for i in train_sampled_indices]


# Randomly sample 80 rows from the test split
test_sampled_indices = random.sample(range(len(test_dataset)), 80)
test_samples = [test_dataset[i] for i in test_sampled_indices]

# Get the full paper text and the questions/answers on each paper from the QASPER samples
# to build the training dataset for cross-encoder fine-tuning (and the evaluation set)

# Utility function to get the full text of a research paper from the dataset
def get_full_text(sample: dict) -> str:
    """
    :param dict sample: the row sample from QASPER
    """
    title = sample["title"]
    abstract = sample["abstract"]
    sections_list = sample["full_text"]["section_name"]
    paragraph_list = sample["full_text"]["paragraphs"]
    combined_sections_with_paras = ""
    if len(sections_list) == len(paragraph_list):
        combined_sections_with_paras += title + "\t"
        combined_sections_with_paras += abstract + "\t"
        for index in range(0, len(sections_list)):
            combined_sections_with_paras += str(sections_list[index]) + "\t"
            combined_sections_with_paras += "".join(paragraph_list[index])
        return combined_sections_with_paras
    else:
        print("Not the same number of sections as paragraphs list")

# Utility function to extract the list of questions from the dataset
def get_questions(sample: dict) -> List[str]:
    """
    :param dict sample: the row sample from QASPER
    """
    questions_list = sample["qas"]["question"]
    return questions_list

# Utility function to extract answers from the dataset; questions without a
# free-form answer (or marked unanswerable) are labelled "Unacceptable"
def get_answers(sample: dict) -> List[str]:
    """
    :param dict sample: the row sample from QASPER
    """
    final_answers_list = []
    answers = sample["qas"]["answers"]
    for answer in answers:
        local_answer = ""
        types_of_answers = answer["answer"][0]  # use the first annotator's answer
        if not types_of_answers["unanswerable"]:
            if types_of_answers["free_form_answer"] != "":
                local_answer = types_of_answers["free_form_answer"]
            else:
                local_answer = "Unacceptable"
        else:
            local_answer = "Unacceptable"

        final_answers_list.append(local_answer)

    return final_answers_list


doc_qa_dict_list = []
eval_doc_qa_answer_list = []

for train_sample in train_samples:
    full_text = get_full_text(train_sample)
    questions_list = get_questions(train_sample)
    answers_list = get_answers(train_sample)
    local_dict = {
        "paper": full_text,
        "questions": questions_list,
        "answers": answers_list,
    }
    doc_qa_dict_list.append(local_dict)

for test_sample in test_samples:
    full_text = get_full_text(test_sample)
    questions_list = get_questions(test_sample)
    answers_list = get_answers(test_sample)
    local_dict = {
        "paper": full_text,
        "questions": questions_list,
        "answers": answers_list,
    }
    eval_doc_qa_answer_list.append(local_dict)
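To see which QASPER questions survive this filtering, here is a simplified, self-contained version of the `get_answers` logic run on a hand-made sample. The toy dictionary mimics only the fields the function reads; real QASPER rows carry more fields (for example extractive spans and yes/no answers), and for such answers `free_form_answer` is typically empty, so they end up marked "Unacceptable" and are skipped later. `extract_answers` is a hypothetical name for this sketch.

```python
from typing import List


def extract_answers(sample: dict) -> List[str]:
    """Simplified get_answers: keep the first annotator's free-form answer, else 'Unacceptable'."""
    answers = []
    for answer in sample["qas"]["answers"]:
        first = answer["answer"][0]
        if not first["unanswerable"] and first["free_form_answer"] != "":
            answers.append(first["free_form_answer"])
        else:
            answers.append("Unacceptable")
    return answers


# Hand-made sample shaped like the fields used above (not real QASPER data)
toy_sample = {
    "qas": {
        "question": ["Which dataset is used?", "Is the model pretrained?", "What is the test F1?"],
        "answers": [
            {"answer": [{"unanswerable": False, "free_form_answer": "a QA dataset"}]},
            {"answer": [{"unanswerable": False, "free_form_answer": ""}]},  # e.g. a yes/no answer
            {"answer": [{"unanswerable": True, "free_form_answer": ""}]},
        ],
    }
}
print(extract_answers(toy_sample))  # ['a QA dataset', 'Unacceptable', 'Unacceptable']
```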

Setup Weaviate

You can set up a Weaviate vector database by creating an account and following the instructions in the Weaviate documentation to create and query an index.

# cloud
client = weaviate.Client(
    url="https://....weaviate.network"  # replace with your own cluster URL
)
client.schema.get()

weaviate_vector_store = WeaviateVectorStore(
    weaviate_client=client, index_name="LlamaIndex"
)

Custom RankCorrectnessEvaluator

I went through LlamaIndex's documentation and couldn't find an evaluator that ranks multiple responses. The PairwiseComparisonEvaluator compares two responses with respect to the reference answer and the query, but this experiment produces ten responses per question. I therefore created a RankCorrectnessEvaluator, which ranks multiple responses, each tagged with the k value used during retrieval. See the prompt below for details.

from typing import Any, Optional, Sequence, Union, Dict

from llama_index.evaluation.base import BaseEvaluator, EvaluationResult
from llama_index.indices.service_context import ServiceContext
from llama_index.prompts import (
BasePromptTemplate,
ChatMessage,
ChatPromptTemplate,
MessageRole,
PromptTemplate,
)
from llama_index.prompts.mixin import PromptDictType

RANKING_SYSTEM_TEMPLATE = """
You are an expert evaluation system for a question answering chatbot.

You are given the following information:
- a user query,
- a reference answer, and
- a list of generated answers, each associated with a different 'k' value.

Your job is to rank the generated answers in order of correctness and relevance to the user query and the reference answer, from best to worst.

Correctness should be judged based on:
- The number of overlapping tokens with the reference answer.
- The absence of incorrect information not present in the reference answer.
- The lack of unnecessary or irrelevant tokens that do not contribute to answering the query.

Please provide a ranked list of the 'k' values associated with the generated answers, starting with the 'k' value of the best answer and ending with the 'k' value of the worst answer.
Do not return answers in any other format.

You are given the following information:
- a user query,
- a reference answer, and
- generated answers.

User Query
query

Reference Answer
reference_answer

Generated Answers
k_1: answer_1
k_2: answer_2
...
k_10: answer_10

Based on the information provided and the criteria for correctness, rank the 'k' values from best to worst.
For example:
["k_7", "k_2", "k_9", ..., "k_3"]
"""

DEFAULT_USER_TEMPLATE = """
## User Query
{query}

## Reference Answer
{reference_answer}

## Generated Answers
{generated_answers}
"""

DEFAULT_EVAL_TEMPLATE = ChatPromptTemplate(
message_templates=[
ChatMessage(role=MessageRole.SYSTEM, content=RANKING_SYSTEM_TEMPLATE),
ChatMessage(role=MessageRole.USER, content=DEFAULT_USER_TEMPLATE),
]
)

class RankCorrectnessEvaluator(BaseEvaluator):
    """Rank correctness evaluator.

    Evaluates and ranks the correctness of multiple generated answers for a question answering system.
    This evaluator depends on `reference` answer to be provided, in addition to the
    query string and multiple response strings.

    It outputs a ranked list of 'k' values associated with the generated answers.

    Args:
        service_context (Optional[ServiceContext]): Service context.
        eval_template (Optional[Union[BasePromptTemplate, str]]):
            Template for the evaluation prompt.
    """

    def __init__(
        self,
        service_context: Optional[ServiceContext] = None,
        eval_template: Optional[Union[BasePromptTemplate, str]] = None,
    ) -> None:
        self._service_context = service_context or ServiceContext.from_defaults()

        self._eval_template: BasePromptTemplate
        if isinstance(eval_template, str):
            self._eval_template = PromptTemplate(eval_template)
        else:
            self._eval_template = eval_template or DEFAULT_EVAL_TEMPLATE

    def _get_prompts(self) -> PromptDictType:
        """Get prompts."""
        return {
            "eval_template": self._eval_template,
        }

    def _update_prompts(self, prompts: PromptDictType) -> None:
        """Update prompts."""
        if "eval_template" in prompts:
            self._eval_template = prompts["eval_template"]

    async def aevaluate(
        self,
        query: Optional[str] = None,
        responses: Optional[Dict[str, str]] = None,
        reference: Optional[str] = None,
        **kwargs: Any,
    ) -> EvaluationResult:
        del kwargs  # Unused

        if query is None or responses is None or reference is None:
            raise ValueError("query, responses, and reference must be provided")

        # One line per candidate, e.g. "k_3: <answer generated with the top-3 passages>"
        generated_answers_str = "\n".join(
            [f"{k}: {answer}" for k, answer in responses.items()]
        )

        eval_response = await self._service_context.llm_predictor.apredict(
            prompt=self._eval_template,
            query=query,
            generated_answers=generated_answers_str,
            reference_answer=reference,
        )

        # The raw LLM output (the ranked list of k values) is returned in `response`
        return EvaluationResult(
            query=query,
            response=eval_response
        )

train_evaluator_service_context = ServiceContext.from_defaults(llm=OpenAI("gpt-4"))
train_evaluator = RankCorrectnessEvaluator(service_context=train_evaluator_service_context)

Retrieval Loop

Here, we iterate from k = 1 to 10 and store the responses in a dictionary keyed by the k value (for example "k_3"). We then send these responses to the RankCorrectnessEvaluator, which returns the k values ranked from best to worst; the best-ranked k becomes the training label.
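One observation about this loop: with exact similarity search, the top-k results for a smaller k are a prefix of the top-10 results, so a single retrieval with k = 10 could supply the context for all ten k values, leaving only the LLM calls to be repeated. The sketch below shows that slicing on hypothetical (chunk, score) pairs; `contexts_for_all_k` is a hypothetical helper. With an approximate index such as Weaviate's, this usually holds but is not strictly guaranteed, so treat it as an optimisation to verify rather than an exact equivalence.

```python
from typing import Dict, List, Tuple


def contexts_for_all_k(
    scored_chunks: List[Tuple[str, float]], max_k: int = 10
) -> Dict[str, List[str]]:
    """Retrieve once, then derive the context for every k = 1..max_k by slicing."""
    ranked = sorted(scored_chunks, key=lambda c: c[1], reverse=True)[:max_k]
    return {f"k_{k}": [text for text, _ in ranked[:k]] for k in range(1, max_k + 1)}


hits = [("chunk A", 0.91), ("chunk B", 0.55), ("chunk C", 0.78)]
contexts = contexts_for_all_k(hits, max_k=3)
print(contexts["k_2"])  # ['chunk A', 'chunk C']
```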

topk_training_dataset = []

for paper in tqdm(doc_qa_dict_list[:100]):
    try:  # safety net against OpenAI/Weaviate errors
        questions_list = paper["questions"]
        documents = [Document(text=paper["paper"])]
        reference_answers_list = paper["answers"]

        assert len(questions_list) == len(reference_answers_list)

        # Chunk and index the paper once per paper: rebuilding the index inside the k loop
        # would re-insert the same chunks (with new IDs) into Weaviate every time, so
        # duplicate chunks could fill the top-k slots.
        # Note: all papers share the "LlamaIndex" class, so chunks from earlier papers stay
        # searchable unless that class is cleared or a separate index is used per paper.
        service_context = ServiceContext.from_defaults(chunk_size=512)
        node_parser = service_context.node_parser
        nodes = node_parser.get_nodes_from_documents(documents)

        storage_context = StorageContext.from_defaults(vector_store=weaviate_vector_store)
        storage_context.docstore.add_documents(nodes)
        weaviate_vector_index = VectorStoreIndex(nodes, storage_context=storage_context)

        for question, reference_answer in zip(questions_list, reference_answers_list):
            if reference_answer == "Unacceptable":
                continue

            # Generate one answer per k value
            responses = {}
            for k_val in tqdm(range(1, 11)):
                weaviate_vector_retriever = VectorIndexRetriever(
                    index=weaviate_vector_index, similarity_top_k=k_val
                )
                query_engine = RetrieverQueryEngine.from_args(weaviate_vector_retriever)
                response = query_engine.query(question)
                responses["k_" + str(k_val)] = response.response

            # Rank the ten candidate answers with the RankCorrectnessEvaluator
            result = await train_evaluator.aevaluate(
                query=question,
                responses=responses,
                reference=reference_answer,
            )
            raw_ranking = result.response
            try:
                eval_response_list = ast.literal_eval(raw_ranking)
                ranked_k_values = [int(k.split("_")[1]) for k in eval_response_list]
            except (ValueError, SyntaxError, TypeError, IndexError, AttributeError):
                # Fall back to extracting a list like ["k_7", "k_2", ...] from free text
                list_pattern = r'\["k_[0-9]+(?:", "k_[0-9]+)*"\]'
                match = re.search(list_pattern, raw_ranking)
                if not match:
                    continue
                eval_response_list = ast.literal_eval(match.group(0))
                ranked_k_values = [int(k.split("_")[1]) for k in eval_response_list]

            if not ranked_k_values:
                continue

            # Insert into the training dataset: the best-ranked k is the label
            best_response_insertion = (paper["paper"], question, str(ranked_k_values[0]))
            topk_training_dataset.append(best_response_insertion)
            with open("training_tuples.pkl", "wb") as file:
                pickle.dump(topk_training_dataset, file)
            time.sleep(10)  # throttle API calls
    except Exception as e:
        print(f"Skipping paper after error: {e}")
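The parsing step in this loop is the fragile part, because the evaluator's reply is free text. A more defensive parser might look like the sketch below: it first tries the whole reply as a Python literal, then falls back to the first `["k_7", "k_2", ...]`-style list it can find (single or double quotes), and returns `None` instead of raising. `parse_ranked_k` is a hypothetical helper, not part of LlamaIndex.

```python
import ast
import re
from typing import List, Optional

K_LIST_PATTERN = re.compile(r"""\[\s*["']k_\d+["'](?:\s*,\s*["']k_\d+["'])*\s*\]""")
K_ITEM_PATTERN = re.compile(r"k_\d+")


def parse_ranked_k(raw: str) -> Optional[List[int]]:
    """Extract ranked k values (best first) from an LLM reply, or return None."""
    candidates = [raw.strip()]
    match = K_LIST_PATTERN.search(raw)
    if match:
        candidates.append(match.group(0))
    for text in candidates:
        try:
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            continue
        # Accept only a non-empty list of strings shaped like "k_<digits>"
        if (
            isinstance(parsed, list)
            and parsed
            and all(isinstance(x, str) and K_ITEM_PATTERN.fullmatch(x) for x in parsed)
        ):
            return [int(x.split("_", 1)[1]) for x in parsed]
    return None


print(parse_ranked_k('Ranking: ["k_3", "k_1", "k_2"]'))  # [3, 1, 2]
```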

Summarize Document

A BERT-based cross-encoder accepts at most 512 input tokens, far fewer than a full research paper. I therefore summarize each paper using the prompt below, then train the cross-encoder on the summarized document and the query, with the ground-truth top-k value as the label.
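Why the summary is needed becomes clearer if we look at what truncation would otherwise do. With `truncation=True`, the Hugging Face tokenizer uses the "longest_first" strategy, which trims tokens from the longer sequence of the pair first. The sketch below only approximates that behaviour on plain token lists (list items instead of word pieces, and three special tokens for `[CLS] A [SEP] B [SEP]`); it is not the library's implementation.

```python
from typing import List, Tuple


def truncate_longest_first(
    a: List[str], b: List[str], max_length: int = 512, num_special: int = 3
) -> Tuple[List[str], List[str]]:
    """Approximate 'longest_first': drop one token at a time from the end of the longer sequence."""
    a, b = list(a), list(b)
    budget = max_length - num_special  # room left after [CLS] a [SEP] b [SEP]
    while len(a) + len(b) > budget:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    return a, b


paper = [f"tok{i}" for i in range(10_000)]  # a long paper, roughly 10k tokens
question = ["q"] * 20
kept_paper, kept_question = truncate_longest_first(paper, question)
print(len(kept_paper), len(kept_question))  # 489 20
```

In other words, without a summary the cross-encoder would see only the first few hundred tokens of each paper (roughly the title, abstract and the start of the body), whatever the question is about.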

system_prompt = """Your task is to summarize the following document into a concise version of approximately 500 words.
Focus on capturing the main themes, essential points, and key arguments presented. Omit any extraneous details or repetitive
information to create a clear, coherent, and comprehensive summary that conveys the document's core message and intent.
Please ensure the summary is well-structured, with a clear beginning, middle, and end, and maintains the original
document's tone and perspective.
"""

def get_summary(document):
    # Uses the pre-1.0 openai SDK interface (openai.ChatCompletion)
    res = openai.ChatCompletion.create(
        model="gpt-3.5-turbo-16k",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": document},
        ],
    )
    return res

# Load the (paper, question, best_k) tuples saved by the retrieval loop
with open("training_tuples.pkl", "rb") as file:
    loaded_topk_training_dataset = pickle.load(file)

final_topk_training_dataset = []
for doc, qst, topk in tqdm(loaded_topk_training_dataset):
    response = get_summary(doc)
    summarized_doc = response["choices"][0]["message"]["content"]
    final_topk_training_dataset.append((summarized_doc, qst, topk))

Train Cross Encoder

Now we set up the training pipeline for the BERT-based cross-encoder, which classifies each (document, question) pair into one of ten labels (k = 1 to 10).
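Framing top-k prediction as ten-way classification means each label k in 1..10 is shifted to a class index 0..9 for training (the `- 1` in `CEDataset`) and shifted back after the argmax at inference (the `+ 1` in `predict_top_k`). A minimal, framework-free sketch of that mapping, with hypothetical helper names:

```python
from typing import Sequence


def k_to_label(k: int, num_labels: int = 10) -> int:
    """Map a top-k value (1..num_labels) to a class index (0..num_labels-1)."""
    if not 1 <= k <= num_labels:
        raise ValueError(f"k must be between 1 and {num_labels}, got {k}")
    return k - 1


def logits_to_k(logits: Sequence[float]) -> int:
    """Take the highest-scoring class (the first one on ties) and map it back to a top-k value."""
    best_class = max(range(len(logits)), key=lambda i: logits[i])
    return best_class + 1


print(k_to_label(8))  # 7
print(logits_to_k([0.1, 0.2, 0.0, 0.3, 0.1, 0.0, 0.2, 2.5, 0.4, 0.1]))  # 8
```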

import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset

class CEDataset(Dataset):
    """Tokenized (document, question) pairs with their best-k labels."""

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # top-k values 1..10 become class indices 0..9
        item["labels"] = torch.tensor(int(self.labels[idx]) - 1)
        return item

    def __len__(self):
        return len(self.labels)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# (summarized document, question, best k) tuples from the previous step
# (load them from disk here instead if you pickled them)
loaded_final_topk_training_dataset = final_topk_training_dataset

texts = [(doc, qst) for doc, qst, topk in loaded_final_topk_training_dataset]
labels = [topk for doc, qst, topk in loaded_final_topk_training_dataset]
# Each pair is encoded as [CLS] document [SEP] question [SEP]
encodings = tokenizer.batch_encode_plus(texts, truncation=True, padding=True, max_length=512)
ce_dataset = CEDataset(encodings, labels)

num_labels = 10
ce_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=num_labels)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=200,
    per_device_train_batch_size=5,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=20,
)

# Fine-tune the cross-encoder
trainer = Trainer(
    model=ce_model,
    args=training_args,
    train_dataset=ce_dataset,
)
trainer.train()

# Save the model
ce_model.save_pretrained("./ce_model")
tokenizer.save_pretrained("./ce_model")

Cross-encoder training loss. Image by author.

Baseline RAG Evaluation

Here we set up the evaluation metrics and, most importantly, use them to compare two strategies: a static top-k, fixed in the query engine's retriever for every question, versus our method, which predicts a top-k value for each question from the question and the document.

gpt4 = OpenAI(temperature=0, model="gpt-4")
service_context_gpt4 = ServiceContext.from_defaults(llm=gpt4)

evaluator_similarity = SemanticSimilarityEvaluator()
evaluator_gpt4_correctness = CorrectnessEvaluator(service_context=service_context_gpt4)

# Count the tokens used by the answer-generating LLM only
token_counter = TokenCountingHandler(
    tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode
)
callback_manager = CallbackManager([token_counter])
llm = OpenAI()
service_context = ServiceContext.from_defaults(
    llm=llm, callback_manager=callback_manager, embed_model="local"
)
set_global_service_context(service_context)

# Test papers prepared in the data-preparation step
# (the results reported below used a subset of 25 test papers)
df_test = pd.DataFrame(eval_doc_qa_answer_list)

BASELINE_TOP_K = 5  # static top-k for this baseline run (e.g. 1, 5 or 10)

similarity_scores_list = []
correctness_scores_list = []
total_llm_token_list = []
baseline_dict_list = []

for _, row in df_test.iterrows():
    documents = [Document(text=row["paper"])]
    query_list = row["questions"]
    reference_answers_list = row["answers"]
    number_of_accepted_queries = 0

    vector_index = VectorStoreIndex.from_documents(documents)
    query_engine = vector_index.as_query_engine(similarity_top_k=BASELINE_TOP_K)

    assert len(query_list) == len(reference_answers_list)
    similarity_local_score = 0
    correctness_local_score = 0
    total_llm_token_local = 0

    for query, reference in zip(query_list, reference_answers_list):
        if reference == "Unacceptable":
            continue  # skip questions without a free-form reference answer
        number_of_accepted_queries += 1

        response = str(query_engine.query(query))

        baseline_dict = {
            "query": query,
            "response": response,
            "reference": reference,
        }
        baseline_dict_list.append(baseline_dict)

        similarity_eval_result = await evaluator_similarity.aevaluate(
            response=response, reference=reference
        )

        correctness_eval_result = await evaluator_gpt4_correctness.aevaluate(
            query=query,
            response=response,
            reference=reference,
        )

        similarity_local_score += similarity_eval_result.score
        correctness_local_score += correctness_eval_result.score
        # prompt + completion tokens spent answering this query
        total_llm_token_local += int(token_counter.total_llm_token_count)

        token_counter.reset_counts()

    # Average per paper first ...
    if number_of_accepted_queries > 0:
        similarity_scores_list.append(similarity_local_score / number_of_accepted_queries)
        correctness_scores_list.append(correctness_local_score / number_of_accepted_queries)
        total_llm_token_list.append(total_llm_token_local / number_of_accepted_queries)

# ... then across papers
overall_similarity_average_score = sum(similarity_scores_list) / len(similarity_scores_list)
overall_correctness_average_score = sum(correctness_scores_list) / len(correctness_scores_list)
overall_total_llm_token_average = sum(total_llm_token_list) / len(total_llm_token_list)

df_responses = pd.DataFrame(baseline_dict_list)
df_responses.to_csv(f"Baseline_Responses_k_{BASELINE_TOP_K}.csv")
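Note how the scores are aggregated: each paper's questions are averaged first, and the overall figure is the mean of those per-paper averages, so every paper counts equally no matter how many answerable questions it has. A small sketch of that two-level (macro) average, with a hypothetical helper name and made-up scores:

```python
from typing import List, Optional


def macro_average(per_paper_scores: List[List[float]]) -> Optional[float]:
    """Average within each paper, then across papers; papers with no scored questions are skipped."""
    paper_means = [sum(scores) / len(scores) for scores in per_paper_scores if scores]
    if not paper_means:
        return None
    return sum(paper_means) / len(paper_means)


scores = [[4.0, 2.0], [5.0], []]  # made-up correctness scores for three papers
print(macro_average(scores))  # 4.0, whereas the flat mean over all questions is 11 / 3
```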

Trained Top-k RAG Evaluation

The main difference at inference time is that, for each question, we call the BERT-based cross-encoder on the summarized paper and the question to predict the top_k value.

model_path = "./ce_model"
loaded_ce_tokenizer = BertTokenizer.from_pretrained(model_path)
loaded_ce_model = BertForSequenceClassification.from_pretrained(model_path)
loaded_ce_model.eval()

def predict_top_k(document, question):
    # Encode in the same (document, question) order used for training
    inputs = loaded_ce_tokenizer.encode_plus(
        document,
        question,
        add_special_tokens=True,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    with torch.no_grad():
        outputs = loaded_ce_model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=-1)
    # class indices 0..9 map back to top-k values 1..10
    predicted_top_k = prediction.item() + 1
    return predicted_top_k

similarity_scores_list = []
correctness_scores_list = []
total_llm_token_list = []
baseline_dict_list = []

for _, row in df_test.iterrows():
    documents = [Document(text=row["paper"])]
    query_list = row["questions"]
    reference_answers_list = row["answers"]
    number_of_accepted_queries = 0

    assert len(query_list) == len(reference_answers_list)
    similarity_local_score = 0
    correctness_local_score = 0
    total_llm_token_local = 0

    vector_index = VectorStoreIndex.from_documents(documents)
    # Summarize once per paper so the cross-encoder input fits in 512 tokens.
    # This direct OpenAI call bypasses LlamaIndex's callback manager, so its tokens
    # are not included in the token counts below.
    response = get_summary(row["paper"])
    summarized_paper = response["choices"][0]["message"]["content"]

    for query, reference in zip(query_list, reference_answers_list):
        if reference == "Unacceptable":
            continue  # skip questions without a free-form reference answer
        number_of_accepted_queries += 1

        # Predict a per-question top-k and build a query engine with it
        predicted_top_k = predict_top_k(summarized_paper, query)
        query_engine = vector_index.as_query_engine(similarity_top_k=predicted_top_k)

        response = str(query_engine.query(query))

        baseline_dict = {
            "query": query,
            "predicted_top_k": predicted_top_k,
            "response": response,
            "reference": reference,
        }
        baseline_dict_list.append(baseline_dict)

        similarity_eval_result = await evaluator_similarity.aevaluate(
            response=response, reference=reference
        )

        correctness_eval_result = await evaluator_gpt4_correctness.aevaluate(
            query=query,
            response=response,
            reference=reference,
        )

        similarity_local_score += similarity_eval_result.score
        correctness_local_score += correctness_eval_result.score
        # prompt + completion tokens spent answering this query
        total_llm_token_local += int(token_counter.total_llm_token_count)

        token_counter.reset_counts()

    # Average per paper first ...
    if number_of_accepted_queries > 0:
        similarity_scores_list.append(similarity_local_score / number_of_accepted_queries)
        correctness_scores_list.append(correctness_local_score / number_of_accepted_queries)
        total_llm_token_list.append(total_llm_token_local / number_of_accepted_queries)

# ... then across papers
overall_similarity_average_score = sum(similarity_scores_list) / len(similarity_scores_list)
overall_correctness_average_score = sum(correctness_scores_list) / len(correctness_scores_list)
overall_total_llm_token_average = sum(total_llm_token_list) / len(total_llm_token_list)

df_responses = pd.DataFrame(baseline_dict_list)
df_responses.to_csv("Trained_topk.csv")

Here is a distribution plot of the predicted top-k value for queries from the test set.

Distribution of Predicted Top-k. Image by author.
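The counts behind such a plot are easy to reproduce, assuming the predicted top-k value of each test question has been collected into a list (for example by recording `predicted_top_k` inside the evaluation loop). `top_k_distribution` is a hypothetical helper, and the values below are made up for illustration:

```python
from collections import Counter
from typing import Dict, List


def top_k_distribution(predicted: List[int], max_k: int = 10) -> Dict[int, int]:
    """Count how often each k in 1..max_k was predicted, including k values never chosen."""
    counts = Counter(predicted)
    return {k: counts.get(k, 0) for k in range(1, max_k + 1)}


predicted = [1, 1, 8, 3, 1, 8, 5]  # made-up predictions
print(top_k_distribution(predicted))
```

The resulting keys and values can then be passed to a bar plot, for example `sns.barplot(x=..., y=...)`.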

Some questions from the test set for which the predicted top-k is 1:

What datasets were used in this work?
Which datasets do they experiment on?
Which language pairs do they evaluate on?
Which domain are the conversations in?

And here are questions with a predicted top-k of 8:

Did the survey provide insight into features commonly found to be predictive of abusive content on online platforms?
What are the opportunities presented by the use of Semantic Web technologies in Machine Translation?
Which other units of text do they experiment with (apart from BPE and ortographic syllables)?
Why is improvement on OntoNotes significantly smaller compared to improvement on WNUT 2017?

Result

The preliminary results are promising. Because of personal resource constraints, I trained on 100 documents (1/8 of the sampled training data) and tested on 25 documents (about 1/3 of the sampled test data), each with multiple questions. The trained top-k model achieves a competitive similarity score of 0.835831, suggesting that choosing top-k dynamically does not make the generated answers drift semantically from the reference answers. Its correctness score of 2.928571 (on the 1 to 5 scale) is on par with even the baseline k_10 model. It achieves this while reducing the average total LLM token count to 1857.238095, a substantial saving in compute and cost compared with the baseline k_5 and k_10 models. LLM APIs bill per token, and models such as GPT-4 are expensive, so supplying unnecessary context directly increases the number of input tokens and therefore the cost. These results suggest that the cross-encoder can tailor retrieval to each query, balancing answer quality against the resource expenditure of the system.
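A back-of-the-envelope calculation shows why smaller k translates into token savings. Assuming every retrieved chunk is at most 512 tokens (the chunk size used when building the training data; chunks can be shorter, and the evaluation index may use a different chunk size), the retrieved context is bounded by k × 512 tokens, so the context budget scales linearly with k. The helper names are hypothetical, and the predicted k values below are illustrative, not the experiment's.

```python
from typing import List


def context_token_upper_bound(k: int, chunk_size: int = 512) -> int:
    """Upper bound on retrieved-context tokens when each of the k chunks has at most chunk_size tokens."""
    return k * chunk_size


def context_budget_saving(static_k: int, predicted_ks: List[int], chunk_size: int = 512) -> float:
    """Fraction of the static-k context budget saved by using the predicted k per question."""
    static_total = context_token_upper_bound(static_k, chunk_size) * len(predicted_ks)
    dynamic_total = sum(context_token_upper_bound(k, chunk_size) for k in predicted_ks)
    return 1 - dynamic_total / static_total


print(context_token_upper_bound(10))  # 5120
print(context_budget_saving(10, [1, 8, 3, 4]))  # about 0.6
```

Actual token counts also include the question, the prompt template and the generated answer, and depend on how long the retrieved chunks really are, so this is only a bound on the context portion.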

Summary

In summary, this project moves away from static top-k retrieval by introducing a dynamic approach that adapts to individual queries. A cross-encoder is trained to predict the most suitable top-k value for retrieval, so that the large language model receives an appropriate amount of context for each question. As the preliminary results suggest, this method preserves the relevance and correctness of the generated answers while markedly reducing the computational resources required. The project shows the potential for a more nuanced application of retrieval-augmented generation, leading to question-answering systems that are both precise in their responses and economical in their use of language model tokens.
