Keras on TPUs in Colab

A guest post by Sam Witteveen

Did you know that Colab includes the ability to select a free Cloud TPU for training models? That’s right, a whole TPU for you to use all by yourself in a notebook! As of TensorFlow 1.11, you can train Keras models with TPUs.

In this post, let’s take a look at what changes you need to make to your code to be able to train a Keras model on TPUs. Note that some of this may be simplified even further with the release of TensorFlow 2.0 later this year, but I thought it’d be helpful to share these tips in case you’d like to try this out now.

Probably the biggest part of getting your models to work on Cloud TPUs is setting up the right data pipeline. The challenge with building models to run on TPUs is that often the performance bottleneck is no longer in the acceleration of the model, but in the pipeline that feeds data to the model.

We want to make sure that we don’t starve the TPU of data while we are training our models. To do this we use tf.data. Let’s quickly go through what we need.

The first thing we need is an input function that takes our data and slices, shuffles, and batches it. Since graphs are compiled using XLA, we need to specify the shapes of tensors in advance. This makes it important to ensure that the TPU gets batches of exactly the same size every time. To do this we use the argument drop_remainder=True, so that any partial batch at the end is discarded rather than fed to the model with a different batch size.

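import tensorflow as tf

# Assumes x_train and y_train (NumPy arrays) are already loaded in the notebook.
def train_input_fn(batch_size=1024):
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    # Cache, shuffle, repeat, and batch the examples.
    dataset = dataset.cache()
    dataset = dataset.shuffle(1000, reshuffle_each_iteration=True)
    dataset = dataset.repeat()
    # drop_remainder=True keeps every batch exactly batch_size examples long.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # Return the dataset.
    return dataset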
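The effect of drop_remainder=True can be illustrated with a small pure-Python sketch. It is not tf.data code: it only computes the size of each batch that a single, non-repeated pass over the data would produce. The function name batch_sizes and the 60,000-example dataset size are hypothetical, chosen for illustration.

```python
def batch_sizes(num_examples, batch_size, drop_remainder):
    """Return the size of every batch cut from num_examples items.

    Mimics only the batch-size behaviour of Dataset.batch(); it does not
    touch any real data.
    """
    full_batches, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full_batches
    # Without drop_remainder, the leftover items form one smaller final batch,
    # which would present the compiled model with a second, different shape.
    if remainder and not drop_remainder:
        sizes.append(remainder)
    return sizes


# Hypothetical dataset of 60,000 examples with a global batch size of 1024.
kept = batch_sizes(60000, 1024, drop_remainder=False)
dropped = batch_sizes(60000, 1024, drop_remainder=True)
print(len(kept), kept[-1])        # 59 608
print(len(dropped), dropped[-1])  # 58 1024
```

In the pipeline above, repeat() comes before batch(), so the stream never actually ends. drop_remainder=True still matters there, because it is what makes the batch dimension a fixed, known size in the dataset's static shape.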

Changes to the Model

Once we have a pipeline set up, only a few key changes are needed to make our model compatible with TPUs.

The first of these is to get the address of the TPU so that we can later pass it to the distribution strategy. We can do this by running the following code at the top of our notebook, which checks whether a TPU is attached to the VM and, if so, prints the address of the TPU.

Getting the TPU address

import os

try:
    # Colab sets COLAB_TPU_ADDR only when a TPU runtime is selected.
    device_name = os.environ['COLAB_TPU_ADDR']
    TPU_ADDRESS = 'grpc://' + device_name
    print('Found TPU at: {}'.format(TPU_ADDRESS))
except KeyError:
    print('TPU not found')
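If you want to reuse this check elsewhere, the same lookup can be wrapped in a function that returns the address instead of printing it. This is a minimal sketch; the name find_tpu_address is hypothetical, and the environment is passed in as a mapping (normally os.environ) so the logic can be exercised without a TPU.

```python
import os
from typing import Mapping, Optional


def find_tpu_address(environ: Mapping[str, str]) -> Optional[str]:
    """Return the gRPC address of the Colab TPU, or None if none is attached.

    Hypothetical helper: same logic as the try/except lookup above.
    """
    device_name = environ.get("COLAB_TPU_ADDR")
    if device_name is None:
        return None
    return "grpc://" + device_name


# In a notebook, pass the real environment.
address = find_tpu_address(os.environ)
print("Found TPU at: {}".format(address) if address else "TPU not found")
```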

Next, we need to choose the optimizer to use for the model. TensorFlow optimizers are currently better supported than Keras optimizers. Here you can see we have chosen to use a TensorFlow Adam optimizer.

import tensorflow as tf

learning_rate = 1e-3  # example value only; tune it for your batch size
# Use a tf optimizer rather than a Keras one for now
opt = tf.train.AdamOptimizer(learning_rate)

Finally, we need to convert our Keras model to a TPU model. Currently, we do this by using the keras_to_tpu_model function and passing in a distribution strategy. (Eventually, this step will be moved into the compile function, and you will just pass a distribution strategy into the compile function of the Keras model.)

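# `model` is the Keras model defined earlier in the notebook;
# TPU_ADDRESS comes from the lookup above.
tpu_model = tf.contrib.tpu.keras_to_tpu_model(
    model,
    strategy=tf.contrib.tpu.TPUDistributionStrategy(
        tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)))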

Once this is done, you should see an output similar to this showing you that the TPU is primed and ready to start training. We can also see the details of the TPU device.

Output from the keras_to_tpu_model function

Batch Sizes

On GPUs, we always want to make the batch sizes big enough to use as many of the CUDA cores as possible at the same time. Similarly, on TPUs we want to make sure our batch sizes are big enough to take full advantage of the 8 cores and each of their systolic arrays.

The Cloud TPU v2 currently available in Colab has 128 x 128 systolic arrays. From this, we can work out that the batch size for each core should be a multiple of 128. The simplest way to achieve this is to use a global batch size of 1024 (128 for each of the 8 cores).
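The batch-size arithmetic above can be captured in a small checker. This is a sketch that only does the arithmetic described in this section (split the global batch evenly across 8 cores, then require each core's share to be a multiple of 128); the function names are hypothetical and nothing here talks to a TPU.

```python
NUM_CORES = 8              # cores on the Cloud TPU v2 available in Colab
SYSTOLIC_ARRAY_SIZE = 128  # systolic arrays are 128 x 128


def per_core_batch_size(global_batch_size, num_cores=NUM_CORES):
    """Split a global batch evenly across the TPU cores."""
    if global_batch_size % num_cores != 0:
        raise ValueError("global batch size must divide evenly across the cores")
    return global_batch_size // num_cores


def fills_systolic_arrays(global_batch_size, num_cores=NUM_CORES):
    """True if every core gets a batch that is a multiple of 128."""
    if global_batch_size % num_cores != 0:
        return False
    per_core = per_core_batch_size(global_batch_size, num_cores)
    return per_core % SYSTOLIC_ARRAY_SIZE == 0


for candidate in (1000, 1024, 2048):
    print(candidate, fills_systolic_arrays(candidate))
# prints: 1000 False / 1024 True / 2048 True
```

A global batch of 1000 splits evenly into 125 per core, but 125 is not a multiple of 128, so it would leave part of each systolic array idle.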

Along with changing your batch sizes, you may also want to tune your learning rate to fit the larger batch sizes.
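One common heuristic for that tuning, named here plainly as a starting point rather than the method used in this post, is the linear scaling rule: grow the learning rate in proportion to the batch size. A minimal sketch, with a hypothetical function name and hypothetical starting values:

```python
def linearly_scaled_lr(base_lr, base_batch_size, new_batch_size):
    """Scale a learning rate in proportion to the batch size.

    Linear scaling heuristic: a starting point for tuning only,
    not a guarantee of good convergence.
    """
    return base_lr * new_batch_size / base_batch_size


# Hypothetical: a learning rate of 1e-3 tuned for batches of 128, moved to 1024.
print(linearly_scaled_lr(1e-3, 128, 1024))  # 0.008
```

Very large jumps in batch size often need more than this (for example a warm-up period), so treat the result as a value to validate, not a final setting.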


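We can now pass the input function to fit() in place of the training data and train the TPU model just as you would normally train a Keras model.

tpu_model.fit(
    train_input_fn,
    steps_per_epoch=60,
    epochs=10,  # example value; choose as needed
)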
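Because train_input_fn calls repeat(), the dataset never runs out, so an "epoch" here is simply however many steps steps_per_epoch asks for. A quick back-of-the-envelope sketch (hypothetical helper names; the 60,000-example dataset size is an illustrative assumption) relates that setting to the data actually seen:

```python
def examples_per_epoch(steps_per_epoch, global_batch_size):
    """Number of training examples the TPU sees in one Keras 'epoch'."""
    return steps_per_epoch * global_batch_size


def steps_for_one_pass(num_examples, global_batch_size):
    """Full batches that fit in one pass over a finite dataset."""
    return num_examples // global_batch_size


print(examples_per_epoch(60, 1024))     # 61440
print(steps_for_one_pass(60000, 1024))  # 58, for a hypothetical 60,000-example dataset
```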

Going Forward — Distribution Strategies

Over the next few versions of TensorFlow and tf.keras, we expect the introduction of distribution strategies, which will make training models with TPUs even easier. With this new API, we will be able to simply pass in a distribution strategy when compiling the model.


To recap, the key things you need to change to run your model on a Cloud TPU are:

  • Find your TPU address
  • Set up your model and tf.data for a fixed Tensor size
  • Convert your Keras model to a TPU Keras model
  • Choose the right batch size
  • Use tf.data to feed your model

It’s early days, but this is an exciting way to try out TPUs! For much more info on using TPUs with TensorFlow, please check out the Cloud TPU performance guide and the official TPU examples.

Feel free to play around with an example I wrote in Colab here.