Deploying a Fine-tuned GPT-3.5 NL-to-SQL Model

Mo Pourreza
Published in Dataherald
11 min read · Sep 12, 2023
Photo by Markus Spiske on Unsplash

Background

In the previous article, we explored the process of fine-tuning a GPT-3.5-turbo model for an NL-to-SQL task. In this post, we will deploy our fine-tuned model using the open-source Dataherald engine. This integration combines the fine-tuned model with a retrieval-augmented generation (RAG) system, improving SQL generation performance.

The modular structure of the engine allows us to easily plug in a custom SQL generation module, while still leveraging other parts of the engine including its context store manager and evaluator modules.

Let’s dive in.

Implementation

To set up an API with our custom fine-tuned model (or any other SQL generation method), we need to do the following:

  1. Create a custom class that returns the database schema and content descriptions in the same format as our fine-tuning training set. We will call this class FineTuningDatabaseContentCreator in this tutorial.
  2. Create a new SQL generator class that inherits from the SQLGenerator base class and implements its abstract methods.
  3. Set up and run the engine in a Docker container, configured to use our new SQL generator class.
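Step 2 relies on ordinary subclassing. Code that depends only on an abstract base class can call any subclass that implements its abstract methods. Below is a minimal sketch of that pattern. The names BaseSQLGenerator, CountRowsGenerator and answer are hypothetical, and the simplified signature is not Dataherald's actual SQLGenerator interface:

```python
from __future__ import annotations

from abc import ABC, abstractmethod


class BaseSQLGenerator(ABC):
    """Hypothetical stand-in for an engine's SQL generator base class."""

    @abstractmethod
    def generate_response(self, question: str, context: list[dict] | None = None) -> str:
        """Return a SQL query that answers the natural language question."""


class CountRowsGenerator(BaseSQLGenerator):
    """Toy generator that always counts the rows of one table (illustration only)."""

    def __init__(self, table: str) -> None:
        self.table = table

    def generate_response(self, question: str, context: list[dict] | None = None) -> str:
        return f"SELECT count(*) FROM {self.table}"


def answer(generator: BaseSQLGenerator, question: str) -> str:
    # The caller depends only on the abstract interface, so any subclass can be plugged in
    return generator.generate_response(question)


print(answer(CountRowsGenerator("singer"), "How many singers do we have?"))
# prints: SELECT count(*) FROM singer
```

Because generate_response is declared abstract, Python refuses to instantiate a subclass that forgets to implement it. This catches incomplete plug-ins early.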

Building the FineTuningDatabaseContentCreator class

While the Dataherald engine comes with its own built-in database scanner API, we will not use it in this tutorial because its output format is not consistent with the prompts our model was fine-tuned on. Instead, we will build a custom class to extract the database schema and content.

Our fine-tuned model relies on three key pieces of information within the prompts to generate a SQL query. These components are:

  1. Task-Specific Instruction: This section provides explicit instructions to the LLM, defining the task by specifying the input requirements and detailing the expected output. It serves as a roadmap for the model to understand the task it needs to accomplish.
  2. Database Schema and Content Description: The second part encompasses descriptions of the database, including both its schema (the structure of tables and columns) and the content contained within. It clarifies how data is organized within tables and columns.
  3. Given Question: Naturally, the question itself is a crucial input for generating an appropriate SQL query.
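Here is a minimal sketch of how the three components come together. The function name build_messages is hypothetical. The split mirrors the code later in this post: the instruction and database description form the system message, and the question is the user message.

```python
from __future__ import annotations


def build_messages(instruction: str, database_content: str, question: str) -> list[dict]:
    """Combine the three prompt components into OpenAI-style chat messages."""
    # Components 1 and 2: the instruction followed by the schema/content description
    system_prompt = instruction + database_content
    # Component 3: the natural language question
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]


messages = build_messages(
    "Generate a correct sqlite SQL query for the given question.\n### Database content ###\n",
    "CREATE TABLE singer (singer_id INTEGER, name TEXT)",
    "How many singers do we have?",
)
```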

We will create a class called FineTuningDatabaseContentCreator, which takes a SQLAlchemy engine as input and structures the database schema and content in the following format, which our fine-tuned model expects:

```sql
CREATE TABLE trip (
id INTEGER,
duration INTEGER,
start_date TEXT,
start_station_name TEXT,
start_station_id INTEGER,
end_date TEXT,
end_station_name TEXT,
end_station_id INTEGER,
bike_id INTEGER,
subscription_type TEXT,
zip_code INTEGER,
PRIMARY KEY (id)
)
/* Columns in trip and 3 examples in each column for high cardinality columns :
id : 900645, 900752, 900524
duration : 1131, 2146, 1155
start_date : 8/21/2015 17:39, 8/21/2015 17:03, 8/21/2015 17:16
start_station_name : Howard at 2nd, 2nd at Folsom, Market at 10th
start_station_id : 56, 65, 49
end_date : 8/21/2015 17:19, 8/21/2015 18:08, 8/21/2015 17:32
end_station_name : Howard at 2nd, 2nd at Folsom, Market at 10th
end_station_id : 56, 65, 49
bike_id : 586, 56, 65
zip_code : 94070, 94530, 94040-1724
*/
/* Columns in trip and all categories for low cardinality columns :
subscription_type : Customer, Subscriber
*/

CREATE TABLE management (
"department_ID" INTEGER,
"head_ID" INTEGER,
temporary_acting TEXT,
PRIMARY KEY ("department_ID", "head_ID"),
FOREIGN KEY("head_ID") REFERENCES head ("head_ID"),
FOREIGN KEY("department_ID") REFERENCES department ("Department_ID")
)
/* Columns in management and all categories for low cardinality columns :
department_ID : 7, 15, 2, 11
head_ID : 5, 4, 6, 3, 10
temporary_acting : Yes, No
*/
...
```
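For SQLite databases such as Spider's, you can preview the CREATE TABLE part of this format with only the standard library. SQLite keeps each table's original CREATE statement in its sqlite_master catalog. This is a swapped-in shortcut, not what the class below does. That class renders DDL with SQLAlchemy's CreateTable, whose output can differ in quoting and layout. sqlite_create_statements is a hypothetical helper:

```python
import sqlite3


def sqlite_create_statements(conn: sqlite3.Connection) -> str:
    """Return the stored CREATE TABLE statements of all user tables, sorted by name."""
    rows = conn.execute(
        "SELECT name, sql FROM sqlite_master WHERE type = 'table' ORDER BY name"
    ).fetchall()
    # Skip SQLite's internal tables (e.g. sqlite_sequence), as the class below does
    return "\n\n".join(sql for name, sql in rows if not name.startswith("sqlite_"))


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE singer (singer_id INTEGER PRIMARY KEY, name TEXT)")
conn.execute("CREATE TABLE concert (concert_id INTEGER PRIMARY KEY, theme TEXT)")
print(sqlite_create_statements(conn))
```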

We first fork and clone the Dataherald repository to create a local working copy.

Next, we’ll create a new file named database_content_creator.py within the sql_generator directory (the dataherald.sql_generator package), containing the following code:

```python
import warnings
from typing import Iterable, List

from sqlalchemy import MetaData, Table, inspect, select
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.schema import CreateTable

warnings.filterwarnings("ignore")


class FineTuningDatabaseContentCreator:
    def __init__(
        self,
        engine: Engine,
        schema: str | None = None,
        metadata: MetaData | None = None,
        ignore_tables: List[str] | None = None,
        include_tables: List[str] | None = None,
        sample_rows_in_table_info: int = 3,
        low_cardinality_threshold: int = 10,
        indexes_in_table_info: bool = False,
        custom_table_info: dict | None = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        self._engine = engine
        self._schema = schema

        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # Include views as well as tables in the list of all tables
        # when view_support is True
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        if not isinstance(low_cardinality_threshold, int):
            raise TypeError("low_cardinality_threshold must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._low_cardinality_threshold = low_cardinality_threshold
        self._indexes_in_table_info = indexes_in_table_info  # stored, not used below

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # Only keep the tables that are also present in the database
            intersection = set(self._custom_table_info).intersection(self._all_tables)
            self._custom_table_info = {
                table: self._custom_table_info[table]
                for table in self._custom_table_info
                if table in intersection
            }

        self._max_string_length = max_string_length  # stored, not used below

        self._metadata = metadata or MetaData()
        # Reflect views as well when view_support is True
        self._metadata.reflect(
            views=view_support,
            bind=self._engine,
            only=list(self._usable_tables),
            schema=self._schema,
        )

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: List[str] | None = None) -> str:
        """Get information about specified tables."""
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in set(all_table_names)
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # Add the CREATE TABLE command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            if self._sample_rows_in_table_info:
                try:
                    table_info += f"\n{self._get_sample_rows(table)}\n"
                except Exception:  # noqa: S110
                    # Sampling is best-effort: keep the DDL even if it fails
                    pass
            tables.append(table_info)
        tables.sort()
        return "\n\n".join(tables)

    def _get_sample_rows(self, table: Table) -> str:
        """Describe column values from up to 200 rows, split by cardinality."""
        limiting_factor = 200
        # Build the SELECT command
        command = select(table).limit(limiting_factor)

        response = ""
        try:
            with self._engine.connect() as connection:
                sample_rows = connection.execute(command).fetchall()

                # Create sections for high and low cardinality columns
                high_cardinality_section = f"/*\nColumns in {table.name} and {self._sample_rows_in_table_info} examples in each column for high cardinality columns :"  # noqa: E501
                low_cardinality_section = f"/*\nColumns in {table.name} and all categories for low cardinality columns :"  # noqa: E501

                low_columns = ""
                high_columns = ""
                for index, column in enumerate(table.columns):
                    values = [str(row[index]) for row in sample_rows]

                    # Distinct values in first-seen order, so the prompt is
                    # deterministic (iterating over a set is not)
                    unique_values = list(dict.fromkeys(values))
                    # Classify the column as high or low cardinality by threshold
                    if len(unique_values) > self._low_cardinality_threshold:
                        high_columns += f"\n{column.name} : {', '.join(unique_values[:self._sample_rows_in_table_info])}"  # noqa: E501
                    else:
                        low_columns += f"\n{column.name} : {', '.join(unique_values)}"

                if high_columns:
                    high_cardinality_section += high_columns + "\n*/\n"
                    response += high_cardinality_section

                if low_columns:
                    low_cardinality_section += low_columns + "\n*/"
                    response += low_cardinality_section

        except ProgrammingError:
            response = ""

        return response
```

The get_table_info method returns information about the specified tables or, if none are specified, about all usable tables. For each table, this includes:

  1. The SQL CREATE TABLE command
  2. Sample column values drawn from up to 200 rows of the table
  3. A split of the columns into high and low cardinality based on a threshold: high cardinality columns are shown with a few example values, while low cardinality columns list all of their distinct values.
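The cardinality split can be isolated into a small function to make the rule easy to test. describe_column is a hypothetical helper, not part of the engine. It is a sketch that applies the same rule as _get_sample_rows, with the same defaults (threshold 10, three examples). It deduplicates in first-seen order, so the same rows always produce the same text.

```python
from __future__ import annotations


def describe_column(
    name: str, values: list, threshold: int = 10, n_examples: int = 3
) -> tuple[str, str]:
    """Classify a column by cardinality and format it like the schema comments.

    More than `threshold` distinct values -> ("high", up to n_examples values);
    otherwise -> ("low", every distinct value).
    """
    # dict.fromkeys removes duplicates while keeping first-seen order
    distinct = list(dict.fromkeys(str(v) for v in values))
    if len(distinct) > threshold:
        return "high", f"{name} : {', '.join(distinct[:n_examples])}"
    return "low", f"{name} : {', '.join(distinct)}"


print(describe_column("subscription_type", ["Customer", "Subscriber", "Customer"]))
print(describe_column("id", list(range(900500, 900520))))
```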

Our new SQL generator will use this class to build the schema and content description for a given database URI, as follows:

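```python
from sqlalchemy import create_engine

from dataherald.sql_generator.database_content_creator import FineTuningDatabaseContentCreator

# Example URI; replace it with your own database connection string
database_uri = "sqlite:///concert_singer.sqlite"

engine = create_engine(database_uri)
db = FineTuningDatabaseContentCreator(engine)
database_content = db.get_table_info()
```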

Implementing a custom SQL generator

We will now create a custom implementation of the SQLGenerator base class called FineTunedGPT. To do this, we first create a new file named fine_tuned_gpt.py within the sql_generator directory. This class inherits from SQLGenerator and overrides its generate_response() method.

The class skeleton in fine_tuned_gpt.py therefore looks as follows:

```python
from typing import List

from overrides import override

from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_generator import SQLGenerator
from dataherald.types import NLQuery, NLQueryResponse


class FineTunedGPT(SQLGenerator):
    @override
    def generate_response(
        self,
        user_question: NLQuery,
        database_connection: DatabaseConnection,
        context: List[dict] = None,
    ) -> NLQueryResponse:
        ...  # implemented step by step below
```

Now, let’s implement the generate_response() method. This method takes three parameters:

  1. An NLQuery object containing the user's question
  2. A DatabaseConnection object describing the target database
  3. A context list, containing previously asked questions similar to the current natural language question along with their corresponding SQL queries. These are used as few-shot demonstrations for the LLM.

As a preliminary step, we’ll connect to the database and build the system prompt from the instruction and the database schema and content descriptions produced by the class from the first step. We also need to decrypt the database URI stored by the engine, since it is encrypted by default. This is done as follows:

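```python
from urllib.parse import unquote

from sqlalchemy import create_engine

from dataherald.sql_generator.database_content_creator import FineTuningDatabaseContentCreator
from dataherald.utils.encrypt import FernetEncrypt

# Decrypt the stored connection URI and open a SQLAlchemy engine on it
fernet_encrypt = FernetEncrypt()
engine = create_engine(unquote(fernet_encrypt.decrypt(database_connection.uri)))
db = FineTuningDatabaseContentCreator(engine)
instruction = f"""
You are an assistant that is an expert in generating {db.dialect} SQL queries.
Having the access to database content, generate a correct {db.dialect} SQL query for the given question.
### Database content ###
"""
database_content = db.get_table_info()
system_prompt = instruction + database_content
```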

Once the system prompt is assembled, we are ready to call the LLM to generate the answer. To improve accuracy, however, we will use one of the main built-in features of the Dataherald engine: retrieving similar, previously verified questions and including them in the prompt as few-shot samples.
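The engine handles this retrieval itself through its context store, and we don't reproduce its implementation here. The sketch below is only a toy stand-in that illustrates the idea. It ranks stored question/SQL pairs by Jaccard overlap of word tokens, a swapped-in similarity measure; a production system would more likely use embeddings. The function names are hypothetical. The nl_question and sql_query keys mirror the context dictionaries used below, and the sample data is made up.

```python
from __future__ import annotations

import re


def _tokens(text: str) -> set[str]:
    """Lowercased word tokens of a question."""
    return set(re.findall(r"[a-z0-9_]+", text.lower()))


def top_k_similar(question: str, verified: list[dict], k: int = 2) -> list[dict]:
    """Return the k verified samples whose questions overlap most with `question`."""
    q = _tokens(question)

    def jaccard(sample: dict) -> float:
        s = _tokens(sample["nl_question"])
        union = q | s
        return len(q & s) / len(union) if union else 0.0

    # sorted() is stable, so samples with equal scores keep their input order
    return sorted(verified, key=jaccard, reverse=True)[:k]


# Made-up verified samples, shaped like the context entries used below
verified = [
    {"nl_question": "How many concerts are there?", "sql_query": "SELECT count(*) FROM concert"},
    {"nl_question": "How many singers are from France?",
     "sql_query": "SELECT count(*) FROM singer WHERE country = 'France'"},
    {"nl_question": "List all stadium names.", "sql_query": "SELECT name FROM stadium"},
]
context = top_k_similar("How many singers do we have?", verified, k=1)
```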

To incorporate the few-shot examples within the prompts, we will leverage the provided context list as follows:

```python
# Append few-shot examples only when the engine found similar verified questions
if context:
    samples_prompt_string = (
        "The following are some similar previous questions and their correct "
        "SQL queries from the database: \n"
    )
    for sample in context:
        samples_prompt_string += (
            f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n"
        )
    question_with_context = (
        f"{user_question.question} An example of a similar question and the query "
        f"that was generated to answer it is the following {samples_prompt_string}"
    )
else:
    question_with_context = user_question.question
```

Now it’s time to send the request to our fine-tuned model and get the SQL query as the result.

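```python
import logging
import os
import time

import openai

logger = logging.getLogger(__name__)
MAX_RETRIES = 3  # assumed retry budget; tune it for your rate limits

response = None
for attempt in range(MAX_RETRIES):
    try:
        # Pre-1.0 openai Python SDK interface
        response = openai.ChatCompletion.create(
            model="ft:gpt-3.5-turbo-0613:dataherald:spider:7t2q6Qhd",  # use your own fine-tuned model ID
            api_key=os.environ.get("OPENAI_API_KEY"),
            temperature=0.0,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question_with_context},
            ],
        )
        break
    except Exception as e:  # e.g. rate limits or transient network errors
        logger.warning("OpenAI request failed (attempt %d): %s", attempt + 1, e)
        time.sleep(2**attempt)  # exponential backoff: 1s, 2s, 4s
if response is None:
    raise RuntimeError("The fine-tuned model did not return a response")
model_response = response.choices[0]["message"]["content"]
```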

To return a natural language answer in addition to the SQL, we use the GeneratesNlAnswer class from Dataherald's SQL generators. This class produces the natural language answer for a question from the generated SQL query and its execution result. First, we extract the SQL from the raw model output with the output_parser method, which is defined in the full class below. Then we pass the SQL to GeneratesNlAnswer:

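```python
from dataherald.db import DB
from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer
from dataherald.types import NLQueryResponse

# Extract the SQL statement from the raw completion
sql = self.output_parser(model_response)
nl_query_response = NLQueryResponse(
    nl_question_id=user_question.id,
    sql_query=sql,
)
# Run the query and produce the natural language answer
generates_nl_answer = GeneratesNlAnswer(self.system, self.system.instance(DB))
nl_query_response = generates_nl_answer.execute(nl_query_response)
```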
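GeneratesNlAnswer's internals aren't shown in this post. The step it depends on is executing the generated SQL and collecting the result, which can be sketched with the standard library's sqlite3 module. execute_for_answer is a hypothetical helper, and the returned shape mirrors the sql_query_result field in the API response shown later. Turning this result into a sentence is the engine's job and is not sketched here.

```python
import sqlite3


def execute_for_answer(conn: sqlite3.Connection, sql: str) -> dict:
    """Run `sql` and return its columns and rows (as dicts), like sql_query_result."""
    cursor = conn.execute(sql)
    if cursor.description is None:  # statement returned no result set
        return {"columns": [], "rows": []}
    columns = [desc[0] for desc in cursor.description]
    rows = [dict(zip(columns, row)) for row in cursor.fetchall()]
    return {"columns": columns, "rows": rows}


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE singer (singer_id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany("INSERT INTO singer (name) VALUES (?)", [("Joe",), ("Rose",)])
print(execute_for_answer(conn, "SELECT count(*) FROM singer"))
```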

Putting all of this together gives us the final FineTunedGPT SQL generator class:

```python
import logging
import os
import re
import time
from typing import List
from urllib.parse import unquote

import openai
from overrides import override
from sqlalchemy import create_engine

from dataherald.db import DB
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_generator import SQLGenerator
from dataherald.sql_generator.database_content_creator import FineTuningDatabaseContentCreator
from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer
from dataherald.types import NLQuery, NLQueryResponse
from dataherald.utils.encrypt import FernetEncrypt

logger = logging.getLogger(__name__)

MAX_RETRIES = 3  # assumed retry budget for OpenAI calls; tune as needed


class FineTunedGPT(SQLGenerator):
    def output_parser(self, model_output: str) -> str:
        """Extract the SQL from the completion and collapse its whitespace."""
        pattern = r"The SQL query I'll be generating is:(.*?)$"
        match = re.search(pattern, model_output, re.DOTALL)
        if match:
            sql = match.group(1).strip()
        else:
            sql = model_output
        re_combine_whitespace = re.compile(r"\s+")
        return re_combine_whitespace.sub(" ", sql).strip()

    @override
    def generate_response(
        self,
        user_question: NLQuery,
        database_connection: DatabaseConnection,
        context: List[dict] = None,
    ) -> NLQueryResponse:
        # Decrypt the stored connection URI and describe the database
        fernet_encrypt = FernetEncrypt()
        engine = create_engine(unquote(fernet_encrypt.decrypt(database_connection.uri)))
        db = FineTuningDatabaseContentCreator(engine)
        # Same instruction text as in the fine-tuning prompts
        instruction = (
            "\n"
            f"You are an assistant that is an expert in generating {db.dialect} SQL queries.\n"
            f"Having the access to database content, generate a correct {db.dialect} SQL query for the given question.\n"
            "### Database content ###\n"
        )
        database_content = db.get_table_info()
        system_prompt = instruction + database_content

        # Append few-shot examples only when similar verified questions exist
        if context:
            samples_prompt_string = (
                "The following are some similar previous questions and their correct "
                "SQL queries from the database: \n"
            )
            for sample in context:
                samples_prompt_string += (
                    f"Question: {sample['nl_question']} \nSQL: {sample['sql_query']} \n"
                )
            question_with_context = (
                f"{user_question.question} An example of a similar question and the query "
                f"that was generated to answer it is the following {samples_prompt_string}"
            )
        else:
            question_with_context = user_question.question

        # Call the fine-tuned model (pre-1.0 openai SDK), retrying with backoff
        response = None
        for attempt in range(MAX_RETRIES):
            try:
                response = openai.ChatCompletion.create(
                    model="ft:gpt-3.5-turbo-0613:dataherald:spider:7t2q6Qhd",
                    api_key=os.environ.get("OPENAI_API_KEY"),
                    temperature=0.0,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": question_with_context},
                    ],
                )
                break
            except Exception as e:
                logger.warning("OpenAI request failed (attempt %d): %s", attempt + 1, e)
                time.sleep(2**attempt)
        if response is None:
            raise RuntimeError("The fine-tuned model did not return a response")

        model_response = response.choices[0]["message"]["content"]
        sql = self.output_parser(model_response)
        nl_query_response = NLQueryResponse(
            nl_question_id=user_question.id,
            sql_query=sql,
        )
        # Run the query and produce the natural language answer
        generates_nl_answer = GeneratesNlAnswer(self.system, self.system.instance(DB))
        return generates_nl_answer.execute(nl_query_response)
```
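The output_parser method handles completions that introduce the query with the phrase "The SQL query I'll be generating is:". In that case it keeps only the text after the phrase; otherwise it treats the whole completion as SQL. Either way, it collapses whitespace into single spaces. Here is the same logic as a standalone function (parse_sql is a hypothetical name) so you can try it in isolation:

```python
import re

MARKER_PATTERN = r"The SQL query I'll be generating is:(.*?)$"


def parse_sql(model_output: str) -> str:
    """Same logic as FineTunedGPT.output_parser, without the class."""
    match = re.search(MARKER_PATTERN, model_output, re.DOTALL)
    sql = match.group(1).strip() if match else model_output
    # Collapse newlines and repeated spaces into single spaces
    return re.sub(r"\s+", " ", sql).strip()


print(parse_sql("Sure. The SQL query I'll be generating is:\nSELECT name\nFROM singer"))
# prints: SELECT name FROM singer
```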

Now let’s use our new SQL generator.

Setting up the engine

First, we need to configure the environment variables. Start by creating a .env file from the provided template:

```bash
cp .env.example .env
```

Next, we generate an encryption key, which the engine uses to encrypt database connection details before storing them in MongoDB:

```bash
# Install the cryptography package
pip3 install cryptography

# Generate a Fernet key and print it as a string
python3 -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
```
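For reference, a Fernet key is 32 random bytes encoded as URL-safe base64, which is why the generated key is 44 characters long. The stdlib-only sketch below (with hypothetical helper names) generates a value in that format and sanity-checks an existing ENCRYPT_KEY. It checks only the format and does not replace Fernet itself, which performs the actual encryption.

```python
import base64
import os


def make_fernet_style_key() -> str:
    """Return 32 random bytes encoded as URL-safe base64 (the Fernet key format)."""
    return base64.urlsafe_b64encode(os.urandom(32)).decode("ascii")


def looks_like_fernet_key(key: str) -> bool:
    """True if `key` is URL-safe base64 that decodes to exactly 32 bytes."""
    try:
        return len(base64.urlsafe_b64decode(key.encode("ascii"))) == 32
    except ValueError:  # covers binascii.Error (bad padding) and non-ASCII input
        return False
```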

Next, we will configure the necessary environment variables:

```
# OpenAI credentials
OPENAI_API_KEY =
ORG_ID =

# Encryption key for storing DB connection data in Mongo
ENCRYPT_KEY =

# Change the SQL generator to our new FineTunedGPT class
SQL_GENERATOR = "dataherald.sql_generator.fine_tuned_gpt.FineTunedGPT"
```
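The SQL_GENERATOR value is a dotted path to a class. We don't reproduce how Dataherald resolves it. A common way to turn such a string into a class is importlib, sketched here with a hypothetical load_class helper. Given the value above, a loader like this would import the module dataherald.sql_generator.fine_tuned_gpt and return its FineTunedGPT attribute.

```python
import importlib


def load_class(dotted_path: str) -> type:
    """Resolve 'package.module.ClassName' to the class object."""
    module_path, _, class_name = dotted_path.rpartition(".")
    if not module_path:
        raise ValueError(f"expected 'module.ClassName', got {dotted_path!r}")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```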

Next, we create a Docker network for communication between the services, then build and start the containers:

```bash
docker network create backendnetwork
docker compose up --build
```

Now we should be able to see the API endpoints by visiting http://localhost/docs.

Endpoints of the Dataherald engine

Testing the solution

Now, we can connect to the Dataherald engine and one of the Spider databases to pose questions. For the purpose of this article, we’ll be focusing on the concert_singer database. To link the engine to the concert_singer database, we’ll make use of the v1/database endpoint using the following call:

```bash
curl -X 'POST' \
  'http://localhost/api/v1/database' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "db_alias": "concert_singer",
  "use_ssh": false,
  "connection_uri": "sqlite:///concert_singer.sqlite",
  "ssh_settings": null
}'
```

Please keep in mind that you’ll need to adjust the connection_uri according to the relative path of the SQLite database file.
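In SQLAlchemy's SQLite URLs, the part after sqlite:/// is a file path. A relative path such as sqlite:///concert_singer.sqlite is resolved against the working directory of the process that opens it. An absolute path produces four slashes, as in sqlite:////data/concert_singer.sqlite (an example path). If the relative path is ambiguous, one option is to pass an absolute path; the hypothetical helper below builds such a URI. Since the engine runs in Docker, the path must exist inside the container, not just on the host.

```python
import os
from pathlib import Path


def sqlite_uri(db_file: str) -> str:
    """Build a SQLAlchemy SQLite URI that points at an absolute file path."""
    # abspath resolves a relative path against the current working directory
    return "sqlite:///" + Path(os.path.abspath(db_file)).as_posix()
```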

Finally, we can ask questions using the v1/question endpoint. For example:

```bash
curl -X 'POST' \
  'http://localhost/api/v1/question' \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
  "question": "How many singers do we have?",
  "db_alias": "concert_singer"
}'
```
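The same call can be made from Python using only the standard library. build_question_request and ask are hypothetical helpers. The URL, headers, and JSON body mirror the curl command above, and base_url assumes the local Docker setup.

```python
import json
import urllib.request


def build_question_request(base_url: str, question: str, db_alias: str) -> urllib.request.Request:
    """Prepare (but do not send) a POST to the v1/question endpoint."""
    payload = json.dumps({"question": question, "db_alias": db_alias}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/api/v1/question",
        data=payload,
        headers={"accept": "application/json", "Content-Type": "application/json"},
        method="POST",
    )


def ask(base_url: str, question: str, db_alias: str) -> dict:
    """Send the question and decode the JSON response."""
    with urllib.request.urlopen(build_question_request(base_url, question, db_alias)) as resp:
        return json.loads(resp.read())


# Example (requires the engine to be running):
# ask("http://localhost", "How many singers do we have?", "concert_singer")
```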

The curl request above returns the following response:

```json
{
  "id": {
    "$oid": "64fa0456dd28b057f30d98c4"
  },
  "nl_question_id": {
    "$oid": "64fa0452dd28b057f30d98c3"
  },
  "nl_response": "We have 6 singers.",
  "intermediate_steps": null,
  "sql_query": "SELECT count(*) FROM singer",
  "sql_query_result": {
    "columns": [
      "count(*)"
    ],
    "rows": [
      {
        "count(*)": 6
      }
    ]
  },
  "sql_generation_status": "VALID",
  "error_message": null,
  "exec_time": 4.125343084335327,
  "total_tokens": null,
  "total_cost": null,
  "confidence_score": 1
}
```
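A client calling the API programmatically often needs only a few of these fields. The small sketch below uses a hypothetical summarize_response helper to extract them from a response like the one above, treating a null sql_query_result as no rows:

```python
import json


def summarize_response(raw: str) -> dict:
    """Pull the generated SQL, status, answer, and result rows out of a v1/question response."""
    body = json.loads(raw)
    result = body.get("sql_query_result")
    return {
        "sql": body["sql_query"],
        "status": body["sql_generation_status"],
        "answer": body["nl_response"],
        "rows": result["rows"] if result else [],
    }


raw = (
    '{"nl_response": "We have 6 singers.", "sql_query": "SELECT count(*) FROM singer", '
    '"sql_query_result": {"columns": ["count(*)"], "rows": [{"count(*)": 6}]}, '
    '"sql_generation_status": "VALID"}'
)
print(summarize_response(raw))
```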

As the response shows, the model generated SELECT count(*) FROM singer, which returns the correct answer: 6.

Conclusion

In this post, we served the GPT-3.5 model we fine-tuned for NL-to-SQL tasks in our previous article through an API. By using the Dataherald engine, we get complex capabilities such as confidence score prediction and few-shot example selection out of the box. The engine's modular design let us integrate our model simply by creating a new SQL generator class.

You can access the forked repository, along with the classes we created, here. Join our Discord to learn more about the project.

References

Our previous article on how to fine-tune GPT-3.5-Turbo: https://medium.com/dataherald/fine-tuning-gpt-3-5-turbo-for-natural-language-to-sql-4445c1d37f7c

About Dataherald

  • Sign up for free and use the hosted version of Dataherald
  • Our open-source engine is available on Github.
  • Join our Discord server to learn more about the project.
