Speeding up Keras with tfrecord datasets
Save yourself some time by removing bottlenecks
Alright. So you just got started with Keras with TensorFlow as a backend. Introducing GPU computing was quite simple, so you started increasing the size of your datasets.
Everything works great, your GPU is happy and hungry for more, so you increase the dataset size even further to improve the robustness of your model.
At a certain size you hit the limit of your RAM, and naturally you write a quick Python generator to feed your data directly into the Keras model. That’s when you notice the performance hit you just took: GPU usage settles at around 20 percent, which means that most of the time your GPU is just waiting for your Python generator to load the data.
That’s where TFRecords enter the game and save the day. But be warned! Speeding up your training time may lead to fewer coffee breaks!
A TFRecord dataset is basically your dataset saved to your hard drive as a sequence of serialized protocol buffer records. The benefits of using this format are:
- You do not need to load the complete dataset into memory. You can ingest your data piece by piece thanks to the Dataset class. TensorFlow takes care of loading more data while your GPU is crunching the numbers.
- It’s blazing fast, since you do not need to load your data into a NumPy array first and then feed it back into your Keras/TensorFlow session. The data stays inside TensorFlow’s C++ runtime end to end.
To build your own input pipeline, you need to take the following steps:
- Convert your dataset into a TFRecord dataset and save it to your disk.
- Load this dataset using the TFRecordDataset class.
- Ingest it into your Keras model.
Convert to TFRecord dataset
Creating your dataset is pretty straightforward. You can follow this guide for more detailed instructions. All you need to do is define the features of each example, serialize every example, and write it to a TFRecord file.
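Here is a minimal sketch of the writing step. It assumes fixed-size uint8 images with integer labels. The feature names "image" and "label", the helper names and the toy data are hypothetical choices, not the original GIST. Each pair becomes a tf.train.Example and is appended to one file with tf.io.TFRecordWriter. Returning the count also gives you the number of samples per file, which comes in handy later.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def _bytes_feature(value):
    """Wrap a raw byte string in a tf.train.Feature."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    """Wrap a single integer in a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def write_tfrecord(path, images, labels):
    """Serialize (image, label) pairs into one TFRecord file.

    Assumes every image is a uint8 array of the same shape.
    Returns the number of examples written.
    """
    count = 0
    with tf.io.TFRecordWriter(path) as writer:
        for image, label in zip(images, labels):
            example = tf.train.Example(features=tf.train.Features(feature={
                # Raw pixel bytes; the shape has to be known again on read.
                "image": _bytes_feature(np.asarray(image, np.uint8).tobytes()),
                "label": _int64_feature(int(label)),
            }))
            writer.write(example.SerializeToString())
            count += 1
    return count


if __name__ == "__main__":
    # Hypothetical toy data: ten random 32x32 RGB images, all category 0.
    images = np.random.randint(0, 256, size=(10, 32, 32, 3), dtype=np.uint8)
    labels = np.zeros(10, dtype=np.int64)
    path = os.path.join(tempfile.mkdtemp(), "category_0.tfrecord")
    print(write_tfrecord(path, images, labels), "examples written to", path)
```

Storing raw bytes keeps writing trivial. The price is that the reader has to know the image shape.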
Try to create small dataset files that are not bigger than your RAM, but big enough that serializing them into TFRecords gives you an advantage. “Relatively” small files make it easier to shuffle the data and to perform other useful tricks on read.
I aim for file sizes of 4–5 GB (on a machine with 32 GB of RAM) to make the files more “user-friendly” when copying them around. One file normally contains a single category, which gives me the flexibility to add and remove categories quickly.
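When each file holds one category, reading the files one after another would feed the model long runs of a single class. The sketch below shows one way to shuffle on read. It assumes a glob pattern that matches your per-category files; the function name and default values are hypothetical. It shuffles the file order with tf.data.Dataset.list_files, reads several files at once with interleave, and mixes the records further with a shuffle buffer.

```python
import tensorflow as tf


def interleaved_records(pattern, cycle_length=4, shuffle_buffer=10_000, seed=None):
    """Mix serialized records from many per-category TFRecord files.

    list_files shuffles the file order, interleave pulls records from
    cycle_length files in turn, and shuffle mixes them within a buffer.
    """
    files = tf.data.Dataset.list_files(pattern, shuffle=True, seed=seed)
    records = files.interleave(
        tf.data.TFRecordDataset,      # one reader per file
        cycle_length=cycle_length,    # number of files read concurrently
        block_length=1,               # take one record per file per turn
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    return records.shuffle(shuffle_buffer, seed=seed)
```

The result still contains serialized records, so you would parse them afterwards. A larger shuffle_buffer mixes better but keeps more records in memory.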
Load the dataset for training
Loading your freshly created TFRecord file as a TensorFlow dataset takes only a few lines of code.
Note that you are not loading the data directly; you are just building a pipeline. The image and label variables are just tensors that get populated later on, during a TensorFlow session.
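Here is a minimal sketch of the loading side, written against the TensorFlow 2.x tf.data API. It assumes each record holds a raw uint8 "image" byte string plus an int64 "label", with a fixed, known image shape. IMAGE_SHAPE and the function names are hypothetical. Nothing is read until the dataset is iterated. In a TensorFlow 1.x, session-based setup like the one described here, you would additionally call make_one_shot_iterator().get_next() on the dataset to obtain the image and label tensors.

```python
import tensorflow as tf

IMAGE_SHAPE = (32, 32, 3)  # hypothetical; must match the shape used when writing

FEATURES = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}


def parse_example(serialized):
    """Turn one serialized tf.train.Example into an (image, label) pair."""
    parsed = tf.io.parse_single_example(serialized, FEATURES)
    image = tf.io.decode_raw(parsed["image"], tf.uint8)   # flat uint8 vector
    image = tf.reshape(image, IMAGE_SHAPE)
    image = tf.cast(image, tf.float32) / 255.0            # scale to [0, 1]
    return image, parsed["label"]


def load_dataset(filenames, batch_size=32, shuffle_buffer=1000):
    """Build (but do not run) the input pipeline."""
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.repeat()        # endless: Keras needs steps_per_epoch later
    dataset = dataset.batch(batch_size)
    return dataset.prefetch(tf.data.AUTOTUNE)  # load ahead while the GPU trains
```

The repeat() call makes the dataset endless. prefetch lets TensorFlow prepare the next batch while the GPU works on the current one.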
Ingest into Keras model
The last step is quite simple. You create your image and label tensors and then create an Input layer for your Keras model by handing it the image tensor (through the tensor argument of Input). When compiling the model, you hand over the label tensor through the target_tensors argument in a similar way.
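Here is a sketch of the Keras side. It swaps in a different technique from the one described above. TensorFlow 2.x Keras does not support target_tensors when running eagerly, so instead of wiring the image and label tensors into the model, this sketch passes the batched tf.data dataset directly to model.fit together with steps_per_epoch. The tiny model and its layer sizes are hypothetical. A synthetic in-memory dataset stands in for the TFRecord pipeline so that the example runs on its own.

```python
import tensorflow as tf


def build_model(input_shape=(32, 32, 3), num_classes=10):
    """A deliberately tiny, hypothetical image classifier."""
    inputs = tf.keras.Input(shape=input_shape)
    x = tf.keras.layers.Conv2D(16, 3, activation="relu")(inputs)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    model = tf.keras.Model(inputs, outputs)
    # Integer labels, so the sparse variant of cross-entropy is used.
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


if __name__ == "__main__":
    # Synthetic stand-in for the TFRecord pipeline: endless, batched (image, label) pairs.
    images = tf.random.uniform((64, 32, 32, 3))
    labels = tf.random.uniform((64,), maxval=10, dtype=tf.int64)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels)).repeat().batch(16)

    model = build_model()
    # The dataset repeats forever, so Keras must be told where an epoch ends.
    model.fit(dataset, epochs=2, steps_per_epoch=64 // 16)
```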
The trickiest part is that Keras does not know how many steps one epoch takes. During loading, we told TensorFlow to repeat the dataset forever, remember?
I normally count the number of data points in each file and write the counts into a .txt file. Dividing the total number of samples by the batch size gives you the steps per epoch.
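This bookkeeping can be sketched in plain Python. The counts file format is not specified, so the one used here (one "filename count" pair per line) and the function names are assumptions. The sketch rounds up so that each epoch draws at least as many samples as the dataset holds. Rounding down is an equally valid choice.

```python
import math
from pathlib import Path


def read_counts(path):
    """Parse a hypothetical counts file with lines like 'cat_0.tfrecord 1200'."""
    counts = {}
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        name, n = line.rsplit(maxsplit=1)
        counts[name] = int(n)
    return counts


def steps_per_epoch(counts, batch_size):
    """Total number of samples divided by the batch size, rounded up."""
    total = sum(counts.values())
    return math.ceil(total / batch_size)


if __name__ == "__main__":
    counts = {"cat_0.tfrecord": 1000, "cat_1.tfrecord": 1050}
    print(steps_per_epoch(counts, batch_size=32))  # 2050 / 32 = 64.06, rounded up to 65
```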
That’s pretty much all it takes to get your neural network in Keras running on a TFRecord dataset.
TFRecords are a great way to improve and clean up your data loading. They speed up your data reads and make it possible to keep your classes in separate, categorized files.
It is even possible to write data filters or data transformers that TensorFlow applies on read (e.g. flipping or rotating images, adding noise, blocking bad data). The advantage is that your raw data stays untouched on your drive, and you do not have to worry about ten different versions of your raw dataset eating away your precious disk space.
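Such on-read filters and transformations map naturally onto tf.data's filter and map. The sketch below assumes per-example float images scaled to [0, 1], with equal height and width so that a 90-degree rotation keeps the shape, and int64 labels. The "bad data" rule (a negative label or a completely flat image) and the noise level are hypothetical placeholders for your own checks.

```python
import tensorflow as tf


def is_valid(image, label):
    """Hypothetical 'bad data' rule: drop negative labels and flat (constant) images."""
    return tf.logical_and(label >= 0, tf.math.reduce_std(image) > 0.0)


def augment(image, label):
    """Random flip, random 90-degree rotation and a little Gaussian noise."""
    image = tf.image.random_flip_left_right(image)
    k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k=k)          # assumes height == width
    image = image + tf.random.normal(tf.shape(image), stddev=0.02)
    image = tf.clip_by_value(image, 0.0, 1.0)   # keep values in [0, 1]
    return image, label


def add_on_read_steps(dataset):
    """Apply the filter and the augmentation per example, before batching."""
    dataset = dataset.filter(is_valid)
    return dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
```

These steps run inside the input pipeline, so the files on disk are never modified. Each epoch sees freshly augmented copies.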
Thank you for reading.
If you have anything to add to this article or know a better way to organize data loading with Keras, feel free to leave a comment. I am happy to learn an even easier solution :)