Fine-tuning LLMs for Enterprise RAG — A design perspective

Raunak Jain
Apr 2, 2024

How do you know whether fine-tuning is making your system better or just burning money?

What do I mean by design perspective?

ChatGPT says: “design perspective” typically refers to the viewpoint or approach taken when creating or evaluating a design. It encompasses various factors such as functionality, usability, user experience, cultural context, and intended audience. Designers often consider these perspectives to ensure that their creations effectively communicate their intended message or fulfill their intended purpose. Additionally, design perspective can also involve considering trends, principles, and theories in design to inform decision-making and problem-solving throughout the design process.

Elaborating on the definition above, in this document we will build a framework for RAG-specific fine-tuning: measurable skills, metrics, datasets and review techniques.

What is RAG and what are we training for?

Retrieval Augmented Generation (RAG) has become a core component for solving enterprise use cases like Summarization, Question Answering, Tool Selection and other context-dependent tasks in automation flows and agent systems. It acts as the memory (long- and short-term) and context management layer of any LLM-based pipeline. It can marry the generative and creative aspects of an LLM with the grounding provided by Knowledge Lakes to produce accurate and aligned behavior in systems like Coding CoPilots, Customer Service agents, Agent Assist, etc.


Challenges in productionising and scaling RAG still exist. There are several components even within the retriever pipeline, and it is very tricky to build a strong retrieval and ranking framework that can be measured and trained in production.

Some problems in a RAG pipeline might be: retrieving bad context, LLM hallucination due to internal parametric knowledge and confirmation bias, answers that are not useful due to misalignment, or responses that are not in the expected format.

Solving these reliably is still an open question, tackled by practitioners in the wild. Let’s first wrap our heads around how to measure how well a RAG pipeline performs and how fine-tuning can help. We will also see why fine-tuning the generator might be enough, rather than going after the retriever.

Do we improve the Retriever or the Generator? What are the tradeoffs?

Seems like a previous life! Back when generative models were not the only thing we talked about, domain adaptation and end-to-end training of the RAG architecture were a big deal. They were also very tricky to scale and get right for language models on a small dataset. See this presentation on work I did on Domain Adaptation and Contrastive Learning for encoder-only / masked language models.

More resources:

  1. ELI5: Long Form Question Answering
  2. How to Finetune the Entire RAG Architecture (including DPR retriever)

There was an understanding that training only the retriever was not enough for optimal accuracy. But the retriever depends on document embeddings, and these must be continuously recomputed (and the index rebuilt) as gradients update the encoder. That is prohibitively expensive at scale for large datasets.

REALM — Retrieval-Augmented Language Model Pre-Training
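The toy numpy sketch below makes that cost concrete. It is not REALM's implementation or any library's, and the linear "encoders" and random "documents" are assumptions chosen only for illustration. A hypothetical document encoder builds a cached index, and one simulated gradient step then updates the encoder. After that step, the cached embeddings no longer match what the updated encoder produces. Keeping retrieval consistent means re-encoding the whole corpus, which is why REALM, for example, refreshes its index asynchronously instead of after every step.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy corpus: 1,000 documents as 64-dim feature vectors (stand-ins for text).
docs = rng.normal(size=(1000, 64))
# Hypothetical linear encoders; real systems use BERT-style towers.
W_doc = rng.normal(size=(64, 32))
W_query = rng.normal(size=(64, 32))

def encode(x, W):
    """Project rows and L2-normalize so dot products are cosine similarities."""
    z = x @ W
    return z / np.linalg.norm(z, axis=1, keepdims=True)

# Build the retrieval index once with the current document encoder.
index = encode(docs, W_doc)

# One simulated gradient step on the document encoder.
W_doc_new = W_doc - 0.1 * rng.normal(size=W_doc.shape)

# Refreshing the index costs len(docs) encoder forward passes.
fresh = encode(docs, W_doc_new)
drift = 1.0 - np.sum(index * fresh, axis=1)  # cosine distance per document
print(f"mean embedding drift after one step: {drift.mean():.4f}")

# Retrieval against the stale index can rank documents differently.
query = encode(rng.normal(size=(1, 64)), W_query)
top_stale = set(np.argsort(-(index @ query.T).ravel())[:10])
top_fresh = set(np.argsort(-(fresh @ query.T).ravel())[:10])
overlap = len(top_stale & top_fresh)
print(f"top-10 overlap, stale vs refreshed index: {overlap}/10")
```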

In my opinion, some other areas of focus (and papers) from that era, some 2 years back, were:

  1. Learning domain specific tokenizers and embeddings.
  2. Pre-training or fine-tuning for new tasks, like training BERT for entity or intent extraction.
  3. Contrastive learning for a holistic view on learning discriminative models.

Enter Decoder-only models

In this new world of RAG with powerful generative models, the focus has shifted. We no longer hyper-optimize the retrieval / embedding models. Instead, we improve the generative model’s ability to focus on the right parts of the context and to synthesize information in a way that improves reasoning, reduces hallucination and strengthens solution generation.

Instead of modifying the model at the fundamental token or representation level, we try to improve its ability to focus on the correct information and to reason. We can fine-tune it on noisy datasets in an adversarial manner, steer it toward a desired behavior with in-context learning, or train it on heavily task-enriched data as in AdaptLLM.

Before we dig into strategies, we need to understand how to measure improvement in these RAG systems. Let’s define the capabilities we want to elicit, the evaluation strategies and metrics we will use to measure improvement, and the datasets and deployment patterns we want to aim for.

Some capabilities we want to develop in the generative models through the training process:

  1. Reasoning within new domains: improving query augmentation, extracting observations and looping through the retrieval process to behave better in new domains.
  2. Overriding parametric knowledge with contextual knowledge, rely less on pre-training knowledge when retrieval contains conflicting information.
  3. Handling longer contexts accurately and being less susceptible to noise, especially in long conversation data.
  4. Coherence, relevance and groundedness of the answer generated.
  5. Safety by avoiding jailbreaks, even in multi-turn conversations.
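Capability 2 (overriding parametric knowledge) can be probed with counterfactual contexts. We plant an answer that conflicts with what pre-training likely taught, then check which answer the model gives. The sketch below is one way to score that. The `model` callable, the probe fields and the toy probe are all hypothetical, and the substring matching is a simplifying assumption.

```python
from typing import Callable, Dict, List, Tuple

def context_override_rates(model: Callable[[str, str], str],
                           probes: List[Dict[str, str]]) -> Tuple[float, float]:
    """Return (followed_context, fell_back_to_parametric) rates over the probes.

    model:  hypothetical callable (question, context) -> answer text.
    probes: dicts with 'question', 'context_template' (containing '{answer}'),
            'parametric' (the answer pre-training likely taught) and
            'counterfactual' (a conflicting answer planted in the context).
    """
    followed = parametric = 0
    for p in probes:
        context = p["context_template"].format(answer=p["counterfactual"])
        prediction = model(p["question"], context).lower()
        if p["counterfactual"].lower() in prediction:
            followed += 1
        elif p["parametric"].lower() in prediction:
            parametric += 1
    n = len(probes)
    return followed / n, parametric / n

probes = [{
    "question": "What is the capital of France?",
    "context_template": "According to the 2030 atlas, the capital of France is {answer}.",
    "parametric": "Paris",
    "counterfactual": "Lyon",
}]

def closed_book(question: str, context: str) -> str:
    """Toy model that ignores the context entirely."""
    return "Paris"

def reader(question: str, context: str) -> str:
    """Toy model that answers with the last word of the context."""
    return context.rsplit(" ", 1)[-1].rstrip(".")

print(context_override_rates(closed_book, probes))  # (0.0, 1.0)
print(context_override_rates(reader, probes))       # (1.0, 0.0)
```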

To be able to improve on the above behaviors, we will try to define some metrics.

Some measurable metrics we should learn how to quantify:

  1. Distractibility — lack of self-consistency and high sensitivity to noise.
  2. Hallucination — Hallucination in a foundation model (FM) refers to the generation of content that strays from factual reality or includes fabricated information.
  3. Misalignment or behavioural safety — paying attention to irrelevant, out-of-domain context and/or answering in an unexpected tone or format.
  4. Reasoning — better thought generation for domain specific logical reasoning.
  5. Overconfidence — improving calibration and failure recognition by the LLM itself.
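Two of these metrics can be sketched with very little code. We operationalize distractibility (our assumption, not a definition from the post) as the share of questions whose answer changes when the context is perturbed, with majority-vote self-consistency alongside it. Overconfidence is measured with the standard Expected Calibration Error over the model's stated confidences. The example inputs are made up for illustration.

```python
from collections import Counter

import numpy as np

def distractibility(answers_per_question):
    """Share of questions whose answer changes across context variants.

    Each item holds the answers a model gave to ONE question under several
    context variants (clean, shuffled, with injected distractors, ...).
    """
    changed = [len(set(answers)) > 1 for answers in answers_per_question]
    return sum(changed) / len(changed)

def self_consistency(samples):
    """Share of sampled answers that agree with the majority answer."""
    _, count = Counter(samples).most_common(1)[0]
    return count / len(samples)

def expected_calibration_error(confidences, correct, n_bins=10):
    """Bin predictions by stated confidence, then average |accuracy - confidence|
    over bins, weighted by the share of predictions falling in each bin."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += in_bin.mean() * gap
    return ece

# Usage with made-up model outputs:
print(distractibility([["Paris", "Paris", "Paris"], ["42", "41", "42"]]))  # 0.5
print(self_consistency(["a", "a", "b", "a"]))                              # 0.75
print(expected_calibration_error([0.95, 0.95, 0.65, 0.65], [1, 0, 1, 1]))  # ~0.4
```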

Some open source frameworks for evaluation

RAGAS, a popular evaluation framework for RAG applications, breaks the score down into generation and retrieval metrics. This split is valid when the system is composed of separate retrieval and generation components. Since here we are talking only about the generative model, we will focus on generation-related metrics.

RAGAs — https://docs.ragas.io/en/latest/concepts/metrics/index.html
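A RAGAS-style faithfulness score is the share of claims in the answer that the retrieved context supports. The sketch below shows that shape and is not the ragas library API. RAGAS uses an LLM both to split the answer into claims and to verify each one. Here `split_claims` and `is_supported` are hypothetical stand-ins (sentence splitting and token overlap) so that the example runs offline.

```python
import re
from typing import List

def split_claims(answer: str) -> List[str]:
    """Hypothetical claim extractor: one claim per sentence (RAGAS uses an LLM)."""
    return [s.strip() for s in re.split(r"(?<=[.!?])\s+", answer) if s.strip()]

def tokens(text: str) -> set:
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def is_supported(claim: str, context: str, threshold: float = 0.9) -> bool:
    """Hypothetical verifier: supported if most claim tokens occur in the context.
    A real setup would ask an LLM or an NLI model instead."""
    claim_tokens = tokens(claim)
    if not claim_tokens:
        return False
    return len(claim_tokens & tokens(context)) / len(claim_tokens) >= threshold

def faithfulness(answer: str, contexts: List[str]) -> float:
    """Supported claims / total claims, the shape of the RAGAS faithfulness score."""
    context = " ".join(contexts)
    claims = split_claims(answer)
    if not claims:
        return 0.0
    return sum(is_supported(c, context) for c in claims) / len(claims)

contexts = ["The Eiffel Tower is in Paris. It was completed in 1889."]
answer = "The Eiffel Tower is in Paris. It was completed in 1925."
print(faithfulness(answer, contexts))  # 0.5: the date claim is not supported
```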

Galileo has well-defined metrics and an index for tracking hallucination: correctness and context adherence.

…to be elaborated and final metrics frozen.

In the community, there have always been discussions about Prompt Engineering and RAG solving all problems, like hallucination and answering style. But these depend heavily on the choice of model and prompt. So we will take an approach that does not depend on correct usage at run time. Rather, it provides some guarantees once the metrics and dataset are frozen.

Let’s look at the following to understand the ideal setup for fine-tuning to improve on the capabilities and metrics mentioned above.

  1. Training strategies — do we train the decoder model from scratch? Can we just train the retriever? Is reinforcement learning a better strategy?
  2. Training approaches — Domain-Adaptive Pre-Training or Parameter-Efficient Domain Adaptation? In-context learning or prompt engineering? We will dig deeper into different strategies and when they work.
  3. Task and data architecture — we can train for question answering on knowledge bases, intent detection in a conversation, or structured data extraction from OCR. The underlying data and the task formulation matter more than we think.
  4. Data Augmentation techniques — depending on task and data architecture, we need to identify augmentation strategies to fulfil capability and metrics improvement.

Training Strategies

Training the Generator to also act as a Retriever

Source — RA-DIT: RETRIEVAL-AUGMENTED DUAL INSTRUCTION TUNING
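RA-DIT fine-tunes in two stages. The LM is instruction-tuned with retrieved chunks prepended to the prompt. The retriever is tuned with an LM-supervised retrieval (LSR) objective, which pushes the retriever's distribution over retrieved chunks toward the chunks that make the frozen LM most likely to produce the gold answer. Below is a numpy sketch of that loss for one example. The KL direction (retriever distribution against LM distribution, REPLUG-style) and the temperature names are our assumptions, so check the paper for the exact formulation. In training, this would be computed in an autodiff framework so that gradients reach the retriever.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def lsr_loss(retriever_scores, lm_log_likelihoods, tau_r=1.0, tau_lm=1.0):
    """LM-supervised retrieval loss for ONE (query, answer) pair over its top-k chunks.

    retriever_scores:   s(x, c_i), retriever similarity for each retrieved chunk.
    lm_log_likelihoods: log p_LM(y | c_i, x), how likely the frozen LM makes the
                        gold answer y when chunk c_i is prepended to query x.
    Returns KL(p_R || p_LSR); only the retriever side would receive gradients.
    """
    p_r = softmax(np.asarray(retriever_scores, dtype=float) / tau_r)
    p_lsr = softmax(np.asarray(lm_log_likelihoods, dtype=float) / tau_lm)
    return float(np.sum(p_r * (np.log(p_r) - np.log(p_lsr))))

# The frozen LM finds the gold answer most likely given chunk 0.
lm_ll = [-1.0, -5.0, -6.0]
aligned = lsr_loss([2.0, 1.0, 0.5], lm_ll)     # retriever already prefers chunk 0
misaligned = lsr_loss([0.5, 1.0, 2.0], lm_ll)  # retriever prefers chunk 2
print(aligned < misaligned)  # True
```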

Training the Generator to focus on the needle in the haystack

Retrieval Augmented Fine Tuning (RAFT), a new technique that optimizes LLMs for RAG on domain-specific knowledge

Our investigations revealed that many open-source models struggle with [RAG], particularly with the following issues: a) dealing with in-accurate retrievers, b) answer style mismatch and c) extracting incorrect information from the retrieved context,
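The RAFT recipe can be sketched as a data-construction step. For a fraction of questions (P in the paper), the oracle document is kept alongside sampled distractors. For the rest, only distractors are shown. The target is a chain-of-thought answer grounded in the oracle document, which also addresses answer-style mismatch. The function and parameter names below are hypothetical, uniform sampling of distractors is a simplifying assumption, and the corpus is illustrative.

```python
import random
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class RaftExample:
    question: str
    context: List[str]  # documents shown to the model, in shuffled order
    target: str         # chain-of-thought answer that cites the oracle document

def make_raft_example(question: str, oracle_doc: str, cot_answer: str,
                      corpus: List[str], num_distractors: int = 3,
                      p_oracle: float = 0.8,
                      rng: Optional[random.Random] = None) -> RaftExample:
    """Build one RAFT-style training example (hypothetical helper)."""
    if rng is None:
        rng = random.Random()
    pool = [d for d in corpus if d != oracle_doc]
    distractors = rng.sample(pool, num_distractors)
    # With probability p_oracle the oracle sits among the distractors; otherwise
    # the model sees only distractors and must still produce the target answer.
    context = distractors + [oracle_doc] if rng.random() < p_oracle else distractors
    rng.shuffle(context)
    return RaftExample(question, context, cot_answer)

corpus = [f"Policy section {i}" for i in range(10)]
example = make_raft_example(
    "How long is the refund window?", "Policy section 3",
    "Policy section 3 states the window length, so the answer is 14 days.",
    corpus, rng=random.Random(0))
print(example.context)
```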

Teacher — student learning
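The post does not spell out which teacher-student setup it means. One common form is token-level knowledge distillation: a smaller student is trained to match a larger teacher's temperature-softened next-token distributions. Another common form is simply training the student on teacher-generated answers. Below is a minimal numpy sketch of the first form, the soft-target loss of Hinton et al. The toy logits are random, and in practice this term is usually mixed with the ordinary cross-entropy on ground-truth tokens.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """KL(teacher || student) between temperature-softened next-token distributions,
    averaged over positions and scaled by T**2 to keep gradient magnitudes comparable."""
    log_p_t = log_softmax(np.asarray(teacher_logits, dtype=float) / T)
    log_p_s = log_softmax(np.asarray(student_logits, dtype=float) / T)
    kl_per_position = np.sum(np.exp(log_p_t) * (log_p_t - log_p_s), axis=-1)
    return float(kl_per_position.mean() * T**2)

# Logits for 3 positions over a toy 5-token vocabulary.
rng = np.random.default_rng(0)
teacher = rng.normal(size=(3, 5))
student = rng.normal(size=(3, 5))
print(distillation_loss(teacher, teacher))      # 0.0
print(distillation_loss(student, teacher) > 0)  # True
```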

Training Approaches

To PEFT or not to PEFT, to Domain Adapt or to Task Extend…

Two trends are clear:

  1. It is not clear whether PEFT is weaker than Domain-Adaptive Pre-Training, but large-scale task-specific data generation and training are extremely beneficial.
  2. Prompt-engineering-dependent strategies degrade when models are trained on new domain data, which might be a good thing if it avoids hallucination driven by bad prompts and contexts.
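To make "parameter efficient" concrete, here is a minimal numpy sketch of LoRA, one popular PEFT method. The post does not say which PEFT method it has in mind. The frozen weight W gets a trainable low-rank update B A, so only r × (d_in + d_out) parameters are trained instead of d_in × d_out. There is no training loop here, and the rank, alpha and layer size are illustrative assumptions. The Hugging Face peft library provides a production implementation.

```python
import numpy as np

class LoRALinear:
    """Frozen linear layer plus a low-rank update: y = x W^T + (alpha / r) * x A^T B^T.

    Only A (r x d_in) and B (d_out x r) would be trained. B starts at zero, so the
    adapted layer initially reproduces the frozen base layer exactly.
    """

    def __init__(self, W, r=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        d_out, d_in = W.shape
        self.W = W                                       # frozen pretrained weight
        self.A = rng.normal(scale=0.01, size=(r, d_in))  # trainable down-projection
        self.B = np.zeros((d_out, r))                    # trainable up-projection
        self.scale = alpha / r

    def __call__(self, x):
        return x @ self.W.T + self.scale * ((x @ self.A.T) @ self.B.T)

    def trainable_params(self):
        return self.A.size + self.B.size

rng = np.random.default_rng(0)
W = rng.normal(size=(1024, 1024))
layer = LoRALinear(W, r=8, alpha=16)
x = rng.normal(size=(2, 1024))

print(np.allclose(layer(x), x @ W.T))              # True at initialization
print(layer.trainable_params(), W.size)            # 16384 1048576
print(f"{layer.trainable_params() / W.size:.2%}")  # 1.56%
```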

Approach 1 — Domain-Adaptive Pre-Training (DAPT) with well-defined tasks.

Taking inspiration from human learning via reading comprehension — practice after reading improves the ability to answer questions based on the learned knowledge — we propose a simple method for transforming raw corpora into reading comprehension texts. Each raw text is enriched with a series of tasks related to its content.

ADAPTING LARGE LANGUAGE MODELS VIA READING COMPREHENSION
Improving Domain Adaptation through Extended-Text Reading Comprehension
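The sketch below illustrates the reading-comprehension idea. It mines simple tasks from the raw text with rules and appends them after the text, so the model first reads the document and then practices on it. The two rules (conclusion markers become inference questions, long sentences become completion tasks) and all function names are hypothetical. AdaptLLM uses its own, larger set of mining patterns, and the sample text is illustrative.

```python
import re
from typing import List, Tuple

def mine_tasks(text: str) -> List[Tuple[str, str]]:
    """Hypothetical rules that turn raw text into (question, answer) pairs."""
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
    tasks = []
    # Rule 1: a sentence opening with a conclusion marker becomes an inference
    # question whose premise is the preceding sentence.
    for premise, sentence in zip(sentences, sentences[1:]):
        m = re.match(r"(Therefore|Thus|As a result),?\s+(.*)", sentence)
        if m:
            tasks.append((f'"{premise}" What can be concluded from this?', m.group(2)))
    # Rule 2: longer sentences become text-completion tasks.
    for sentence in sentences:
        words = sentence.split()
        if len(words) >= 8:
            half = len(words) // 2
            tasks.append((f"Complete the sentence: {' '.join(words[:half])} ...",
                          " ".join(words[half:])))
    return tasks

def to_reading_comprehension(text: str) -> str:
    """Raw text first, then the mined tasks, as a single training document."""
    parts = [text]
    for i, (question, answer) in enumerate(mine_tasks(text), start=1):
        parts.append(f"Question {i}: {question}\nAnswer: {answer}")
    return "\n\n".join(parts)

raw = ("The claims team flagged a duplicate invoice. "
       "Therefore, the payment was put on hold until review.")
print(to_reading_comprehension(raw))
```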

https://arxiv.org/pdf/2005.11401.pdf

https://arxiv.org/abs/2310.01352

Approach 2 — Solving at inference


Data Augmentation and Synthetic Data Generation

When training LLMs, data augmentation is needed in low-resource and/or adversarial settings. In these settings, we either do not have enough raw data or lack application-specific data for the behavior we are trying to optimize. Some common generation strategies:

Multi-task instruction tuning, e.g. Reading Comprehension as in AdaptLLM

Adversarial

Contrastive

Rule based
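For the contrastive strategy, one simple recipe pairs each query with its gold passage and a hard negative. A hard negative is a passage that looks relevant but is not the answer. In the minimal sketch below, token Jaccard similarity stands in for BM25 or a dense retriever, and the function names and sample corpus are hypothetical.

```python
import re
from typing import List, Tuple

def tokens(text: str) -> set:
    return set(re.findall(r"[a-z0-9]+", text.lower()))

def jaccard(a: set, b: set) -> float:
    return len(a & b) / len(a | b) if a | b else 0.0

def make_contrastive_triplets(pairs: List[Tuple[str, str]], corpus: List[str],
                              negatives_per_query: int = 1):
    """Build (query, positive, hard_negative) triplets.

    Hard negatives are the passages most lexically similar to the query that are
    NOT its gold passage; Jaccard overlap is a stand-in for a real retriever.
    """
    triplets = []
    for query, gold in pairs:
        q = tokens(query)
        candidates = sorted((p for p in corpus if p != gold),
                            key=lambda p: jaccard(q, tokens(p)), reverse=True)
        for negative in candidates[:negatives_per_query]:
            triplets.append((query, gold, negative))
    return triplets

corpus = ["Refunds are issued within 14 days of a return.",
          "Refunds for digital goods are not available.",
          "Our office is closed on public holidays."]
pairs = [("How many days until refunds are issued?", corpus[0])]
print(make_contrastive_triplets(pairs, corpus))
```

The triplets can feed a contrastive loss such as InfoNCE for an embedding model. The negatives can also be reused as distractor documents when building noisy contexts for generator fine-tuning.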
