Improving RAG using LangGraph and LangChain

Introduction to LangGraph for multi-agent environments

Mehul Gupta
Data Science in your pocket


I’ve already discussed a lot of LangChain demos and concepts in my previous posts, so much so that I even wrote a book.

LangGraph is the latest addition to the LangChain family, alongside LangServe & LangSmith, all revolving around building Generative AI applications using LLMs. Remember that these are separate packages and must be pip-installed separately.

Before we jump into LangGraph, we need to understand two major LangChain concepts.

  1. Chains: Programs written around LLMs to execute a specific task, say auto SQL writing or NER extraction. Do note that a chain can’t be used for any other task (not even general use cases) and may break if you try doing so. The steps a chain follows are pre-defined and not flexible.
  2. Agents: A much more flexible version of chains, agents are usually LLMs equipped with 3rd-party tools (say Google search, YouTube), where the LLM itself decides what to do next to solve the given query. A minimal sketch of both follows this list.
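
To make the distinction concrete, here’s a minimal sketch of both (the NER prompt and the serpapi tool are just illustrative choices, not part of this post’s pipeline; serpapi needs its own API key):

from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.agents import initialize_agent, load_tools, AgentType

llm = OpenAI(openai_api_key='your API')

# Chain: a fixed, pre-defined pipeline built for one task (here, NER extraction)
ner_chain = LLMChain(llm=llm, prompt=PromptTemplate.from_template("Extract all named entities from: {text}"))
ner_chain.run(text="Mehul works at OOO organization.")

# Agent: the LLM itself decides which tool to call next to answer the query
agent = initialize_agent(
    tools=load_tools(["serpapi"], llm=llm),  # e.g. a Google-search tool
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
)
agent.run("Who developed LangGraph?")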

You can read more about chains & agents in the dedicated chapters of the book.

Now, a common issue when working on real-world problems is that you want a solution somewhere in between chains & agents, i.e. not as hard-coded as chains, but not fully driven by LLMs as in the case of agents.

LangGraph

LangGraph, with LangChain at its core, helps in creating cyclic graphs in workflows. So, consider this example:

You wish to build a RAG-based retrieval system over your knowledge base, and you want to handle the case where the RAG output doesn’t meet a particular quality bar: the agent/chain should retrieve data again, but with a changed prompt this time, decided on its own, and repeat this until the quality threshold is met.

Such cyclic logic can be implemented using LangGraph. And this is just one example; a lot more can be done with LangGraph.

Note: You can think of it as introducing cyclic logic to chains, making them cyclic chains.
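
In plain Python terms, the cycle we are after would look roughly like this. The three helpers are hypothetical stand-ins for the real RAG call, the quality check, and the prompt rewrite (the rewrite shown mirrors the count-only prompt used later in this post):

# Hypothetical stand-ins, just to show the shape of the cycle
def retrieve(prompt):
    # stand-in for the real RAG call; a count-only prompt gets a short answer
    return "4" if "count only" in prompt else "a long, detailed answer well over the limit"

def good_enough(answer):
    # stand-in for the quality threshold; here, the length check used later in this post
    return len(answer) < 30

def rewrite(prompt):
    # stand-in for changing the prompt on its own
    return prompt + ". Return total count only."

prompt = "Mehul developed which projects?"
answer = retrieve(prompt)
while not good_enough(answer):  # keep cycling until the threshold is met
    prompt = rewrite(prompt)
    answer = retrieve(prompt)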

LangGraph can be crucial for building multi-agent applications like AutoGen or MetaGPT.

LangGraph, as the name suggests, has all the components a general graph has, like nodes, edges, etc. Let’s understand with an example:

Improving RAG using LangGraph

In this example, I wish the final output from my RAG system over a database to be less than 30 characters long. If the output length is greater than 30, I wish to introduce a cycle: use a different prompt and try again until the length is less than 30. This is baseline logic for demonstration purposes; you can implement more complex logic to improve RAG results.

The graph we would be creating looks like this

The versions used here are langchain==0.0.349, openai==1.3.8, langgraph==0.0.26
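
If you’re following along, the matching installs would be:

pip install langchain==0.0.349 openai==1.3.8 langgraph==0.0.26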

  1. First, let’s import the important stuff and initialize the LLM. I’m using the OpenAI API, but you can use other LLMs as well.
from typing import Dict, TypedDict, Optional
from langgraph.graph import StateGraph, END
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.llms import OpenAI
from langchain.embeddings.openai import OpenAIEmbeddings

llm = OpenAI(openai_api_key='your API')

Next, we will define a StateGraph.

class GraphState(TypedDict, total=False):
    # total=False: every key is optional; nodes fill them in as the graph runs
    question: Optional[str]
    classification: Optional[str]
    response: Optional[str]
    length: Optional[int]
    greeting: Optional[str]

workflow = StateGraph(GraphState)

What’s a StateGraph?

The heart of any LangGraph flow, the StateGraph stores the state of the various variables we track while executing the workflow. In this case, we have 5 variables whose values get updated as the graph executes, and the state is shared across all nodes and edges.
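
Also note that a node doesn’t have to return the whole state: as you’ll see in the nodes below, a node returns a partial dict, and LangGraph merges those keys into the shared state, leaving the rest untouched. A toy node (not part of our graph) to illustrate:

def some_node(state):
    # read any state variable
    question = state.get('question')
    # return only the keys you want to update; everything else stays as-is
    return {"classification": "not_greeting"}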

2. Next, let’s initialize a RAG retrieval chain from an existing vector DB.

def retriever_qa_creation():
    embeddings = OpenAIEmbeddings()
    db = Chroma(embedding_function=embeddings, persist_directory='/database', collection_name='details')
    qa = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=db.as_retriever())
    return qa

rag_chain = retriever_qa_creation()
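
This assumes a Chroma collection named 'details' is already persisted at '/database'. If you don’t have one yet, it could have been built along these lines (the texts here are placeholders for your actual knowledge base):

from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings

docs = ["Mehul developed project ABC.", "Mehul developed project XYZ."]  # placeholder texts
db = Chroma.from_texts(docs, embedding=OpenAIEmbeddings(), persist_directory='/database', collection_name='details')
db.persist()  # write the collection to disk so it can be reloaded later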

3. Next, we will add nodes to this graph.

def classify(question):
    return llm("classify intent of given input as greeting or not_greeting. Output just the class.Input:{}".format(question)).strip()

def classify_input_node(state):
    question = state.get('question', '').strip()
    classification = classify(question)
    return {"classification": classification}

def handle_greeting_node(state):
    return {"greeting": "Hello! How can I help you today?"}

def handle_RAG(state):
    question = state.get('question', '').strip()
    prompt = question
    if state.get("length") < 30:
        search_result = rag_chain.run(prompt)
    else:
        search_result = rag_chain.run(prompt + '. Return total count only.')

    return {"response": search_result, "length": len(search_result)}

def bye(state):
    return {"greeting": "The graph has finished"}

workflow.add_node("classify_input", classify_input_node)
workflow.add_node("handle_greeting", handle_greeting_node)
workflow.add_node("handle_RAG", handle_RAG)
workflow.add_node("bye", bye)

This needs some explanation:

  • Every node is a Python function which can read any state variable and update any state variable. In this case, each node’s return value updates one or more state variables.
  • state.get() is used to read any state variable.
  • The handle_RAG node implements the custom logic we want for the cycle: if the length of the output is < 30, use prompt A, else use prompt B. For the 1st pass (when the RAG node hasn’t executed yet), we pass length=0 while providing the prompt.

4. Next, we will add the entry point and edges.

workflow.set_entry_point("classify_input")
workflow.add_edge('handle_greeting', END)
workflow.add_edge('bye', END)

In the above code snippet,

  • We added an entry point to the graph, i.e. the first node to execute irrespective of the input prompt.
  • An edge from node A to node B specifies that node B should be executed after node A. In this case, once “handle_greeting” or “bye” runs, the graph should END (a special node that terminates the workflow).

5. Next, let's add conditional edges

def decide_next_node(state):
    return "handle_greeting" if state.get('classification') == "greeting" else "handle_RAG"

def check_RAG_length(state):
    return "handle_RAG" if state.get("length") > 30 else "bye"

workflow.add_conditional_edges(
    "classify_input",
    decide_next_node,
    {
        "handle_greeting": "handle_greeting",
        "handle_RAG": "handle_RAG"
    }
)

workflow.add_conditional_edges(
    "handle_RAG",
    check_RAG_length,
    {
        "bye": "bye",
        "handle_RAG": "handle_RAG"
    }
)

A conditional edge helps choose between 2 nodes depending upon a condition (an if-else, say). In the 2 conditional edges created:

1st conditional edge

Once “classify_input” has run, choose either “handle_greeting” or “handle_RAG” depending upon the output of the decide_next_node function.

2nd conditional edge

Once “handle_RAG” has run, choose either “handle_RAG” or “bye” depending upon check_RAG_length.

6. Compile and invoke with a prompt, keeping the length variable = 0 initially.

app = workflow.compile()
app.invoke({'question':'Mehul developed which projects?','length':0})
#output 
{'question': 'Mehul developed which projects?',
 'classification': 'not_greeting',
 'response': ' 4',
 'length': 2,
 'greeting': 'The graph has finished'}

The graph flow looks something like this for the above prompt:

classify_input: The classification would be not_greeting

Due to the 1st conditional_edge, moves to handle_RAG

As length=0, use the 1st prompt and retrieve the answer (total length would be >30)

Due to the 2nd conditional_edge, moves again to handle_RAG

As length>30, use the 2nd prompt

Due to the 2nd conditional_edge, moves to bye

END

If LangGraph wasn’t used

rag_chain.run("Mehul developed which projects?")

#output
"Mehul developed projects like ABC, XYZ, QWERTY. Not only these, he has major contribution in many other projects as well at OOO organization"

7. Next input

app.invoke({'question':'Hello bot','length':0})

#output
{'question': 'Hello bot',
 'classification': 'greeting',
 'response': None,
 'length': 0,
 'greeting': 'Hello! How can I help you today?'}

The flow here is simpler:

classify_input: The classification would be “greeting”

Due to the 1st conditional_edge, moves to handle_greeting

END

Though the condition I applied here is quite naive, this framework can easily be extended with more complex conditions to improve your RAG results.
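
For example, instead of the length check, check_RAG_length could be swapped for a grader that asks the LLM itself whether the answer is good enough. A hedged sketch (the grading prompt is just an illustration, not part of the code above):

def check_RAG_quality(state):
    # ask the LLM to judge the answer instead of checking its length
    verdict = llm("Does the following answer fully address the question? Reply with only yes or no.\nQuestion: {}\nAnswer: {}".format(
        state.get('question'), state.get('response'))).strip().lower()
    return "bye" if verdict.startswith("yes") else "handle_RAG"

It would be wired in with workflow.add_conditional_edges exactly like check_RAG_length above.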

See you soon !!
