Build a Reliable RAG Agent using LangGraph

Plaban Nayak
Published in The AI Forum
13 min read · May 4, 2024
RAG Agent Workflow

Introduction

Here we will build a reliable RAG agent using LangGraph, Groq-Llama-3 and Chroma. We will combine the concepts below to build the RAG agent.

  • Adaptive RAG (paper). We use the concept described in this paper to build a router that routes questions to different retrieval approaches.
  • Corrective RAG (paper). We use the concept described in this paper to develop a fallback mechanism for when the retrieved context is irrelevant to the question asked.
  • Self-RAG (paper). We use the concept described in this paper to develop a hallucination grader, i.e. to fix answers that hallucinate or don't address the question asked.

What is an Agent ?

The core idea behind agents is to use a language model to choose a sequence of actions. In a chain, this sequence is hardcoded. An agent instead uses a language model as a reasoning engine to decide which actions to take and in what order.
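
The difference between a chain and an agent can be sketched in a few lines of plain Python. This is only an illustration: the two tools and the `fake_llm_decide` function are hypothetical stand-ins, and a real agent would ask an LLM to pick the next action.

```python
from typing import Callable, Dict, List


# Two toy "tools"; real tools would call search APIs, databases, etc.
def search(text: str) -> str:
    return f"search results for '{text}'"


def summarize(text: str) -> str:
    return f"summary of [{text}]"


TOOLS: Dict[str, Callable[[str], str]] = {"search": search, "summarize": summarize}


def run_chain(question: str) -> str:
    """Chain: the sequence search -> summarize is fixed in the code."""
    return summarize(search(question))


def fake_llm_decide(history: List[str]) -> str:
    """Hypothetical stand-in for an LLM that picks the next action."""
    if not history:
        return "search"
    if len(history) == 1:
        return "summarize"
    return "finish"


def run_agent(question: str, max_steps: int = 5) -> str:
    """Agent: a decision function chooses each action at run time."""
    history: List[str] = []
    result = question
    for _ in range(max_steps):
        action = fake_llm_decide(history)
        if action == "finish":
            break
        result = TOOLS[action](result)
        history.append(action)
    return result
```

Here both functions happen to produce the same result, but only `run_agent` could take a different path if the decision function chose differently.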

An agent comprises three components:

  • Planning: breaking tasks into smaller sub-goals
  • Memory: short term (chat history) / long term (vectorstore)
  • Tool use: it can make use of different tools to extend its capabilities

Agents can be implemented using the ReAct concept with LangChain or by using LangGraph.

Tradeoffs between a LangChain agent and LangGraph:

Reliability

  • ReAct / LangChain agent: less reliable, as the LLM has to make the correct decision at each step
  • LangGraph: more reliable, as the control flow is fixed and the LLM has a specific job to perform at each node

Flexibility

  • ReAct / LangChain agent: more flexible, as the LLM can choose any sequence of actions
  • LangGraph: less flexible, as actions are constrained by the control flow set up at each node

Compatibility with smaller LLMs

  • ReAct / LangChain agent: worse compatibility
  • LangGraph: better compatibility

Here we have used LangGraph to create the Agent.

🦜️🔗 What is Langchain ?

LangChain is a framework for developing applications powered by language models. It enables applications that:

  • Are context-aware: connect a language model to sources of context (prompt instructions, few shot examples, content to ground its response in, etc.)
  • Reason: rely on a language model to reason (about how to answer based on provided context, what actions to take, etc.)

What is LangGraph ?

LangGraph is a library that extends LangChain, providing cyclic computational capabilities for LLM applications. While LangChain supports defining computation chains (Directed Acyclic Graphs or DAGs), LangGraph enables the incorporation of cycles. This allows for more intricate, agent-like behaviors, where an LLM can be called in a loop to determine the next action to take.

Key Concepts

  • Stateful Graph: LangGraph revolves around the concept of a stateful graph, where each node in the graph represents a step in our computation, and the graph maintains a state that is passed around and updated as the computation progresses.
  • Nodes: Nodes are the building blocks of your LangGraph. Each node represents a function or a computation step. We define nodes to perform specific tasks, such as processing input, making decisions, or interacting with external APIs.
  • Edges: Edges connect the nodes in your graph, defining the flow of computation. LangGraph supports conditional edges, allowing you to dynamically determine the next node to execute based on the current state of the graph.
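
To make these three ideas concrete, here is a toy graph runner in plain Python. It is not LangGraph and does not use its API; `TinyGraph` and its method names are invented for this sketch. It only shows state being passed between nodes, nodes returning partial updates, and a conditional edge forming a cycle.

```python
from typing import Callable, Dict, Optional

State = Dict[str, object]
Node = Callable[[State], State]
Router = Callable[[State], Optional[str]]


class TinyGraph:
    """Toy stateful graph (NOT LangGraph), illustrating nodes and edges."""

    def __init__(self) -> None:
        self.nodes: Dict[str, Node] = {}
        self.routers: Dict[str, Router] = {}

    def add_node(self, name: str, fn: Node) -> None:
        self.nodes[name] = fn

    def add_edge(self, src: str, dst: Optional[str]) -> None:
        # A plain edge always leads to the same node (None means stop).
        self.routers[src] = lambda state: dst

    def add_conditional_edge(self, src: str, router: Router) -> None:
        # A conditional edge inspects the state to pick the next node.
        self.routers[src] = router

    def run(self, entry: str, state: State, max_steps: int = 20) -> State:
        current: Optional[str] = entry
        for _ in range(max_steps):
            if current is None:
                break
            # Each node returns a partial update that is merged into the state.
            state = {**state, **self.nodes[current](state)}
            current = self.routers[current](state)
        return state


graph = TinyGraph()
graph.add_node("increment", lambda s: {"count": s["count"] + 1})
graph.add_node("report", lambda s: {"message": f"count reached {s['count']}"})
# Loop back to "increment" until the count reaches 3: a cycle, which a DAG cannot express.
graph.add_conditional_edge(
    "increment", lambda s: "increment" if s["count"] < 3 else "report"
)
graph.add_edge("report", None)

final_state = graph.run("increment", {"count": 0})
print(final_state)
```

The `max_steps` cap is a design choice of this sketch, so that a cycle whose exit condition never becomes true cannot run forever.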

Steps involved in creating a graph using LangGraph:

  1. Define the Graph State : This represents the state of the graph.
  2. Create the Graph.
  3. Define the Nodes: Here we define the different functions associated with each workflow state.
  4. Add nodes to the Graph: Here we add our nodes to the graph and define the flow using edges and conditional edges.
  5. Set Entry and End Points of the Graph.

What is Tavily Search API ?

Tavily Search API is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results. Unlike other search APIs such as Serp or Google, Tavily focuses on optimizing search for AI developers and autonomous AI agents.

What is Groq?

Groq offers high-performance AI models and API access for developers, with a focus on faster inference and lower cost than competitors.

Models Supported

What is FastEmbed ?

FastEmbed is a lightweight, fast, Python library built for embedding generation.

Supported Embedding Models

Workflow of the RAG Agent

  1. Based on the question, the router decides whether to retrieve context from the vectorstore or to perform a web search.
  2. If the router decides the question has to be directed to the vectorstore, matching documents are retrieved from it; otherwise a web search is performed using the Tavily search API.
  3. The document grader then grades the documents as relevant or irrelevant.
  4. If the retrieved context is graded as relevant, a response is generated and checked by the hallucination grader. If the grader finds the response free of hallucination, the response is presented to the user.
  5. If the context is graded as irrelevant, a web search is performed to retrieve content.
  6. After retrieval, the document grader grades the content returned by the web search. If it is found relevant, the response is synthesized by the LLM and then presented.
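
The control flow described above can be sketched as one plain Python function with every component passed in as a callable. This is a simplified outline, not the LangGraph implementation built later: the parameter names and the `max_attempts` cap are assumptions made for this example, and the separate answer-usefulness check is left out.

```python
from typing import Callable, List


def answer_question(
    question: str,
    route: Callable[[str], str],              # returns "vectorstore" or "web_search"
    retrieve: Callable[[str], List[str]],
    is_relevant: Callable[[str, str], bool],  # (question, document) -> bool
    web_search: Callable[[str], List[str]],
    generate: Callable[[str, List[str]], str],
    is_grounded: Callable[[str, List[str]], bool],
    max_attempts: int = 3,                    # assumed cap on generation attempts
) -> str:
    if route(question) == "vectorstore":
        docs = retrieve(question)
        relevant = [d for d in docs if is_relevant(question, d)]
        # Corrective step: any irrelevant document triggers a web search.
        if len(relevant) < len(docs):
            relevant = relevant + web_search(question)
        docs = relevant
    else:
        docs = web_search(question)

    # Self-check step: regenerate while the answer is not grounded in the documents.
    answer = generate(question, docs)
    for _ in range(max_attempts - 1):
        if is_grounded(answer, docs):
            break
        answer = generate(question, docs)
    return answer
```

Passing the components as callables keeps the sketch testable with simple stubs before any LLM or vectorstore is wired in.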

Technology Stack Used

  • Embedding Model : BAAI/bge-base-en-v1.5
  • LLM : Llama-3-8B
  • Vectorstore : Chroma
  • Graph /Agent : LangGraph

Code Implementation

Install required libraries

! pip install -U langchain-nomic langchain_community tiktoken langchainhub chromadb langchain langgraph tavily-python gpt4all fastembed langchain-groq 

Import required libraries

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

Instantiate the Embedding Model

embed_model = FastEmbedEmbeddings(model_name="BAAI/bge-base-en-v1.5")

Instantiate LLM

from groq import Groq
from langchain_groq import ChatGroq
from google.colab import userdata

llm = ChatGroq(temperature=0,
model_name="Llama3-8b-8192",
api_key=userdata.get("GROQ_API_KEY"),)

Download the data

urls = [
"https://lilianweng.github.io/posts/2023-06-23-agent/",
"https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
"https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]
print(f"len of documents :{len(docs_list)}")

Chunk the Documents to be in sync with the context window of the LLM

text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
chunk_size=512, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)
print(f"length of document chunks generated :{len(doc_splits)}")
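
To see what `chunk_size` and `chunk_overlap` control, here is a deliberately simplified fixed-window chunker over a list of tokens. It is only a sketch: `RecursiveCharacterTextSplitter` splits on separators and measures length with a tiktoken encoder, which this toy `chunk_tokens` function does not attempt to reproduce.

```python
from typing import List


def chunk_tokens(
    tokens: List[str], chunk_size: int, chunk_overlap: int = 0
) -> List[List[str]]:
    """Split tokens into windows of chunk_size, sharing chunk_overlap tokens."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")

    step = chunk_size - chunk_overlap
    chunks: List[List[str]] = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + chunk_size])
        # Stop once a window has reached the end of the token list.
        if start + chunk_size >= len(tokens):
            break
    return chunks


tokens = list("abcdefghij")  # ten one-character "tokens"
no_overlap = chunk_tokens(tokens, chunk_size=4)
with_overlap = chunk_tokens(tokens, chunk_size=4, chunk_overlap=2)
```

With no overlap each token appears in exactly one chunk; with overlap, neighbouring chunks share tokens, which can help when a relevant passage straddles a chunk boundary.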

Load the documents to vectorstore

vectorstore = Chroma.from_documents(documents=doc_splits,
embedding=embed_model,
collection_name="local-rag")

Instantiate the retriever

retriever = vectorstore.as_retriever(search_kwargs={"k":2})
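
The `k` setting asks the retriever for the two closest chunks. As a rough mental model, here is a brute-force top-k search by cosine similarity in plain Python. It is an illustration only: Chroma uses its own index and distance settings, and the helper names below are made up for this sketch.

```python
import math
from typing import List, Sequence, Tuple


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # treat a zero vector as similar to nothing
    return dot / (norm_a * norm_b)


def top_k(
    query_vec: Sequence[float], doc_vecs: List[Sequence[float]], k: int = 2
) -> List[Tuple[int, float]]:
    """Return (document index, similarity) pairs for the k most similar vectors."""
    scored = [(i, cosine_similarity(query_vec, v)) for i, v in enumerate(doc_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Tiny 2-dimensional "embeddings" for three documents.
doc_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
results = top_k([1.0, 0.0], doc_vecs, k=2)
```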

Implement the Router

import time
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.output_parsers import StrOutputParser

prompt = PromptTemplate(
template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an expert at routing a
user question to a vectorstore or web search. Use the vectorstore for questions on LLM agents,
prompt engineering, and adversarial attacks. You do not need to be stringent with the keywords
in the question related to these topics. Otherwise, use web-search. Give a binary choice 'web_search'
or 'vectorstore' based on the question. Return a JSON with a single key 'datasource' and
no preamble or explanation. Question to route: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
input_variables=["question"],
)
start = time.time()
question_router = prompt | llm | JsonOutputParser()
#
question = "llm agent memory"
print(question_router.invoke({"question": question}))
end = time.time()
print(f"The time required to generate response by Router Chain in seconds:{end - start}")

#############################RESPONSE ###############################
{'datasource': 'vectorstore'}
The time required to generate response by Router Chain in seconds:0.34175705909729004
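
The router is only useful if its output is one of the two expected values, and a small model may occasionally return something else. A possible safeguard, sketched here in plain Python, is to validate the parsed output before routing on it. The `parse_route` helper and the choice of `web_search` as the fallback are assumptions of this example, not part of the original workflow.

```python
import json
from typing import Any

VALID_DATASOURCES = {"vectorstore", "web_search"}


def parse_route(raw: Any, default: str = "web_search") -> str:
    """Return a valid datasource, falling back to `default` on bad output."""
    if isinstance(raw, str):
        # The model output may arrive as raw text instead of a parsed dict.
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            return default
    if not isinstance(raw, dict):
        return default
    value = raw.get("datasource")
    if isinstance(value, str) and value in VALID_DATASOURCES:
        return value
    return default
```

For example, `parse_route({'datasource': 'vectorstore'})` returns `'vectorstore'`, while unparseable text falls back to the default.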

Implement the Generate Chain

prompt = PromptTemplate(
template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise <|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
input_variables=["question", "context"],
)

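# Post-processing: join retrieved chunks into one context string
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

# Chain
start = time.time()
rag_chain = prompt | llm | StrOutputParser()
# Run the chain once; `docs` and `generation` are reused by the graders below
question = "agent memory"
docs = retriever.invoke(question)
generation = rag_chain.invoke({"context": docs, "question": question})
end = time.time()
print(f"The time required to generate response by the generation chain in seconds:{end - start}")
print(generation)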

#############################RESPONSE##############################
The time required to generate response by the generation chain in seconds:1.0384225845336914
The agent memory in the context of LLM-powered autonomous agents refers to the ability of the agent to learn from its past experiences and adapt to new situations.

Implement the Retrieval Grader

#
prompt = PromptTemplate(
template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance
of a retrieved document to a user question. If the document contains keywords related to the user question,
grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question. \n
Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
<|eot_id|><|start_header_id|>user<|end_header_id|>
Here is the retrieved document: \n\n {document} \n\n
Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
""",
input_variables=["question", "document"],
)
start = time.time()
retrieval_grader = prompt | llm | JsonOutputParser()
question = "agent memory"
docs = retriever.invoke(question)
doc_txt = docs[1].page_content
print(retrieval_grader.invoke({"question": question, "document": doc_txt}))
end = time.time()
print(f"The time required to generate response by the retrieval grader in seconds:{end - start}")

############################RESPONSE ###############################
{'score': 'yes'}
The time required to generate response by the retrieval grader in seconds:0.8115921020507812

Implement the hallucination grader

# Prompt
prompt = PromptTemplate(
template=""" <|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether
an answer is grounded in / supported by a set of facts. Give a binary 'yes' or 'no' score to indicate
whether the answer is grounded in / supported by a set of facts. Provide the binary score as a JSON with a
single key 'score' and no preamble or explanation. <|eot_id|><|start_header_id|>user<|end_header_id|>
Here are the facts:
\n ------- \n
{documents}
\n ------- \n
Here is the answer: {generation} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
input_variables=["generation", "documents"],
)
start = time.time()
hallucination_grader = prompt | llm | JsonOutputParser()
hallucination_grader_response = hallucination_grader.invoke({"documents": docs, "generation": generation})
end = time.time()
print(f"The time required to generate response by the generation chain in seconds:{end - start}")
print(hallucination_grader_response)

####################################RESPONSE#################################
The time required to generate response by the generation chain in seconds:1.020448923110962
{'score': 'yes'}

Implement the Answer Grader

# Prompt
prompt = PromptTemplate(
template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing whether an
answer is useful to resolve a question. Give a binary score 'yes' or 'no' to indicate whether the answer is
useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.
<|eot_id|><|start_header_id|>user<|end_header_id|> Here is the answer:
\n ------- \n
{generation}
\n ------- \n
Here is the question: {question} <|eot_id|><|start_header_id|>assistant<|end_header_id|>""",
input_variables=["generation", "question"],
)
start = time.time()
answer_grader = prompt | llm | JsonOutputParser()
answer_grader_response = answer_grader.invoke({"question": question,"generation": generation})
end = time.time()
print(f"The time required to generate response by the answer grader in seconds:{end - start}")
print(answer_grader_response)

##############################RESPONSE###############################
The time required to generate response by the answer grader in seconds:0.2455885410308838
{'score': 'yes'}

Implement Web Search tool

import os
from langchain_community.tools.tavily_search import TavilySearchResults
os.environ['TAVILY_API_KEY'] = "YOUR API KEY"
web_search_tool = TavilySearchResults(k=3)

Define the Graph State : represents state of the graph.

Define the following attributes:

  • question
  • generation : LLM Generation
  • web_search : whether to add search
  • documents : list of documents

from typing_extensions import TypedDict
from typing import List

### State

class GraphState(TypedDict):
    question : str
    generation : str
    web_search : str
    documents : List[str]

Define the Nodes

from langchain.schema import Document

def retrieve(state):
    """
    Retrieve documents from vectorstore

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieval
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}

def generate(state):
    """
    Generate answer using RAG on retrieved documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    generation = rag_chain.invoke({"context": documents, "question": question})
    return {"documents": documents, "question": question, "generation": generation}

def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question
    If any document is not relevant, we will set a flag to run web search

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Filtered out irrelevant documents and updated web_search state
    """

    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # Score each doc
    filtered_docs = []
    web_search = "No"
    for d in documents:
        score = retrieval_grader.invoke({"question": question, "document": d.page_content})
        grade = score['score']
        # Document relevant
        if grade.lower() == "yes":
            print("---GRADE: DOCUMENT RELEVANT---")
            filtered_docs.append(d)
        # Document not relevant
        else:
            print("---GRADE: DOCUMENT NOT RELEVANT---")
            # We do not include the document in filtered_docs
            # We set a flag to indicate that we want to run web search
            web_search = "Yes"
            continue
    return {"documents": filtered_docs, "question": question, "web_search": web_search}

def web_search(state):
    """
    Web search based on the question

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Appended web results to documents
    """

    print("---WEB SEARCH---")
    question = state["question"]
    # "documents" may be missing when the router sends the question straight to web search
    documents = state.get("documents")

    # Web search
    docs = web_search_tool.invoke({"query": question})
    web_results = "\n".join([d["content"] for d in docs])
    web_results = Document(page_content=web_results)
    if documents is not None:
        documents.append(web_results)
    else:
        documents = [web_results]
    return {"documents": documents, "question": question}

Define Conditional Edges

def route_question(state):
    """
    Route question to web search or RAG.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """

    print("---ROUTE QUESTION---")
    question = state["question"]
    print(question)
    source = question_router.invoke({"question": question})
    print(source)
    print(source['datasource'])
    if source['datasource'] == 'web_search':
        print("---ROUTE QUESTION TO WEB SEARCH---")
        return "websearch"
    elif source['datasource'] == 'vectorstore':
        print("---ROUTE QUESTION TO RAG---")
        return "vectorstore"

def decide_to_generate(state):
    """
    Determines whether to generate an answer, or add web search

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """

    print("---ASSESS GRADED DOCUMENTS---")
    question = state["question"]
    web_search = state["web_search"]
    filtered_documents = state["documents"]

    if web_search == "Yes":
        # At least one document was graded as not relevant,
        # so we supplement the context with a web search
        print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
        return "websearch"
    else:
        # We have relevant documents, so generate answer
        print("---DECISION: GENERATE---")
        return "generate"

def grade_generation_v_documents_and_question(state):
    """
    Determines whether the generation is grounded in the document and answers question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Decision for next node to call
    """

    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]

    score = hallucination_grader.invoke({"documents": documents, "generation": generation})
    grade = score['score']

    # Check hallucination
    if grade == "yes":
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        score = answer_grader.invoke({"question": question,"generation": generation})
        grade = score['score']
        if grade == "yes":
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        else:
            print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
            return "not useful"
    else:
        # print, not pprint: pprint is only imported later, in the test cells
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
        return "not supported"
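
Because a "not supported" decision sends the flow back to generation, a model that keeps failing the hallucination check would keep looping until LangGraph's recursion limit stops the run. One possible way to bound this, sketched below in plain Python, is to count attempts in the state and give up after a cap. The `attempts` key, the `MAX_ATTEMPTS` value and the `"give up"` label are all hypothetical; in a real graph the counter would need to be added to the state and returned by the generate node, and the new label mapped to a node or to END.

```python
from typing import Callable, Dict

MAX_ATTEMPTS = 3  # hypothetical cap, not part of the original workflow


def count_attempt(state: Dict) -> Dict:
    """Node-side helper: return a state update that bumps the attempt counter."""
    return {"attempts": state.get("attempts", 0) + 1}


def bounded_decision(state: Dict, decide: Callable[[Dict], str]) -> str:
    """Edge-side helper: stop retrying once the cap is reached."""
    decision = decide(state)
    if decision != "useful" and state.get("attempts", 0) >= MAX_ATTEMPTS:
        return "give up"
    return decision


# Simulate a grader that never accepts the generation.
state: Dict = {}
decisions = []
while True:
    state.update(count_attempt(state))  # what the generate node would return
    decision = bounded_decision(state, lambda s: "not supported")
    decisions.append(decision)
    if decision in ("useful", "give up"):
        break
```

The counter is updated in a node rather than in the routing function because routing functions only choose the next node; they do not write to the state.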

Add nodes

from langgraph.graph import END, StateGraph
workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("websearch", web_search) # web search
workflow.add_node("retrieve", retrieve) # retrieve
workflow.add_node("grade_documents", grade_documents) # grade documents
workflow.add_node("generate", generate) # generate

Set the Entry Point and End Point

workflow.set_conditional_entry_point(
route_question,
{
"websearch": "websearch",
"vectorstore": "retrieve",
},
)

workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
"grade_documents",
decide_to_generate,
{
"websearch": "websearch",
"generate": "generate",
},
)
workflow.add_edge("websearch", "generate")
workflow.add_conditional_edges(
"generate",
grade_generation_v_documents_and_question,
{
"not supported": "generate",
"useful": END,
"not useful": "websearch",
},
)

Compile the workflow

app = workflow.compile()

Test the Workflow

from pprint import pprint
inputs = {"question": "What is prompt engineering?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Finished running: {key}:")
pprint(value["generation"])

########################RESPONSE##############################
---ROUTE QUESTION---
What is prompt engineering?
{'datasource': 'vectorstore'}
vectorstore
---ROUTE QUESTION TO RAG---
---RETRIEVE---
'Finished running: retrieve:'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: GENERATE---
'Finished running: grade_documents:'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
'Finished running: generate:'
('Prompt engineering refers to methods for communicating with large language '
'models to steer their behavior for desired outcomes without updating the '
'model weights. It is an empirical science that requires heavy '
'experimentation and heuristics.')

Test The Workflow for a different question

app = workflow.compile()

# Test
from pprint import pprint
inputs = {"question": "Who are the Bears expected to draft first in the NFL draft?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Finished running: {key}:")
pprint(value["generation"])

#############################RESPONSE##############################
---ROUTE QUESTION---
Who are the Bears expected to draft first in the NFL draft?
{'datasource': 'web_search'}
web_search
---ROUTE QUESTION TO WEB SEARCH---
---WEB SEARCH---
'Finished running: websearch:'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
'Finished running: generate:'
('According to the provided context, the Chicago Bears are expected to take '
'USC quarterback Caleb Williams with the first overall pick in the NFL draft.')

Test the workflow for a different question

app = workflow.compile()

#
inputs = {"question": "What are the types of agent memory?"}
for output in app.stream(inputs):
    for key, value in output.items():
        pprint(f"Finished running: {key}:")
pprint(value["generation"])

###########################RESPONSE############################
---ROUTE QUESTION---
What are the types of agent memory?
{'datasource': 'vectorstore'}
vectorstore
---ROUTE QUESTION TO RAG---
---RETRIEVE---
'Finished running: retrieve:'
---CHECK DOCUMENT RELEVANCE TO QUESTION---
---GRADE: DOCUMENT RELEVANT---
---GRADE: DOCUMENT NOT RELEVANT---
---ASSESS GRADED DOCUMENTS---
---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---
'Finished running: grade_documents:'
---WEB SEARCH---
'Finished running: websearch:'
---GENERATE---
---CHECK HALLUCINATIONS---
---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---
---GRADE GENERATION vs QUESTION---
---DECISION: GENERATION ADDRESSES QUESTION---
'Finished running: generate:'
('The text mentions the following types of agent memory:\n'
'\n'
'1. Short-Term Memory (STM) or Working Memory: It stores information that the '
'agent is currently aware of and needed to carry out complex cognitive '
'tasks.\n'
'2. Long-Term Memory (LTM): It can store information for a remarkably long '
'time, ranging from a few days to decades, with an essentially unlimited '
'storage capacity.')

Visualize the Agent / Graph

!apt-get install python3-dev graphviz libgraphviz-dev pkg-config
!pip install pygraphviz
from IPython.display import Image

Image(app.get_graph().draw_png())

Conclusion

LangGraph is a flexible tool for building intricate, stateful applications with LLMs. Beginners can use it in their projects by learning its fundamental principles and working through basic examples. It is important to focus on managing state, handling conditional edges, and ensuring there are no dead-end nodes in the graph.
In my opinion it is more beneficial than ReAct agents, since here we keep complete control over the workflow instead of leaving each decision to the agent.
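
As a small aid for the dead-end check mentioned above, the sketch below looks for nodes that can be reached from an entry point but from which the end of the graph can never be reached. It works on a plain adjacency dictionary written out by hand, not on a compiled LangGraph object; `find_dead_ends` and the `END` placeholder string are assumptions of this example.

```python
from typing import Dict, List, Set

END = "__end__"  # placeholder label for the terminal state in this sketch


def find_dead_ends(edges: Dict[str, List[str]], entry_points: List[str]) -> Set[str]:
    """Return reachable nodes from which END can never be reached."""
    # 1. Collect every node reachable from the entry points.
    reachable: Set[str] = set()
    stack = list(entry_points)
    while stack:
        node = stack.pop()
        if node == END or node in reachable:
            continue
        reachable.add(node)
        stack.extend(edges.get(node, []))

    # 2. Repeatedly mark nodes that have an edge to END or to a marked node.
    can_finish: Set[str] = set()
    changed = True
    while changed:
        changed = False
        for node in reachable - can_finish:
            targets = edges.get(node, [])
            if any(t == END or t in can_finish for t in targets):
                can_finish.add(node)
                changed = True
    return reachable - can_finish


# Edges of the workflow built in this article, written out by hand.
rag_edges = {
    "retrieve": ["grade_documents"],
    "grade_documents": ["websearch", "generate"],
    "websearch": ["generate"],
    "generate": ["generate", END, "websearch"],
}
dead_ends = find_dead_ends(rag_edges, entry_points=["websearch", "retrieve"])
```

An empty result means every reachable node has some path to the end; it does not guarantee that a run terminates, since cycles such as generate to generate can still repeat.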
