
How to (quickly) Build a Tensorflow Training Pipeline

8 min read · Oct 26, 2018


Tensorflow is great. Really, you can do everything imaginable with it, and you can build really cool products. However, Tensorflow’s code examples generally tend to gloss over how to get data into your model: they either naively assume that someone else did the hard work for you and serialized the data into Tensorflow’s native format, or showcase unreasonably slow methods that would leave a GPU idling away at shockingly low utilization. On top of that, the code is often very hacky and difficult to follow. So I thought it might be useful to show a small, self-contained example that handles both training and efficient data pipelining on a nontrivial task.

3 Ways to Feed Tensorflow Models with Data: how to get it right?

There are three paths to enlightenment. Or at least to feeding Tensorflow models with data.

In typical Tensorflow fashion, there are many ways you could get your data pipeline set up. This guide will quickly list the top 3 and show you a compromise that gets you a go-to solution that is very easy to code and blazingly fast for 80% of use cases.

Generally speaking, there are 3 ways in which you can get data into your model:

  1. Use a feed_dict command, where you override an input tensor using an input array. This method is widely covered in online tutorials, but has the disadvantage of being very slow, for a bunch of reasons. Most notably, to use feed_dict you have to load the data into memory in python, which already negates the possibility of multithreading, since python has this ugly beast called the global interpreter lock (aka the GIL). Using feed_dict is widely covered in other tutorials and is a fine solution for some cases. As soon as you try to utilize a high-capacity GPU, however, you’ll find that you’re straining to utilize even 30% of its compute power!
  2. Use Tensorflow TfRecords. I feel I can go out on a limb here and say outright that 9 times out of 10 it’s just a bad idea to get into this mess. Serializing records from python is slow and painful, and deserializing them (i.e. reading them into Tensorflow) is an equally error-prone, coding-intensive affair. You can read about how to use them here.
  3. Use Tensorflow’s tf.data.Dataset object. Now, that’s a great way to go, but there are so many different ways to use it, and Tensorflow’s documentation doesn’t really help in building non-trivial data pipelines. This is where this guide comes in.

Let’s take a look at a real-life use case and build a complex data pipeline that trains blazingly fast on a single machine with a potentially high-capacity GPU.

Our Model: Basic Face Recognition

So let’s deal with a concrete example. Let’s imagine that our goal is to build a Face Recognition model. The inputs to the model are 2 images, and the output is 1 if they’re the same person, and 0 otherwise. Let’s see how a super-naive Tensorflow model might approach this task.
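Here is a minimal sketch of what such a model could look like. This is an illustrative reconstruction (the layer sizes, names, and optimizer settings are placeholders, not the exact code from the repo):

import tensorflow as tf

def build_naive_comparison_model(img1, img2, label):
    # Deliberately simple: subtract the two images and classify the difference map.
    diff = img1 - img2  # [batch, height, width, 3]

    net = tf.layers.conv2d(diff, filters=32, kernel_size=3, activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)
    net = tf.layers.conv2d(net, filters=64, kernel_size=3, activation=tf.nn.relu)
    net = tf.layers.max_pooling2d(net, pool_size=2, strides=2)

    net = tf.layers.flatten(net)
    net = tf.layers.dense(net, units=64, activation=tf.nn.relu)
    logits = tf.layers.dense(net, units=1)  # one logit: same person or not

    loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.reshape(label, [-1, 1]), logits=logits)
    opt_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
    return loss, opt_step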

Alright, so this model isn’t going to win any awards for best face recognition in history. We’re just taking the difference between the two images and feeding this difference map through a standard conv-relu-maxpool neural net. If this is gibberish to you, don’t worry: it’s not a very good approach for comparing images anyway. Just take my word for it that it would be at least somewhat capable of identifying photos that are of the same person. All the model needs now is data, which is the point of our fun little post.

So what does our data look like?

The Data

A classic (tiny) dataset for face recognition is called Labeled Faces in the Wild, which you can download here. The data is quite simple: you get a bunch of folders, and each folder contains photos of the same person, like so:

/lfw
/lfw/Dalai_Lama/Dalai_Lama_0001.jpg
/lfw/Dalai_Lama/Dalai_Lama_0002.jpg
...
/lfw/George_HW_Bush/George_HW_Bush_0001.jpg
/lfw/George_HW_Bush/George_HW_Bush_0002.jpg
...

One thing we could do is generate all the pairs imaginable, cache them, and feed them to the model. That would be highly inefficient and take up all of our memory, since there are 18,984 photos here, and 18,984 squared is… a lot.

So, let’s build a very lightweight pythonic function that yields a pair of photos and indicates whether they’re the same person, sampling another random pair at each iteration.
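Something along these lines; a sketch, assuming the LFW folder layout shown above (the dictionary keys match the ones printed in the run output further down):

import os
import random

def generate_pairs(lfw_dir='resources/lfw'):
    # Map each person to the list of their photo paths.
    people = {}
    for person in os.listdir(lfw_dir):
        person_dir = os.path.join(lfw_dir, person)
        if os.path.isdir(person_dir):
            people[person] = [os.path.join(person_dir, f) for f in os.listdir(person_dir)]
    names = list(people.keys())

    while True:
        # Flip a coin: draw two photos of the same person, or of two different people.
        same_person = random.random() < 0.5
        if same_person:
            person = random.choice(names)
            img1, img2 = random.choice(people[person]), random.choice(people[person])
        else:
            person1, person2 = random.sample(names, 2)
            img1, img2 = random.choice(people[person1]), random.choice(people[person2])
        yield {'person1': img1, 'person2': img2, 'same_person': same_person}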

Whoa! But didn’t I say python is too slow for a data pipeline? The answer is yes, python is slow, but when it comes to randomly drawing file-path strings and yielding them, it’s snappy enough. The important thing is that all of the heavy lifting (reading .jpg images from disk, resizing them, batching them, queueing them, etc.) is done in pure Tensorflow.

Tensorflow Dataset Pipeline

So now we’re left with building a data pipeline in Tensorflow! Without further ado:
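The snippet below is a sketch of that pipeline; the exact original code differs, but the structure (from_generator → map → batch → prefetch) and the _read_image_and_resize helper discussed below are the same, with image size, batch size, and prefetch depth matching the numbers mentioned in the breakdown:

def _read_image_and_resize(inputs):
    # Turn the two file-path strings into decoded, resized image tensors; pass the label through.
    target_size = [128, 128]
    images = []
    for key in ['person1', 'person2']:
        file_contents = tf.read_file(inputs[key])
        image = tf.image.decode_jpeg(file_contents, channels=3)
        image = tf.image.resize_images(image, target_size)
        images.append(image)
    label = tf.cast(inputs['same_person'], tf.float32)
    return images[0], images[1], label

dataset = tf.data.Dataset.from_generator(
    generate_pairs,
    output_types={'person1': tf.string, 'person2': tf.string, 'same_person': tf.bool})
dataset = dataset.map(_read_image_and_resize)
dataset = dataset.batch(10)
dataset = dataset.prefetch(5)

img1_resized, img2_resized, label = dataset.make_one_shot_iterator().get_next()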

So basically we start out with a dictionary holding the pythonic generator output (two file-path strings plus a same-person label). Let’s break down what happens here:

  1. tf.data.Dataset.from_generator() lets Tensorflow know that it’s going to be fed by our pythonic generator. This line doesn’t yet evaluate our pythonic generator at all! It just establishes a plan: whenever our dataset is hungry for more input, it’s going to grab it from that generator.
    That’s why we need to painstakingly specify the types of the outputs that the generator is going to generate. In our case image1 and image2 are both string paths to image files, and label is going to be a boolean indicating whether it’s the same person.
  2. map operation: this is where we set up all the tasks necessary to get from the generator output (file names) to what we actually want to feed our model (loaded and resized images). _read_image_and_resize() takes care of that.
  3. batch operation is a convenience function that batches images into bundles with a consistent number of elements. This is very useful in training, where we typically want to process multiple inputs at once. Notice that if we start out with, say, an image of [128,128,3] dimensions, after the batch we’ll have [10,128,128,3], with 10 being the batch size in this example.
  4. prefetch operation lets Tensorflow do the book-keeping involved in setting up a queue such that the data pipeline continues to read and enqueue data until it has N batches all loaded up and ready to go. In this case I chose 5, and have generally found that numbers 1–5 are usually good enough to utilize GPU capacity to the fullest, without burdening the machine’s memory consumption unnecessarily.

That’s it!

Now that everything is set up, simply instantiating a session and calling session.run(element) will automagically fetch actual values for img1_resized, img2_resized, and label. If we run session.run(opt_step), a new piece of data flows through the pipeline to perform a single optimization step. Here’s a tiny script that fetches one data element and performs 100 training steps, just to see that it all works:
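A sketch of that script, reusing the names from the snippets above:

loss, opt_step = build_naive_comparison_model(img1_resized, img2_resized, label)

with tf.Session() as session:
    session.run(tf.global_variables_initializer())

    # Peek at a couple of raw generator elements, just to see what flows in.
    pairs = generate_pairs()
    print(next(pairs))
    print(next(pairs))

    # Each session.run pulls a fresh batch through the pipeline and takes one optimization step.
    for step in range(100):
        loss_value, _ = session.run([loss, opt_step])
        print('step', step, 'log-loss', loss_value)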

When you keep only George Bush and the Dalai Lama as classes, the model converges rather quickly. Here’s the result of this dummy run:

/Users/urimerhav/venv3_6/bin/python /Users/urimerhav/Code/tflow-dataset/recognizer/train.py
{'person1': 'resources/lfw/Dalai_Lama/Dalai_Lama_0002.jpg', 'person2': 'resources/lfw/George_HW_Bush/George_HW_Bush_0006.jpg', 'same_person': False}
{'person1': 'resources/lfw/Dalai_Lama/Dalai_Lama_0002.jpg', 'person2': 'resources/lfw/George_HW_Bush/George_HW_Bush_0011.jpg', 'same_person': False}
step 0 log-loss 6.541984558105469
step 1 log-loss 11.30261516571045
...
step 98 log-loss 0.11421843618154526
step 99 log-loss 0.09954185783863068
Process finished with exit code 0

I hope you’ll find this guide useful. The entire (tiny) code repo is available on github. Feel free to use it however you see fit!

Example #2 — Removing Background From Product Pictures

I’ll probably follow up with a more elaborate blog post about it later, but I should note that I recently built a data training pipeline which I rather like for a small proof-of-concept site called Pro Product Pix that I built as a side project. The concept is simple: e-sellers need to take photos of products, and then have to do the laborious task of removing the background pixels (which is quite tedious!). Why not do that automatically?

But getting labeled images of products with and without background is very tricky. So: data generator to the rescue! Here’s what I ended up doing:

  1. Search images in the public domain where only an object is present (there’s actually a ton of those — transparent png files are all over the place)
  2. Search Bing Images for some generic background templates like “carpet” or “wall”
  3. Build a generator that generates pairs of background file path + object file path and melds them together.
  4. Task the model with guessing which pixels belong to the background and which ones to the object (and a bunch of other things, but that’s another post).

So all told we have a bunch of backgrounds like these:

Bunch of backgrounds

And we merge them with product photos that have no background

Random product photos

There’s a bunch of small problems to solve here. For example, the backgrounds are usually much bigger than the product photo, so we need to crop the background photo to the product size and then merge the two photos. The entire pipeline is a tad complicated, as it wasn’t written with a tutorial in mind, but I’ll put the highlight snippets here. We have a generator that yields a background path + product path:
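A sketch of that generator (the directory names here are made up for illustration):

import os
import random

def generate_background_product_pairs(background_dir='resources/backgrounds',
                                      product_dir='resources/products'):
    # Endlessly yield one random background path together with one random transparent-png product path.
    backgrounds = [os.path.join(background_dir, f) for f in os.listdir(background_dir)]
    products = [os.path.join(product_dir, f) for f in os.listdir(product_dir)]
    while True:
        yield {'background_path': random.choice(backgrounds),
               'product_path': random.choice(products)}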

We load the background as an RGB image, and the object png actually contains both an RGB image and a transparency map, which we call object_pixels: it tells us which pixels belong to the background and which belong to the object. Then inside our dataset.map operation, we crop a random piece of background and blend it with the foreground:
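Roughly, the map function does something like this (a simplified sketch: the crop is a uniformly random patch, and the real pipeline does a few more things):

def _crop_and_blend(inputs):
    # Decode the background as RGB and the product png as RGBA.
    background = tf.cast(tf.image.decode_jpeg(tf.read_file(inputs['background_path']), channels=3), tf.float32)
    product_png = tf.cast(tf.image.decode_png(tf.read_file(inputs['product_path']), channels=4), tf.float32)

    product_rgb = product_png[:, :, :3]
    # The alpha channel is the transparency map: 1 where the object is, 0 where the background shows through.
    object_pixels = product_png[:, :, 3:] / 255.0

    # Crop a random background patch with the same spatial size as the product image.
    product_shape = tf.shape(product_rgb)
    background_crop = tf.random_crop(
        background, size=tf.stack([product_shape[0], product_shape[1], 3]))

    # Alpha-blend: object pixels come from the product, everything else from the background crop.
    blended = object_pixels * product_rgb + (1.0 - object_pixels) * background_crop

    # The blended image is the model input; object_pixels is the ground-truth mask it has to guess.
    return blended, object_pixels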

The results are pretty neat — sometimes. Here’s one before/after pic I rather like, though it’s still not 100% perfect:

I hope I’ll find the time someday to write more about the actual model used. You can play around with the model on the site here, but be forewarned — this basically only works on pictures that resemble real product photos. Anything else, it’s garbage in — garbage out!

Originally posted on Hoss Technology tech blog (that’s my ML consulting company)
