Building an efficient input data pipeline for deep learning.
Problem
When a dataset consists of thousands or millions of files, the design of the input data pipeline can make or break training performance.
When a dataset does not fit in RAM, a Python generator-based pipeline often becomes a serious bottleneck for compute-intensive GPU models: the GPU sits idle waiting for data, slowing training by a huge margin.
That’s where TensorFlow’s TFRecord format comes to the rescue.
TFRecord
TFRecord stores data as a sequence of protocol buffers serialized into a binary record. TFRecord files are fast to read and light on disk space.
TFRecord does away with the need to read each sample file from disk on every epoch.
This efficiency comes at a cost, though: creating and reading record files normally requires a lot of complex code.
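To see why, here is roughly what writing and reading even a single example looks like with raw TensorFlow. This is a minimal sketch using only the standard tf.train/tf.io APIs; the feature names and values are illustrative:

```python
import tensorflow as tf

# Writing: every feature must be wrapped in the right proto type by hand.
def serialize_example(image_bytes, label):
    feature = {
        "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature)).SerializeToString()

path = "sample.tfrecord"
with tf.io.TFRecordWriter(path) as writer:
    writer.write(serialize_example(b"fake-image-bytes", 3))

# Reading: a matching feature description must be maintained for parsing.
feature_description = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}
dataset = tf.data.TFRecordDataset(path).map(
    lambda record: tf.io.parse_single_example(record, feature_description))
for example in dataset:
    print(example["label"].numpy())  # 3
```

Multiply this by every feature type and every task, and the boilerplate adds up quickly.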
Here we present a solution to this cost: Datum, a library for managing TFRecord files.
Datum is a library built on top of TensorFlow for building fast, efficient input pipelines. It is designed to create and manage TFRecord datasets with almost no complex code.
Datum reads and writes TFRecord datasets and builds a fast input pipeline for single-GPU or distributed training in just a few lines of code. Under the hood, Datum uses tf.data and TFRecord to build the pipeline.
Installation
Datum can be installed from PyPI:
pip install datum
Export to TFRecord
Writing (exporting) data to the TFRecord format can get very complex without Datum. Datum makes it easy: it provides a few predefined problem types for creating a dataset in a few lines of code, without digging into the details of TFRecord serialization.
Import TFRWriteConfigs to define the Datum configuration for writing/exporting data to TFRecord:
from datum.configs import TFRWriteConfigs
Define the split information in the configs; the split names are important, as Datum uses them to automatically identify each split’s data.
write_configs = TFRWriteConfigs()
write_configs.splits = {
    "train": {
        "num_examples": <num of train examples in the dataset>
    },
    "val": {
        "num_examples": <num of validation examples in the dataset>
    },
}
Next, import the export API and the problem type to convert the dataset for.
Different datasets serve different purposes. For example, an image classification dataset with only class labels cannot be used for image segmentation or object detection. To make conversion easier, Datum defines separate problem types for classification, detection, and segmentation tasks.
from datum.export.export import export_to_tfrecord
from datum.problem.types import IMAGE_CLF
Suppose we want to build a TFRecord dataset for an image classification task; the corresponding problem type is IMAGE_CLF.
Convert the dataset to TFRecord format:
export_to_tfrecord(input_path, output_path, IMAGE_CLF, write_configs)
Datum converts the dataset and saves the output .tfrecord files, along with dataset metadata files, to the output path. The exported TFRecord files can then be loaded as a tf.data.Dataset using Datum’s load API.
Load as tf.data.Dataset
Import the load API for loading the TFRecord dataset as a tf.data.Dataset:
from datum.reader import load
Loading the dataset is very simple: just pass the output_path from the export step above.
dataset = load(<path to tfrecord files folder>)
train_dataset = dataset.train_fn('train', shuffle=True)
val_dataset = dataset.val_fn('val', shuffle=False)
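The returned datasets are regular tf.data.Dataset objects, so they plug straight into Keras. A minimal sketch of that last step, where a synthetic dataset stands in for Datum’s output and each batch is assumed to be an (image, label) tuple:

```python
import tensorflow as tf

# Synthetic stand-in for the batched (image, label) dataset Datum would return.
images = tf.random.uniform((32, 64, 64, 3))
labels = tf.random.uniform((32,), maxval=10, dtype=tf.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(8)

# A tiny classifier, just to show the dataset feeding into model.fit.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(64, 64, 3)),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
history = model.fit(train_dataset, epochs=1, verbose=0)
```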
Examples in the dataset can be augmented before being fed into the model. Samples are easy to preprocess and postprocess using pre_batching_callback and post_batching_callback.
pre_batching_callback: processes each example individually, before batching.
post_batching_callback: processes examples after batching; examples are processed as a batch.
Suppose we want to augment the dataset; this can be done with the following pre_batching_callback:
def augment_image(example):
    image = tf.image.resize(example["image"], IMG_SIZE)
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    example.update({"image": image})
    return example

dataset_configs = dataset.dataset_configs
dataset_configs.pre_batching_callback = lambda example: augment_image(example)
train_dataset = dataset.train_fn('train', shuffle=True)
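A post_batching_callback follows the same pattern but receives a whole batch at once. As an illustration, a batch-level normalization callback might look like the sketch below; the function is demonstrated on a synthetic batch, since the callback itself is just a transformation of the example dict:

```python
import tensorflow as tf

# Batch-level callback: scale pixel values from [0, 255] to [0, 1].
def normalize_batch(example):
    image = tf.cast(example["image"], tf.float32) / 255.0
    example.update({"image": image})
    return example

# Hypothetical Datum usage, mirroring the pre_batching_callback pattern above:
#   dataset_configs.post_batching_callback = normalize_batch

# Demonstrate on a synthetic batch of four 8x8 RGB images.
batch = {"image": tf.random.uniform((4, 8, 8, 3), maxval=256, dtype=tf.int32)}
out = normalize_batch(batch)
```

Because the callback sees the whole batch, it is the right place for operations that are cheaper vectorized, such as normalization or per-batch mixing.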
Try Datum in practice
This notebook demonstrates using Datum for transfer learning with an EfficientNet-B0 model on an image classification task.
Play with the Transfer Learning with Datum notebook and use it to make your input pipeline fast.