Getting started with Vanna

Ashish Singal
Published in Vanna AI, Jul 9, 2023

Vanna is a Python-based AI SQL co-pilot. Our initial users are data analysts, data scientists, engineers, and others comfortable with data who use Vanna to automate writing complex SQL. Most of them start using Vanna in notebooks.

All of these notebooks are on our GitHub. You can run them in one click on Google Colab, or download or clone them and run them in your local Jupyter environment.

Starter notebook (vn-starter)

Let’s get started. Vanna’s starter notebook is only a few lines of code. It’s available on GitHub here.

First, let’s install Vanna from PyPI and import it.

%pip install vanna
import vanna as vn

Next, enter your email between the quotes below. Vanna will send you a code that you can enter inline, which automatically sets your API key.

my_email = '' # Enter your email here
vn.login(email=my_email)

Next, we’ll choose the demo-sales dataset to work with. We have a complete overview of the TPC-H dataset here.

vn.set_org('demo-sales')

Finally, we can generate SQL from our question, “What are the top 10 customers by sales?”. This sends the question to Vanna, which uses AI to generate the SQL and sends it back. Here are more details on how Vanna works.

sql = vn.generate_sql(question='What are the top 10 customers by Sales?')
print(sql)

And here is the result:

SELECT customer_name,
       total_sales
FROM (SELECT c.c_name as customer_name,
             sum(l.l_extendedprice * (1 - l.l_discount)) as total_sales,
             row_number() OVER (ORDER BY sum(l.l_extendedprice * (1 - l.l_discount)) desc) as row_num
      FROM snowflake_sample_data.tpch_sf1.lineitem l join snowflake_sample_data.tpch_sf1.orders o
        ON l.l_orderkey = o.o_orderkey join snowflake_sample_data.tpch_sf1.customer c
        ON o.o_custkey = c.c_custkey
      GROUP BY customer_name)
WHERE row_num <= 10;
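The generated query ranks customers with row_number() and then filters on the rank. To see what that pattern does without a Snowflake account, here is a minimal sketch that runs the same shape of query against an in-memory SQLite database and compares it with a plain ORDER BY ... LIMIT. The tiny tables and their rows are hypothetical stand-ins for the TPC-H sample data, the limit is 2 instead of 10, and SQLite's dialect is only an approximation of Snowflake's.

```python
import sqlite3

# Hypothetical toy data standing in for the customer / orders / lineitem tables.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customer (c_custkey INTEGER, c_name TEXT);
CREATE TABLE orders   (o_orderkey INTEGER, o_custkey INTEGER);
CREATE TABLE lineitem (l_orderkey INTEGER, l_extendedprice REAL, l_discount REAL);
INSERT INTO customer VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Carol');
INSERT INTO orders   VALUES (10, 1), (11, 2), (12, 3), (13, 1);
INSERT INTO lineitem VALUES (10, 100.0, 0.0), (11, 300.0, 0.5),
                            (12, 50.0, 0.0), (13, 200.0, 0.25);
""")

# Same shape as the generated SQL: rank by discounted sales, keep the top ranks.
window_sql = """
SELECT customer_name, total_sales
FROM (SELECT c.c_name AS customer_name,
             sum(l.l_extendedprice * (1 - l.l_discount)) AS total_sales,
             row_number() OVER (
                 ORDER BY sum(l.l_extendedprice * (1 - l.l_discount)) DESC
             ) AS row_num
      FROM lineitem l
      JOIN orders o ON l.l_orderkey = o.o_orderkey
      JOIN customer c ON o.o_custkey = c.c_custkey
      GROUP BY c.c_name)
WHERE row_num <= 2
ORDER BY row_num
"""

# A shorter formulation that sorts the aggregated rows and truncates them.
limit_sql = """
SELECT c.c_name AS customer_name,
       sum(l.l_extendedprice * (1 - l.l_discount)) AS total_sales
FROM lineitem l
JOIN orders o ON l.l_orderkey = o.o_orderkey
JOIN customer c ON o.o_custkey = c.c_custkey
GROUP BY c.c_name
ORDER BY total_sales DESC
LIMIT 2
"""

window_rows = conn.execute(window_sql).fetchall()
limit_rows = conn.execute(limit_sql).fetchall()
print(window_rows)
print(window_rows == limit_rows)
```

On this toy data both queries return the same two customers, which suggests the window-function form is one valid way to express "top N" rather than the only one.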

Vanna on your schema (vn-offline)

The notebook above uses a demo organization hosted by Vanna. However, you’ll want to create your own organization and train Vanna on your schema. The second notebook, vn-offline.ipynb, does this.

First, let’s create an org. Note that the org name needs to be globally unique or you’ll get an error.

my_org = '' # Globally unique org identifier
vn.set_org(my_org)

Now let’s train Vanna. We do this by storing question / SQL pairs that we know are correct, as below:

vn.store_sql(
    question='What are the top 10 customers?',
    sql='SELECT customer_name, sales FROM customers ORDER BY sales desc LIMIT 10'
)

vn.store_sql(
    question='What are the top 10 customers in the US?',
    sql="SELECT customer_name, sales FROM customers WHERE country_name = 'UNITED STATES' ORDER BY sales desc LIMIT 10"
)
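Since training depends on the pairs being correct, it can help to check that each SQL statement at least compiles before storing it. The sketch below does this against a local SQLite copy of the schema. The customers table definition, the sql_compiles helper, and the deliberately broken third pair are all hypothetical, and compiling in SQLite does not guarantee that the statement runs on your warehouse or answers the question.

```python
import sqlite3

# Hypothetical local copy of the schema that the training pairs refer to.
schema = """
CREATE TABLE customers (
    customer_name TEXT, sales REAL, country_name TEXT, region TEXT
)
"""

training_pairs = [
    ("What are the top 10 customers?",
     "SELECT customer_name, sales FROM customers ORDER BY sales desc LIMIT 10"),
    ("What are the top 10 customers in the US?",
     "SELECT customer_name, sales FROM customers "
     "WHERE country_name = 'UNITED STATES' ORDER BY sales desc LIMIT 10"),
    # Deliberately broken pair (misspelled column) to show a rejection.
    ("Deliberately broken example", "SELECT customer_nme FROM customers"),
]

def sql_compiles(conn, sql):
    """Return True if SQLite can prepare the statement against the schema."""
    try:
        conn.execute("EXPLAIN " + sql)
        return True
    except sqlite3.Error:
        return False

conn = sqlite3.connect(":memory:")
conn.execute(schema)

valid_pairs = [(q, s) for q, s in training_pairs if sql_compiles(conn, s)]
rejected = [(q, s) for q, s in training_pairs if not sql_compiles(conn, s)]

print(len(valid_pairs), "pairs ready to store")
print("rejected:", [q for q, _ in rejected])
```

Each pair in valid_pairs could then be passed to vn.store_sql(question=..., sql=...) as shown above.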

Finally, we can ask questions in plain English to generate SQL automatically, as below:

sql = vn.generate_sql('How many customers are there in each region?')
print(sql)

And Vanna generates the following SQL:

SELECT region,
       count(distinct customer_id) as num_customers
FROM customers
GROUP BY region

If you haven’t trained Vanna well on your schema, the questions you ask won’t always generate correct SQL. See this post for more information on training Vanna properly.

Securely executing SQL using Vanna (vn-full)

The notebooks above only generated SQL. More likely, you want to execute that SQL and generate a corresponding chart as well. Let’s look at how this works.

First, instead of entering training data question by question as in the previous example, we simply load a pre-existing JSON file that contains a couple dozen questions:

import pandas as pd

for _, row in pd.read_json("https://raw.githubusercontent.com/vanna-ai/vanna-training-queries/main/tpc-h/questions.json").iterrows():
    vn.train(question=row.question, sql=row.sql)
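To make the shape of that loop concrete, here is a small offline sketch using only the standard library. It assumes the file is a JSON array of objects with question and sql keys, which is a guess consistent with the row.question and row.sql accesses above rather than something confirmed here. The sample entries, the load_pairs helper, and the train stand-in are hypothetical; in the notebook you would call vn.train instead.

```python
import json

# Hypothetical sample in the assumed "list of records" layout.
sample_json = """
[
  {"question": "How many customers are there?",
   "sql": "SELECT count(*) FROM customer"},
  {"question": "How many orders are there?",
   "sql": "SELECT count(*) FROM orders"},
  {"question": "Entry with no SQL", "sql": ""}
]
"""

def load_pairs(text):
    """Parse the JSON text and keep only entries with both fields filled in."""
    pairs = []
    for record in json.loads(text):
        question = (record.get("question") or "").strip()
        sql = (record.get("sql") or "").strip()
        if question and sql:
            pairs.append((question, sql))
    return pairs

trained = []

def train(question, sql):
    """Stand-in for vn.train so that the sketch runs without network access."""
    trained.append((question, sql))

for question, sql in load_pairs(sample_json):
    train(question=question, sql=sql)

print(len(trained), "pairs sent for training")
```

Skipping incomplete entries up front is a design choice for this sketch; you may prefer to fail loudly so that gaps in the training file get fixed at the source.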

Next, instead of just getting the SQL for our questions, we’ll actually run it.

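import snowflake.connector

# user, password, account, and database are your own Snowflake credentials.
conn = snowflake.connector.connect(user=user, password=password, account=account, database=database)
cs = conn.cursor()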

# This function is provided as a convenience. You can choose to run your SQL
# however you normally do.
df = vn.get_results(cs, database, sql)
df

The result is a pandas DataFrame.
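If you would rather run the SQL yourself, as the comment in the code suggests, the usual database cursor pattern is to execute the statement, read the column names from cursor.description, and fetch the rows. The sketch below shows that pattern with an in-memory SQLite cursor so that it runs anywhere. The run_query helper and the sample rows are hypothetical, and this is not a description of how vn.get_results is implemented.

```python
import sqlite3

def run_query(cursor, sql):
    """Execute sql on a DB-API cursor and return (column_names, rows)."""
    cursor.execute(sql)
    columns = [col[0] for col in cursor.description]
    rows = cursor.fetchall()
    return columns, rows

# Hypothetical sample data so that the query has something to return.
conn = sqlite3.connect(":memory:")
cs = conn.cursor()
cs.execute("CREATE TABLE customers (customer_id INTEGER, region TEXT)")
cs.executemany(
    "INSERT INTO customers VALUES (?, ?)",
    [(1, "EUROPE"), (2, "EUROPE"), (3, "ASIA")],
)

columns, rows = run_query(
    cs,
    "SELECT region, count(distinct customer_id) AS num_customers "
    "FROM customers GROUP BY region ORDER BY region",
)
print(columns)
print(rows)
```

With pandas installed, pd.DataFrame(rows, columns=columns) turns the result into a DataFrame that can be used for charting.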

We can even generate a Plotly chart automatically:

plotly_code = vn.generate_plotly_code(question=my_question, sql=sql, df=df)
fig = vn.get_plotly_figure(plotly_code=plotly_code, df=df)
fig.show()

For a deeper look at how Vanna works, check out this post.
