Build a RAG Application with Spring Boot
Using LangChain to build a RAG application
What is RAG?
Retrieval-Augmented Generation (RAG) is a machine learning approach that combines two key techniques: retrieval and generation. This method is particularly effective in natural language processing tasks, such as conversational agents and question-answering systems.
Key Components of RAG
1. Retrieval:
- RAG first retrieves relevant documents or data from a large dataset or knowledge base based on the user's query.
- This is typically done using information retrieval techniques, such as vector search or keyword matching.
2. Generation:
- After retrieving the relevant information, the model generates a response or answer, leveraging both the retrieved content and the original query.
- This is typically achieved using a generative model, such as GPT (Generative Pre-trained Transformer) or a similar architecture.
How RAG Works
1. Input: A user submits a query or question.
2. Retrieval Phase: The system searches a large corpus of documents to find the most relevant pieces of information related to the query.
3. Generation Phase: The retrieved information is then fed into a generative model, which synthesizes a coherent response that incorporates both the query and the retrieved content.
4. Output: The final response is returned to the user (a minimal code sketch follows).
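The whole flow fits in a few lines of code. Below is a minimal conceptual sketch in Java; the Retriever and Generator interfaces are hypothetical placeholders for a vector store and an LLM, not any specific library's API:

import java.util.List;

// Hypothetical interfaces standing in for a vector store and an LLM.
interface Retriever { List<String> findRelevant(String query); }
interface Generator { String generate(String prompt); }

class NaiveRag {

    private final Retriever retriever;
    private final Generator generator;

    NaiveRag(Retriever retriever, Generator generator) {
        this.retriever = retriever;
        this.generator = generator;
    }

    String answer(String query) {
        // Retrieval phase: find passages relevant to the query.
        List<String> passages = retriever.findRelevant(query);
        // Augmentation: combine the retrieved context with the original question.
        String prompt = "Context:\n" + String.join("\n", passages)
                + "\n\nQuestion: " + query;
        // Generation phase: the LLM synthesizes the final answer.
        return generator.generate(prompt);
    }
}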
Advantages of RAG
- Enhanced Accuracy: By combining retrieval with generation, RAG can provide more accurate and contextually relevant answers.
- Knowledge Utilization: It can leverage a vast amount of external knowledge, making it suitable for complex queries where simple generative models might fail.
- Dynamic Responses: The system can generate answers that are not only based on pre-defined responses but also tailored to specific user queries and retrieved information.
Applications
- Chatbots: RAG can be used to create conversational agents that provide accurate and contextually relevant answers.
- Question Answering Systems: It excels in systems that need to answer questions based on large datasets, such as in customer support or educational platforms.
- Content Creation: RAG can assist in generating content by retrieving relevant information and crafting it into cohesive narratives.
In summary, RAG is a powerful approach that enhances the capabilities of AI systems by combining the strengths of information retrieval and text generation.
Why is retrieval-augmented generation important?
Large language models (LLMs) are a key artificial intelligence (AI) technology powering intelligent chatbots and other natural language processing (NLP) applications. The goal is to create bots that can answer user questions in a variety of contexts by cross-referencing authoritative knowledge sources. Unfortunately, the nature of LLM technology makes responses unpredictable. In addition, LLM training data is static, which imposes a cutoff date on the knowledge the model possesses. Known challenges with LLMs include:
- Presenting false information when they do not have the answer.
- Providing outdated or generic information when users need a specific, current response.
- Creating responses from non-authoritative sources.
- Producing inaccurate responses due to terminology confusion, where different training sources use the same term to refer to different concepts.
You can think of a large language model as an overly enthusiastic new employee who refuses to stay up to date on current events yet answers every question with absolute confidence. Unfortunately, that attitude erodes user trust, and it is not something you want your chatbot to emulate! RAG is one way to address some of these challenges. It directs the LLM to retrieve relevant information from authoritative, predetermined knowledge sources. Organizations gain more control over the generated text output, and users gain insight into how the LLM produced its responses.
How does retrieval-augmented generation work?
Without RAG, the LLM takes user input and creates responses based on the information it was trained on or what it already knows. RAG introduces an information retrieval component that uses user input to first extract information from new data sources. Both the user query and the related information are provided to the LLM. The LLM uses the new knowledge and its training data to create better responses. The following sections outline the process.
Creating external data
New data outside the LLM's original training dataset is called external data. It can come from multiple sources, such as APIs, databases, or document repositories, and may exist in various formats, such as files, database records, or long-form text. Another AI technique, the embedding model, converts this data into numerical vector representations and stores them in a vector database. This process creates a knowledge base that the generative AI model can draw on.
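As a concrete illustration, here is a minimal ingestion sketch using the LangChain4j APIs that the full configuration later in this article also uses (the 500-character segment size is just an example):

import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;

import java.util.List;

class IngestionSketch {

    static EmbeddingStore<TextSegment> ingest(List<Document> documents) {
        EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
        EmbeddingStore<TextSegment> store = new InMemoryEmbeddingStore<>();
        // Split documents into ~500-character segments, embed each segment,
        // and index the vectors so they can be searched later.
        EmbeddingStoreIngestor.builder()
                .documentSplitter(DocumentSplitters.recursive(500, 0))
                .embeddingModel(embeddingModel)
                .embeddingStore(store)
                .build()
                .ingest(documents);
        return store;
    }
}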
Retrieving relevant information
The next step is to perform a relevance search. The user query is converted into a vector representation and matched against the vector database. For example, consider an intelligent chatbot that can answer an organization's HR questions. If an employee searches for "How much annual leave do I have?", the system will retrieve the annual leave policy document as well as the employee's personal past leave records. These specific documents are returned because they are highly relevant to what the employee entered. Relevance is computed using vector similarity between these representations.
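Continuing with LangChain4j, the relevance search can be sketched as follows; the retriever embeds the query and runs a vector similarity search against the store:

import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;

import java.util.List;

class RetrievalSketch {

    static List<TextSegment> relevantSegments(EmbeddingStore<TextSegment> store,
                                              EmbeddingModel embeddingModel,
                                              String query) {
        // findRelevant embeds the query and returns the best-matching segments.
        EmbeddingStoreRetriever retriever = EmbeddingStoreRetriever.from(store, embeddingModel);
        return retriever.findRelevant(query);
    }
}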
Augmenting the LLM prompt
Next, the RAG model augments the user input (or prompt) by adding the retrieved relevant data in context. This step uses prompt engineering techniques to effectively communicate with the LLM. Augmented prompts allow the large language model to generate accurate answers to user queries.
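Prompt augmentation can be sketched with LangChain4j's PromptTemplate, which the configuration later in this article also uses. The template text here is illustrative, not the library's default:

import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;

import java.util.Map;

class PromptSketch {

    static Prompt augment(String question, String retrievedInformation) {
        // {{question}} and {{information}} are template variables that get
        // filled in with the user query and the retrieved context.
        PromptTemplate template = PromptTemplate.from(
                "Answer the question using only the information below.\n"
                        + "Information:\n{{information}}\n\n"
                        + "Question: {{question}}");
        return template.apply(Map.of(
                "question", question,
                "information", retrievedInformation));
    }
}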
Updating external data
The next question might be: what if the external data becomes outdated? To keep the retrievable information current, update the documents and their embedded representations asynchronously, either through an automated real-time process or a periodic batch process. This kind of change management is a common challenge in data pipelines and can be handled with different data science methods.
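As one possible approach, a periodic batch refresh can be sketched with Spring's scheduling support. This assumes the ingestor and document list are exposed as beans and that scheduling is enabled with @EnableScheduling; note that re-ingesting into an in-memory store as-is would add duplicate vectors rather than replace stale ones:

import dev.langchain4j.data.document.Document;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

import java.util.List;

@Component
class KnowledgeBaseRefresher {

    private final EmbeddingStoreIngestor ingestor;
    private final List<Document> documents;

    KnowledgeBaseRefresher(EmbeddingStoreIngestor ingestor, List<Document> documents) {
        this.ingestor = ingestor;
        this.documents = documents;
    }

    // Re-embed the documents every night at 03:00. A real system would
    // reload them from their sources and replace stale vectors instead
    // of simply adding new ones.
    @Scheduled(cron = "0 0 3 * * *")
    void refresh() {
        ingestor.ingest(documents);
    }
}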
What is LangChain?
LangChain is a framework designed for developing applications that utilize large language models (LLMs). It provides tools and abstractions that simplify integrating LLMs into various workflows, enabling developers to create more complex and functional applications. It has two main capabilities:
- Connecting LLMs to external data sources
- Enabling interaction with LLMs
Here are the key features and components of LangChain:
Key Features of LangChain
1. Modular Architecture: LangChain is built with a modular approach, allowing developers to easily plug in different components (such as models, retrievers, and chains) as needed.
2. Chains: Chains are pipeline-like sequences of calls to different components (e.g., LLMs, APIs, databases) that can be combined to perform complex tasks. For example, a chain might first retrieve relevant documents and then generate responses based on those documents.
3. Retrieval: LangChain integrates with various retrieval mechanisms, allowing applications to fetch relevant information from different data sources (such as databases, APIs, or external knowledge bases).
4. Prompt Management: The framework includes utilities for managing prompts, which are essential for guiding LLMs to produce the desired output. Developers can define and customize prompts easily.
5. Memory: LangChain can maintain conversational memory, allowing applications to remember past interactions and provide more contextually relevant responses in ongoing conversations (a minimal sketch follows this list).
6. Integration with Other Tools: LangChain can connect to various data stores, web APIs, and other tools, enabling developers to build rich, interactive applications.
7. Scalability: The framework is designed to scale, allowing developers to handle larger datasets and more complex queries efficiently.
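As an example of the memory capability from item 5, here is a minimal sketch using LangChain4j's MessageWindowChatMemory; the same builder appears, commented out, in the configuration later in this article:

import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.MessageWindowChatMemory;

class MemorySketch {

    // Keep only the 10 most recent messages in the conversation window;
    // older messages are evicted so the prompt stays within model limits.
    static ChatMemory chatMemory() {
        return MessageWindowChatMemory.builder()
                .maxMessages(10)
                .build();
    }
}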
Implementation
Using LangChain4j to implement a RAG application with Spring Boot (Java).
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>3.2.1</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <artifactId>rag</artifactId>
    <properties>
        <java.version>17</java.version>
        <langchain4j.version>0.23.0</langchain4j.version>
    </properties>
    <dependencies>
        <!-- Web layer for the REST controller -->
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>
        <!-- LangChain4j core, OpenAI chat model, and local embedding model -->
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-open-ai</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-embeddings</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>dev.langchain4j</groupId>
            <artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
            <version>${langchain4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <optional>true</optional>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
                <configuration>
                    <excludes>
                        <exclude>
                            <groupId>org.projectlombok</groupId>
                            <artifactId>lombok</artifactId>
                        </exclude>
                    </excludes>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
Controller:
import com.et.rag.service.SBotService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequiredArgsConstructor
public class SBotController {

    private final SBotService sBotService;

    // Accepts a plain-text question and returns the chain's answer.
    @PostMapping("/ask")
    public ResponseEntity<String> ask(@RequestBody String question) {
        try {
            return ResponseEntity.ok(sBotService.askQuestion(question));
        } catch (Exception e) {
            return ResponseEntity.badRequest().body("Sorry, I can't process your question right now.");
        }
    }
}
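Once the application is running, the /ask endpoint can be exercised with any HTTP client. A minimal smoke test using the JDK's built-in HttpClient (the question text is just an example):

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class AskSmokeTest {

    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newHttpClient();
        // POST a plain-text question to the controller's /ask endpoint.
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/ask"))
                .header("Content-Type", "text/plain")
                .POST(HttpRequest.BodyPublishers.ofString("What is Spring Boot?"))
                .build();
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.body());
    }
}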
Service:
import dev.langchain4j.chain.ConversationalRetrievalChain;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;

@Service
@RequiredArgsConstructor
@Slf4j
public class SBotService {

    private final ConversationalRetrievalChain chain;

    // Delegates the question to the retrieval chain and logs the exchange.
    public String askQuestion(String question) {
        log.debug("======================================================");
        log.debug("Question: {}", question);
        String answer = chain.execute(question);
        log.debug("Answer: {}", answer);
        log.debug("======================================================");
        return answer;
    }
}
Retriever:
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.retriever.Retriever;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import java.util.List;

/**
 * EmbeddingStoreLoggingRetriever is a logging-enhanced wrapper around an
 * EmbeddingStoreRetriever.
 * <p>
 * It logs the relevant TextSegments found by the wrapped
 * EmbeddingStoreRetriever for improved transparency and debugging.
 * <p>
 * Logging happens at DEBUG level, printing each relevant TextSegment found
 * for a given input text when the findRelevant method is called.
 */
@RequiredArgsConstructor
@Slf4j
public class EmbeddingStoreLoggingRetriever implements Retriever<TextSegment> {

    private final EmbeddingStoreRetriever retriever;

    @Override
    public List<TextSegment> findRelevant(String text) {
        List<TextSegment> relevant = retriever.findRelevant(text);
        relevant.forEach(segment -> {
            log.debug("=======================================================");
            log.debug("Found relevant text segment: {}", segment);
        });
        return relevant;
    }
}
Document configuration:
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.UrlDocumentLoader;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.List;

import static com.et.rag.constant.Constants.SPRING_BOOT_RESOURCES_LIST;

@Configuration
public class DocumentConfiguration {

    // Loads each configured URL into a LangChain4j Document at startup.
    @Bean
    public List<Document> documents() {
        return SPRING_BOOT_RESOURCES_LIST.stream()
                .map(url -> {
                    try {
                        return UrlDocumentLoader.load(url);
                    } catch (Exception e) {
                        throw new RuntimeException("Failed to load document from " + url, e);
                    }
                })
                .toList();
    }
}
LangChain configuration:
import com.et.rag.retriever.EmbeddingStoreLoggingRetriever;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.time.Duration;
import java.util.List;

import static com.et.rag.constant.Constants.PROMPT_TEMPLATE_2;

@Configuration
@RequiredArgsConstructor
@Slf4j
public class LangChainConfiguration {

    @Value("${langchain.api.key}")
    private String apiKey;

    @Value("${langchain.timeout}")
    private Long timeout;

    private final List<Document> documents;

    @Bean
    public ConversationalRetrievalChain chain() {
        // Embed the loaded documents and index them in an in-memory vector store.
        EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
        EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
        EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
                .documentSplitter(DocumentSplitters.recursive(500, 0))
                .embeddingModel(embeddingModel)
                .embeddingStore(embeddingStore)
                .build();

        log.info("Ingesting Spring Boot Resources ...");
        ingestor.ingest(documents);
        log.info("Ingested {} documents", documents.size());

        // Wrap the retriever so every retrieved segment is logged.
        EmbeddingStoreRetriever retriever = EmbeddingStoreRetriever.from(embeddingStore, embeddingModel);
        EmbeddingStoreLoggingRetriever loggingRetriever = new EmbeddingStoreLoggingRetriever(retriever);

        /*MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
                .maxMessages(10)
                .build();*/

        log.info("Building ConversationalRetrievalChain ...");
        ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
                .chatLanguageModel(OpenAiChatModel.builder()
                        .apiKey(apiKey)
                        .timeout(Duration.ofSeconds(timeout))
                        .build())
                .promptTemplate(PromptTemplate.from(PROMPT_TEMPLATE_2))
                //.chatMemory(chatMemory)
                .retriever(loggingRetriever)
                .build();

        log.info("Spring Boot knowledge base is ready!");
        return chain;
    }
}
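The configuration above references a Constants class that the article does not show. Here is a plausible sketch; the URL list and the prompt text are illustrative assumptions, not the original author's values. Note that application.properties must also define langchain.api.key (an OpenAI API key) and langchain.timeout (a timeout in seconds):

package com.et.rag.constant;

import java.util.List;

public class Constants {

    // Illustrative document sources; the article does not list the real URLs.
    public static final List<String> SPRING_BOOT_RESOURCES_LIST = List.of(
            "https://docs.spring.io/spring-boot/docs/current/reference/htmlsingle/");

    // Illustrative RAG prompt; {{question}} and {{information}} are the
    // variables the ConversationalRetrievalChain fills in.
    public static final String PROMPT_TEMPLATE_2 = """
            Answer the question using only the information below.

            Information:
            {{information}}

            Question: {{question}}
            """;
}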
In summary, RAG enhances AI systems by combining the strengths of information retrieval and text generation, and with LangChain4j the whole pattern fits into a small Spring Boot application: documents are embedded into a vector store at startup, relevant segments are retrieved for each question, and the LLM generates an answer grounded in that content.