A guide to custom DataGenerators in Keras

Learn how you can use custom DataGenerators to create batches of data on-the-fly

Robin Vinod
The Startup
3 min read · May 1, 2020


Data generators are a way to provide batches of data on-the-fly while training a neural network. They become incredibly useful when dealing with large datasets, where the data cannot be stored in memory in its entirety. In this story, I go through the process of making your own custom data generator in Keras.

Outline

  1. Where can I use data generators?
  2. How does a data generator work?
  3. Full code

1. Where can I use data generators?

Keras models provide three methods that can consume a custom data generator: fit_generator(), evaluate_generator() and predict_generator(). (In recent versions of TensorFlow, fit(), evaluate() and predict() also accept a generator or Sequence directly.)

These methods take in a generator as an argument rather than a dataset stored in memory. The model will fit using the data supplied by the generator.
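For instance, assuming a compiled model and hypothetical training_generator, validation_generator and test_generator instances, the calls look like this:

```python
# Train, evaluate and predict from generators instead of in-memory arrays
model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    epochs=10)

model.evaluate_generator(generator=validation_generator)

model.predict_generator(generator=test_generator)
```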

2. How does a data generator work?

a. DataGenerator Class

While you can write your own generator in Python using the yield keyword, Keras provides a keras.utils.Sequence class that you can inherit from to make your custom generator. Inheriting from Sequence guarantees that the network trains on each sample only once per epoch, which makes it a safer way to do multiprocessing.

list_IDs is an array containing the filenames/paths of all the samples in the dataset.
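A sketch of the constructor might look like this (the labels dict and the dim parameter are illustrative assumptions, not required names):

```python
import numpy as np
from tensorflow.keras.utils import Sequence

class DataGenerator(Sequence):
    def __init__(self, list_IDs, labels, batch_size=32, dim=(32, 32), shuffle=True):
        self.list_IDs = list_IDs      # filenames/paths of every sample
        self.labels = labels          # dict mapping each ID to its target
        self.batch_size = batch_size
        self.dim = dim                # shape of a single sample
        self.shuffle = shuffle
        self.on_epoch_end()           # build (and optionally shuffle) the indexes
```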

The on_epoch_end() method is called once at the very beginning of training, and again at the end of every epoch. It creates self.indexes, which is simply an array containing the indexes of all the files listed in list_IDs. If shuffle is set to True, the indexes are randomised so that the files that make up a batch differ at every epoch.
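A minimal implementation might look like this (it runs at the start of training because the constructor above also calls it once):

```python
def on_epoch_end(self):
    # One index per entry in list_IDs, reshuffled after every epoch
    self.indexes = np.arange(len(self.list_IDs))
    if self.shuffle:
        np.random.shuffle(self.indexes)
```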

Every Sequence generator needs to implement the __len__() and __getitem__() methods.

The __len__() method is used to calculate the total number of possible batches, to ensure that each batch is seen at most once per epoch. As a general rule, it should return:

floor(number of samples / batch size)

Given that you have 1024 samples and a batch size of 2, __len__() will return 512, meaning that there are 512 unique batches available in total. Each batch is given an index from 0 to 511, which is used in the __getitem__() method.
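In code, that rule might look like:

```python
def __len__(self):
    # e.g. 1024 samples with a batch size of 2 gives 512 batches per epoch
    return int(np.floor(len(self.list_IDs) / self.batch_size))
```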

The __getitem__() method is called whenever the model needs a new batch of data. If the requested batch has an index of 312 and the batch size is 2, the method will extract the samples from list_IDs whose indexes run from 312*2 up to, but not including, (312+1)*2 (that is, indexes 624 and 625).

Essentially, batch_size samples are extracted from list_IDs, with the given batch index used to compute the starting offset.

A new array, list_IDs_temp, is created containing the filenames/paths of the samples in the current batch. It is then passed to the __data_generation() method, where those files are loaded into memory.
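Putting those steps together, a sketch of __getitem__() might be:

```python
def __getitem__(self, index):
    # Indexes of the samples that belong to this batch
    indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
    # Look up the corresponding filenames/paths
    list_IDs_temp = [self.list_IDs[k] for k in indexes]
    # Load those files into memory and return the batch
    return self.__data_generation(list_IDs_temp)
```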

An empty placeholder array with the appropriate dimensions is created, and it is populated as the for loop runs.

The logic to load each sample is housed within the for loop.
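Here is an example, assuming each sample is stored as a .npy file and that a labels dict mapping each ID to its class was passed to the constructor:

```python
def __data_generation(self, list_IDs_temp):
    # Placeholder arrays, populated by the loop below
    X = np.empty((self.batch_size, *self.dim))
    y = np.empty((self.batch_size,), dtype=int)

    for i, ID in enumerate(list_IDs_temp):
        X[i,] = np.load('data/' + ID + '.npy')  # assumed on-disk layout
        y[i] = self.labels[ID]

    return X, y
```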

3. Full code
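Putting all the pieces together, here is a minimal end-to-end sketch of the generator, under the same assumptions as above (.npy files on disk and a labels dict):

```python
import numpy as np
from tensorflow.keras.utils import Sequence


class DataGenerator(Sequence):
    """Generates batches of data for Keras from files on disk."""

    def __init__(self, list_IDs, labels, batch_size=32, dim=(32, 32), shuffle=True):
        self.list_IDs = list_IDs      # filenames/paths of every sample
        self.labels = labels          # dict mapping each ID to its target
        self.batch_size = batch_size
        self.dim = dim                # shape of a single sample
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # Total number of complete batches per epoch
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        # Indexes of the samples that belong to this batch
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        return self.__data_generation(list_IDs_temp)

    def on_epoch_end(self):
        # Rebuild (and optionally shuffle) the index array
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        # Placeholder arrays, populated by the loop below
        X = np.empty((self.batch_size, *self.dim))
        y = np.empty((self.batch_size,), dtype=int)
        for i, ID in enumerate(list_IDs_temp):
            X[i,] = np.load('data/' + ID + '.npy')  # assumed on-disk layout
            y[i] = self.labels[ID]
        return X, y
```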

When utilising the generator, make sure that the workers and use_multiprocessing arguments are enabled, as shown below. Using a higher number of workers spins up more CPU processes (or threads, when use_multiprocessing is False) to generate batches of data in parallel, ensuring that data generation is not a bottleneck, especially when training with multiple GPUs.
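For example, with hypothetical train_IDs, val_IDs and labels objects:

```python
# Hypothetical ID lists and labels dict
training_generator = DataGenerator(train_IDs, labels)
validation_generator = DataGenerator(val_IDs, labels)

model.fit_generator(generator=training_generator,
                    validation_data=validation_generator,
                    use_multiprocessing=True,
                    workers=6)
```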
