Building an efficient input data pipeline for deep learning.

Mrinal Haloi
Subex AI Labs
3 min read · May 3, 2021

Problem

When dealing with huge datasets of thousands or millions of files, the input data pipeline can be a game-changer or a bottleneck depending on its design.

When a dataset doesn't fit in RAM, a Python generator-based approach can become a huge bottleneck when training complex, GPU compute-intensive models. A GPU sitting idle waiting for data slows down model training by a huge margin.
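
For illustration, here is a minimal sketch of such a generator-based pipeline. The file list and labels here are hypothetical stand-ins; the point is that every epoch re-reads and re-decodes each file from disk in Python, one sample at a time.

import tensorflow as tf

image_paths = ["img_0.jpg", "img_1.jpg"]  # hypothetical file list
labels = [0, 1]  # hypothetical labels

def naive_generator(image_paths, labels):
    # One disk read and one JPEG decode per sample, repeated every epoch.
    for path, label in zip(image_paths, labels):
        image = tf.io.decode_jpeg(tf.io.read_file(path))
        yield image, label

dataset = tf.data.Dataset.from_generator(
    lambda: naive_generator(image_paths, labels),
    output_signature=(
        tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    ),
)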
That’s where TensorFlow’s TFRecord comes to the rescue.

TFRecord

TFRecord stores data as a sequence of protocol buffers serialized into binary records. TFRecord files are very efficient to read and light on disk space.

TFRecord does away with the need to read each sample file from disk for every epoch.

The efficiency of TFRecord comes at a cost: creating and reading the record files requires a fair amount of complex code.
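
For a taste of that complexity, here is roughly what writing a single record looks like with raw TensorFlow. This is a sketch; the "image" and "label" feature names and the file names are made up for illustration.

import tensorflow as tf

def serialize_example(image_bytes, label):
    # Wrap raw image bytes and an integer label into a tf.train.Example protobuf.
    feature = {
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

with tf.io.TFRecordWriter("data.tfrecord") as writer:
    with open("cat.jpg", "rb") as f:
        writer.write(serialize_example(f.read(), 0))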

Here we present a solution to this cost by introducing Datum to manage TFRecord files.

Datum

Datum is a library built on top of TensorFlow for building fast, efficient input pipelines. It is designed to create and manage TFRecord datasets with almost no complex code.

Datum is designed to read and write TFRecord datasets and to build a fast input pipeline for single-GPU or distributed training with just a few lines of code.

Datum makes use of tf.data and TFRecord to build a fast input pipeline.
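
Internally, a tf.data pipeline over TFRecord files looks roughly like the following generic sketch. This is not Datum's exact code; the feature names and shapes in parse_fn are assumptions for illustration.

import tensorflow as tf

def parse_fn(serialized):
    # Hypothetical per-record parser: deserialize one Example protobuf.
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized, features)
    image = tf.io.decode_jpeg(parsed["image"], channels=3)
    image = tf.image.resize(image, (224, 224))  # fixed size so batching works
    return image, parsed["label"]

dataset = (
    tf.data.TFRecordDataset(tf.io.gfile.glob("path/to/*.tfrecord"))
    .map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)  # parse records in parallel
    .shuffle(1024)
    .batch(32)
    .prefetch(tf.data.AUTOTUNE)  # overlap input preparation with training
)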

Installation

Datum can be installed from PyPI:

pip install datum

Export to TFRecord

Writing/exporting data to the tfrecord format can get very complex without datum.

Datum makes it easier to export datasets to the tfrecord format. It provides a few predefined problem types so you can create a dataset with a few lines of code, without getting deep into the workings of tfrecord and serialization.

Import TFRWriteConfigs to define the datum configuration for writing/exporting data to tfrecord:

from datum.configs import TFRWriteConfigs

Define the split information in the configs; split names are important for datum to automatically identify each split's data.

write_configs = TFRWriteConfigs()
write_configs.splits = {
    "train": {
        "num_examples": <num of train examples in the dataset>
    },
    "val": {
        "num_examples": <num of validation examples in the dataset>
    },
}

Import the export API and the problem type you want to convert the dataset for.

Different datasets serve different purposes. For example, an image classification dataset with only class labels cannot be used for image segmentation or detection. To make conversion easier, datum defines separate problem types for classification, detection, and segmentation tasks.

from datum.export.export import export_to_tfrecord
from datum.problem.types import IMAGE_CLF

Suppose we want to build a tfrecord dataset for an image classification task; the problem type for that is IMAGE_CLF.

Convert the dataset to the tfrecord format:

export_to_tfrecord(input_path, output_path, IMAGE_CLF, write_configs)

Datum will convert the dataset and save the output .tfrecord files and dataset metadata files in the output path. The exported tfrecord files can then be easily loaded as a tf.data.Dataset using Datum's load API.

Load as tf.data.Dataset

Import the loading API to load the tfrecord dataset as a tf.data.Dataset:

from datum.reader import load

Loading the dataset is very simple; just pass the output_path from the previous export step.

dataset = load(<path to tfrecord files folder>)
train_dataset = dataset.train_fn('train', shuffle=True)
val_dataset = dataset.val_fn('val', shuffle=False)
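
Both functions return a tf.data.Dataset, so you can sanity-check the output with standard tf.data operations (a quick sketch, assuming the dataset is already batched):

for batch in train_dataset.take(1):
    print(batch)  # one batch of examples, e.g. image and label tensors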

Examples in the dataset can be augmented before being fed into the model. It’s easy to preprocess and postprocess samples in the dataset using pre_batching_callback and post_batching_callback.

pre_batching_callback: using this callback, examples can be processed before batching, one example at a time.

post_batching_callback: using this callback, examples can be processed after batching. Examples are processed as a batch.

Suppose we want to augment the dataset; this can be achieved using the following pre_batching_callback:

import tensorflow as tf

IMG_SIZE = (224, 224)  # target image size (assumed value; set for your model)

def augment_image(example):
    # Resize and apply random flips to the image in each example dict.
    image = tf.image.resize(example["image"], IMG_SIZE)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    example.update({"image": image})
    return example

dataset_configs = dataset.dataset_configs
dataset_configs.pre_batching_callback = lambda example: augment_image(example)
train_dataset = dataset.train_fn('train', shuffle=True)
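
Similarly, a post_batching_callback operates on a whole batch at once. As a sketch under the same assumptions (an "image" key holding uint8 tensors), batch-level normalization could look like this:

def normalize_batch(example):
    # Scale the whole batch of images to [0, 1] floats.
    example.update({"image": tf.cast(example["image"], tf.float32) / 255.0})
    return example

dataset_configs.post_batching_callback = lambda example: normalize_batch(example)
train_dataset = dataset.train_fn('train', shuffle=True)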

Try Datum in practice

This notebook demonstrates the use of datum for transfer learning with an EfficientNet-B0 model on an image classification project.

Play with the Transfer Learning with Datum notebook and use it to make your input pipeline fast.
