Training Faster With Large Datasets using Scale and PyTorch

Published in PyTorch, Mar 31, 2020

Authored by Daniel Havir & Nathan Hayflick at Scale AI.

Scale AI, the Data Platform for AI development, shares some tips on how ML engineers can more easily build and work with large datasets by using PyTorch’s asynchronous data loading capabilities and Scale as a continuous labeling pipeline.

When training models for production applications, ML engineers often need to iterate on their training datasets as much as they iterate on their model architecture. This is because increasing the size, variety or labeling quality of a dataset is one of our most powerful levers to increase model performance¹. Ideally, our machine learning pipeline should make acquiring new data and retraining as quick and seamless as possible to maximize these performance gains. However, engineers face several barriers to retraining on larger datasets, including the complexity of managing large-scale data labeling operations and the time (as well as compute cost) of retraining on more data. This article demonstrates some techniques for managing the issues specific to working with large datasets by using Scale's labeling platform and PyTorch's rich support for custom training workflows.

Data Labeling and Annotation

The first challenge to training on larger datasets is simply obtaining large quantities of annotations without sacrificing quality. While smaller datasets can be labeled successfully as a one-off by a few labelers (or even by ML engineers themselves), building datasets composed of hundreds of thousands of scenes requires extensive QA processes and tooling. Ensuring that labeling rules are followed consistently across a large and diverse group of labelers is key to improving training accuracy. Scale's Labeling API was designed to abstract away many of these concerns, so whether you are collecting a small prototype training set or millions of labels, the process is as painless as possible.

We start with a collection of unlabeled images of indoor scenes that we’ve collected for our project, like the following example:

Example indoor scene for segmentation

In this example, we show how to train a model to identify pieces of furniture, appliances and architectural features on increasingly large datasets to improve model performance. This model is similar to the types of computer vision solutions we see being deployed in retail, real estate or e-commerce. We start by establishing the format we need for our annotations: our labels must match the desired output format of our model. While there are several common output formats for computer vision models (including boxes and polygons), for this task we've chosen to use a semantic segmentation model. This format allows our model to provide the greatest amount of environmental context, which comes at the expense of greater computational resources (more on this later).

Scale provides a Semantic Segmentation endpoint where we can submit our unlabeled images along with labeling instructions and our desired labeling structure.

Our label mapping hierarchy must encompass all the classes of objects we are interested in predicting, such as furniture, lighting, etc. While Scale supports generating panoptic labels (where we store a separate pixel map for each distinct object within a class), we will be using a classic semantic segmentation approach here.
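Whatever hierarchy we submit, the model ultimately predicts one integer class id per pixel, so it helps to pin down that mapping early. The sketch below assumes a hypothetical two-level hierarchy (group to leaf classes); the group and class names are illustrative and are not Scale's request schema.

```python
from typing import Dict, List

# Hypothetical two-level label hierarchy: group -> leaf classes.
# The names are illustrative only; they are not Scale's request schema.
LABEL_HIERARCHY: Dict[str, List[str]] = {
    "furniture": ["chair", "table", "sofa", "bed"],
    "lighting": ["lamp", "ceiling_light"],
    "appliance": ["refrigerator", "oven", "television"],
    "architecture": ["wall", "floor", "ceiling", "door", "window"],
}


def build_class_index(hierarchy: Dict[str, List[str]], background: str = "background") -> Dict[str, int]:
    """Flatten the hierarchy into a leaf-class -> integer-index mapping.

    Index 0 is reserved for unlabeled/background pixels. Leaf classes are
    numbered in the order they are listed (dicts preserve insertion order),
    so once labels exist, new classes should only ever be appended.
    """
    class_to_index = {background: 0}
    for leaves in hierarchy.values():
        for leaf in leaves:
            if leaf in class_to_index:
                raise ValueError(f"duplicate class name: {leaf!r}")
            class_to_index[leaf] = len(class_to_index)
    return class_to_index


CLASS_TO_INDEX = build_class_index(LABEL_HIERARCHY)
NUM_CLASSES = len(CLASS_TO_INDEX)  # number of output channels the model needs
```

Numbering by listing order (rather than sorting) keeps existing ids stable as long as new classes are appended, which matters once labeled masks already encode those ids.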

Data Labeling Pipeline

Once Scale has completed the labeling process, a task response is sent to our designated callback URL with links to a set of PNGs. These PNGs contain the combined, indexed and individual label layers that we can use to train our model in the following steps.
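For training, the useful layer is the one that stores a class id per pixel. A minimal sketch of turning such a PNG into a target tensor, assuming the indexed layer is a palette-mode (or grayscale) PNG whose pixel values are the class ids; the function name is our own.

```python
import io

import numpy as np
import torch
from PIL import Image


def load_index_mask(png_bytes: bytes) -> torch.Tensor:
    """Decode an indexed (palette-mode) PNG into an (H, W) tensor of class ids.

    For a palette ("P") image, np.asarray returns the palette *indices*
    rather than the RGB colors, which is what a segmentation loss expects
    as its target. Calling .convert("RGB") here would destroy those ids.
    """
    image = Image.open(io.BytesIO(png_bytes))
    if image.mode not in ("P", "L"):
        raise ValueError(f"expected an indexed or grayscale PNG, got mode {image.mode!r}")
    indices = np.asarray(image, dtype=np.uint8)
    # copy() gives torch a writable buffer; long is the dtype cross_entropy wants.
    return torch.from_numpy(indices.copy()).long()
```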

Scene segmentation — each color represents a label layer.

Getting Started With Local Training

Now that we are receiving data from our labeling pipeline, we can train a prototype model with a simple implementation that we can run locally or on a single notebook instance. We want to start with a smaller subset of our dataset to speed up prototyping and benchmark performance before increasing the quantity of training data in later steps.

By setting up our ML environment with PyTorch, we can leverage an ecosystem of popular model architectures to start training on. Here we import the DeepLabV3 model from the TorchVision package and then perform some preprocessing on our Scale-formatted labeled validation set before loading data into the model.

One thing to keep in mind: your production dataset will change over time as more data is labeled, so it is important to keep track of your train/test split in order to evaluate your models consistently on the same examples. An easy way to achieve a consistent split is to use a hashing function: hash a stable identifier for each example and assign it to the test set based on the hash value. Because the assignment depends only on that identifier, all of your test examples remain in your test set even as your dataset continues to grow.
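A minimal sketch of a hash-based split, assuming each example has a stable string ID (an image URL or task ID, for instance). The `salt` parameter is our own addition so the split can be deliberately reshuffled by changing it.

```python
import hashlib


def split_for(example_id: str, test_fraction: float = 0.1, salt: str = "split-v1") -> str:
    """Deterministically assign an example to 'train' or 'test' from its ID alone.

    Uses SHA-256 rather than Python's built-in hash(), which is randomized
    per process for strings and would reshuffle the split on every run.
    """
    digest = hashlib.sha256(f"{salt}:{example_id}".encode("utf-8")).digest()
    # Keep 53 bits so the division is exact and the result lies in [0, 1).
    bucket = (int.from_bytes(digest[:8], "big") >> 11) / 2**53
    return "test" if bucket < test_fraction else "train"
```

New examples get their own assignment when they arrive, and nothing already in the test set ever moves to train.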

Now we can run our validation tests and confirm the model’s training is working!
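For segmentation, a common number to track during validation is mean intersection-over-union (mIoU). A minimal sketch, assuming a torchvision-style model that returns a dict with an `"out"` key and a validation loader yielding `(images, masks)` with 255 marking unlabeled pixels; all function names are hypothetical.

```python
import torch


@torch.no_grad()
def confusion_matrix(preds: torch.Tensor, targets: torch.Tensor,
                     num_classes: int, ignore_index: int = 255) -> torch.Tensor:
    """(num_classes, num_classes) counts: rows = true class, columns = predicted class."""
    valid = targets != ignore_index
    t = targets[valid].long()
    p = preds[valid].long()
    counts = torch.bincount(t * num_classes + p, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)


def mean_iou(conf: torch.Tensor) -> float:
    """Mean IoU over classes that appear in either targets or predictions."""
    conf = conf.double()
    intersection = conf.diag()
    union = conf.sum(0) + conf.sum(1) - intersection
    present = union > 0
    return (intersection[present] / union[present]).mean().item()


@torch.no_grad()
def evaluate(model: torch.nn.Module, loader, num_classes: int, ignore_index: int = 255) -> float:
    model.eval()
    conf = torch.zeros(num_classes, num_classes, dtype=torch.long)
    for images, masks in loader:
        preds = model(images)["out"].argmax(dim=1)   # (N, H, W) predicted class ids
        conf += confusion_matrix(preds, masks, num_classes, ignore_index)
    return mean_iou(conf)
```

Accumulating one confusion matrix over the whole validation set (instead of averaging per-batch IoU) avoids over-weighting small batches and rare classes.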

Streaming Data With PyTorch

At this point, it's common to begin diving into the architecture of the model or fine-tuning. Instead, we are going to take a different approach and see how far we can get by simply increasing the size of our training dataset. One challenge we will have to address is that our training code is currently slow and memory-intensive due to the size requirements of our segmentation data (each labeled scene is a full-size PNG). Spinning up a development environment also becomes increasingly painful if we are required to preload gigabytes of data before running. To continue to expand the size of our dataset, we will need to modify our approach to work outside the bounds of a single machine.

The flexibility of PyTorch’s DataLoader class enables us to implement data streaming fairly easily by customizing how our Dataset class loads items. This allows us to move our dataset from disk to cloud storage (GCS or the like) and progressively stream our labels as we train.
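A minimal sketch of such a Dataset, assuming each example is a pair of locations (image, mask) and that masks are indexed or grayscale PNGs. The class name and the `loader` parameter are our own; the point is that `__init__` stores only locations, and where the bytes come from is decided by whichever loader function is passed in.

```python
from typing import Callable, Sequence, Tuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

Loader = Callable[[str], Image.Image]   # location (path or URI) -> PIL image


class SegmentationDataset(Dataset):
    """Holds only (image_location, mask_location) pairs; pixels are fetched lazily.

    Nothing is read in __init__, so construction stays cheap for very large
    datasets, and swapping `loader` changes where the bytes come from (local
    disk, cloud storage) without touching the rest of the training code.
    """

    def __init__(self, pairs: Sequence[Tuple[str, str]], loader: Loader):
        self.pairs = list(pairs)
        self.loader = loader

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int):
        image_loc, mask_loc = self.pairs[idx]
        image = np.asarray(self.loader(image_loc).convert("RGB"))   # (H, W, 3) uint8
        mask = np.asarray(self.loader(mask_loc))                    # (H, W) class ids, no RGB conversion
        image_t = torch.from_numpy(image.copy()).permute(2, 0, 1)   # (3, H, W) uint8
        mask_t = torch.from_numpy(mask.copy()).long()
        return image_t, mask_t
```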

When we need to run locally, we can use pil_loader as the loader function and then switch to gcs_pil_loader for streaming.
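A minimal sketch of the two loaders, assuming examples are addressed as local paths or `gs://bucket/path` URIs. Both functions are our own illustrations that reuse the names from the text. torchvision also ships a `pil_loader` in `torchvision.datasets.folder`, but it converts every image to RGB, which would turn an indexed mask's class ids into colors. The GCS calls (`storage.Client`, `Client.bucket`, `Bucket.blob`, `Blob.download_as_bytes`) come from the `google-cloud-storage` package.

```python
import functools
import io
from typing import Tuple

from PIL import Image


def pil_loader(path: str) -> Image.Image:
    """Local loader: read fully, then close the file (no RGB conversion,
    so indexed masks keep their class ids)."""
    with open(path, "rb") as f:
        image = Image.open(f)
        image.load()   # force the read before the file is closed
    return image


def split_gcs_uri(uri: str) -> Tuple[str, str]:
    """'gs://bucket/path/to/file.png' -> ('bucket', 'path/to/file.png')."""
    if not uri.startswith("gs://"):
        raise ValueError(f"not a GCS URI: {uri!r}")
    bucket, _, blob = uri[len("gs://"):].partition("/")
    if not bucket or not blob:
        raise ValueError(f"malformed GCS URI: {uri!r}")
    return bucket, blob


@functools.lru_cache(maxsize=None)
def _gcs_client():
    # Imported lazily so local runs don't need google-cloud-storage installed.
    # Cached per process; we assume the main process doesn't call this before
    # DataLoader workers start, so each worker builds its own client.
    from google.cloud import storage
    return storage.Client()


def gcs_pil_loader(uri: str) -> Image.Image:
    """Streaming loader: download one object into memory and decode it."""
    bucket_name, blob_name = split_gcs_uri(uri)
    data = _gcs_client().bucket(bucket_name).blob(blob_name).download_as_bytes()
    return Image.open(io.BytesIO(data))
```

Because both functions share the signature "location in, PIL image out", switching from local files to streaming is a one-argument change wherever the Dataset is constructed.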

Moving to a streaming architecture also allows us to save on training costs by migrating to preemptible instances (Scale's ML team uses Kubernetes, which integrates nicely with preemptible instances on most cloud providers). This gives us access to cheaper, on-demand hardware, as long as our workflows are fault-tolerant to instance shutdowns.
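Much of that fault tolerance comes down to checkpointing regularly and resuming cleanly when a replacement instance starts. A minimal sketch, assuming the checkpoint path lives on storage that outlives the node (a persistent volume, for example); the function names are hypothetical.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer,
                    epoch: int, step: int) -> None:
    """Write to a temp file, then rename, so a shutdown mid-write keeps the old checkpoint."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict(),
                "epoch": epoch, "step": step}, tmp_path)
    os.replace(tmp_path, path)   # atomic on POSIX when both paths are on the same filesystem


def load_checkpoint(path: str, model: nn.Module, optimizer: torch.optim.Optimizer):
    """Resume if a checkpoint exists; otherwise start from (epoch 0, step 0)."""
    if not os.path.exists(path):
        return 0, 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"], state["step"]
```

Calling `load_checkpoint` unconditionally at startup means the same container command works for both the first run and every restart after a preemption.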

It also allows us to easily experiment with different samples of our dataset by interacting with our data as a set of indexed URLs in a storage bucket (preventing us from needing to transfer files in bulk every time we want to change the distribution).
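As an illustration, changing the training distribution can then be a matter of picking a different list of URLs. The sketch below assumes a hypothetical manifest format (one dict per example with a `"url"` and a `"source"` tag, such as which labeling batch it came from) and draws a reproducible subset with a chosen mix of sources.

```python
import random
from typing import Dict, List, Sequence


def sample_manifest(manifest: Sequence[Dict[str, str]], mix: Dict[str, float],
                    size: int, seed: int = 0) -> List[str]:
    """Draw a reproducible training subset from a manifest of labeled URLs.

    `mix` gives the desired share of each source (weights are normalized).
    Changing it only changes which URLs are listed; no files are moved.
    """
    rng = random.Random(seed)   # local RNG so the draw doesn't depend on global state
    by_source: Dict[str, List[str]] = {}
    for entry in manifest:
        by_source.setdefault(entry["source"], []).append(entry["url"])
    total = sum(mix.values())
    chosen: List[str] = []
    for source, weight in mix.items():
        pool = by_source.get(source, [])
        k = min(len(pool), round(size * weight / total))
        chosen.extend(rng.sample(pool, k))
    rng.shuffle(chosen)
    return chosen
```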

We can see this approach scales much better, allowing us to keep our data in one central place without needing to download massive datasets to an instance every time we start training. Our benchmarks show an average data load time of 25.04 ms. To speed things up further, we may want to consider an additional modification: asynchronous streaming.

Asynchronous Streaming

To further ease working with a larger dataset, we can minimize data loading time by moving to an async workflow.

Asynchronous Training Pipeline

In our previous streaming example, training speed was bottlenecked by the speed of each data request, which introduces significant latency when loading training data from a cloud object store such as GCS or S3. This latency can be hidden using concurrency, i.e., latency hiding. We evaluated several options for concurrency in Python, including multithreading, multiprocessing and asyncio. The Python libraries for both S3 and GCS block the global interpreter lock, so gains are small with multithreading. Multiprocessing gives us true concurrency, but at the expense of memory usage: in our tests, we frequently ran into memory constraints when we tried to scale up the number of processes in production. In the end, we landed on asyncio as the best solution for highly concurrent data loading, since asyncio is well suited to IO-bound and high-level structured network code.

DataLoader already achieves some concurrency using PyTorch's multiprocessing; however, for the purpose of hiding network latency, we can achieve higher concurrency at a lower cost using asyncio. We use both multiprocessing and asyncio in production, with each process running its own asyncio event loop. Multiprocessing gives us concurrency for CPU-intensive tasks (such as conversion to torch tensors or data augmentation), while the event loops hide the latency of hundreds of GCS/S3 requests.

This requires some rewriting of our code to handle concurrency.
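One way to combine the two layers is to hand each DataLoader worker a whole batch of indices and let that worker fetch the batch concurrently on its own event loop. The sketch below is our illustration of that pattern, not the authors' code. It assumes a hypothetical async `fetch(url) -> bytes` coroutine (in production this would come from an asyncio-compatible storage or HTTP client) and images of equal size so they can be stacked.

```python
import asyncio
import io
from typing import Awaitable, Callable, List, Sequence

import numpy as np
import torch
from PIL import Image
from torch.utils.data import BatchSampler, DataLoader, Dataset, RandomSampler

Fetch = Callable[[str], Awaitable[bytes]]   # async: URL -> raw object bytes (hypothetical)
Decode = Callable[[bytes], torch.Tensor]    # sync, CPU-bound: bytes -> tensor


def decode_png(data: bytes) -> torch.Tensor:
    """CPU-bound decoding, done in the worker process once the bytes arrive."""
    array = np.asarray(Image.open(io.BytesIO(data)).convert("RGB"))
    return torch.from_numpy(array.copy()).permute(2, 0, 1)   # (3, H, W) uint8


class AsyncBatchDataset(Dataset):
    """Each __getitem__ call receives a whole batch of indices and fetches them concurrently."""

    def __init__(self, urls: Sequence[str], fetch: Fetch, decode: Decode = decode_png,
                 max_concurrency: int = 64):
        self.urls = list(urls)
        self.fetch = fetch
        self.decode = decode
        self.max_concurrency = max_concurrency

    def __len__(self) -> int:
        return len(self.urls)

    async def _fetch_all(self, indices: List[int]) -> List[bytes]:
        # The semaphore caps in-flight requests per worker process.
        semaphore = asyncio.Semaphore(self.max_concurrency)

        async def fetch_one(i: int) -> bytes:
            async with semaphore:
                return await self.fetch(self.urls[i])

        # gather preserves order, so results line up with `indices`.
        return await asyncio.gather(*(fetch_one(i) for i in indices))

    def __getitem__(self, indices: List[int]) -> torch.Tensor:
        # A fresh event loop runs inside this worker process: the network waits
        # overlap, then decoding proceeds on this process's CPU.
        raw = asyncio.run(self._fetch_all(indices))
        return torch.stack([self.decode(data) for data in raw])


def make_loader(dataset: AsyncBatchDataset, batch_size: int, num_workers: int = 4) -> DataLoader:
    # BatchSampler yields lists of indices; batch_size=None disables the DataLoader's
    # own batching so each list is passed straight to __getitem__.
    sampler = BatchSampler(RandomSampler(dataset), batch_size=batch_size, drop_last=False)
    return DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=num_workers)
```

`asyncio.run` refuses to start inside an already-running loop, so with `num_workers=0` inside a Jupyter kernel this sketch would fail; in worker processes (or a plain script) there is no running loop and it works as written.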

Looking at some training benchmarks, we can see that synchronous streaming triggers an image loading bottleneck that causes the average batch time to spike to almost 20 seconds (when using the largest batch size), while the async streaming code took less than a second on average. Training becomes much more efficient since we are able to keep the GPUs saturated even as our batch sizes increase.
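Numbers like these can be gathered by timing how long the loader takes to produce each batch. A minimal sketch of such a measurement (a hypothetical helper, not the harness behind the benchmarks above) that skips a few warm-up batches so worker start-up and connection setup don't skew the mean:

```python
import time
from typing import Iterable


def average_batch_time(loader: Iterable, num_batches: int = 50, warmup: int = 2) -> float:
    """Average wall-clock seconds to produce one batch, excluding `warmup` batches."""
    times = []
    iterator = iter(loader)
    for i in range(warmup + num_batches):
        start = time.perf_counter()
        try:
            next(iterator)
        except StopIteration:
            break
        if i >= warmup:
            times.append(time.perf_counter() - start)
    if not times:
        raise ValueError("loader produced no batches after warm-up")
    return sum(times) / len(times)
```

Running it once with a synchronous loader and once with an asynchronous one, at the same batch size, gives a like-for-like comparison of data loading cost alone.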

Data loading benchmarks

Before finishing up, we will want to add a series of data augmentation techniques to artificially increase the size of our dataset even further. Some color jitter, random affine transforms and random cropping are all good starting points.

Wrapping up: A system for constant dataset iteration

Through this article, we've moved from a prototype tied to a small static dataset to a more flexible training system. We can now swap out and experiment with increasingly large datasets (provided by our labeling pipeline with Scale) while minimizing retraining time and cost. In the future, we can make our distributed training workflow even more efficient with the recently released PyTorch Elastic framework, which makes it easier to work with a dynamic pool of spot instances. By integrating Scale's labeling platform and PyTorch's tooling for distributed training workflows, ML engineers have a powerful set of building blocks to continuously expand their selection of training data and repeatedly drive performance gains in their ML models.

References

¹ https://developers.google.com/machine-learning/data-prep/construct/collect/data-size-quality
