FastAPI, SQLAlchemy, and Parallel Queries Walk Into a Bar…

Liron Ben Yeda
May 10, 2023


As a tech lead at Via, I was involved in developing a new service aimed at retrieving data from a database and processing it as needed. Our technology stack included FastAPI, SQLAlchemy (Postgres DB), and Python 3.11.

Before we began implementing our new service, we took some time to research the best practices for using FastAPI + SQLAlchemy.
During our research, we found that it was a common recommendation to use get_db_session as a dependency.
We came across this quote 👇🏼 on the FastAPI website:

“We need to have an independent database session/connection (SessionLocal) per request, use the same session through all the request and then close it after the request is finished.”

We learned about the motivation behind using the session as a dependency. Seeing this approach being used repeatedly in multiple examples gave us the confidence that it was the right choice. As a result, we decided to follow the same method in our project.

We had all the necessary preparations in place and were eager to get started building our service.

Our service was coming together nicely, with various endpoints already in place. However, things took a turn when we encountered a scenario where we needed to execute multiple queries to the database simultaneously within the same request 😈

So, you’re probably thinking, “What’s the problem here?” Well, we were thinking the same thing at first. We were like, “We got this, we have TaskGroup, no biggie.”
But turns out, TaskGroup can’t solve all of life’s problems. Who knew?

Let's see some code.

async def get_db_session():
    db_session = DBSession()
    try:
        yield db_session
    finally:
        await db_session.close()

@app.get("/total_books")
async def get_total_books(db_session: AsyncSession = Depends(get_db_session)):
    return await crud.get_book_table_total_count(db_session)


async def get_book_table_total_count(db_session: AsyncSession):
    result = await db_session.execute(select(count()).select_from(Book))
    return result.scalar() or 0

So, what do we have here?
We have an endpoint that retrieves the total number of books in the database.

Avg. elapsed time: 2.258295 seconds

Let’s say we have three independent queries that need to be executed.
Why bother doing it synchronously when we can do it in parallel?

* Just a heads up, for this example, I’ve used the same query 3 times to illustrate the point. In reality, these queries would likely be different, but independent of each other.

@app.get("/multiple_counts")
async def get_multiple_counts(db_session: AsyncSession = Depends(get_db_session)):
    async with TaskGroup() as tg:
        task1 = tg.create_task(crud.get_book_table_total_count(db_session))
        task2 = tg.create_task(crud.get_book_table_total_count(db_session))
        task3 = tg.create_task(crud.get_book_table_total_count(db_session))
    return task1.result(), task2.result(), task3.result()

So, how long do you think it will take to execute those 3 queries in parallel? Well, I expect it to take almost the same time as the previous endpoint, maybe a little bit longer due to the overhead of using a TaskGroup.

But guess what?

Avg. elapsed time: 6.493947 seconds

Yep, we were shocked too…

We expected that running 3 queries in parallel would take the same amount of time as running one, but it actually took three times longer. This suggests that there is something synchronous happening that we need to investigate ⛏️.

Through our investigation, we found this:

“The Session object is entirely designed to be used in a non-concurrent fashion, which in terms of multithreading means ‘only in one thread at a time’.”
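
To build intuition for why sharing one session serializes the work, here is a stdlib-only sketch. This is an illustration, not SQLAlchemy's actual internals: an `asyncio.Lock` stands in for the single connection owned by one session, and `asyncio.sleep` stands in for query latency.

```python
import asyncio
import time

async def fake_query(session_lock: asyncio.Lock, latency: float = 0.1) -> None:
    # Stand-in for `await db_session.execute(...)`: the lock plays the role
    # of the single connection owned by one session.
    async with session_lock:
        await asyncio.sleep(latency)  # simulated round trip to the DB

async def one_shared_session() -> float:
    # Three "queries" on the SAME session: they queue behind one lock.
    lock = asyncio.Lock()
    start = time.perf_counter()
    await asyncio.gather(fake_query(lock), fake_query(lock), fake_query(lock))
    return time.perf_counter() - start

async def one_session_per_task() -> float:
    # Three "queries", each with its OWN session: they overlap.
    start = time.perf_counter()
    await asyncio.gather(
        fake_query(asyncio.Lock()),
        fake_query(asyncio.Lock()),
        fake_query(asyncio.Lock()),
    )
    return time.perf_counter() - start

if __name__ == "__main__":
    print(f"shared session:   ~{asyncio.run(one_shared_session()):.2f}s")
    print(f"session per task: ~{asyncio.run(one_session_per_task()):.2f}s")
```

The shared version takes roughly three times as long as the per-task version, which matches the 3x slowdown we measured.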

To solve this issue, SQLAlchemy recommends using async_scoped_session, which manages Session objects with scoped management. Here’s an example from the official site:

from asyncio import current_task

from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker

async_session_factory = async_sessionmaker(
    some_async_engine, expire_on_commit=False
)
AsyncScopedSession = async_scoped_session(
    async_session_factory, scopefunc=current_task
)
some_async_session = AsyncScopedSession()

By using current_task as the scopefunc, a new session is created for each new coroutine. Nice! This is exactly what we need.
The sessions are stored in the registry, which is a dictionary where the scopefunc output is the key and the value is an AsyncSession object.

To properly close the session and remove it from the registry, it is necessary to call the remove() function from within the outermost awaitable.
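
To make the registry mechanics concrete, here is a toy, stdlib-only sketch of what scoped management does under the hood. The names are hypothetical; the real async_scoped_session also proxies the Session API, and its real remove() is a coroutine that closes the session as well:

```python
import asyncio

class ScopedRegistry:
    """Toy stand-in for the registry inside async_scoped_session:
    one object per scope key, the key being the current task's identity."""

    def __init__(self, factory):
        self.factory = factory
        self.registry = {}  # scopefunc output -> "session"

    def __call__(self):
        key = id(asyncio.current_task())
        if key not in self.registry:
            self.registry[key] = self.factory()  # new "session" for this task
        return self.registry[key]

    def remove(self) -> None:
        # Drops the entry for the CURRENT scope only, which is why it must be
        # called from the task that owns the session.
        self.registry.pop(id(asyncio.current_task()), None)

async def demo() -> int:
    scoped = ScopedRegistry(factory=object)

    async def use(cleanup: bool) -> None:
        session = scoped()
        assert scoped() is session  # same task -> same "session"
        if cleanup:
            scoped.remove()

    await asyncio.gather(use(True), use(False))
    return len(scoped.registry)  # only the task that skipped remove() remains
```

Running demo() leaves exactly one entry behind: the task that never called remove(). That leftover entry is the memory leak and open session we have to guard against.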

So we realized that we can’t keep using a single session per request, and we must switch to async_scoped_session in our project. This means replacing all occurrences of Depends(get_db_session) with an appropriate alternative. Additionally, we need to make sure we call remove() to avoid memory leaks and open sessions.

TL;DR: We created the DBManager class and a new dependency:

from asyncio import current_task, TaskGroup
from typing import AsyncGenerator, Iterable

from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session

from db.db_connector import SqlAlchemyConnector


class DBManager:
    def __init__(self):
        connector = SqlAlchemyConnector()
        self.scoped_session_factory = async_scoped_session(
            connector.session_factory, scopefunc=_get_current_task_id
        )

    def get_session(self) -> AsyncSession:
        return self.scoped_session_factory()


# New dependency
async def get_db_manager() -> AsyncGenerator[DBManager, None]:
    db_manager = DBManager()
    try:
        yield db_manager
    finally:
        sessions = db_manager.scoped_session_factory.registry.registry.values()
        await _close_sessions(sessions)


async def _close_sessions(db_sessions: Iterable[AsyncSession]):
    async with TaskGroup() as task_group:
        for db_session in db_sessions:
            task_group.create_task(db_session.close())


def _get_current_task_id() -> int:
    return id(current_task())

What do we have here?

  1. The DBManager class creates an instance of async_scoped_session and provides the get_session function, which returns either an existing session for the current scope or creates a new one if it doesn’t exist yet.
  2. get_db_manager is our new dependency. It creates a DBManager instance and yields it. After the request is finished, we will enter the finally block and close all the sessions that were opened.
  3. The _close_sessions() function uses a TaskGroup to create tasks that close each session asynchronously.
  4. The _get_current_task_id() function returns the unique identifier of the current Task object, which is used as the scope identifier for the async_scoped_session.

How does it look now that we’re using the new dependency?

@app.get("/multiple_counts")
async def get_multiple_counts(db_manager: DBManager = Depends(get_db_manager)):
    async with TaskGroup() as tg:
        task1 = tg.create_task(crud.get_book_table_total_count(db_manager))
        task2 = tg.create_task(crud.get_book_table_total_count(db_manager))
        task3 = tg.create_task(crud.get_book_table_total_count(db_manager))
    return task1.result(), task2.result(), task3.result()


async def get_book_table_total_count(db_manager: DBManager):
    db_session = db_manager.get_session()
    result = await db_session.execute(select(count()).select_from(Book))
    return result.scalar() or 0

And now, as expected 😇

Avg. elapsed time: 2.741495 seconds

By using the DBManager as a dependency, we can create sessions anywhere in the code (in any layer: BL, DAL) without having to worry about handling them ourselves. It is also very convenient for testing (override the dependency, or override the get_session function). Additionally, we can create utility functions within the class, such as the following:

async def execute_queries_in_parallel(self, queries: Dict[str, Select]) -> Dict[str, Task[Result]]:
    tasks = {}

    async with TaskGroup() as task_group:
        for query_name, query in queries.items():
            # We need the `execute_query` function because we want the
            # `get_session()` method to be called in the scope of each new task,
            # rather than in the scope of the root task. This ensures that each
            # query is executed in its own session.
            task = task_group.create_task(self.execute_query(query))
            tasks[query_name] = task

    return tasks


async def execute_query(self, query: Select) -> Result:
    db_session = self.get_session()
    return await db_session.execute(query)


async def get_multiple_counts(db_manager: DBManager):
    count_query = select(count()).select_from(Book)
    queries = {
        "count_1": count_query.filter(Book.author == "Author 1"),
        "count_2": count_query.filter(Book.author == "Author 2"),
        "count_3": count_query.filter(Book.author == "Author 3"),
    }
    result = await db_manager.execute_queries_in_parallel(queries=queries)

    count_1 = result["count_1"].result().scalar() or 0
    count_2 = result["count_2"].result().scalar() or 0
    count_3 = result["count_3"].result().scalar() or 0

    return count_1, count_2, count_3

This is just an example, but you can leverage this class to create any utility function that you need, keeping your code DRY (Don’t Repeat Yourself) 🍸

So until next time, keep coding like a boss, and may the async be with you! 💪🏼
