Dataset Streaming in PyTorch

Building an Efficient Pipeline for Datasets that do not fit in RAM

Alexander Wei
6 min read · Jan 6, 2024

In the realm of machine learning, managing large datasets efficiently is often a critical task. PyTorch, known for its flexibility and ease of use, offers robust tools for this purpose. This article aims to guide you through constructing a data pipeline that not only manages memory efficiently by streaming data from the hard drive but also integrates seamlessly with PyTorch’s DataLoader and models.

Understanding the Challenge

Working with large datasets in memory can be a daunting task. The traditional approach of loading the entire dataset into RAM is often impractical due to memory constraints. An alternative, more memory-efficient approach involves streaming data directly from the filesystem. This method significantly reduces the memory footprint, making it possible to handle larger datasets effectively.

Without streaming, the size of a dataset is limited to RAM. Image by author
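
To make the constraint concrete, here is a back-of-the-envelope estimate of how much RAM a dense array needs. The sample count and feature width are illustrative assumptions, not figures from a real dataset, and `dataset_bytes` is a hypothetical helper.

```python
import numpy as np

def dataset_bytes(n_samples: int, n_features: int, dtype=np.float32) -> int:
    """Bytes needed to hold a dense (n_samples, n_features) array in RAM."""
    return n_samples * n_features * np.dtype(dtype).itemsize

# Hypothetical numbers: 5 million samples with 768 float32 features each
size = dataset_bytes(5_000_000, 768)
print(f"{size / 1e9:.2f} GB")  # 15.36 GB
```

An array of that size already exceeds the memory of many workstations, before counting the model and any intermediate copies.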

Setting Up the Data Pipeline

The journey begins with preprocessing the data into a list of NumPy arrays, one array per batch. NumPy arrays are a common format for numerical data, and joblib can serialize them to disk efficiently, so we use it to save the preprocessed batches. This step lays the foundation of the pipeline. Splitting it into a few convenience functions keeps the preprocessing versatile:

import os

import joblib

def split_iter(a, n: int):
    """Pack a dataset (array of samples) into batches of exactly n samples"""
    q = len(a) // n  # number of full batches; a trailing partial batch is dropped
    assert q > 0, "dataset is smaller than one batch"
    for i in range(q):
        yield a[i*n:(i+1)*n]

def batches(data, batch_size=36):
    return split_iter(data, batch_size)

def save(x, y, batch_size=36, dest="./out"):
    """Write features x and labels y to dest as lists of batches"""
    os.makedirs(dest, exist_ok=True)  # create the output directory if missing
    in_batches = batches(x, batch_size)
    out_batches = batches(y, batch_size)
    joblib.dump(list(out_batches), os.path.join(dest, "y_%d.job" % batch_size))
    joblib.dump(list(in_batches), os.path.join(dest, "x_%d.job" % batch_size))

With these helpers, converting NumPy arrays of features and labels into the batch-streaming format takes a single call:

import numpy as np

# Dummy dataset generation
number_of_samples = 1000
number_of_features = 10
number_of_classes = 5

features = np.random.randn(number_of_samples, number_of_features)
labels = np.random.randint(0, number_of_classes, (number_of_samples,))

save(features, labels, batch_size=4)
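
A quick sanity check of the batching arithmetic, kept in memory so it does not touch the disk. The helper `full_batches` is a simplified, hypothetical stand-in for the splitting step, and it assumes that trailing samples which do not fill a batch are dropped.

```python
import numpy as np

def full_batches(data, batch_size):
    """Yield consecutive slices of exactly batch_size samples (remainder dropped)."""
    for i in range(len(data) // batch_size):
        yield data[i * batch_size:(i + 1) * batch_size]

features = np.random.randn(1000, 10)
packed = list(full_batches(features, 4))
print(len(packed), packed[0].shape)  # 250 (4, 10)

# With 10 samples and batch_size=4, the last 2 samples are left out
small = list(full_batches(np.arange(10), 4))
print([b.tolist() for b in small])  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Keeping every batch the same size is what later allows the training loop to reshape each batch with a fixed batch_size.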

The next step is to build the training component of the pipeline.

Creating a Custom Dataset in PyTorch

PyTorch’s Dataset class comes into play here. A custom subclass defines how many items there are (__len__) and how a single item is loaded (__getitem__), so that data is read on demand rather than held in memory for the whole run. In this pipeline, one item is one pre-packed batch.

import joblib
import torch
from torch.utils.data import Dataset as TorchDataset

class Dataset(TorchDataset):
    """Dataset for training reads from joblib pickle files"""
    def __init__(self, feature_file, label_file):
        self.feature_file, self.label_file = feature_file, label_file

    def __len__(self):
        """returns the total number of batches"""
        return len(joblib.load(self.feature_file))

    def __getitem__(self, idx):
        """loading a batch on the fly"""
        # joblib.load reads the whole file on each call; only batch idx is kept
        X = joblib.load(self.feature_file)[idx]
        Y = joblib.load(self.label_file)[idx]

        return torch.from_numpy(X).float(), torch.from_numpy(Y).float()
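
If the data can be stored as a single NumPy array, memory-mapping is one way to read only the rows of the requested batch. The sketch below is an alternative to the joblib files used in this article, not part of the original pipeline; the file name and the helper `read_batch` are assumptions for illustration.

```python
import os
import tempfile

import numpy as np

def read_batch(path: str, idx: int, batch_size: int) -> np.ndarray:
    """Return batch idx as an in-memory copy of the corresponding rows."""
    mapped = np.load(path, mmap_mode="r")  # maps the file without loading it
    start = idx * batch_size
    # Copying the slice pulls just these rows into RAM
    return np.array(mapped[start:start + batch_size])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "features.npy")  # hypothetical file name
    data = np.arange(40, dtype=np.float32).reshape(10, 4)
    np.save(path, data)

    batch = read_batch(path, idx=1, batch_size=4)
    print(batch.shape)  # (4, 4)
```

Inside a custom Dataset, __getitem__ could call a helper like this and wrap the result with torch.from_numpy.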

Implementing DataLoader for Efficient Data Streaming

The DataLoader in PyTorch is a versatile tool for batching and shuffling data. It becomes especially useful with large datasets, as it streams data from disk in manageable batches and so keeps the memory profile low. The num_workers and prefetch_factor parameters together determine the size of the streaming buffer (each worker process loads prefetch_factor batches in advance) and should be tuned to your machine. If the dataset is indexed over batches before saving, as in the example above, it is important to set batch_size=1 here: each iteration then yields tensors with a leading dimension of size one, and each of these is reshaped into a tensor holding the full batch of size 4.

from torch.utils.data import DataLoader

dataset = Dataset("./out/x_4.job", "./out/y_4.job")

data_loader = DataLoader(dataset,
                         batch_size=1,
                         shuffle=True,
                         num_workers=3,
                         prefetch_factor=64)
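
To see the extra dimension that batch_size=1 adds, here is a small in-memory sketch. `PreBatched` is a hypothetical toy dataset standing in for the file-backed one. Passing batch_size=None, which disables the DataLoader's automatic batching, is shown as an alternative that returns each stored batch unchanged.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class PreBatched(Dataset):
    """Toy dataset whose items are already batches of 4 samples."""
    def __init__(self, n_batches=3):
        self.x = [torch.randn(4, 10) for _ in range(n_batches)]
        self.y = [torch.randint(0, 5, (4,)) for _ in range(n_batches)]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

ds = PreBatched()

# batch_size=1 stacks one stored batch, adding a leading dimension of size 1
x1, y1 = next(iter(DataLoader(ds, batch_size=1)))
print(x1.shape, y1.shape)  # torch.Size([1, 4, 10]) torch.Size([1, 4])

# batch_size=None disables automatic batching; items come back as stored
x2, y2 = next(iter(DataLoader(ds, batch_size=None)))
print(x2.shape, y2.shape)  # torch.Size([4, 10]) torch.Size([4])
```

Either way works; with batch_size=1 the training loop has to drop the leading dimension itself.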

Building a Simple PyTorch Model for Demonstration

We have built a dataset module for streaming, but have yet to define a model! To demonstrate the integration of our data pipeline with PyTorch, we construct a basic model with a single linear layer. This model serves as a testbed to illustrate how seamlessly the DataLoader interfaces with PyTorch models, ensuring a smooth training process.

import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size, self.output_size = input_size, output_size
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

# One output logit per class, as CrossEntropyLoss expects
model = SimpleModel(input_size=number_of_features, output_size=number_of_classes)

batch_size = 4

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Example of using DataLoader and Model
num_epochs = 100
for epoch in range(num_epochs):
    for batch_idx, (features, labels) in enumerate(data_loader):
        # Here, features and labels are streamed from the disk per batch
        # Drop the leading dimension of size one added by batch_size=1
        batch = features.view(batch_size, model.input_size)
        output = model(batch)
        # Targets must be integer class indices of shape (batch_size,)
        loss = criterion(output, labels.view(batch_size).long())

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}', flush=True)
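
A single-batch check of the shapes and dtypes that nn.CrossEntropyLoss works with when targets are class indices: logits of shape (batch, classes) and integer targets of shape (batch,). The sizes mirror the dummy dataset; the tensors themselves are random placeholders.

```python
import torch
import torch.nn as nn

batch_size, n_features, n_classes = 4, 10, 5

layer = nn.Linear(n_features, n_classes)  # one logit per class

# Shapes as yielded by a DataLoader with batch_size=1 over pre-packed batches
features = torch.randn(1, batch_size, n_features)
labels = torch.randint(0, n_classes, (1, batch_size)).float()

logits = layer(features.view(batch_size, n_features))  # shape (4, 5)
targets = labels.view(batch_size).long()               # shape (4,), int64

loss = nn.CrossEntropyLoss()(logits, targets)
print(logits.shape, targets.dtype, loss.ndim)
```

Running such a check on one batch before starting a long training run catches shape and dtype mismatches early.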

Putting It All Together

Now, we combine all the elements: preprocessing, custom dataset creation, DataLoader implementation, and the PyTorch model. This concise example demonstrates the pipeline’s functionality and provides a template for adaptation to various datasets and models.

"""Streaming Data Loader and Simple PyTorch Model"""
import os

import numpy as np
from torch.utils.data import Dataset as TorchDataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch
import joblib

# ***PREPROCESSING - 1/3***
number_of_samples = 1000
number_of_features = 10
number_of_classes = 5

features = np.random.randn(number_of_samples, number_of_features)
labels = np.random.randint(0, number_of_classes, (number_of_samples,))

def split_iter(a, n: int):
    """Pack a dataset (array of samples) into batches of exactly n samples"""
    q = len(a) // n  # number of full batches; a trailing partial batch is dropped
    assert q > 0, "dataset is smaller than one batch"
    for i in range(q):
        yield a[i*n:(i+1)*n]

def batches(data, batch_size=36):
    return split_iter(data, batch_size)

def save(x, y, batch_size=36, dest="./out"):
    """Write features x and labels y to dest as lists of batches"""
    os.makedirs(dest, exist_ok=True)  # create the output directory if missing
    in_batches = batches(x, batch_size)
    out_batches = batches(y, batch_size)
    joblib.dump(list(out_batches), os.path.join(dest, "y_%d.job" % batch_size))
    joblib.dump(list(in_batches), os.path.join(dest, "x_%d.job" % batch_size))

save(features, labels, batch_size=4)

# ***LOADING - 2/3***
class Dataset(TorchDataset):
    """Dataset for training reads from joblib pickle files"""
    def __init__(self, feature_file, label_file):
        self.feature_file, self.label_file = feature_file, label_file

    def __len__(self):
        """returns the total number of batches"""
        return len(joblib.load(self.feature_file))

    def __getitem__(self, idx):
        """loading a batch on the fly"""
        # joblib.load reads the whole file on each call; only batch idx is kept
        X = joblib.load(self.feature_file)[idx]
        Y = joblib.load(self.label_file)[idx]

        return torch.from_numpy(X).float(), torch.from_numpy(Y).float()

dataset = Dataset("./out/x_4.job", "./out/y_4.job")

data_loader = DataLoader(dataset,
                         batch_size=1,
                         shuffle=True,
                         num_workers=3,
                         prefetch_factor=64)


class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size, self.output_size = input_size, output_size
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

# ***TRAINING - 3/3***
# One output logit per class, as CrossEntropyLoss expects
model = SimpleModel(input_size=number_of_features, output_size=number_of_classes)

num_epochs = 100
batch_size = 4
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(num_epochs):
    for batch_idx, (features, labels) in enumerate(data_loader):
        # Drop the leading dimension of size one added by batch_size=1
        output = model(features.view(batch_size, model.input_size))
        # Targets must be integer class indices of shape (batch_size,)
        loss = criterion(output, labels.view(batch_size).long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}', flush=True)

Performance Considerations and Best Practices

Tuning the data pipeline to a specific dataset and model matters for throughput. In PyTorch, most of this tuning happens in the DataLoader’s configuration, particularly the number of workers and the prefetch factor. The number of workers (num_workers) determines how many parallel processes load data, which affects the speed of data retrieval and overall throughput. Choosing it well is important: with too few workers the model waits on data and the system is under-utilized; with too many, the workers compete for CPU, memory, and disk bandwidth. The prefetch factor (prefetch_factor) sets how many batches each worker loads in advance. A larger value helps keep the data flow steady, especially when loading or preprocessing a batch is computationally intensive, at the cost of more memory held in the buffer. Balancing the two parameters reduces idle time and keeps the training loop supplied with data.
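
As a rough guide to the memory side of this trade-off: PyTorch documents prefetch_factor as the number of batches each worker loads in advance, so about num_workers × prefetch_factor batches can be buffered at once. The helper below is a hypothetical estimate that ignores per-process overhead.

```python
def prefetch_buffer_bytes(num_workers: int, prefetch_factor: int,
                          bytes_per_batch: int) -> int:
    """Approximate bytes held by prefetched batches across all workers."""
    return num_workers * prefetch_factor * bytes_per_batch

# Settings from this article: 3 workers, prefetch_factor=64.
# One batch: 4 x 10 float32 features plus 4 float32 labels = 176 bytes.
bytes_per_batch = 4 * 10 * 4 + 4 * 4
print(prefetch_buffer_bytes(3, 64, bytes_per_batch))  # 33792
```

For the toy dataset the buffer is negligible, but the same product grows quickly once a batch holds large images or long embedding sequences.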

Conclusion

Building an efficient data pipeline in PyTorch is a valuable skill in the arsenal of any machine learning practitioner. By following the steps outlined in this article, you can handle large datasets with ease, ensuring your models are both scalable and efficient.

The complexities of training on extensive datasets are manifold. Consider the challenges in diverse fields: Image processing grapples with behemoths like ImageNet, a repository of millions of high-resolution images; natural language processing contends with colossal text corpora such as the entirety of Wikipedia; financial market analysis constantly processes an unending stream of transaction data; and genomics confronts the Herculean task of deciphering extensive genomic sequences. Each of these scenarios underscores the critical need for adept data handling.

In another article, we take this data streamer for a spin on a dataset of food recipe ingredients. Although millions of Unicode text samples fit easily into 1 GB of RAM, their floating-point embeddings do not.
