Multi-GPU training with Estimators, tf.keras and tf.data

At Zalando Research, as in most AI research departments, we realize the importance of experimenting and quickly prototyping ideas. As datasets grow, it becomes useful to know how to train deep learning models quickly and efficiently on the shared resources we have.

TensorFlow’s Estimators API is useful for training models in a distributed environment with multiple GPUs. Here, we’ll present this workflow by training a custom estimator written with tf.keras for the tiny Fashion-MNIST dataset, and then show a more practical use case at the end.

Note: there’s also a cool new feature the TensorFlow team has been working on (which, at the time of writing, is still in master) that lets you train a tf.keras model without first converting it to an Estimator, with just a couple of lines of additional code! That workflow is great too. Below I’ll focus on the Estimators API; whichever you choose is up to you!

TL;DR: Essentially, what we want to remember is that a tf.keras.Model can be trained with the tf.estimator API by converting it to a tf.estimator.Estimator object via the tf.keras.estimator.model_to_estimator method. Once converted, we can apply the machinery that Estimators provide to train on different hardware configurations.

You can download the code for this post from this notebook and run it yourself.

import os
import time
#!pip install -q -U tensorflow-gpu
import tensorflow as tf
import numpy as np

Import the Fashion-MNIST dataset

We will use the Fashion-MNIST dataset, a drop-in replacement for MNIST, which contains 70,000 grayscale images of Zalando fashion articles (60,000 for training and 10,000 for testing). Getting the training and test data is as simple as:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

We want to convert the pixel values of these images from a number between 0 and 255 to a number between 0 and 1, and convert the dataset to the [B,H,W,C] format, where B is the number of images in a batch, H and W are the height and width, and C is the number of channels (1 for grayscale):

TRAINING_SIZE = len(train_images)
TEST_SIZE = len(test_images)
train_images = np.asarray(train_images, dtype=np.float32) / 255
# Convert the train images and add channels
train_images = train_images.reshape((TRAINING_SIZE, 28, 28, 1))
test_images = np.asarray(test_images, dtype=np.float32) / 255
# Convert the test images and add channels
test_images = test_images.reshape((TEST_SIZE, 28, 28, 1))

Next, we want to convert the labels from an integer id (e.g., 2 or Pullover), to a one-hot-encoding (e.g., 0,0,1,0,0,0,0,0,0,0). To do so we will use the tf.keras.utils.to_categorical function:

# How many categories we are predicting from (0-9)
LABEL_DIMENSIONS = 10
train_labels = tf.keras.utils.to_categorical(train_labels, LABEL_DIMENSIONS)
test_labels = tf.keras.utils.to_categorical(test_labels, LABEL_DIMENSIONS)
# Cast the labels to floats, needed later
train_labels = train_labels.astype(np.float32)
test_labels = test_labels.astype(np.float32)

Build a tf.keras model

We will create our neural network using the Keras Functional API. Keras is a high-level API to build and train deep learning models, and it is user friendly, modular and easy to extend. tf.keras is TensorFlow’s implementation of this API, and it supports features such as Eager Execution, tf.data pipelines and Estimators.

In terms of the architecture, we will use ConvNets. At a very high level, ConvNets are stacks of convolutional layers (Conv2D) and pooling layers (MaxPooling2D). Most importantly, for each training example they take a 3D tensor of shape (height, width, channels), where channels=1 for grayscale images, and return a 3D tensor.

Therefore, after the ConvNet part we need to flatten the tensor with a Flatten layer and add Dense layers, the last of which returns a vector of size LABEL_DIMENSIONS with the tf.nn.softmax activation:

inputs = tf.keras.Input(shape=(28,28,1))  # Returns a placeholder
x = tf.keras.layers.Conv2D(filters=32,
                           kernel_size=(3, 3),
                           activation=tf.nn.relu)(inputs)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
x = tf.keras.layers.Conv2D(filters=64,
                           kernel_size=(3, 3),
                           activation=tf.nn.relu)(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=2)(x)
x = tf.keras.layers.Conv2D(filters=64,
                           kernel_size=(3, 3),
                           activation=tf.nn.relu)(x)
x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(64, activation=tf.nn.relu)(x)
predictions = tf.keras.layers.Dense(LABEL_DIMENSIONS,
                                    activation=tf.nn.softmax)(x)

We can now define our model, select the optimizer (we choose one from TensorFlow rather than using one from tf.keras.optimizers) and compile it:

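model = tf.keras.Model(inputs=inputs, outputs=predictions)
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
# One-hot labels and a softmax output go with categorical cross-entropy
model.compile(loss='categorical_crossentropy',
              optimizer=optimizer,
              metrics=['accuracy'])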
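To make the shape argument above concrete, here is a small plain-Python sketch (no TensorFlow needed) that traces the per-example tensor shape through the layer stack. It assumes the Conv2D layers use the Keras defaults of stride 1 and padding='valid', and that the pooling layers also use 'valid' padding; the helper names are hypothetical.

```python
def conv_out(size, kernel=3):
    # Stride-1 convolution with padding='valid' (the Keras default)
    return size - kernel + 1


def pool_out(size, pool=2, stride=2):
    # Max pooling with padding='valid': floor((size - pool) / stride) + 1
    return (size - pool) // stride + 1


def trace_convnet(side=28, channels=1):
    """Per-example (height, width, channels) after each layer of the ConvNet above."""
    shapes = [("Input", (side, side, channels))]
    # (layer, filters) pairs; filters is None for pooling layers
    stack = [("Conv2D", 32), ("MaxPooling2D", None),
             ("Conv2D", 64), ("MaxPooling2D", None),
             ("Conv2D", 64)]
    for name, filters in stack:
        if filters is None:
            side = pool_out(side)
        else:
            side = conv_out(side)
            channels = filters
        shapes.append((name, (side, side, channels)))
    # Flatten turns the final 3D tensor into a single vector
    shapes.append(("Flatten", (side * side * channels,)))
    return shapes


for name, shape in trace_convnet():
    print(f"{name:>12}: {shape}")
```

Under those assumptions the last Conv2D layer outputs a 3x3x64 tensor, so Flatten hands a 576-dimensional vector to the first Dense layer.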

Create an Estimator

To create an Estimator from the compiled Keras model we call the model_to_estimator method. Note that the initial model state of the Keras model is preserved in the created Estimator.

So what is so good about Estimators? Well to start off with:

  • you can run Estimator based models on a local host or on a distributed multi-GPU environment without changing your model;
  • Estimators simplify sharing implementation between model developers;
  • Estimators build the graph for you, so a bit like Eager Execution, there is no explicit session.

So how do we go about training our simple tf.keras model on multiple GPUs? We can use the tf.contrib.distribute.MirroredStrategy, which does in-graph replication with synchronous training. See this talk on Distributed TensorFlow training for more information about this strategy.

Essentially, each worker GPU has a copy of the network and gets a subset of the data, on which it computes the local gradients; it then waits for all the workers to finish in a synchronous manner. The workers then communicate their local gradients to each other via a Ring All-reduce operation, which is designed to make efficient use of network bandwidth and increase throughput. Once all the gradients have arrived, each worker averages them and updates its parameters, and the next step begins. This is ideal when you have multiple GPUs on a single node connected via a high-speed interconnect.
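The Ring All-reduce step can be simulated in plain Python to show how the data moves. This is a hypothetical sketch, not the implementation TensorFlow uses: each of the n workers splits its gradient into n chunks; a reduce-scatter phase passes partial sums around the ring until every worker owns one fully summed chunk, and an all-gather phase then circulates those finished chunks so everyone ends up with the complete sum, which is finally divided by n.

```python
def ring_allreduce_mean(grads):
    """Average equal-length gradient lists across workers with a ring all-reduce.

    grads[i] is worker i's local gradient (a list of floats).
    """
    n = len(grads)
    length = len(grads[0])
    # Chunk k covers indices [bounds[k], bounds[k + 1])
    bounds = [length * k // n for k in range(n + 1)]
    bufs = [list(g) for g in grads]  # each worker's working copy

    def chunk(k):
        return range(bounds[k], bounds[k + 1])

    # Phase 1, reduce-scatter: worker i sends chunk (i - step) to worker i + 1,
    # which adds it. Afterwards worker i owns the full sum of chunk (i + 1) % n.
    for step in range(n - 1):
        # Snapshot all messages first so sends in one step happen "simultaneously"
        sends = [((i + 1) % n, (i - step) % n,
                  [bufs[i][j] for j in chunk((i - step) % n)]) for i in range(n)]
        for dst, k, values in sends:
            for j, v in zip(chunk(k), values):
                bufs[dst][j] += v

    # Phase 2, all-gather: pass the completed chunks around the ring (overwrite)
    for step in range(n - 1):
        sends = [((i + 1) % n, (i + 1 - step) % n,
                  [bufs[i][j] for j in chunk((i + 1 - step) % n)]) for i in range(n)]
        for dst, k, values in sends:
            for j, v in zip(chunk(k), values):
                bufs[dst][j] = v

    # Every worker divides the summed gradient by the number of workers
    return [[v / n for v in buf] for buf in bufs]


local_grads = [[1.0, 2.0, 3.0, 4.0], [3.0, 4.0, 5.0, 6.0]]  # two hypothetical GPUs
print(ring_allreduce_mean(local_grads))  # [[2.0, 3.0, 4.0, 5.0], [2.0, 3.0, 4.0, 5.0]]
```

For a gradient of length L, each worker sends 2(n-1) chunks of size L/n, i.e. a little under 2L in total however many workers there are, which is why the ring pattern makes good use of the bandwidth between GPUs.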

To use this strategy, we create an Estimator from the compiled tf.keras model and give it the MirroredStrategy configuration via its RunConfig. By default this configuration uses all the available GPUs, but you can also pass a num_gpus option to use a specific number of GPUs:

NUM_GPUS = 2  # example value: set to the number of GPUs you want to use
strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=NUM_GPUS)
config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.keras.estimator.model_to_estimator(model, config=config)

Create an Estimator input function

To pipe data into Estimators we need to define a data importing function which returns a dataset of (images,labels) batches of our data. The function below takes in numpy arrays and returns the dataset via an ETL process.

Note that at the end we also call the prefetch method, which buffers upcoming batches while the GPUs are training, so that the next batch is ready and waiting instead of the GPUs waiting for the data at each iteration. The GPUs might still not be fully utilized; to improve this, we can use fused versions of the transformation operations, such as shuffle_and_repeat instead of two separate operations, but I have kept the simple case here.

def input_fn(images, labels, epochs, batch_size):
    # Convert the inputs to a Dataset. (E)
    ds = tf.data.Dataset.from_tensor_slices((images, labels))
    # Shuffle, repeat, and batch the examples. (T)
    # SHUFFLE_SIZE is the shuffle buffer size, a global you need to define
    ds = ds.shuffle(SHUFFLE_SIZE).repeat(epochs).batch(batch_size)
    ds = ds.prefetch(2)
    # Return the dataset. (L)
    return ds
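To make the order of these transformations concrete, here is a plain-Python analogue of shuffle(SHUFFLE_SIZE).repeat(epochs).batch(batch_size).prefetch(2). It is an illustrative sketch, not tf.data itself: the buffered shuffle mimics the documented behaviour of Dataset.shuffle (fill a buffer, emit a random element, refill), and prefetching is modelled with a background thread and a bounded queue. The function names are hypothetical.

```python
import queue
import random
import threading


def shuffle_buffer(items, buffer_size, rng):
    """Buffered shuffle: once the buffer is full, emit a random element per input."""
    buf = []
    for item in items:
        buf.append(item)
        if len(buf) == buffer_size:
            yield buf.pop(rng.randrange(len(buf)))
    while buf:  # drain what is left at the end of the epoch
        yield buf.pop(rng.randrange(len(buf)))


def shuffle_repeat_batch(items, shuffle_size, epochs, batch_size, seed=0):
    """Mimics ds.shuffle(shuffle_size).repeat(epochs).batch(batch_size)."""
    rng = random.Random(seed)
    batch = []
    for _ in range(epochs):
        # shuffle sits before repeat, so each epoch is shuffled on its own
        for item in shuffle_buffer(items, shuffle_size, rng):
            batch.append(item)
            if len(batch) == batch_size:
                yield batch
                batch = []
    if batch:  # the final partial batch is kept (no drop_remainder)
        yield batch


def prefetch(iterable, buffer_size=2):
    """Produce elements on a background thread so the consumer rarely waits."""
    q = queue.Queue(maxsize=buffer_size)
    done = object()

    def producer():
        for item in iterable:
            q.put(item)  # blocks once buffer_size elements are waiting
        q.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is done:
            return
        yield item


for batch in prefetch(shuffle_repeat_batch(list(range(10)), 4, epochs=2, batch_size=3)):
    print(batch)
```

Because shuffle comes before repeat, every example appears once per epoch before the next epoch starts, while batch coming after repeat lets a batch straddle an epoch boundary; only the very last batch can be smaller than batch_size.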

Train the Estimator

Let's first define a SessionRunHook class for recording the time of each iteration of stochastic gradient descent:

class TimeHistory(tf.train.SessionRunHook):
    def begin(self):
        self.times = []

    def before_run(self, run_context):
        self.iter_time_start = time.time()

    def after_run(self, run_context, run_values):
        self.times.append(time.time() - self.iter_time_start)

Now the good part! We can call the train method on our Estimator, giving it the input_fn we defined (with the batch size and the number of epochs we wish to train for) and a TimeHistory instance via its hooks argument:

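time_hist = TimeHistory()
# BATCH_SIZE and EPOCHS are hyperparameters you need to define
estimator.train(lambda: input_fn(train_images,
                                 train_labels,
                                 epochs=EPOCHS,
                                 batch_size=BATCH_SIZE),
                hooks=[time_hist])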


Thanks to our timing hook, we can now calculate the total training time as well as the average number of images we train on per second (the average throughput):

total_time = sum(time_hist.times)
print(f"total time with {NUM_GPUS} GPU(s): {total_time} seconds")
avg_time_per_batch = np.mean(time_hist.times)
print(f"{BATCH_SIZE*NUM_GPUS/avg_time_per_batch} images/second with {NUM_GPUS} GPU(s)")
Fashion-MNIST training throughput and total times on two K80 GPUs with different NUM_GPUS, exhibiting poor scaling.
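The hook's bookkeeping and the throughput formula can be checked without TensorFlow or GPUs. The sketch below is a hypothetical plain-Python stand-in: TimeHistory keeps the same begin/before_run/after_run methods, fake_train calls them in the order the Estimator calls a SessionRunHook (begin once, then before_run and after_run around every step), and time.sleep stands in for an SGD step. It uses time.perf_counter, which is intended for measuring intervals, instead of time.time; the batch size of 64 is just an example value.

```python
import time


class TimeHistory:
    """Stand-in for the SessionRunHook above, without the TensorFlow base class."""

    def begin(self):
        self.times = []

    def before_run(self, run_context):
        self.iter_time_start = time.perf_counter()

    def after_run(self, run_context, run_values):
        self.times.append(time.perf_counter() - self.iter_time_start)


def fake_train(hook, num_steps, step_seconds):
    """Hypothetical loop: begin() once, then before_run/after_run around each step."""
    hook.begin()
    for _ in range(num_steps):
        hook.before_run(None)
        time.sleep(step_seconds)  # stands in for one SGD step
        hook.after_run(None, None)


def images_per_second(times, batch_size, num_gpus):
    # Each step processes batch_size images on every GPU
    avg_time_per_batch = sum(times) / len(times)
    return batch_size * num_gpus / avg_time_per_batch


hook = TimeHistory()
fake_train(hook, num_steps=5, step_seconds=0.01)
print(f"total time: {sum(hook.times):.3f} s")
print(f"{images_per_second(hook.times, batch_size=64, num_gpus=2):.0f} images/second")
```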

Evaluate the Estimator

In order to check the performance of our model, we call the evaluate method on our Estimator:

estimator.evaluate(lambda: input_fn(test_images,
                                    test_labels,
                                    epochs=1,
                                    batch_size=BATCH_SIZE))


Retinal OCT (optical coherence tomography) images example

To test the scaling performance on a bigger dataset, we use the Retinal OCT images dataset, one of the many great datasets from Kaggle. This dataset consists of cross-sectional OCT images of the retinas of living humans, grouped into four categories: NORMAL, CNV, DME and DRUSEN.

Representative Optical Coherence Tomography images from “Identifying Medical Diagnoses and Treatable Diseases by Image-Based Deep Learning” by Kermany et al.

The dataset has a total of 84,495 JPEG images, typically 512x496, and can be downloaded via the kaggle CLI:

#!pip install kaggle
#!kaggle datasets download -d paultimothymooney/kermany2018

Once downloaded, the images of each class in the training and test sets sit in their own folder, so we can define the file patterns as:

labels = ['CNV', 'DME', 'DRUSEN', 'NORMAL']
train_folder = os.path.join('OCT2017', 'train', '**', '*.jpeg')
test_folder = os.path.join('OCT2017', 'test', '**', '*.jpeg')

Next, we write our Estimator’s input function, which takes any file pattern and returns resized images and one-hot encoded labels as a tf.data.Dataset. This time we follow the best practices from the Input Pipeline Performance Guide. Note in particular that if prefetch's buffer_size is None, TensorFlow will use an optimal prefetch buffer size automatically.
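The labelling part of such an input function can be sketched in plain Python: the class is simply the name of the folder an image sits in, which is then mapped to its index in labels and one-hot encoded. This is an illustrative, hypothetical sketch using Python's glob and os.path; inside a tf.data pipeline the same logic would have to be written with TensorFlow ops so that it runs on tensors.

```python
import glob
import os

LABELS = ['CNV', 'DME', 'DRUSEN', 'NORMAL']


def label_from_path(path, labels=LABELS):
    """Class name is the parent folder, e.g. OCT2017/train/CNV/x.jpeg -> ('CNV', 0)."""
    class_name = os.path.basename(os.path.dirname(path))
    return class_name, labels.index(class_name)


def one_hot(index, depth):
    return [1.0 if i == index else 0.0 for i in range(depth)]


def list_examples(pattern, labels=LABELS):
    """Expand a pattern like OCT2017/train/**/*.jpeg into (path, one-hot label) pairs."""
    examples = []
    for path in sorted(glob.glob(pattern, recursive=True)):
        _, index = label_from_path(path, labels)
        examples.append((path, one_hot(index, len(labels))))
    return examples


train_examples = list_examples(os.path.join('OCT2017', 'train', '**', '*.jpeg'))
print(f"{len(train_examples)} training images found")
```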

This time, to train this model, we will use a pretrained VGG16 and retrain just its last four layers:

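keras_vgg16 = tf.keras.applications.VGG16(input_shape=(224,224,3),
                                          include_top=False)
output = keras_vgg16.output
output = tf.keras.layers.Flatten()(output)
prediction = tf.keras.layers.Dense(len(labels),
                                   activation=tf.nn.softmax)(output)
model = tf.keras.Model(inputs=keras_vgg16.input,
                       outputs=prediction)
# Freeze all VGG16 layers except the last four
for layer in keras_vgg16.layers[:-4]:
    layer.trainable = False
# Compile as before (the default Adam settings here are an assumption)
model.compile(loss='categorical_crossentropy',
              optimizer=tf.train.AdamOptimizer(),
              metrics=['accuracy'])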
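To see which layers the [:-4] slice leaves trainable, the sketch below rebuilds the layer names of the VGG16 convolutional base in plain Python. It assumes include_top=False and the Keras naming scheme (block{b}_conv{i} and block{b}_pool, preceded by an input layer whose exact name varies), so treat the names as illustrative.

```python
def vgg16_base_layer_names():
    """Assumed layer names of VGG16 with include_top=False, input layer first."""
    names = ["input"]  # the real input layer name varies, e.g. input_1
    # Five blocks with 2, 2, 3, 3 and 3 conv layers, each ending in a pooling layer
    for block, num_convs in enumerate([2, 2, 3, 3, 3], start=1):
        names += [f"block{block}_conv{i}" for i in range(1, num_convs + 1)]
        names.append(f"block{block}_pool")
    return names


layers = vgg16_base_layer_names()
frozen = layers[:-4]      # same slice as keras_vgg16.layers[:-4]
trainable = layers[-4:]
print("trainable:", trainable)
```

Under these assumptions the slice freezes everything up to block4_pool and retrains block5_conv1 to block5_conv3 (block5_pool has no weights), together with the new Flatten and Dense head.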

Now we have all we need and can proceed as above to train our model in a few minutes using NUM_GPUS GPUs:

strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=NUM_GPUS)
config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.keras.estimator.model_to_estimator(model, config=config)

Once trained, we can evaluate the accuracy on the test set, which should be around 95% (not bad for an initial baseline 😀).

Retinal OCT training throughput and total times on two K80 GPUs with different NUM_GPUS, exhibiting linear scaling.


We showed above how easy it is to train deep learning Keras models on multiple GPUs using the Estimators API, how to write an input pipeline that follows the best practices to get good utilisation of our resources (linear scaling), and how to time our training throughput via hooks.

Do note that, in the end, the main thing we care about is the test set error. You might notice that the test set accuracy decreases as we increase NUM_GPUS. One reason for this could be that MirroredStrategy effectively trains with a batch size of BATCH_SIZE*NUM_GPUS, which might require adjusting either the BATCH_SIZE or the learning rate as we use more GPUs. Here I have kept all the hyperparameters other than NUM_GPUS constant for the sake of making the plots, but in reality one would need to tune them.
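Both adjustments amount to simple arithmetic. The sketch below is hypothetical: it computes the effective global batch size, a learning rate scaled by the same factor (a commonly used heuristic often called the linear scaling rule, which this post did not test), and, as the alternative, the per-GPU batch size that keeps the global batch constant. The base learning rate of 0.001 matches the optimizer above; the per-GPU batch size of 64 is just an example value.

```python
def global_batch_size(per_gpu_batch_size, num_gpus):
    # With MirroredStrategy each GPU processes per_gpu_batch_size examples per step
    return per_gpu_batch_size * num_gpus


def linearly_scaled_lr(base_lr, num_gpus):
    """Linear scaling heuristic: grow the learning rate with the global batch size."""
    return base_lr * num_gpus


def per_gpu_batch_for_fixed_global(global_batch, num_gpus):
    """Alternative: shrink the per-GPU batch so the global batch stays constant."""
    if global_batch % num_gpus:
        raise ValueError("global batch size must be divisible by num_gpus")
    return global_batch // num_gpus


for num_gpus in (1, 2):
    print(num_gpus,
          global_batch_size(64, num_gpus),
          linearly_scaled_lr(0.001, num_gpus),
          per_gpu_batch_for_fixed_global(64, num_gpus))
```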

The size of the dataset, as well as the model size, also affects how well these schemes scale. GPUs make poor use of their bandwidth when reading or writing small chunks of data, especially older GPUs like the K80, and this could account for the poor scaling in the Fashion-MNIST plots above.
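A rough back-of-envelope comparison shows how different the two workloads are. The sketch counts the bytes of a single float32 input image in each example (28x28x1 for Fashion-MNIST versus the 224x224x3 VGG16 input used for Retinal OCT); it ignores model size, so it is only indicative.

```python
def bytes_per_image(height, width, channels, bytes_per_value=4):
    """Size in bytes of one image tensor (float32 by default)."""
    return height * width * channels * bytes_per_value


fashion = bytes_per_image(28, 28, 1)    # Fashion-MNIST input
oct_img = bytes_per_image(224, 224, 3)  # resized Retinal OCT input for VGG16
print(fashion, oct_img, oct_img // fashion)  # 3136 602112 192
```

Each Retinal OCT input is about 192 times larger than a Fashion-MNIST one, and it feeds a much bigger network, so each step carries far more work relative to the fixed per-step overheads.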


Thank you to the TensorFlow team, especially Josh Gordon, and to everyone in Zalando Research for their help in fixing up the draft, especially Duncan Blythe, Gokhan Yildirim and Sebastian Heinz.