How to write a custom Estimator model for the Cloud TPU

By Lak Lakshmanan (@lak_gcp), Technical Lead, Google Cloud Platform

Tensor Processing Units (TPUs) accelerate a wide range of machine learning workloads within Google and are available to Google Cloud customers. You can find TPU-enabled versions of state-of-the-art image models like ResNet and AmoebaNet in the Cloud TPU reference model repository; text summarization and question answering tasks can be performed on TPUs using the powerful Tensor2Tensor library. These tutorials walk you step-by-step through using many of the most popular Cloud TPU reference models.

But what if you have a custom TensorFlow model? In this article, I will step you through the process of writing a custom Estimator to run on the Cloud TPU. Along the way, I will point out gotchas to watch for and best practices to follow. The full code for the solution is on GitHub; I will show only pertinent snippets here.

A custom TensorFlow Estimator is simply the base Estimator with a model function passed to it.

The model function takes in features, labels, and mode and returns an EstimatorSpec. For example, the model function for an image classification problem might consist of:
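
Here is a minimal, self-contained sketch of that shape (not the article’s full code: the stand-in dense layer and the num_classes and learning_rate parameters are illustrative, and the real model math lives in the GitHub repository):

```python
import tensorflow as tf

def image_classifier(features, labels, mode, params):
    """Model function: features + labels + mode -> EstimatorSpec."""
    image = features['image'] if isinstance(features, dict) else features
    # Stand-in for the real model math: flatten the image, apply a dense layer.
    logits = tf.layers.dense(tf.layers.flatten(image), params['num_classes'])
    probabilities = tf.nn.softmax(logits)
    class_ids = tf.argmax(probabilities, axis=1)

    loss, train_op, eval_metric_ops = None, None, None
    if mode != tf.estimator.ModeKeys.PREDICT:
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels, class_ids)}
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_op = tf.train.AdamOptimizer(params['learning_rate']).minimize(
                loss, global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={'probabilities': probabilities, 'class_ids': class_ids},
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)

# The custom Estimator is then just the base Estimator handed this function:
estimator = tf.estimator.Estimator(
    model_fn=image_classifier,
    model_dir='/tmp/imgclass',                           # checkpoint directory
    params={'num_classes': 5, 'learning_rate': 0.001})   # example hyperparameters
```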

The tf.contrib.tpu package in TensorFlow provides wrapper classes to help you write your code in such a way that it can run on CPUs, GPUs, and Cloud TPUs. Let’s walk through the process of writing a custom Estimator in this accelerator-agnostic way.

1. Convert inputs to TF Records

Cloud TPUs are so fast that if you aren’t careful, your training will be dominated by reading and writing data (“infeed” and “outfeed”) and by the saving of checkpoints. Because it is wasteful to have TPUs wait on input/output, we will do several things to maximize the amount of time that the TPU spends on computation.

The first of these is to avoid parsing and data wrangling in the input function to the Estimator. Instead, transform the data beforehand into TF Records. TF Records are easier to batch than individual image files, and because the labels are in the record itself, this cuts down on the number of small files that have to be read. I used Apache Beam to carry out this transformation — you can find a script to read JPEGs and write out TF Records in the official TPU repository. The Apache Beam program can be executed at scale on Cloud Dataflow, but if your data source is not currently on Google Cloud, you can simply execute the program locally on a large VM (make sure to pip install apache-beam).

Each TF Record holds a tf.train.Example, which is essentially a dictionary of features. For image classification, two of the entries written by the above pipeline are important: ‘image/class/label’, which is an int64, and ‘image/encoded’, which holds the contents of the JPEG file.
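
For reference, here is a hand-rolled sketch of what each record holds (the Beam pipeline in the TPU repository writes additional fields such as the image dimensions; make_example is just an illustrative helper):

```python
import tensorflow as tf

def make_example(jpeg_bytes, label_int):
    """Builds one tf.train.Example with the two fields the model needs."""
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
        'image/class/label': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label_int])),
    }))
```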

2. Write an input function to read TF Records

As with any Estimator, you will need to write an input function to read in these TF Records. This task is considerably simplified when using the Dataset API, but there are a few things to keep in mind. I’ll point them out as we go along.

Here’s my input function:
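
The exact function, with all of its flags and helpers, is in the GitHub repository; the following is a condensed sketch that captures the structure. The names make_input_fn, pattern, and num_cores are illustrative, and read_and_preprocess is the parsing function described in the next section.

```python
import tensorflow as tf

def make_input_fn(pattern, mode, num_cores=8, transpose_input=False):
    """Returns an input_fn that reads TF Records matching the file pattern."""

    def _set_shapes(batch_size, images, labels):
        # TPUs need statically known shapes: pin the batch dimension,
        # which the Dataset API leaves as None.
        if transpose_input:
            images.set_shape(images.get_shape().merge_with(
                tf.TensorShape([None, None, None, batch_size])))
        else:
            images.set_shape(images.get_shape().merge_with(
                tf.TensorShape([batch_size, None, None, None])))
        labels.set_shape(labels.get_shape().merge_with(
            tf.TensorShape([batch_size])))
        return images, labels

    def input_fn(params):
        # The (TPU)Estimator supplies the batch size through params.
        batch_size = params['batch_size']
        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        # Read all the file shards in parallel and interleave their records,
        # so that each worker sees a different subset of the data.
        dataset = tf.data.Dataset.list_files(pattern, shuffle=is_training)
        dataset = dataset.apply(tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=num_cores, sloppy=is_training))
        if is_training:
            dataset = dataset.shuffle(buffer_size=1024).repeat()

        # Parse and augment each record, producing parallel batches and
        # dropping any partial final batch.
        dataset = dataset.apply(tf.contrib.data.map_and_batch(
            lambda record: read_and_preprocess(record, augment=is_training),
            batch_size=batch_size,
            num_parallel_batches=num_cores,
            drop_remainder=True))

        if transpose_input:
            # On TPUs, putting the batch dimension last is faster.
            dataset = dataset.map(
                lambda images, labels: (tf.transpose(images, [1, 2, 3, 0]), labels))

        # Replace the None in the batch dimension with the actual batch size.
        dataset = dataset.map(
            lambda images, labels: _set_shapes(batch_size, images, labels))

        # Fetch the next batch while the TPU crunches the current one.
        return dataset.prefetch(1)

    return input_fn
```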

Note that the input function takes a single argument, params. In practice, these are the command-line parameters passed to your training program, from which we extract details about the dataset, such as the number of training and evaluation images.

The batch_size is special: because TPUs have multiple cores, the batch_size is set by the TPUEstimator and is the effective batch size. You have to return exactly batch_size records; you cannot send back a partially filled batch. This is not a problem during training, since you will be looping over the training data indefinitely. However, it means that it’s simplest to round down the evaluation dataset to a multiple of the number of cores. If the number of cores is 8 and you have 1026 images in your evaluation set, you will only use the first 1024 of them for evaluation. The remaining 2 will be dropped. (There are ways to process the final partial batch on Cloud TPU as well, but I won’t cover that detail here.)

As with any distributed training, you should ensure that each worker sees a different subset of the data — this is handled by the parallel interleaving of all your files and shuffling of records within the buffer itself.

A common need for image classification is to augment your original data by adding random crops, flips, etc. That is done by my read_and_preprocess function. Note that I apply this function to each TF Record and create 8 parallel batches, dropping any remaining records (again, this has no effect during training, since you repeat indefinitely).

The next step is the transpose. It turns out that, on TPUs, transposing the data so that the batch dimension comes last greatly improves performance, so we do that if necessary. The transpose_input flag will be false if we are running on a GPU or CPU.

TPUs require statically sized tensors. Although we have ensured that this is the case (by dropping the remainder), the Dataset API is written for core TensorFlow, which is more general. So, we call a function that changes the batch_size in the shape from None to, well, the batch_size.

The final bit of optimization is important. We need to prefetch the data. In other words, while the TPU is crunching one batch of records, we have the I/O threads go out and fetch the next batch. This keeps the TPU (or GPU) at maximum utilization. There is no impact on a CPU.

3. Processing TF Records

The input function (above) sets up how the input is handled, but defers the actual parsing to a method that I called read_and_preprocess(). Here’s what that looks like:
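
A condensed sketch of that function (the augmentation in the full code is richer, and HEIGHT and WIDTH stand in for whatever image size your model expects):

```python
import tensorflow as tf

HEIGHT, WIDTH = 299, 299   # illustrative target image size

def read_and_preprocess(example_data, augment=False):
    """Parses one serialized TF Record and returns (image, label) tensors."""
    # parse_single_example, because this is invoked from map() on one record.
    parsed = tf.parse_single_example(example_data, {
        'image/encoded': tf.FixedLenFeature([], tf.string),
        'image/class/label': tf.FixedLenFeature([], tf.int64),
    })

    # Decode the JPEG and scale pixel values to [0, 1).
    image = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)

    if augment:
        # Data augmentation: resize slightly larger, random-crop, random flip.
        image = tf.squeeze(tf.image.resize_bilinear(
            tf.expand_dims(image, 0), [HEIGHT + 10, WIDTH + 10]), 0)
        image = tf.random_crop(image, [HEIGHT, WIDTH, 3])
        image = tf.image.random_flip_left_right(image)
    else:
        image = tf.squeeze(tf.image.resize_bilinear(
            tf.expand_dims(image, 0), [HEIGHT, WIDTH]), 0)

    # The label must already be numeric: TPUs handle only numeric data, so the
    # string-to-index lookup happened in the preprocessing pipeline.
    label = tf.cast(parsed['image/class/label'], tf.int32)
    return image, label
```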

There are two key things to note here. One is the use of parse_single_example: because this function is called from a map(), it is invoked on a single TF Record. We pull out the pertinent information (encoded image and label) from the record and use it to construct the necessary tensors. The second thing to note is that the data have to be numeric. I cannot, for example, send back the label string, because TPUs handle only numeric data. It is necessary to have computed the index of the label in the preprocessing pipeline so that the label, at this point, is simply an integer.

4. Serving input function

After you train the model, you will want to deploy the model and serve it with TF Serving. The code here is the same as you would have with any Estimator:
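
Here is a sketch of that serving input function, reusing the HEIGHT and WIDTH constants assumed above (the client sends a single JPEG as a byte string under the key image_bytes, which is an illustrative name):

```python
import tensorflow as tf

def serving_input_fn():
    # The caller sends exactly one JPEG-encoded image as a byte string.
    feature_placeholders = {
        'image_bytes': tf.placeholder(tf.string, shape=(), name='image_bytes')
    }
    image = tf.image.decode_jpeg(feature_placeholders['image_bytes'], channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.squeeze(tf.image.resize_bilinear(
        tf.expand_dims(image, 0), [HEIGHT, WIDTH]), 0)
    # Add the batch dimension the model function expects, and hand the image
    # over as a dict so the model function can tell serving apart from training.
    features = {'image': tf.expand_dims(image, 0)}
    return tf.estimator.export.ServingInputReceiver(features, feature_placeholders)
```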

The TPU is optimized for batch inference — if your use case requires online prediction, you are currently better off serving from a CPU or a GPU (depending on the size and complexity of your model). The way I have written the input function, I am assuming that I am sent only one image, so this is really meant for CPU/GPU serving.

5. Model function

The model function needs to create and return a TPUEstimatorSpec. Here’s the implementation:
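
A condensed sketch of that function follows. Here, my_model (the actual layers), LIST_OF_LABELS (the class-name strings), and the learning_rate and use_tpu entries in params are placeholders for things defined elsewhere in the full code.

```python
import tensorflow as tf

def image_classifier(features, labels, mode, params):
    # Features may be the image itself (training/eval input functions) or a
    # dict holding the image (serving input function).
    image = features['image'] if isinstance(features, dict) else features

    # The actual model math, written with tf.layers; see the full source.
    ylogits, nclasses = my_model(image, mode)

    probabilities = tf.nn.softmax(ylogits)
    class_int = tf.cast(tf.argmax(probabilities, 1), tf.int32)
    # LIST_OF_LABELS is an assumed module-level list of class-name strings.
    class_str = tf.gather(LIST_OF_LABELS, class_int)

    loss, train_op, evalmetrics = None, None, None
    if mode != tf.estimator.ModeKeys.PREDICT:
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=ylogits, labels=tf.one_hot(labels, nclasses)))

        # The TPUEstimator wants the eval metrics as (metric_fn, tensors),
        # not as a dictionary of already-computed metric ops.
        def metric_fn(class_int, labels):
            return {'accuracy': tf.metrics.accuracy(labels, class_int)}
        evalmetrics = (metric_fn, [class_int, labels])

        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
            if params['use_tpu']:
                # Distribute the optimization across the TPU cores.
                optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
            # Note: optimizer.minimize(), not layers.optimize_loss().
            train_op = optimizer.minimize(loss, tf.train.get_global_step())

    return tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode,
        predictions={'probabilities': probabilities,
                     'classid': class_int,
                     'class': class_str},
        loss=loss,
        train_op=train_op,
        eval_metrics=evalmetrics,
        export_outputs={'classes': tf.estimator.export.PredictOutput(
            {'probabilities': probabilities,
             'classid': class_int,
             'class': class_str})})
```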

The features that are passed in might be either the image itself (from my training and evaluation input functions) or a dictionary (from my serving input function). I check, and retrieve the image from the features.

Then, I invoke my actual model math on the image. This should be familiar TensorFlow code that uses tf.layers. Browse the full source code to see how this looks.

Because this is a classification problem, I compute the integer label and the string label based on the logits for each of the classes using softmax followed by argmax and gather. I compute the cross entropy loss. This is like any Estimator.

The one difference is that while a regular Estimator requires the evaluation metrics as a dictionary, the TPUEstimator asks for a function that can be invoked either on the controlling CPU or on the TPU. Hence, the way you specify the eval metrics is a bit different.

The optimizer that you use has to be wrapped in a CrossShardOptimizer if you are using a TPU. This distributes the optimization across the cores.

The training operation is the minimization of this cross-shard-optimized loss. Use optimizer.minimize() and not layers.optimize_loss().

Put all of these together and return a TPUEstimatorSpec.

6. Train and evaluate loop

You may be familiar with Estimator’s train_and_evaluate loop. Unfortunately, it does not (yet) work effectively with TPUs. Fortunately, it is not too difficult to roll your own to gain more control in terms of how often and what you checkpoint (recall that you want to minimize the context switching and I/O overhead associated with overly frequent checkpointing).

The first thing is to pull out some of the command-line parameters and use them to specify the maximum number of steps and the batch sizes of training and evaluation.
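
For example (the hyperparameter names below, such as train_steps and num_eval_images, are whatever your own argument parser defines):

```python
# Pull training knobs out of the command-line parameters (names illustrative).
max_steps = hparams['train_steps']
train_batch_size = hparams['train_batch_size']

# Round the evaluation batch size down to a multiple of the number of
# TPU cores (8), because a partially filled batch cannot be sent to the TPU.
eval_batch_size = min(1024, hparams['num_eval_images'])
eval_batch_size = eval_batch_size - eval_batch_size % 8
```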

The next bit is to find the TPU. If you created a Cloud TPU yourself on Google Compute Engine, you would have given it a name. I’m assuming that this name is passed in as the command-line parameter named ‘tpu’. If you are using Cloud ML Engine, the TPU name, zone, etc. are automatically inferred. Make sure to do this only if the use_tpu flag is set. If the user is running on a CPU or GPU, just create an empty RunConfig.
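
A sketch of that logic (the tpu, tpu_zone, project, save_checkpoints_steps, and iterations_per_loop parameters are assumed command-line flags; on Cloud ML Engine the TPU name can be left unset and TPUClusterResolver discovers it):

```python
import tensorflow as tf

if hparams['use_tpu']:
    # Find the TPU by name (or let Cloud ML Engine fill in the details).
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu=hparams.get('tpu'),
        zone=hparams.get('tpu_zone'),
        project=hparams.get('project'))
    config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        model_dir=output_dir,
        save_checkpoints_steps=hparams['save_checkpoints_steps'],
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=hparams['iterations_per_loop']))
else:
    # On a CPU or GPU, an essentially empty RunConfig is all we need.
    config = tf.contrib.tpu.RunConfig(model_dir=output_dir)
```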

Next, create a TPUEstimator with model function, config, parameters and batch sizes. With the estimator created, we can move on to the actual training and evaluation loop:
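
Putting it together looks roughly like this (train_data_path, eval_data_path, and steps_per_eval are assumed hyperparameters; make_input_fn and image_classifier are the functions sketched earlier):

```python
import tensorflow as tf

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=image_classifier,
    config=config,
    params=hparams,
    use_tpu=hparams['use_tpu'],
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size)

train_input_fn = make_input_fn(hparams['train_data_path'],
                               tf.estimator.ModeKeys.TRAIN,
                               transpose_input=hparams['use_tpu'])
eval_input_fn = make_input_fn(hparams['eval_data_path'],
                              tf.estimator.ModeKeys.EVAL,
                              transpose_input=hparams['use_tpu'])

def load_global_step_from_checkpoint_dir(checkpoint_dir):
    # Warm start: resume from the latest checkpoint in the output directory.
    try:
        return tf.train.NewCheckpointReader(
            tf.train.latest_checkpoint(checkpoint_dir)).get_tensor(
                tf.GraphKeys.GLOBAL_STEP)
    except:  # no checkpoint yet
        return 0

current_step = load_global_step_from_checkpoint_dir(output_dir)
while current_step < max_steps:
    # Train for a chunk of steps, then checkpoint and evaluate.
    next_checkpoint = min(current_step + hparams['steps_per_eval'], max_steps)
    estimator.train(input_fn=train_input_fn, max_steps=next_checkpoint)
    current_step = next_checkpoint
    eval_results = estimator.evaluate(
        input_fn=eval_input_fn,
        steps=hparams['num_eval_images'] // eval_batch_size)
    tf.logging.info('Eval results at step %d: %s', current_step, eval_results)
```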

TensorFlow Estimators warm-start from any previously existing checkpoints. We can replicate that by loading the latest checkpoint found in the output directory. Then, until we reach the maximum number of steps specified, we train for a chunk of steps at a time, checkpointing after each chunk.

In my case, I am evaluating at every checkpoint on the full evaluation set, but obviously, you can make this less computationally intensive.

7. Export model for serving

Finally, once the training is complete, I export the trained model as a SavedModel, which can be deployed for prediction using TF Serving or Cloud ML Engine.
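
In code, that is the usual Estimator export call (output_dir being the training output directory, and serving_input_fn the function from step 4):

```python
import os

# Write out a SavedModel that TF Serving or Cloud ML Engine can load.
estimator.export_savedmodel(
    os.path.join(output_dir, 'export/exporter'),
    serving_input_receiver_fn=serving_input_fn)
```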

At this point, we have a custom Estimator model that can be trained on Cloud TPUs. We wrote it in such a way (honoring the use_tpu flag and making the transpose optional, for example) that the same code also supports a variety of hardware, including CPUs and GPUs — so we actually have an Estimator model that works on all three types of hardware.

Next steps:

  1. Try out the code that accompanies this article by downloading it from GitHub.
  2. Learn how to run ResNet on TPUs on your own data (without having to write any of the code) by running a codelab.

Take the Machine Learning with TensorFlow specialization on Coursera — it steps you through TensorFlow concepts and how to train, tune and deploy ML models at scale on Google Cloud.