Why I Chose WebDataset for Training on 50 TB of Data
I recently worked on a distributed training and MLOps project in which one of our goals was to make the training of a large model on a huge dataset highly scalable. This article covers the data loading choices we made to achieve that.
Understanding the Problem
Here are the key points of the problem:
- 50 TB of image data (LAION-5B) was to be used for training.
- The existing code was written in PyTorch for a single-GPU, single-node setup.
- AWS was the cloud provider. The existing system required all the training data to be copied to Amazon EFS before it could be processed.
- Consequently, keeping 50 TB of data on EFS meant high costs and large storage requirements.
- The DataLoader read small JPG files one at a time, which added I/O and processing overhead.
Existing Data Loading Method
The existing DataLoader used a MapDataset that loaded all the files into memory before processing them. This increased loading time, limited parallelization, and made training less scalable.
To understand the changes we made, we first need to understand the difference between PyTorch's map-style datasets (referred to in this article as MapDataset) and IterableDataset.
MapDataset
A PyTorch dataset that gives random access to samples by index. In our existing code, it loaded all the data into memory up front and applied a user-defined transform to each item when that item was requested.
It implements the __getitem__() and __len__() protocols and represents a map from (possibly non-integral) indices/keys to data samples.
Here is a simple example of a data loader built on a MapDataset:
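A minimal sketch, assuming PyTorch is installed. It follows the preload-everything pattern described above, but uses random tensors in place of decoded JPGs so it runs without image files; the class name `InMemoryImageDataset` and the sizes are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class InMemoryImageDataset(Dataset):
    """Map-style dataset: every sample is loaded into memory in __init__."""

    def __init__(self, num_samples: int = 64, image_size: int = 32):
        # Stand-in for reading and decoding every JPG up front.
        # With real data, this is where all the memory and I/O cost lands.
        self.images = torch.rand(num_samples, 3, image_size, image_size)
        self.labels = torch.randint(0, 10, (num_samples,))

    def __len__(self) -> int:
        # The DataLoader's sampler uses this to know which indices exist.
        return len(self.images)

    def __getitem__(self, idx: int):
        # Random access by index: any sample can be fetched in any order.
        return self.images[idx], self.labels[idx]


dataset = InMemoryImageDataset()
# shuffle=True works because the sampler can draw indices in any order.
loader = DataLoader(dataset, batch_size=8, shuffle=True)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([8, 3, 32, 32]) torch.Size([8])
```

Random access is what makes index-level shuffling trivial here; the price, in this preload style, is that `__init__` has to hold the whole dataset, which is exactly what breaks down at 50 TB.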
IterableDataset
A PyTorch dataset that reads data on-the-fly and defines an __iter__() method to return an iterator that yields the data items.
It implements the __iter__() protocol and represents an iterable over data samples. This type of dataset is particularly suitable for cases where random reads are expensive or even improbable and where the batch size depends on the fetched data.
Here is a simple example of a data loader built on an IterableDataset:
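A minimal sketch, assuming PyTorch is installed. The shard names are hypothetical and each "shard" produces synthetic tensors instead of reading a real file. It also uses `torch.utils.data.get_worker_info()` to give each DataLoader worker its own subset of shards, since an IterableDataset does not split work across workers on its own.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, get_worker_info


class StreamingImageDataset(IterableDataset):
    """Iterable-style dataset: samples are produced on the fly, shard by shard."""

    def __init__(self, shards, samples_per_shard=4, image_size=32):
        self.shards = list(shards)  # hypothetical shard names, e.g. tar files
        self.samples_per_shard = samples_per_shard
        self.image_size = image_size

    def _read_shard(self, shard_idx):
        # Stand-in for opening one shard and decoding its samples one at a time.
        for i in range(self.samples_per_shard):
            image = torch.rand(3, self.image_size, self.image_size)
            label = shard_idx * self.samples_per_shard + i  # unique sample id
            yield image, label

    def __iter__(self):
        info = get_worker_info()
        shard_ids = range(len(self.shards))
        if info is not None:
            # Inside a DataLoader worker: take every num_workers-th shard so
            # that no two workers yield the same samples.
            shard_ids = shard_ids[info.id::info.num_workers]
        for shard_idx in shard_ids:
            yield from self._read_shard(shard_idx)


dataset = StreamingImageDataset([f"shard-{i:06d}.tar" for i in range(4)])
# No shuffle=True: an IterableDataset controls its own iteration order.
loader = DataLoader(dataset, batch_size=8, num_workers=0)

labels_seen = []
for images, labels in loader:
    labels_seen.extend(labels.tolist())
print(sorted(labels_seen) == list(range(16)))  # True
```

With `num_workers > 0`, PyTorch runs a copy of the dataset in every worker process; without the `get_worker_info()` split, each worker would yield the full stream and every sample would be duplicated.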
Issues with Using MapDataset
Memory Consumption
- Our MapDataset loaded all the data into memory at once.
- With large datasets, this leads to very high memory consumption.
Slow Loading Times
- If the dataset is too large to fit into memory, samples have to be read from disk each time they are accessed.
- This slows down data loading and, in turn, increases overall training time.
I/O Bottleneck
- When working with large datasets, I/O (input/output) can become a bottleneck. Loading all the data into memory at once from many small JPG files requires a very large number of small read operations before training can even start.
- This slows down the whole pipeline and reduces performance.
Limited Parallelism
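- Preloading the entire dataset happens in a single process before training starts, which limits the parallelism you can achieve. PyTorch's DataLoader can use multiple worker processes with a map-style dataset, but those workers only parallelize the per-item work in __getitem__(), not the up-front load.
- This is a problem with large datasets, where parallelism is key to fast data loading.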
Limited Scalability
- MapDataset may not be scalable to larger datasets or distributed environments. As the size of your dataset grows, it may become impractical to load all the data into memory at once.
- In distributed environments, you may need to load data from multiple nodes or partitions, which can be challenging with a MapDataset.
Using WebDataset for Data Loading
TAR (.tar) File Format
A tar file is a type of archive file that can contain multiple files within it.
The contents of the tar file are stored in a contiguous sequence of blocks.
A plain .tar file stores its member files one after another without compressing them. Here is the order in which the files are stored within the .tar file:
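As an illustration of that order, the sketch below uses Python's standard `tarfile` module to write a tiny shard in the layout WebDataset expects: all files belonging to one sample share a basename (the sample key) and are stored next to each other, so a reader can stream through the archive sequentially and group files by key. The file names and the placeholder "JPEG" bytes are made up for illustration.

```python
import io
import tarfile


def add_file(tar, name, data):
    """Append one in-memory file to an open tar archive."""
    info = tarfile.TarInfo(name=name)
    info.size = len(data)
    tar.addfile(info, io.BytesIO(data))


# (key, placeholder image bytes, caption) for two hypothetical samples
samples = [
    ("000000", b"<jpeg bytes 0>", b"a photo of a cat"),
    ("000001", b"<jpeg bytes 1>", b"a photo of a dog"),
]

buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as tar:  # "w" = uncompressed tar
    for key, image_bytes, caption in samples:
        # Files of the same sample are written consecutively and share the key.
        add_file(tar, f"{key}.jpg", image_bytes)
        add_file(tar, f"{key}.txt", caption)

# Reading the archive back shows the order in which members are stored.
buffer.seek(0)
with tarfile.open(fileobj=buffer, mode="r") as tar:
    names = tar.getnames()
print(names)  # ['000000.jpg', '000000.txt', '000001.jpg', '000001.txt']
```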
WebDataset
WebDataset is a Python package that provides a data loading and processing pipeline for working with large-scale datasets in PyTorch.
It uses a sharded, shufflable TAR (.tar) file format that is efficient for networked and distributed storage. It can efficiently load large .tar files containing thousands of images while providing options to decode the images on the fly.
WebDataset implements the standard PyTorch IterableDataset interface and works with the PyTorch DataLoader.
DataLoader Using WebDataset
WebDataset provides several important options for reading and processing the data in .tar files on the fly. For example, the decode option converts the images to the required format on the fly in a single line of code.
Here is a simple example of a data loader with IterableDataset using WebDataset:
Streaming in the Data from S3 Directly with Pipe Method
WebDataset provides a convenient way to stream data from Amazon S3 through its pipe: URLs. When a URL starts with pipe:, WebDataset runs the rest of the URL as a shell command and reads the tar data from that command's standard output. With the AWS CLI, aws s3 cp <url> - writes the S3 object to standard output, so the shard flows into the WebDataset pipeline as it downloads.
Here is how the URL is rewritten and used for this method:
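import webdataset as wds
url = "s3://bucket/data.tar"
# Run the AWS CLI and stream the object to stdout ("-") instead of a local file
s3_url = f"pipe:aws s3 cp {url} -"
dataset = wds.WebDataset(s3_url)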
This enables efficient loading and preprocessing of large datasets without downloading the entire dataset to local storage or copying it to EFS, saving both time and storage costs.
Additionally, WebDataset pipelines can decode samples in parallel when the DataLoader runs with multiple worker processes, which speeds up data loading and preprocessing further.
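With many shards, each one needs its own pipe: URL. A small helper like the hypothetical `make_pipe_urls` below builds that list; the bucket name, prefix and `shard-NNNNNN.tar` naming are assumptions for illustration. WebDataset also expands brace notation such as `shard-{000000..000019}.tar` inside a single URL string, which is an alternative to building the list by hand.

```python
def make_pipe_urls(bucket, prefix, num_shards):
    """Build one `pipe:` URL per shard so each shard is streamed via the AWS CLI.

    `aws s3 cp <s3-url> -` writes the object to stdout; WebDataset reads that
    output as the tar stream, so nothing has to be stored on local disk.
    """
    urls = []
    for i in range(num_shards):
        s3_url = f"s3://{bucket}/{prefix}/shard-{i:06d}.tar"
        urls.append(f"pipe:aws s3 cp {s3_url} -")
    return urls


# Hypothetical bucket and prefix, for illustration only.
urls = make_pipe_urls("my-training-bucket", "laion/train", num_shards=20)
print(urls[0])
# pipe:aws s3 cp s3://my-training-bucket/laion/train/shard-000000.tar -
```

The resulting list can be passed as the first argument of wds.WebDataset, which accepts a list of URLs as well as a single URL string.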
Nodesplitter for Data Distribution
WebDataset can split data across nodes (GPU processes) and across DataLoader workers. With nodesplitter, if there are 20 tar files (tar URLs) and 2 GPUs, the URLs are distributed across the two GPUs so that each GPU reads different shards and no data is duplicated.
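Here is how the shards are split across nodes in the code:
dataset = wds.WebDataset(s3_url, nodesplitter=wds.split_by_node)
Splitting shards across the DataLoader worker processes within each node is done with wds.split_by_worker.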
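In plain Python, the shard-level split for the 20-shard, 2-GPU example amounts to a round-robin assignment: each rank takes every `world_size`-th shard. The function `shards_for_rank` below is hypothetical and only mirrors the idea; it is not WebDataset's own code.

```python
def shards_for_rank(urls, rank, world_size):
    """Round-robin shard assignment: rank r gets shards r, r + world_size, ..."""
    return urls[rank::world_size]


urls = [f"shard-{i:06d}.tar" for i in range(20)]
rank0 = shards_for_rank(urls, rank=0, world_size=2)
rank1 = shards_for_rank(urls, rank=1, world_size=2)

print(len(rank0), len(rank1))  # 10 10
print(rank0[:2])  # ['shard-000000.tar', 'shard-000002.tar']
```

Because the unit of work is a whole shard, the split is only even when the number of shards is a multiple of the number of ranks; with 21 shards, one rank would get an extra shard. Having many more shards than GPUs keeps this imbalance small.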
This distribution happens at the shard level, not at the sample level. If sample-level distribution is needed, a workaround has to be implemented inside the data_generator method of the data loader.
The logic would look roughly like this, assuming the images inside the .tar file are named 1, 2, 3, … :
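import os

def data_generator(iter_dataset):
    # Total number of processes (GPUs) and the global rank of this one
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    # Iterate over the stream once; next(iter(...)) in a loop would restart it
    for img_val, img_id in iter_dataset:
        # Keep only the samples whose numeric id maps to this rank
        if int(img_id) % world_size != rank:
            continue
        yield img_val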
How Is the Data Loader with WebDataset Better?
- Better data loading performance through lazy loading: data is read on demand as needed, rather than loaded into memory up front.
  - This reduces memory usage and allows much larger datasets to be processed.
- Support for shuffling data without loading all of it into memory, which is otherwise a challenge for large datasets.
  - This is done by shuffling the order of the tar shards and then shuffling samples within an in-memory buffer. The result is an approximate shuffle: samples are well mixed, but not as thoroughly as with a full random permutation of the dataset.
- Parallel data processing, which can significantly speed up data loading.
  - This comes from running the pipeline in multiple PyTorch DataLoader worker processes (built on Python's multiprocessing), spread across CPU cores.
- A cost-effective way to load large amounts of data, since shards are streamed directly from AWS S3 instead of being kept on EFS.
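The two-level shuffle described above can be sketched in plain Python. `two_level_shuffle` and its `buffer_size` are illustrative names, not WebDataset's API, although WebDataset's shard shuffling and its `.shuffle(n)` buffer stage follow the same principle. Shards are modeled as plain lists of sample ids.

```python
import random


def two_level_shuffle(shards, buffer_size, seed=0):
    """Yield samples from `shards` (lists of samples) in approximately random order.

    1. Shuffle the order in which shards are read.
    2. Stream samples through a fixed-size buffer and emit a random element
       from it, so memory use is bounded by `buffer_size`, not the dataset size.
    """
    rng = random.Random(seed)
    order = list(range(len(shards)))
    rng.shuffle(order)  # level 1: shard-order shuffle

    buffer = []
    for shard_idx in order:
        for sample in shards[shard_idx]:
            buffer.append(sample)
            if len(buffer) >= buffer_size:
                # level 2: swap a random element to the end and emit it
                i = rng.randrange(len(buffer))
                buffer[i], buffer[-1] = buffer[-1], buffer[i]
                yield buffer.pop()
    rng.shuffle(buffer)  # drain whatever is left once the stream ends
    yield from buffer


shards = [list(range(s * 5, s * 5 + 5)) for s in range(4)]  # 4 shards x 5 samples
out = list(two_level_shuffle(shards, buffer_size=3))
print(sorted(out) == list(range(20)))  # True: every sample appears exactly once
```

A larger buffer mixes samples more thoroughly at the cost of memory, while shuffling the shard order changes which parts of the dataset end up next to each other in the buffer.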
Some of the Alternate Options
- Use s3Dataset (an IterableDataset for S3) to stream input data samples. (It lacks some of the options WebDataset provides and is reported not to be fully compatible with PyTorch's distributed training framework.)
- Use Amazon FSx for Lustre to load the data from S3, and SageMaker's FastFile input mode to stream in the data. (Too costly.)
- Use a MapDataset or IterableDataset directly with SageMaker's FastFile mode. (Requires a SageMaker setup.)
Conclusion
When training on a huge amount of data, say 50 TB of image data, the code needs to be highly scalable and efficient to maximize speed and reduce costs.
A key part of this is choosing how the data is loaded and fed into training. This article explained why an IterableDataset can be a better data loading choice than the MapDataset approach we started with, and how WebDataset further improves training performance.
The options we chose are based on our problem requirements and on what worked best in our experiments. Feel free to share your knowledge or any suggestions on this topic in the responses. I'd be happy to learn from you :)
About Ahmad Sachal
- LinkedIn Profile — linkedin.com/in/ahmad-sachal/
About Red Buffer
- Website — redbuffer.ai
- LinkedIn Page — linkedin.com/company/red-buffer