Using asyncpg with pytorch

Jonathan Wickens · Published in The Startup · May 2, 2020

tl;dr don’t use pytorch’s Dataset and DataLoader; with a bit of async magic you can create a pytorch-compatible alternative.

One of the first things you will find out when running machine learning experiments is that if you can’t feed data to your GPU fast enough, training will be aggravatingly slow.

This is why pytorch provides DataLoader. On the surface it merely collates several pieces of data into a batch, but it also provides worker processes to load data more quickly. If you’re working with, say, images, pytorch’s built-in DataLoader and torchvision’s ImageFolder work great for this.

Using asyncpg we can get a similar effect by opening multiple connections and querying all the samples in our batch concurrently. We do this with asyncio.gather, which runs multiple coroutines concurrently. Assuming start_points for all the samples in the batch and a coroutine called get_sequence (which fetches a single sample), it looks like this:

tasks = map(get_sequence, start_points)
batch = await asyncio.gather(*tasks)
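To make the gather pattern concrete, here is a minimal, self-contained sketch. The get_sequence body is a stand-in (the real one would issue an asyncpg query); only the shape of the tasks/gather code matches the snippet above.

```python
import asyncio

async def get_sequence(start_point):
    # Stand-in for the real coroutine, which would fetch one sample
    # from postgres via its own asyncpg connection.
    await asyncio.sleep(0)  # placeholder for the database round-trip
    return [start_point, start_point + 1, start_point + 2]

async def fetch_batch(start_points):
    # Fire off one coroutine per sample and wait for all of them at once.
    tasks = map(get_sequence, start_points)
    return await asyncio.gather(*tasks)

batch = asyncio.run(fetch_batch([0, 10, 20]))
# batch is a list of per-sample sequences, in the same order as start_points
```

Because gather preserves argument order, the resulting batch lines up with start_points even though the queries complete in arbitrary order.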

We get these start points using postgres’s built-in TABLESAMPLE clause to randomly sample rows from a pre-partitioned table of test or training data.

SELECT *
FROM test_samples
TABLESAMPLE SYSTEM (100)

Here 100 is the percentage of rows to sample. To keep fetching start points efficient, this query runs only once, when the dataset iterator is created (for a single epoch). The result set stays on the postgres server behind a server-side cursor until the epoch is finished.

If we combine asyncpg with these postgres tricks into an async generator, we get something like the following:

async for start_point in self.iter_start_point_cursor:
    start_points.append(start_point)
    if len(start_points) == self.batch_size:
        batch = await map_start_points()
        yield batch
        start_points = []
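Here is the same batching loop as a self-contained sketch that runs without a database. The fake_cursor and get_sequence stand-ins are mine; in the real version the cursor is the server-side postgres cursor and get_sequence issues an asyncpg query.

```python
import asyncio

async def fake_cursor(n):
    # Stand-in for the server-side postgres cursor of start points.
    for i in range(n):
        yield i

async def get_sequence(start_point):
    # Stand-in for the per-sample asyncpg query.
    await asyncio.sleep(0)
    return [start_point, start_point + 1]

async def iter_batches(cursor, batch_size):
    start_points = []
    async for start_point in cursor:
        start_points.append(start_point)
        if len(start_points) == batch_size:
            # Fetch every sample in the batch concurrently.
            yield await asyncio.gather(*map(get_sequence, start_points))
            start_points = []  # reset for the next batch

async def collect():
    return [batch async for batch in iter_batches(fake_cursor(4), 2)]

batches = asyncio.run(collect())
```

Note that any leftover samples smaller than batch_size are dropped here, matching the snippet above; a real implementation might yield the final partial batch instead.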

Once we have our async generator we need to make it a normal sync iterator for compatibility with pytorch libraries. This involves a bit of magic, which I mostly just copied from Stack Overflow. You can see it being used, and the magic itself, in the code below.
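For illustration, one common Stack Overflow recipe for this trick (a sketch, not necessarily the author's exact version) drives the async generator one item at a time through an event loop:

```python
import asyncio

def iter_over_async(ait, loop):
    """Wrap an async iterator so it can be consumed synchronously."""
    ait = ait.__aiter__()

    async def get_next():
        try:
            return False, await ait.__anext__()
        except StopAsyncIteration:
            return True, None

    while True:
        # Block the calling thread until the next item is ready.
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj

# Usage with a toy async generator:
async def agen():
    for i in range(3):
        yield i

loop = asyncio.new_event_loop()
items = list(iter_over_async(agen(), loop))
loop.close()
```

Since each item makes its own run_until_complete call, this must be driven from a thread that does not already have a running event loop.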

The point is that with this method I was able to get single-batch retrieval down from around 30 seconds with a naive approach (fetching samples sequentially) to less than a second using multiple connections concurrently.

Finally the code:

Here are some other things that I found helped:

  • NamedTuple is faster than TypedDict
  • Indexes on ID columns are faster than indexes on text / date columns
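For the first tip, a sample record as a NamedTuple might look like this (the Sample type and its fields are hypothetical, just to show the shape):

```python
from typing import NamedTuple

class Sample(NamedTuple):
    # NamedTuple instances are real tuples, so field access compiles to
    # an index lookup rather than a dict lookup.
    id: int
    sequence: list

s = Sample(id=1, sequence=[0.1, 0.2])
```

Unlike a TypedDict, which is a plain dict at runtime, a NamedTuple is immutable and tuple-backed, which is where the speed difference comes from.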

How do you typically use SQL for ML projects? Are there any libraries that make async iterator use-cases easy out of the box? What are your best practices for loading datasets?
