How to Make Your RAG Less Distracted?

Mathieu Martial
Jul 18, 2024

Introduction

In this article, we investigate the optimal number of question-answer pairs from HotPotQA needed to fine-tune Mistral-7B-Instruct in order to improve a Retrieval-Augmented Generation (RAG) system. We study how fine-tuning with distracting context affects model performance. Our goal is to identify an efficient and cost-effective fine-tuning strategy for this setting.

Methodology

In our experiments, we simulate a RAG system by freezing the retrieval component and feeding context directly to the Large Language Model (LLM). The model we fine-tune is Mistral-7B-Instruct, and the dataset we use is HotPotQA in its distractor setting.

Dataset and metrics

HotPotQA is a multi-hop question-answering dataset on general trivia. It has two settings: fullwiki and distractor. In this work, we use the latter, described as a task where “a question-answering system reads 10 paragraphs to provide an answer (Ans) to a question. They must also justify these answers with supporting facts (Sup)”.

Figure 1: An example of the multi-hop questions in HOTPOTQA. We also highlight the supporting facts in blue italics, which are also part of the dataset.

In the HotPotQA paper, the headline metrics are Joint F1 and Joint EM (Exact Match), which combine the answer and supporting-fact scores.

For simplicity, we only evaluate the answers and ignore the supporting facts, so we use the standard answer-level F1 and EM.
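
As a minimal sketch of these answer-level metrics, the code below assumes the SQuAD-style normalization that the official HotPotQA evaluation script also uses (lowercasing, removing punctuation and the articles a/an/the, collapsing whitespace), plus that script's rule that yes/no answers only score on an exact match. The article does not show its evaluation code, so the function names and details here are illustrative.

```python
import re
import string
from collections import Counter

SPECIAL_ANSWERS = {"yes", "no", "noanswer"}


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between the normalized prediction and ground truth."""
    pred = normalize_answer(prediction)
    gold = normalize_answer(ground_truth)
    # Yes/no answers only score when they match exactly.
    if (pred in SPECIAL_ANSWERS or gold in SPECIAL_ANSWERS) and pred != gold:
        return 0.0
    pred_tokens, gold_tokens = pred.split(), gold.split()
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def evaluate(predictions, ground_truths):
    """Average EM and F1 over a test set."""
    pairs = list(zip(predictions, ground_truths))
    em = sum(exact_match(p, g) for p, g in pairs) / len(pairs)
    f1 = sum(f1_score(p, g) for p, g in pairs) / len(pairs)
    return {"em": em, "f1": f1}
```

For instance, a correct but verbose answer such as "Paris, France" scored against the ground truth "Paris" gets an EM of 0 and an F1 of about 0.67, which is why output formatting weighs heavily on both metrics.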

Baseline

Our baseline for this experiment is Mistral-7B-Instruct without any fine-tuning, referred to as the baseline model in this article.

For the test set, we simulate a RAG system by including in the prompt the k paragraphs provided with HotPotQA’s distractor setting. They are all closely related to the query and can thus be seen as the top-k documents retrieved by a ranker. The test set contains 1405 queries. Here is the prompt template we used:

Context information is below
---------------------
{context}
---------------------
Given the context information and not prior knowledge,
answer the query. If the context does not help answer the question, then don't answer.
Your answers must be as short as possible, don't make sentences! The shorter the better.
Query: {query}

Here, {context} is the concatenation of the k paragraphs and {query} is the question. This prompt was inspired by the one given in Mistral’s documentation.
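
The template above can be filled programmatically. In this sketch, the way each paragraph is rendered (title, colon, then its sentences) and the blank line between paragraphs are assumptions, since the article does not say how the k paragraphs are concatenated; `format_paragraph` and `build_prompt` are hypothetical helper names.

```python
# Prompt template from the article, with {context} and {query} placeholders.
PROMPT_TEMPLATE = """Context information is below
---------------------
{context}
---------------------
Given the context information and not prior knowledge,
answer the query. If the context does not help answer the question, then don't answer.
Your answers must be as short as possible, don't make sentences! The shorter the better.
Query: {query}"""


def format_paragraph(title: str, sentences: list[str]) -> str:
    # Assumption: a paragraph is rendered as "title: sentence sentence ...".
    return f"{title}: " + " ".join(s.strip() for s in sentences)


def build_prompt(paragraphs: list[tuple[str, list[str]]], query: str) -> str:
    # Assumption: paragraphs are separated by a blank line.
    context = "\n\n".join(format_paragraph(title, sents) for title, sents in paragraphs)
    return PROMPT_TEMPLATE.format(context=context, query=query)
```

With an instruct model, the resulting string would typically be sent as the user message and wrapped by the model’s chat template.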

Let’s see some of the baseline model’s answers on the test set with all 10 paragraphs used as context:

  • The baseline model achieves an F1 score of 0.59 and an exact-match score of 0.42;
  • It shows formatting issues, as in lines 3 and 4 of Table 1, where it found the right answer but its output did not match the ground truth;
  • It also gets many questions wrong because it fails to extract the correct answer from the context.
Table 1: Baseline model answers and ground truth for various queries

Training settings

Our training setup uses a batch size of 2 with 8 gradient accumulation steps and trains for a single epoch, so that each sample is processed exactly once. Training runs on a single NVIDIA A100 80GB GPU. As we increase the number of training samples, we keep the previous training data and add new samples to it. However, each training run starts again from the original Mistral-7B-Instruct weights, and the largest training set contains 6000 samples.
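
The setup above implies an effective batch size of 2 × 8 = 16 samples per optimizer update, so a single epoch over N samples amounts to roughly N / 16 updates. The sketch below also shows one way to build nested training sets, where every smaller set is a prefix of one fixed shuffle; the seed, helper names and rounding rule are assumptions. If the Hugging Face `Trainer` were used (the article does not say), the batch and epoch settings would correspond to `per_device_train_batch_size=2`, `gradient_accumulation_steps=8` and `num_train_epochs=1` in `TrainingArguments`.

```python
import math
import random

# Settings reported in the article.
PER_DEVICE_BATCH_SIZE = 2
GRAD_ACCUM_STEPS = 8
EFFECTIVE_BATCH_SIZE = PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS  # samples per optimizer update


def nested_training_sets(pool, sizes, seed=0):
    """Shuffle the pool once, then take prefixes so smaller sets are subsets of larger ones."""
    order = list(pool)
    if max(sizes) > len(order):
        raise ValueError("requested training set is larger than the pool")
    random.Random(seed).shuffle(order)
    return {n: order[:n] for n in sorted(sizes)}


def optimizer_steps(num_samples, epochs=1):
    """Number of optimizer updates, assuming a final partial accumulation window still counts."""
    return math.ceil(num_samples / EFFECTIVE_BATCH_SIZE) * epochs
```

With 6000 samples this gives 375 updates per run, and with 3000 samples about 188.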

We consider three training scenarios, depending on how many paragraphs of context are included during training:

  • D0: The only paragraphs given to the LLM during training are the 2 that contain the supporting facts;
  • D1: One distracting paragraph, picked at random from the 8 that do not contain supporting facts, is added, for a total of 3 paragraphs of context;
  • D8: All 10 paragraphs are used as context, including all 8 distracting ones.
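
A minimal sketch of how D0, D1 and D8 contexts could be assembled from a single HotPotQA example, assuming the raw JSON layout where `context` is a list of `[title, sentences]` pairs and `supporting_facts` is a list of `[title, sentence_id]` pairs. Gold paragraphs are identified by the titles that appear in the supporting facts. The helper name `select_paragraphs` and keeping the dataset’s original paragraph order are our assumptions.

```python
import random

# Number of distracting paragraphs per training scenario.
SETTINGS = {"D0": 0, "D1": 1, "D8": 8}


def select_paragraphs(example, num_distractors, rng):
    """Return (title, sentences) pairs: both gold paragraphs plus `num_distractors` random distractors."""
    gold_titles = {title for title, _sent_id in example["supporting_facts"]}
    context = example["context"]  # list of [title, list_of_sentences]
    gold_idx = [i for i, (title, _) in enumerate(context) if title in gold_titles]
    distractor_idx = [i for i, (title, _) in enumerate(context) if title not in gold_titles]
    k = min(num_distractors, len(distractor_idx))
    keep = set(gold_idx) | set(rng.sample(distractor_idx, k))
    # Preserve the original paragraph order (assumption: the article does not specify one).
    return [(context[i][0], context[i][1]) for i in sorted(keep)]
```

Calling `select_paragraphs(example, SETTINGS["D1"], random.Random(0))` returns the three paragraphs that would be concatenated into {context} for a D1 training prompt.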

We refer to models as “Dx–y”, where x is the number of distracting documents used in training and y the number of training samples; for example, D1–3K is the model trained with one distracting document on 3000 samples.

Results

We evaluate the models on two test sets: the D0 set, where only the 2 paragraphs containing the supporting facts are fed to the LLM, and the D8 set, where all 10 paragraphs are used as context.

Answer comparison

Let us start with some answers from a fine-tuned model, D1–3K, on the D8 set:

  • It is better at formatting, as shown in lines 3 and 4 of Table 2;
  • It is better at extracting the answer from the context, as shown in lines 5 and 6;
  • Overall, it reaches an F1 score of 0.70 (baseline: 0.59) and an EM score of 0.57 (baseline: 0.42), a sizeable improvement.
Table 2: Answer comparison for D1–3K and Baseline models
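
To put these numbers in perspective, the gains of D1–3K over the baseline on the D8 set, computed from the rounded scores above, are:

```latex
\begin{aligned}
\Delta F_1 &= 0.70 - 0.59 = 0.11, & \frac{0.11}{0.59} &\approx 18.6\%\ \text{relative},\\
\Delta \mathrm{EM} &= 0.57 - 0.42 = 0.15, & \frac{0.15}{0.42} &\approx 35.7\%\ \text{relative}.
\end{aligned}
```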

Training time comparison

As Figure 2 shows, training time grows linearly with the training set size. It also increases as we add distracting paragraphs to the training prompts, since longer prompts take longer to process. D0 and D1 are therefore cheaper to train than D8, but all training times remain fairly short: fine-tuning D0–3K and D1–3K, for example, takes only about 10 minutes with our setup.

Fine-tuning D1 and D8 models takes 1.31 and 1.75 times as long as fine-tuning D0 models, respectively.

Figure 2. Training times for every setting based on the number of training samples

Performance comparison

Finally, let’s compare the performance of all our models on both sets. Overall, the D8 training setting underperforms by a large margin, barely keeping up with the baseline model. On the D0 set, D0 and D1 models achieve similar results and behave similarly: the F1 score rises quickly from the baseline’s 0.69 to around 0.74 with 1000 training samples, then plateaus at around 0.76. Still, D0 slightly outperforms D1 between 750 and 3000 training samples.

Figure 3. Model F1 score on the D0 set based on the number of training samples
Figure 4. Model F1 score on the D8 set based on the number of training samples

However, on the D8 set, where all distracting paragraphs are present, D1 models always perform better than the other settings. While D0 models plateau at 1000 training samples, D1 models keep improving up to 3000 training samples. In each setting, the best model is the 3K model, and D1–3K outperforms D0–3K by almost 4 F1 points.

These results agree with those presented in Figure 6 of the RAFT paper, which shows that training on HotPotQA with a distracting document improves the model’s accuracy on test sets containing varying numbers of distracting documents.

Hypotheses

Here are some of our hypotheses to explain these improvements:

  • Part of the improvement could be related to how we prompt the model: since we fine-tuned the models with the same prompt template that we used during testing, the fine-tuned models are specifically adapted to it;
  • The rapid early increase in F1 score most likely comes from formatting: this metric is closely tied to how answers are formatted, and fine-tuning can quickly change the models’ output format. The EM scores support this hypothesis;
  • As seen in Table 2, the fine-tuned models also seem better at finding the supporting facts within the context to answer the question;
  • While more experiments are needed to explain the D8 models’ performance, the problem with fine-tuning in this setting might come from the sheer size of the training prompts: Mistral may struggle to learn from prompts that exceed a certain length.

Overall, the model learns to handle this specific use case very quickly.

Conclusions

Here are the main take-aways from this article:

  • Adding even a single distracting document to the context when training an LLM helps it distinguish distracting context from supporting facts at inference time, making it more accurate. This is especially relevant in a RAG setting, where the retrieval step may surface distracting context;
  • The number of distracting documents added during training must be chosen carefully: in our experiments, one distracting document helped, whereas training with all eight yielded models that barely matched the baseline;
  • If no distracting context is expected at inference time, adding distracting documents during training is unnecessary;
  • Fine-tuning Mistral-7B-Instruct for a single epoch on fewer than 1000 training samples is very quick and can significantly improve results.

Future works

  • Repeating all our training sessions using different training samples (e.g., by changing the seed) to determine if any of the performance declines are due to statistical fluctuations or other factors;
  • Investigating the ideal number of distracting documents for training on HotPotQA;
  • Conducting the same experiments on domain-specific tasks. Since HotPotQA is based on general trivia, it leaves little room for models to improve by learning new information;
  • Training with various prompt sizes (by filtering the data based on the length of the supporting facts). This would be a first step toward understanding D8’s poor performance;
  • Testing our models on a different general trivia QA dataset (such as Natural Questions) to check whether the models learnt to filter out unnecessary information or only got better at this specific HotPotQA task;
  • Trying other models for fine-tuning, such as Llama;
  • Implementing automatic fine-tuning from a knowledge base with Giskard’s RAGET. RAGET can automatically generate question, reference_answer and reference_context triples from a RAG system’s knowledge base. Although it is primarily intended for testing, we could repurpose it to generate fine-tuning data with distracting context.
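
As a starting point for the prompt-size experiment listed above, the sketch below groups training examples by the length of their supporting paragraphs. It assumes the raw HotPotQA JSON layout, uses whitespace-separated word counts as a rough stand-in for tokens (a real run would use Mistral’s tokenizer), and the bucket edges are arbitrary placeholders.

```python
def supporting_length(example: dict) -> int:
    """Word count of the paragraphs that contain supporting facts (a rough proxy for tokens)."""
    gold_titles = {title for title, _sent_id in example["supporting_facts"]}
    return sum(
        len(" ".join(sentences).split())
        for title, sentences in example["context"]
        if title in gold_titles
    )


def bucket_by_length(examples, edges):
    """Group examples into half-open buckets [edges[i], edges[i+1]); examples outside all buckets are dropped."""
    buckets = {(lo, hi): [] for lo, hi in zip(edges, edges[1:])}
    for ex in examples:
        n = supporting_length(ex)
        for lo, hi in buckets:
            if lo <= n < hi:
                buckets[(lo, hi)].append(ex)
                break
    return buckets
```

Each bucket could then be used as a separate training set to see whether performance degrades as prompts get longer.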
