Enhancing SQL Agents with Retrieval Augmented Generation (RAG)

Luc Nguyen
Dec 21, 2023


In the previous blog post, we walked through building an SQL Agent that answers questions by querying data in a database. In this article, let’s explore how to enhance your SQL Agent by giving it access to advanced analytics functions. Imagine an Agent that not only provides basic statistics, such as the average amount a customer paid, but also delivers more advanced insights: identifying similarities between users or products in the database, or determining the paths taken by users who frequently cancel their memberships. Let’s discuss how to achieve this.

Teradata’s Advanced Analytics Functions

Unlike many other databases, Teradata offers a wide range of advanced analytics functions, spanning Data Cleaning, Data Exploration, Model Training, Text Analytics, and Path and Pattern Analysis.

The distinctive feature is that all of these functions run in-database, so you don’t need to set up a separate environment. When you execute them, they are processed directly within the database, next to the data, which avoids moving the data out to an external system.

[Figure: Tables]

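For instance, consider two tables in the database: UserHistory and UserHistoryReferences. Using the TD_VectorDistance function, you can find, for each user in the target table, the most similar users in the reference table. The query syntax is as follows (the example names its target and reference tables target_mobile_data_dense and ref_mobile_data_dense; substitute your own tables):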

SELECT target_id, reference_id, distancetype, CAST(distance AS DECIMAL(36,8)) AS distance
FROM TD_VECTORDISTANCE (
ON target_mobile_data_dense AS TargetTable
ON ref_mobile_data_dense AS ReferenceTable DIMENSION
USING
TargetIDColumn('userid')
TargetFeatureColumns('CallDuration','DataCounter','SMS')
RefIDColumn('userid')
RefFeatureColumns('CallDuration','DataCounter','SMS')
DistanceMeasure('cosine')
TopK(2)
) AS dt ORDER BY 3,1,2,4;

Here are the results from the database:

Target_ID  Reference_ID  DistanceType  Distance
---------  ------------  ------------  ----------
1          5             cosine        0.45486518
1          7             cosine        0.32604815
2          5             cosine        0.02608923
2          7             cosine        0.00797609
3          5             cosine        0.02415054
3          7             cosine        0.00337338
4          5             cosine        0.43822243
4          7             cosine        0.31184844
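To make the numbers easier to interpret, here is a small pure-Python sketch of the same kind of computation. It assumes that TD_VectorDistance’s cosine measure is the conventional cosine distance (1 minus cosine similarity, so smaller means more similar) and that TopK keeps the K nearest references per target; check both against the Teradata documentation. The feature vectors are made up for illustration and are not the data behind the table above.

```python
import math


def cosine_distance(a, b):
    """Return 1 - cosine similarity of two equal-length, non-zero feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def top_k_neighbors(targets, references, k):
    """For each target id, return the k (reference_id, distance) pairs with the smallest distance."""
    result = {}
    for t_id, t_vec in targets.items():
        scored = sorted(
            (cosine_distance(t_vec, r_vec), r_id) for r_id, r_vec in references.items()
        )
        result[t_id] = [(r_id, dist) for dist, r_id in scored[:k]]
    return result


# Hypothetical feature vectors in the order (CallDuration, DataCounter, SMS)
targets = {1: (10.0, 0.0, 5.0), 2: (3.0, 4.0, 0.0)}
references = {5: (10.0, 0.0, 5.0), 7: (0.0, 1.0, 0.0), 9: (6.0, 8.0, 0.0)}

print(top_k_neighbors(targets, references, k=2))
```

Because cosine distance ignores vector length, target 2 and reference 9 come out at distance 0 even though reference 9’s values are twice as large; only the direction of the feature vector matters.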

For details on Teradata’s advanced analytics functions, see the documentation here.

Retrieval Augmented Generation (RAG)

To help your agent understand how to use these functions, I propose a technique known as Retrieval Augmented Generation (RAG).

This approach retrieves the instructions relevant to each query. For instance, if I ask my agent to help me find similar users based on the tables UserHistory and UserHistoryReferences, the retrieval step returns the stored syntax and examples for that kind of request.

Syntax Instruction

For the SQL Agent to perform well, each syntax instruction should contain two essential pieces of information. First, the syntax, with an explanation of each parameter. Second, and most importantly, examples. In general, the more examples you provide, the more likely the Agent is to generate correct SQL.
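One way to keep every instruction in the same shape is to assemble it from its parts. The helper below is a hypothetical sketch (build_syntax_instruction is not part of LangChain or Teradata) that follows the "Syntax Description / Example" layout used in the next section; the parameter explanations are illustrative placeholders that you would replace with the wording from the Teradata documentation.

```python
def build_syntax_instruction(function_name, syntax, parameters, examples):
    """Assemble one RAG document: syntax, per-parameter explanations, then numbered examples."""
    lines = [f"Syntax Description : {function_name}", syntax, "", "Parameters :"]
    for name, explanation in parameters.items():
        lines.append(f"- {name}: {explanation}")
    for i, example in enumerate(examples, start=1):
        lines += ["", f"Example {i} :", example]
    return "\n".join(lines)


# Illustrative content only; copy the real syntax and explanations from the documentation
doc = build_syntax_instruction(
    "TD_VectorDistance",
    "TD_VectorDistance ( ON ... AS TargetTable ON ... AS ReferenceTable DIMENSION USING ... )",
    {
        "TargetIDColumn": "column that identifies each row of the target table",
        "DistanceMeasure": "distance metric to use, for example 'cosine'",
        "TopK": "(assumed) number of closest reference rows to return per target row",
    },
    ["SELECT ... FROM TD_VECTORDISTANCE ( ... ) AS dt;"],
)
print(doc)
```

Keeping the examples in a list makes it easy to add more of them over time, which is where most of the accuracy gains tend to come from.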

Let’s Construct the RAG

To create a RAG system, begin by preparing the documents. Convert these documents into vectors (embeddings) and save them in a vector database, which we’ll refer to as a Vector DB. In this example, I’ll use FAISS, an open-source library for efficient vector similarity search, as the Vector DB.

# Import the required libraries
import os

from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

Start by reviewing the documentation provided by Teradata. First, prepare syntax instructions that include explanations and examples.

syntax_1 = """
Syntax Description :
TD_VectorDistance () ...

Example :
TD_VectorDistance ( ... )
"""

syntax_2 = """
...
"""

syntax_3 = """
...
"""

Next, choose an embedding model: an open-source model, for example from Hugging Face, or a hosted service such as the OpenAI Embeddings API. In this instance, I used OpenAI.

embedding_function = OpenAIEmbeddings(openai_api_key=os.getenv("OPENAI_API"))

Finally, with the help of LangChain and FAISS, you can complete the process in just a few lines of code:

# Collect the syntax instructions; append further ones (..., syntax_n) as needed
technical_list = [syntax_1, syntax_2, syntax_3]

# Embed every syntax instruction and index the vectors in FAISS
db = FAISS.from_texts(technical_list, embedding_function)

You can then search for relevant documents with the simple code below. For instance, if you’re looking for the syntax used to calculate similarity, the following code returns the stored syntax instruction whose embedding is closest to your query:

db.similarity_search("Calculate similarity", k=1)[0].page_content
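To show what from_texts and similarity_search do conceptually, here is a self-contained toy version. It swaps in a bag-of-words embedding for OpenAIEmbeddings and a brute-force cosine search for FAISS, so it runs without an API key; ToyVectorStore and its sample documents are hypothetical. Real embeddings also capture meaning beyond shared words, and FAISS uses optimized indexes rather than scanning every vector.

```python
import math
import re
from collections import Counter


def embed(text):
    """Toy embedding: lowercase word counts (a stand-in for a real embedding model)."""
    return Counter(re.findall(r"[a-z_]+", text.lower()))


def cosine_similarity(a, b):
    """Cosine similarity of two sparse count vectors; 0.0 if either is empty."""
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(sum(v * v for v in b.values()))
    return dot / norm if norm else 0.0


class ToyVectorStore:
    """In-memory stand-in for FAISS.from_texts plus similarity_search."""

    def __init__(self, texts):
        # "Indexing": embed every document once and keep the vectors next to the texts
        self.texts = list(texts)
        self.vectors = [embed(t) for t in self.texts]

    def similarity_search(self, query, k=4):
        # Embed the query and return the k texts whose vectors are most similar to it
        q = embed(query)
        ranked = sorted(
            zip(self.vectors, self.texts),
            key=lambda pair: cosine_similarity(q, pair[0]),
            reverse=True,
        )
        return [text for _, text in ranked[:k]]


store = ToyVectorStore([
    "Calculate similarity between users with TD_VectorDistance",
    "Binning or bucketing numeric columns",
    "Moving average over a time series",
])
print(store.similarity_search("Calculate similarity", k=1))
```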

Integrating RAG with SQL Agent

We have already covered how to create an SQL Agent in a previous blog post. If you are not familiar with the process, please refer to that blog. Additionally, we have discussed creating a RAG to retrieve relevant syntax instruction information. Now, let’s explore how to integrate these two components seamlessly.

RAG as a tool

In the previous blog post, I detailed how the SQL Agent uses tools such as sql_db_list_tables to interact with the database. My idea now is to make RAG another tool. This lets the SQL Agent decide when it needs to consult the relevant documents and which keywords to search for.
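Under a ReAct-style agent such as ZERO_SHOT_REACT_DESCRIPTION, the model writes an "Action:" line naming a tool and an "Action Input:" line with its argument; the agent runs that tool and feeds the result back as an observation. The sketch below is a simplified illustration of that dispatch step, not LangChain’s actual output parser, and the two lambda tools are hypothetical stand-ins for sql_db_list_tables and the RAG search tool.

```python
import re


def parse_action(llm_output):
    """Extract (tool name, tool input) from a ReAct-style 'Action:' / 'Action Input:' block."""
    match = re.search(r"Action:\s*(.+?)\s*\nAction Input:\s*(.*)", llm_output, re.DOTALL)
    if match is None:
        return None  # e.g. the model produced a final answer instead of a tool call
    return match.group(1).strip(), match.group(2).strip()


def dispatch(llm_output, tools):
    """Run the tool the model chose; its return value becomes the next observation."""
    parsed = parse_action(llm_output)
    if parsed is None:
        return None
    name, tool_input = parsed
    if name not in tools:
        return f"{name} is not a valid tool, try one of {sorted(tools)}."
    return tools[name](tool_input)


# Hypothetical tools: stand-ins for sql_db_list_tables and the RAG search tool
tools = {
    "sql_db_list_tables": lambda _: "UserHistory, UserHistoryReferences",
    "teradata_search_tool": lambda q: f"Syntax instructions for: {q}",
}

step = "Thought: I need the similarity syntax.\nAction: teradata_search_tool\nAction Input: similarity"
print(dispatch(step, tools))
```

Because the model picks the tool by name from the tool descriptions in its prompt, a clear description is what makes the agent reach for the RAG tool at the right moment.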

Create a custom tool

To create a custom tool with LangChain, extend the BaseTool class and implement the _run method as follows. Keep the description clear, since the SQL Agent relies on it to understand the tool’s purpose.

from typing import Optional

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.tools import BaseTool

# Define the retriever on top of the FAISS store built earlier
retriever = db.as_retriever()

NO_SYNTAX_FOUND = "There are no Teradata syntax examples to be used in this scenario."

# Define the custom tool
class TeradataSearchTool(BaseTool):
    name: str = "teradata_search_tool"
    description: str = (
        "Input to this tool is a keyword such as binning or bucketing, similarity, moving average. "
        "Output is an instruction on how to use Teradata Syntax with examples to improve queries."
    )

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        """Return the most relevant Teradata syntax instruction for the query."""
        # Skip the retrieval call entirely when the agent passes an empty keyword
        if not query.strip():
            return NO_SYNTAX_FOUND
        relevant_docs = retriever.get_relevant_documents(query)
        if not relevant_docs:
            return NO_SYNTAX_FOUND
        return relevant_docs[0].page_content

    async def _arun(
        self, query: str, run_manager: Optional[AsyncCallbackManagerForToolRun] = None
    ) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("teradata_search_tool does not support async")

# Instantiate the Teradata search tool
teradata_search_tool = TeradataSearchTool()
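Because the tool only needs a retriever’s get_relevant_documents method, its decision logic can be tested without an OpenAI key by injecting a stub. The sketch below is a hypothetical refactoring (search_syntax, Doc and StubRetriever are illustrative names, not LangChain classes) that mirrors what _run does. Note that a similarity-based retriever usually returns its nearest documents even for unrelated queries, so in practice the fallback message mostly covers an empty query or an empty store.

```python
from dataclasses import dataclass


@dataclass
class Doc:
    """Minimal stand-in for LangChain's Document; only page_content is used."""
    page_content: str


class StubRetriever:
    """Hypothetical retriever that matches by substring, so no embeddings are needed."""

    def __init__(self, docs):
        self.docs = docs

    def get_relevant_documents(self, query):
        return [d for d in self.docs if query.lower() in d.page_content.lower()]


FALLBACK = "There are no Teradata syntax examples to be used in this scenario."


def search_syntax(retriever, query):
    """Same decision logic as TeradataSearchTool._run, with the retriever passed in."""
    if not query.strip():
        return FALLBACK
    docs = retriever.get_relevant_documents(query)
    return docs[0].page_content if docs else FALLBACK


stub_retriever = StubRetriever([Doc("Syntax Description : TD_VectorDistance ...")])
print(search_syntax(stub_retriever, "vectordistance"))
print(search_syntax(stub_retriever, ""))
```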

Create the SQL Agent with custom tools

After defining the Teradata search tool, create the SQL Agent with the following code, passing the `teradata_search_tool` from the previous step through the `extra_tools` parameter.

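from langchain.agents import create_sql_agent
from langchain.agents.agent_types import AgentType

# Step 4. Create the Agent Executor
# model, toolkit, prefix and suffix come from the earlier steps (see the previous blog post)
agent_executor = create_sql_agent(
    llm=model,
    toolkit=toolkit,
    verbose=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    extra_tools=[teradata_search_tool],
    prefix=prefix,
    suffix=suffix,
)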
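The prefix passed to create_sql_agent is defined in the previous post. As a hypothetical option, you could also mention the new tool there so the agent is nudged to consult it before writing analytics queries; the wording below is only a sketch that reuses the keywords from the tool description. It deliberately contains no curly braces, since prompt text like this may be treated as a format template.

```python
# Hypothetical prefix text; merge it with the prefix from your previous setup
prefix = (
    "You are an agent designed to interact with a Teradata database.\n"
    "Before writing a query that needs a Teradata advanced analytics function "
    "(for example similarity, binning or bucketing, or a moving average), "
    "call teradata_search_tool with a short keyword and follow the syntax "
    "and examples it returns.\n"
)
print(prefix)
```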

Finally, test it:

agent_executor.run("Identify user similarities by analyzing the 'UserHistory' table using 'UserHistoryReference' as the reference table, focusing on attributes CallDuration, DataCounter, and SMS")

I asked the agent to identify user similarities by analyzing the ‘UserHistory’ table, using ‘UserHistoryReference’ as the reference table and focusing on the attributes CallDuration, DataCounter, and SMS. Here is the result.

[Figure: Result from Agent]

Conclusion

By combining the SQL Agent with RAG, we significantly extend what the LLM can do with your database. The same approach also lets you build another RAG over text documents, so that your agent can answer questions based on both structured data and text data. However, there are still issues related to token limits, because every retrieved syntax instruction is added to the prompt and consumes the model’s context window. In the next blog post, I will discuss fine-tuning the model so that your agent can perform the same tasks without relying on RAG.

You can find the full code here.
