PyTorch Dataset, DataLoader, Sampler and the collate_fn

Stephen Cow Chau
Apr 4

Intention

There have been cases where my dataset is not strictly numerical and does not necessarily fit into tensors, so I have been looking for ways to manage data loading beyond passing the input to a PyTorch DataLoader object and letting it sample the batches for me automatically. Having done this multiple times, I wanted to study the topic a bit deeper and share it here as a record for my future reference.

Main Reference

PyTorch official reference: https://pytorch.org/docs/stable/data.html

Main Classes / function(s)

Dataset (and its subclasses)

Writing your own Dataset is not always necessary, especially since our data normally comes in the form of lists, NumPy arrays, or tensor-like objects. This is because the DataLoader can treat such objects as a Dataset directly.

What does a Dataset object do?

It is the object that encapsulates a data source and defines how to access the items in that data source.

On what occasions would I create a custom Dataset?

In some of my scenarios, the data come from multiple sources and need to be combined (like multiple CSV files or a database), or a data transform can be applied statically, before the data loader starts iterating.
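As a minimal sketch (the file paths and the “value” / “label” column names are hypothetical), such a custom Dataset could look like this:

```python
import pandas as pd
import torch
from torch.utils.data import Dataset

class MultiCsvDataset(Dataset):
    """Map-style Dataset combining several CSV files into one source,
    with a static transform applied once, up front."""

    def __init__(self, csv_paths):
        # combine multiple sources into a single DataFrame
        self.df = pd.concat((pd.read_csv(p) for p in csv_paths),
                            ignore_index=True)
        # hypothetical static transform: normalize a "value" column
        self.df["value"] = self.df["value"] / self.df["value"].abs().max()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        return torch.tensor(row["value"]), row["label"]
```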

What are the 2 types of datasets mentioned in the document?

According to the documentation, there are two types of Dataset: one is iterable-style and the other is map-style.

The documentation says an iterable-style Dataset implements __iter__(), while a map-style Dataset implements __getitem__() and __len__().

One can reference the official samples implementing both types of dataset.
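As a bare-bones sketch (not the official sample itself, just the minimum each style requires):

```python
from torch.utils.data import Dataset, IterableDataset

# Map-style: implements __getitem__() and __len__()
class MapStyleDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Iterable-style: implements __iter__(); the dataset itself
# decides the order in which samples come out
class IterStyleDataset(IterableDataset):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)
```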

On the other hand, the documentation explicitly mentions that for iterable-style datasets, how the data loader samples data is up to the implementation of the dataset’s __iter__(), and that shuffling, custom samplers, and custom batch samplers are not supported for such datasets.

Before I explore the differences further, it is worth looking into how the data loader samples data.

DataLoader

This is the main vehicle that helps us sample data from our data source. To my (admittedly limited) understanding, these are the key points:

  1. Manage multi-process fetching
  2. Sample data from the dataset in small batches
  3. Transform the data with collate_fn()
  4. Pin memory (for GPU transfer performance)

How does DataLoader sample data?

The high-level idea is: it checks which style of dataset it is dealing with (iterable / map), then either iterates through it by calling __iter__() (for an iterable-style dataset), or samples a set of indices and queries __getitem__() (for a map-style dataset).
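For example, a plain Python list already satisfies the map-style protocol (it has __getitem__() and __len__()), so the DataLoader can sample from it directly:

```python
from torch.utils.data import DataLoader

data = ["a", "b", "c", "d", "e"]
loader = DataLoader(data, batch_size=2, shuffle=True)
for batch in loader:
    print(batch)
# e.g. ['b', 'e'] then ['a', 'c'] then ['d']
```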

Sampler

Defines how samples are drawn from the dataset by the data loader. It is only used for map-style datasets (again, for an iterable-style dataset it is up to the dataset’s __iter__() to sample data, and no Sampler should be used, otherwise the DataLoader throws an error).

What does a Sampler actually do?

It generates a sequence of indices for the whole dataset. Consider a data source [“a”, “b”, “c”, “d”, “e”]: the Sampler should generate a sequence of indices with the same length as the dataset, for example [0, 2, 1, 4, 3].
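As a sketch, a custom Sampler only needs an __iter__() that yields indices (plus, typically, a __len__()); this toy example walks the dataset backwards:

```python
from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    """Toy Sampler yielding indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

data = ["a", "b", "c", "d", "e"]
print(list(ReverseSampler(data)))  # [4, 3, 2, 1, 0]
```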

BatchSampler

The BatchSampler’s objective is to take a Sampler object (which has an __iter__() that returns the index sequence) and decide how to generate batches of indices from it.

Using the same example above, if the sampler’s __iter__() returns [0, 2, 1, 4, 3], the default implementation breaks the index sequence into chunks of batch_size; say batch_size is 2, then it returns [[0, 2], [1, 4], [3]] (note the last item [3] is returned assuming the “drop_last” parameter of the data loader is False).

The data loader takes this sequence of batch indices and draws the samples batch by batch, which yields [“a”, “c”] | [“b”, “e”] | [“d”].
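This is easy to observe with the built-in classes, for example pairing a RandomSampler with a BatchSampler:

```python
from torch.utils.data import BatchSampler, RandomSampler

data = ["a", "b", "c", "d", "e"]
sampler = RandomSampler(data)  # yields a random permutation of 0..4
batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=False)
for batch_indices in batch_sampler:
    print(batch_indices, [data[i] for i in batch_indices])
# e.g. [0, 2] ['a', 'c'] / [1, 4] ['b', 'e'] / [3] ['d']
```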

Without discussing collate_fn yet, the process is expected to be: the Sampler produces an index sequence, the BatchSampler groups it into batches of indices, and the data loader fetches the corresponding items from the Dataset batch by batch.

collate_fn()

This is where the transformation of the data takes place. Normally one does not need to bother with it, because there is a default implementation that works for straightforward datasets such as lists.

default collate_fn

To see what the default collate_fn() does, one can read its implementation in the PyTorch source (torch/utils/data/_utils/collate.py).

Let’s look at a few examples to get a feeling for it; note that the input to collate_fn() is a batch of samples:
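These can be reproduced with the default collate function, importable from torch.utils.data.dataloader; the values below are illustrative:

```python
from torch.utils.data.dataloader import default_collate

# Sample 1: a batch of scalars
print(default_collate([1, 2, 3]))
# -> tensor([1, 2, 3])

# Sample 2: a batch that is a tuple of 2 lists
print(default_collate(([1, 2], [3, 4])))
# -> [tensor([1, 3]), tensor([2, 4])]

# Samples 3 and 4: records with multiple attributes,
# e.g. (attr1, attr2, label) per record
print(default_collate([(1.0, 2.0, 0), (3.0, 4.0, 1)]))
# -> [tensor([1., 3.]), tensor([2., 4.]), tensor([0, 1])]
#    (Python floats come back as float64 tensors here)
```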

For sample 1, what it does is convert the input into a tensor.

For sample 2, the batch is a tuple of 2 lists, and it returns a list of tensors, where each tensor takes one item from each list in the original tuple.

For samples 3 and 4, the input looks like a typical data form with multiple attributes. Consider case 4: if the 3rd element per record is the label and the first 2 elements are input data attributes, the returned list of tensors is not directly usable by the model; a preferable return could be a (features, labels) pair:
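A minimal sketch of such a collate_fn, assuming each record is shaped (attr1, attr2, label):

```python
import torch

def collate_features_labels(batch):
    # batch: list of records like (attr1, attr2, label)
    features = torch.tensor([[rec[0], rec[1]] for rec in batch])
    labels = torch.tensor([rec[2] for rec in batch])
    return features, labels

print(collate_features_labels([(1.0, 2.0, 0), (3.0, 4.0, 1)]))
# -> (tensor([[1., 2.], [3., 4.]]), tensor([0, 1]))
```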

Side note: for a pandas DataFrame, the data loader massages the data into a list through the fetch function in the _MapDatasetFetcher class, so we can treat it as a list sample as well.

(possibly_batched_index is a list like [1, 3].)
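Paraphrasing the relevant part of the PyTorch source (torch/utils/data/_utils/fetch.py; details vary across versions), the fetch function looks roughly like this:

```python
# excerpt-style sketch of PyTorch internals, not standalone-runnable
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            # a batch of indices: query __getitem__ once per index,
            # producing a plain Python list of samples
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
```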

The PyTorch documentation gives the following use cases for a custom collate_fn: collating along a dimension other than the first, padding sequences of various lengths, and adding support for custom data types.

For the first example, “collating along a dimension other than the first”, my interpretation is that you want the batch data grouped differently compared to the default collate function.

Implementations can be as follows:
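Here is a minimal sketch of both variants, assuming every sample is a 1-D tensor and all samples share the same length:

```python
import torch

# Variant 1: regroup the batch with Python lists,
# then build the tensor from the regrouped values
def collate_batch_dim_second_list(batch):
    regrouped = [[sample[i].item() for sample in batch]
                 for i in range(len(batch[0]))]
    return torch.tensor(regrouped)      # shape: (features, batch)

# Variant 2: operate on the tensors directly, stacking along dim=1
# so the batch dimension ends up second
def collate_batch_dim_second_stack(batch):
    return torch.stack(batch, dim=1)    # shape: (features, batch)

batch = [torch.tensor([1, 2]), torch.tensor([3, 4])]
print(collate_batch_dim_second_list(batch))   # tensor([[1, 3], [2, 4]])
print(collate_batch_dim_second_stack(batch))  # tensor([[1, 3], [2, 4]])
```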

The first variant uses Python lists and the second operates on the tensors directly.

I believe that is the most common reason to define a custom collate_fn().

For the second example, padding sequences, one use case is an RNN/LSTM model for NLP. When we sample a batch of sentences randomly, we get sentences of different lengths, and because we perform batched operations, we need to pad the shorter sequences up to the longest one. Another option is to pad to a pre-defined maximum length, which is typically the case for Transformer models; but in the old days of RNN/LSTM, reducing the number of pads was preferred, as it saves the processing time the model spends running over non-meaningful pad tokens.
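A minimal sketch of such a padding collate_fn, using torch.nn.utils.rnn.pad_sequence and assuming each sample is a (token_ids, label) pair:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    # batch: list of (token_ids, label) pairs with varying lengths
    sequences = [item[0] for item in batch]
    labels = torch.tensor([item[1] for item in batch])
    lengths = torch.tensor([len(seq) for seq in sequences])
    # pad only up to the longest sequence in this batch
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, lengths, labels

batch = [(torch.tensor([5, 6, 7]), 1), (torch.tensor([8]), 0)]
padded, lengths, labels = pad_collate(batch)
print(padded)   # tensor([[5, 6, 7], [8, 0, 0]])
```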

There is one more use case where I might consider putting code in collate_fn: as in the example below, converting the text sentences into the batch input a Transformer expects, inside the collate_fn.
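A sketch of that idea, assuming a Hugging Face tokenizer (the checkpoint name is only an example):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def transformer_collate(batch):
    # batch: list of (text, label) pairs; the Dataset keeps raw strings
    texts = [item[0] for item in batch]
    labels = torch.tensor([item[1] for item in batch])
    # tokenize the whole batch at once, padding to the longest
    # sentence in the batch and returning PyTorch tensors
    encoded = tokenizer(texts, padding=True, truncation=True,
                        return_tensors="pt")
    return encoded, labels
```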

This is not the only way to do it; one can keep the text data in the Dataset and, once the data loader returns, process the data before passing it to the model.

This is a matter of choice, but there is one potential implication: performance. Because the data loader supports multiprocessing through multiple workers, code inside collate_fn() naturally enjoys the multi-worker speed-up.

Final words

Writing this article was fulfilling and yet not so enjoyable. The fulfilling part was exploring the whole data loading pipeline in more depth and thinking through how to implement the logic in different parts of the code. The not-so-enjoyable part was the prolonged period of time being unable to finish the article, as I kept seeing more to write and explore.

I hope this helps my future self as well as some other people.

If anything here is wrong or you have a better suggestion, please feel free to let me know.
