Speed, speed, and more speed!

Optimize COUNT(*) when SELECTing data with filters

Jesum
4 min read · Jul 11, 2023


We recently had the following ask from the front-end UI team that we needed to implement in our REST APIs: "I want to fetch a set of records from the backend database with filter clauses such as WHERE, ORDER BY, GROUP BY, etc., and be able to show a UI element for page navigation."

This means if:

  1. Your table has 1 million rows.
  2. Your SELECT statement returns 20,000 rows.
  3. Your page control UI element must be rendered based on #2, not #1.

This gets a bit more complicated when you include LIMIT and OFFSET clauses (well, yeah, coz our APIs support paging!).
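
To make this concrete, here's a rough sketch of the paging math, using the numbers above and a hypothetical page size of 50: the UI derives its page count from the filtered row count, and each requested page maps onto LIMIT/OFFSET values for the query.

import math

# Hypothetical numbers: a 1,000,000-row table where the filtered SELECT
# matches 20,000 rows. The paging UI cares about the 20,000, not the 1,000,000.
page_size = 50
filtered_row_count = 20_000

total_pages = math.ceil(filtered_row_count / page_size)  # 400 pages for the UI

# A requested page (1-based) translates into LIMIT/OFFSET on the SELECT:
page_number = 3
limit = page_size                       # LIMIT 50
offset = (page_number - 1) * page_size  # OFFSET 100 -> rows 101-150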

There are 2 fundamental ways to do this with most relational databases:

  1. Run a dual-SQL query approach. You do a SELECT * bla bla bla to fetch all your records. Then run another SELECT COUNT(*) bla bla bla to get a row count. You can speed things up by caching both queries.
  2. Run a single-SQL query approach. You perform a COUNT(*) in the same SELECT statement by using the OVER() clause (but the partition is across the entire set of rows returned). You can also speed things up by caching this little fella.

Intuitively, we expected #2 to be a lot quicker than #1, since it avoids the second round trip to the database (at the cost of a little extra memory for the repeated count column). The gap between #1 and #2 widens as your SELECT statement returns more and more rows.

In the dual-query approach, we have to be careful to drop the LIMIT and OFFSET clauses from the count query (but keep all the other filter clauses), or the row count will be wrong.
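
For illustration, here is roughly what the two approaches look like as raw SQL, with a hypothetical orders table and filter (the single-query version attaches the total to every returned row via a window function):

# Hypothetical table and columns ("orders", "status") purely for illustration.
# Approach 1: two statements; the count query keeps the filters but drops LIMIT/OFFSET.
page_sql = "SELECT * FROM orders WHERE status = 'open' ORDER BY id LIMIT 20 OFFSET 40"
count_sql = "SELECT COUNT(*) FROM orders WHERE status = 'open'"

# Approach 2: one statement; COUNT(*) OVER () repeats the total count on every row,
# and it is computed over the whole filtered set, before LIMIT/OFFSET kick in.
single_sql = """
    SELECT *, COUNT(*) OVER () AS total_rows
    FROM orders
    WHERE status = 'open'
    ORDER BY id
    LIMIT 20 OFFSET 40
"""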

We power our backend using SQLAlchemy (would have been a nightmare otherwise to parse SQL statements manually to solve the above problem). So here’s what we ended up doing: we supported BOTH approaches that can be toggled via a variable in our source code — hahaha — just coz we could :P

The first approach is pretty straightforward. Assuming you have a SQLAlchemy Query object, you can do this for the 2nd SQL call:

# Assume my_query is a SQLAlchemy Query object.
# First, drop both the LIMIT and OFFSET clauses.
new_query = my_query.limit(None).offset(None)
# Our database wrapper returns rows as a list of dicts; with a plain
# SQLAlchemy Query you would call .all() instead of .fetchall().
dbms_records = new_query.fetchall()
record_count = len(dbms_records)

Be careful! We initially used my_query.limit(None) only, and the resulting SQL string was SELECT …… LIMIT -1 OFFSET …., which most databases would reject. What's a LIMIT -1?!
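
If you want to see what your setup actually emits (the rendered SQL depends on your SQLAlchemy version and dialect), printing the compiled query is a quick way to eyeball the LIMIT/OFFSET handling:

# Printing a Query compiles it to a SQL string (with a generic dialect unless bound),
# which is handy for checking how LIMIT/OFFSET are rendered. my_query is the Query from above.
print(str(my_query.limit(None)))
print(str(my_query.limit(None).offset(None)))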

The second approach was a bit more tricky. Here’s the code to do it:

from sqlalchemy import func

# Label that won't collide with a real column; COUNT(*) OVER () spans the full filtered set.
_count_column_name = "some_crazy_col_name_that_will_never_exist"
new_query = my_query.add_columns(func.count().over().label(_count_column_name))
dbms_records = new_query.fetchall()
# Every row carries the same total; our wrapper exposes rows as dicts.
record_count = dbms_records[0][_count_column_name]

This nets you an array of records where each row has one additional column (named by the variable _count_column_name), and you can simply take the value in that column from the first row of the results. You do not even have to remove the LIMIT and OFFSET clauses, because the window function is computed over the full filtered result set before LIMIT and OFFSET are applied.
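
Purely for illustration, a page of results comes back looking something like this (made-up data, 20,000 matching rows in total):

# Every row on the page repeats the same total in the extra column.
example_page = [
    {"id": 41, "status": "open", "some_crazy_col_name_that_will_never_exist": 20000},
    {"id": 42, "status": "open", "some_crazy_col_name_that_will_never_exist": 20000},
    # ... the rest of the rows on this page
]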

Then comes the next step — you can’t return this result set naked as it is. The caller will be wondering, “What’s this additional column?”

And so we have to remove that from the results.

In our database wrapper library, all database query results are returned as a list of dictionaries. We needed to know which method is faster for our scenario:

  1. Creating a new list of dictionaries while removing that column from every dictionary in the list; or
  2. Dynamically shrinking the dictionaries in the list, one at a time.

We wrote a little test and the results speak for themselves:

import random
import time

def generate_list_of_dicts(dict_list, num_elements):
    # Build a large list by sampling (by reference) from a few template dicts.
    list_of_dicts = [random.choice(dict_list) for _ in range(num_elements)]
    return list_of_dicts

dict_list = [
    {'key1': 'value1', 'key2': 'value2'},
    {'key1': 'value3', 'key2': 'value4'},
    {'key1': 'value5', 'key2': 'value6'}
]
list_of_dicts = generate_list_of_dicts(dict_list, 10000000)
key_to_delete = 'key2'

# Time the in-place approach: delete the key from each dictionary.
start_time = time.time()
for d in list_of_dicts:
    if key_to_delete in d:
        del d[key_to_delete]
end_time = time.time()
print(f"{end_time - start_time}")

# Time the rebuild approach: construct a new list of new dictionaries without the key.
list_of_dicts = generate_list_of_dicts(dict_list, 10000000)
start_time = time.time()
list_of_dicts = [{k: v for k, v in d.items() if k != key_to_delete} for d in list_of_dicts]
end_time = time.time()
print(f"{end_time - start_time}")

The final output of the code above?

0.5273728370666504
5.8756279945373535

In-place deletion is about 11 times faster than building a new list of dictionaries. :)

So, to let the front-end UI render paging controls for a SELECT statement with filters, we use the OVER() window function to do a COUNT(*) of the rows matched by the SELECT statement's filters. We can therefore execute a single SQL statement to get both our rows and our row count, which is much better than submitting 2 SQL statements.

Once we have our rows and the row count, we have to sanitize the results to remove the extra count column. And for that, shrinking the dictionaries in the list in place is a lot faster in Python than building a new list of dictionaries.
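
Putting the two pieces together, the cleanup ends up looking roughly like this (a sketch, assuming the list-of-dicts results and the _count_column_name label from earlier):

# Grab the total off the first row, then strip the helper column in place.
record_count = dbms_records[0][_count_column_name] if dbms_records else 0
for row in dbms_records:
    row.pop(_count_column_name, None)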
