MedGraphRAG: A New Revolution in AI Performance in Medicine
Explore how RAG, GraphRAG, and MedGraphRAG revolutionize AI by enhancing the accuracy and performance of LLM responses in the medical field.
LLMs have transformed how humans search for information for everyday tasks. Although well suited to general scenarios, LLMs hallucinate and produce irrelevant information when queried on specialized knowledge such as medicine, law, and finance.
They also fail to provide up-to-date information in these constantly evolving fields and offer simplistic responses that overlook novel insights or discoveries.
LLMs also cannot access private, specialized data out of the box unless they are fine-tuned on it. However, fine-tuning is a complex process that demands domain expertise, considerable time, and computational resources.
Introduction To MedGraphRAG
To combat these issues, Retrieval-Augmented Generation (RAG) was introduced in 2020. This method lets LLMs answer user queries using specialized private datasets without requiring any fine-tuning.
The process was made even more accurate in early 2024 using Graph Retrieval-Augmented Generation (GRAG).
Finally, we have MedGraphRAG, a novel graph-based Retrieval-Augmented Generation (RAG) framework specifically designed for the medical domain. This method consistently outperforms state-of-the-art LLMs (even fine-tuned ones) on multiple medical Q&A benchmarks.
It also avoids the ‘black-box’ approach to LLM response generation by ensuring that responses include source documentation, which is absolutely necessary in medicine, where hallucinated responses might cost lives.
GPT-4 answers a medical question wrong (Image from the original research paper)
For the same question, MedGraphRAG generates evidence-based responses with grounded citations and terminology explanations (Image from the original research paper)
In this article, we deep-dive into how RAG, GraphRAG, and MedGraphRAG work and how they significantly improve the performance of LLM responses in specialized domains.
Let’s go!
Let’s Start With RAG
RAG, or Retrieval-Augmented Generation, is an information retrieval technique that helps an LLM produce more accurate and up-to-date responses using private datasets specific to the use case.
RAG Visualised (Image from author’s book ‘AI In 100 Images’)
The terms in RAG mean the following:
- Retrieval: the process of retrieving relevant information/documents from a knowledge base or specific private datasets.
- Augmentation: the process where the retrieved information is added to the input context.
- Generation: the process where the LLM generates a response based on the original query and the augmented context.
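To make these three stages concrete, here is a minimal, runnable sketch of the RAG loop. The bag-of-words `embed` function and the stubbed `generate_answer` are deliberate simplifications standing in for a real embedding model and LLM client; they are illustrative assumptions, not part of any specific RAG implementation.

```python
from collections import Counter
import math

# Toy private knowledge base (stand-in for a domain-specific dataset)
documents = [
    "Metformin is a first-line medication for type 2 diabetes.",
    "Aspirin is used to reduce fever and relieve mild pain.",
    "Insulin therapy is required for type 1 diabetes management.",
]

def embed(text: str) -> Counter:
    """Toy bag-of-words 'embedding'; a real system uses a neural encoder."""
    return Counter(text.lower().split())

def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[t] * b[t] for t in a)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

def retrieve(query: str, k: int = 2) -> list[str]:
    """Retrieval: rank documents by similarity to the query and keep the top k."""
    q = embed(query)
    return sorted(documents, key=lambda d: cosine(q, embed(d)), reverse=True)[:k]

def generate_answer(prompt: str) -> str:
    """Stub for the LLM call; replace with a real model client."""
    return f"[LLM response conditioned on a {len(prompt)}-character prompt]"

query = "What drug is commonly prescribed first for type 2 diabetes?"
context = "\n".join(retrieve(query))  # Augmentation: add retrieved documents
prompt = f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"
print(generate_answer(prompt))        # Generation
```

In practice, the retriever would use dense embeddings and a vector index rather than word counts, but the retrieve-augment-generate shape stays the same.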
Although RAG is very helpful, it sometimes struggles to connect information based on shared attributes.
Its performance also suffers in tasks that require a deep understanding of summarized semantic concepts across large datasets.
To deal with these limitations, Graph RAG was introduced in 2024.
Let’s talk about it next.
What Is Graph RAG?
Graph RAG, or Graph Retrieval-Augmented Generation (GRAG), extends RAG by incorporating knowledge graphs.
While RAG inherently neglects the topological relationships within textual information, GraphRAG makes it possible to exploit them.
GRAG’s core workflow has four stages:
- Indexing of k-hop ego-graphs: This step involves creating searchable subgraphs centred around each node (called the ego) that include all connected nodes within k steps of the ego.
- Graph retrieval: This step retrieves the most relevant ego-graphs for a given query from the indexed subgraphs.
- Soft pruning: This step removes the irrelevant entities in retrieved subgraphs to reduce their impact on the generation process.
- Generation with pruned textual subgraphs: This step involves text generation from the LLM using the pruned subgraphs.
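The following sketch illustrates the first three stages using networkx. The toy graph, the token-overlap scoring, and the pruning rule are all illustrative assumptions; in particular, GRAG's soft pruning uses a learned soft mask, which is approximated here by simply dropping nodes with no overlap with the query.

```python
import networkx as nx

# Illustrative textual graph of medical entities
G = nx.Graph()
G.add_edges_from([
    ("diabetes", "insulin"), ("diabetes", "metformin"),
    ("insulin", "pancreas"), ("aspirin", "pain"),
])

def tokens(text: str) -> set[str]:
    return set(text.lower().split())

def score(query: str, subgraph: nx.Graph) -> float:
    """Score a subgraph by token overlap between the query and its node names
    (a crude stand-in for comparing graph embeddings with the query vector)."""
    q, nodes = tokens(query), tokens(" ".join(subgraph.nodes))
    return len(q & nodes) / len(q | nodes) if q | nodes else 0.0

query = "treatment for diabetes"

# (1) Indexing: one k-hop ego-graph per node (k = 1, as in the figure)
ego_graphs = {node: nx.ego_graph(G, node, radius=1) for node in G.nodes}

# (2) Graph retrieval: keep the top-N most relevant ego-graphs
top_n = sorted(ego_graphs.items(), key=lambda kv: score(query, kv[1]), reverse=True)[:2]

# (3) Pruning: drop nodes unrelated to the query (the paper uses a learned
#     soft mask; a hard mask is used here for simplicity)
for ego, sub in top_n:
    kept = [n for n in sub.nodes if n == ego or tokens(n) & tokens(query)]
    print(ego, "->", kept)
```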
The Graph RAG workflow. Given a query and a related textual graph: (a) all k-hop ego-graphs (e.g., 1-hop here) are embedded into graph embeddings and compared with the query vector to retrieve the top-N similar subgraphs. Irrelevant entities in the graph are partially masked using a soft pruning module. (b) For the final generation, pruned ego-graphs are encoded into a soft graph token, and textual information is encoded into text tokens. (c) Soft pruning module. (d) Generating text descriptions of ego-graphs preserving both textual and topological information. (Image from the research paper titled ‘GRAG: Graph Retrieval-Augmented Generation’ published on arXiv)
GRAG retrieves subgraphs relevant to the query rather than discrete documents like RAG.
This reduces the negative impact of semantically similar but irrelevant documents (shown in red in the image below) on the generation.
GraphRAG retrieval approach (Image from the research paper titled ‘GRAG: Graph Retrieval-Augmented Generation’ published on arXiv)
These concepts from Graph RAG are then further extended to the medical domain, and that’s how we get MedGraphRAG.
Let’s discuss how it works next.
How Does MedGraphRAG Work?
The workflow of MedGraphRAG, or Medical Graph RAG, can be described in three steps as follows:
1. Medical Graph Construction
- Segmenting medical documents into chunks
- Extracting relevant entities from these chunks
- Organizing them into a three-tier graph structure that links these entities
2. Graph Retrieval
- Given a user query, retrieving relevant graphs and entities
3. Text Generation
- Generating text using this retrieved information, along with citations to source documents
Let’s learn about each of these steps in more detail.
Semantic Document Segmentation
Given a knowledge base or private dataset, this step segments its documents into chunks.
The traditional RAG approach chunks documents by token size or a fixed number of characters. However, these approaches lose semantic information because subtle shifts in topic are not well detected.
To fix this, a different chunking approach is used.
First, line-break symbols are used to isolate the different paragraphs of a document.
Each paragraph is then converted into self-contained statements (propositions) using a semantic segmentation technique called proposition transfer.
Next, an LLM analyzes each proposition sequentially to determine whether it should be merged into an existing chunk or start a new one.
This process is performed five paragraphs at a time using a sliding window technique to reduce noise.
A hard threshold is also set so that the longest chunk does not exceed the LLM’s context length limit.
These steps divide the documents into meaningful chunks, on which graphs are constructed later.
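Here is a hedged sketch of this chunking procedure. The `to_propositions` and `llm_decides_merge` functions are hypothetical stand-ins for the two LLM calls (proposition transfer and the merge decision), approximated here by sentence splitting and word overlap; the window size and token threshold values are also assumptions.

```python
MAX_CHUNK_TOKENS = 512  # hard threshold: no chunk may exceed the LLM's context limit
WINDOW = 5              # paragraphs analysed together, per the sliding-window idea

def to_propositions(paragraph: str) -> list[str]:
    """Stand-in for proposition transfer; a real system prompts an LLM to
    rewrite each paragraph as self-contained statements."""
    return [s.strip() for s in paragraph.split(".") if s.strip()]

def llm_decides_merge(chunk: str, proposition: str) -> bool:
    """Hypothetical stand-in for the LLM's merge decision, approximated by word overlap."""
    return len(set(chunk.lower().split()) & set(proposition.lower().split())) >= 2

def chunk_document(paragraphs: list[str]) -> list[str]:
    chunks: list[str] = []
    for start in range(0, len(paragraphs), WINDOW):  # windows of 5 paragraphs
        for paragraph in paragraphs[start:start + WINDOW]:
            for prop in to_propositions(paragraph):
                fits = chunks and len((chunks[-1] + prop).split()) <= MAX_CHUNK_TOKENS
                if fits and llm_decides_merge(chunks[-1], prop):
                    chunks[-1] += " " + prop + "."   # merge into the current chunk
                else:
                    chunks.append(prop + ".")        # start a new chunk
    return chunks

doc = ("Diabetes is a chronic disease. Diabetes affects insulin production.\n"
       "Aspirin treats pain. Aspirin also reduces fever.")
print(chunk_document(doc.split("\n")))  # paragraphs isolated at line breaks
```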
Element Extraction
This step involves identifying and extracting relevant entities (nodes) from each chunk.
An LLM is prompted to output each entity’s name, type (from a predefined list of professional medical terminology), and description from each chunk.
This extraction process is repeated multiple times to reduce noise and ensure completeness and quality.
Each extracted entity is also given a unique ID to trace its source document and paragraph.
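Below is a minimal sketch of this extraction loop. The `call_llm` stub, the prompt wording, and the entity type list are all illustrative assumptions; a real system would call an actual model and merge its varying outputs across passes.

```python
import json
import uuid

ENTITY_TYPES = ["Disease", "Medication", "Symptom", "Procedure"]  # illustrative type list

PROMPT = (
    "Extract all medical entities from the text below. For each entity, return "
    f"JSON with 'name', 'type' (one of {ENTITY_TYPES}), and 'description'.\n\n"
    "Text: {chunk}"
)

def call_llm(prompt: str) -> str:
    """Stub for a real model client; returns a fixed example response here."""
    return json.dumps([{"name": "metformin", "type": "Medication",
                        "description": "First-line drug for type 2 diabetes."}])

def extract_entities(chunk: str, doc_id: str, para_id: int, passes: int = 3) -> dict:
    """Repeat extraction several times and merge results to reduce noise."""
    merged: dict[str, dict] = {}
    for _ in range(passes):
        for entity in json.loads(call_llm(PROMPT.format(chunk=chunk))):
            key = entity["name"].lower()
            if key not in merged:
                # The unique ID traces the entity to its source document and paragraph
                entity["id"] = f"{doc_id}:para{para_id}:{uuid.uuid4().hex[:8]}"
                merged[key] = entity
    return merged

print(extract_entities("Metformin is first-line for type 2 diabetes.", "note_17", 2))
```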
Hierarchy Linking
This step ensures that the LLM does not distort information or add irrelevant content beyond precise medical terminology.
This is done by linking each extracted entity to grounded medical knowledge and terms by constructing a three-tiered Graph RAG data structure.
Its first (top) level involves entity extraction from private, user-provided documents.
For their experiments, the researchers use the MIMIC-IV dataset, a publicly available electronic health record dataset, at this level.
The second level is constructed by linking these entities with graphs of foundational medical knowledge created from textbooks and scholarly articles.
The MedC-K corpus, containing 4.8 million biomedical papers and 30,000 textbooks, is used at this level.
In the third level, the second-level graphs are further connected to well-established medical terms from reliable resources like the Unified Medical Language System (UMLS).
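The sketch below illustrates the idea of linking an extracted entity down through the three tiers. The mini-vocabularies and the fuzzy string matching are illustrative assumptions; a real system would use embedding similarity or proper ontology lookup against MedC-K and UMLS.

```python
import difflib

# Tier 2: concepts from medical literature (e.g., MedC-K) -- illustrative entries
literature_concepts = ["type 2 diabetes mellitus", "metformin therapy", "insulin resistance"]
# Tier 3: controlled vocabulary (e.g., UMLS) -- illustrative term-to-CUI entries
umls_terms = {"type 2 diabetes mellitus": "C0011860", "metformin": "C0025598"}

def best_match(name, candidates):
    """Fuzzy string matching as a stand-in for embedding/ontology-based linking."""
    hits = difflib.get_close_matches(name.lower(), candidates, n=1, cutoff=0.5)
    return hits[0] if hits else None

def link_entity(entity_name):
    tier2 = best_match(entity_name, literature_concepts)        # ground in literature
    tier3 = best_match(tier2 or entity_name, list(umls_terms))  # ground in UMLS
    return {"entity": entity_name, "literature_concept": tier2,
            "umls_term": tier3, "umls_cui": umls_terms.get(tier3)}

# Tier 1 entity extracted from a private document (e.g., a MIMIC-IV note)
print(link_entity("metformin"))
```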
Relationship Linking
This step involves creating weighted directed graphs (called Meta-graphs) using an LLM to identify all relationships between clearly related entities.
This enhances the richness of the graph structure.
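A minimal sketch of building such a meta-graph with networkx follows; the relationship tuples are hard-coded stand-ins for what an LLM would actually propose.

```python
import networkx as nx

# Hard-coded stand-ins for (source, relation, target, weight) tuples that an
# LLM would propose for clearly related entities
llm_relations = [
    ("metformin", "treats", "type 2 diabetes", 0.9),
    ("type 2 diabetes", "causes", "insulin resistance", 0.7),
]

meta_graph = nx.DiGraph()  # weighted, directed meta-graph
for src, rel, dst, weight in llm_relations:
    meta_graph.add_edge(src, dst, relation=rel, weight=weight)

for src, dst, data in meta_graph.edges(data=True):
    print(f"{src} --[{data['relation']}, w={data['weight']}]--> {dst}")
```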
Tags Generation and Merging of Graphs
This next step links all the Meta-graphs to generate a Global graph that can be used for efficient information retrieval for medical queries.
Firstly, an LLM generates summaries of the meta-graphs based on predefined medical categories (such as symptoms, patient history, body functions, and medications).
This results in a list of tags that succinctly describe the meta-graph’s central themes.
Using these tags, the similarity between different meta-graphs is calculated, which guides their merging into a single global graph.
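Below is a rough sketch of tag-based merging. The tag sets and the similarity threshold are illustrative assumptions, and the averaged Jaccard overlap merely approximates however the paper actually scores tag similarity.

```python
CATEGORIES = ["symptoms", "patient history", "body functions", "medications"]

# LLM-generated tag lists per meta-graph (hard-coded stand-ins here)
meta_graph_tags = {
    "mg1": {"medications": {"metformin"}, "symptoms": {"fatigue"}},
    "mg2": {"medications": {"metformin", "insulin"}, "symptoms": {"thirst"}},
}

def tag_similarity(tags_a: dict, tags_b: dict) -> float:
    """Average Jaccard overlap of tags across the predefined categories."""
    scores = []
    for category in CATEGORIES:
        a, b = tags_a.get(category, set()), tags_b.get(category, set())
        if a or b:
            scores.append(len(a & b) / len(a | b))
    return sum(scores) / len(scores) if scores else 0.0

similarity = tag_similarity(meta_graph_tags["mg1"], meta_graph_tags["mg2"])
print(f"similarity(mg1, mg2) = {similarity:.2f}")

# Meta-graphs whose similarity clears a threshold get merged into the global graph
if similarity > 0.2:  # illustrative threshold
    print("merge mg1 and mg2 into the global graph")
```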
A summary of all the above steps is shown in the image below.
The MedGraphRAG framework (Image from the original research paper)
Graph Retrieval
This step involves an LLM retrieving information from the global graph to respond to user queries.
Summarized tags are first generated for each user query.
These are then used to identify the most relevant sections of the graph.
This is done using a top-down matching approach called U-retrieve, where matching starts from the larger graphs and moves towards the smaller ones.
This finds the relevant entities in the graph, along with their top-k related entities, to answer a user query.
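A simplified sketch of this top-down matching is shown below. The two-level hierarchy, the tag sets, and the overlap scoring are illustrative assumptions standing in for the paper's U-retrieve procedure.

```python
# Illustrative two-level hierarchy: global graph -> meta-graphs -> entities
hierarchy = {"global": ["mg1", "mg2"],
             "mg1": ["metformin", "fatigue", "type 2 diabetes"],
             "mg2": ["insulin", "thirst", "type 1 diabetes"]}
graph_tags = {"mg1": {"metformin", "fatigue", "diabetes"},
              "mg2": {"insulin", "thirst", "diabetes"}}

def query_tags(query: str) -> set[str]:
    """Stand-in for the LLM that summarizes a user query into tags."""
    return set(query.lower().split())

def u_retrieve(query: str, top_k: int = 1) -> list[str]:
    """Top-down matching: score graphs against the query tags at the global
    level first, then descend into the best matches to collect entities."""
    q = query_tags(query)
    ranked = sorted(hierarchy["global"],
                    key=lambda g: len(q & graph_tags[g]), reverse=True)
    entities = []
    for graph in ranked[:top_k]:  # move from larger graphs toward smaller ones
        entities.extend(hierarchy[graph])
    return entities

print(u_retrieve("fatigue and metformin in diabetes"))
```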
Text Generation
This step involves an LLM generating an intermediate response from the retrieved information.
This intermediate response is then combined with the summarized tag information of the higher-level graphs in a bottom-up manner, scanning all the indexed graphs along the trajectory to generate the final response.
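A minimal sketch of this bottom-up refinement, with `call_llm` as a stub for a real model client, might look like this:

```python
def call_llm(prompt: str) -> str:
    """Stub for a real model client."""
    return f"[draft refined using: {prompt[:50]}...]"

def generate_response(query: str, entities: list[str], level_summaries: list[str]) -> str:
    # Intermediate response generated from the retrieved entities (bottom level)
    answer = call_llm(f"Question: {query}\nEvidence: {', '.join(entities)}")
    # Refine bottom-up with the summarized tags of each higher-level graph
    for summary in level_summaries:
        answer = call_llm(f"Question: {query}\nDraft: {answer}\nGraph summary: {summary}")
    return answer

print(generate_response(
    "What is first-line therapy for type 2 diabetes?",
    ["metformin", "type 2 diabetes"],
    ["meta-graph tags: diabetes medications", "global tags: endocrine disorders"],
))
```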
Performance Of MedGraphRAG
MedGraphRAG significantly improves the performance of different LLMs on multiple medical benchmarks (PubMedQA, MedMCQA, and USMLE datasets).
Substantial performance enhancement is noted for smaller models like LLaMA2-13B and LLaMA3-8B, which typically underperform on these benchmarks on their own.
MedGraphRAG improves the accuracy of GPT-4, resulting in SOTA performance on the MedQA USMLE benchmark (Image from the original research paper)
The method also improves the accuracy of GPT-4 and LLaMA3-70B, leading to state-of-the-art (SOTA) results that even surpass human expert accuracy in clinical workflows.
Note that MedGraphRAG even outperforms the fine-tuned models in the medical domain.
Performance of MedGraphRAG enhanced LLMs (Image from the original research paper)
Ablation studies on the method show that document chunking, hierarchical graph construction, and U-retrieve, although complex, each significantly enhance its performance.
Results of Ablation Studies on MedGraphRAG support the importance of using sophisticated data processing and retrieval techniques (Image from the original research paper)
Finally, MedGraphRAG also enables LLMs to generate evidence-based responses to complex medical questions by listing their sources.
MedGraphRAG is a big step towards improving the safety and explainability of LLMs in medicine.
Research like this gives me hope that the extensive use of AI in medicine is not far off.