Higher-Level APIs in TensorFlow

How to use Estimator, Experiment and Dataset to train models

TensorFlow has many libraries, such as Keras, TFLearn, and Sonnet, that make it easier to train models than working with the lower-level functionality directly. While the Keras API is being integrated directly into TensorFlow, TensorFlow also provides some higher-level constructs of its own, and several new ones were introduced in the latest release, version 1.3.

In this blog, we’ll look at an example using some of these new higher-level constructs, including Estimator, Experiment, and Dataset. It’s also worth noting that you can use Experiment and Dataset on their own. I’ll assume you know the basics of TensorFlow; if not, TensorFlow provides some great tutorials.

Overview of the Experiment, Estimator and Dataset framework and how they interact. (These components are explained in the following sections.)

We’ll be using MNIST as a dataset in this blog. It’s an easy-to-use dataset that is already accessible from TensorFlow. You can find the full code example in this gist. One benefit of using these frameworks is that we don’t have to deal with Graphs and Sessions directly.


Estimator

The Estimator class represents a model, as well as how this model should be trained and evaluated. To create an Estimator we need to pass in a model function, a collection of parameters, and some configuration:

  • The parameters should be a collection of the model’s hyperparameters. This can be a dictionary, but we will represent it in this example as an HParams object, which acts as a namedtuple.
  • The configuration specifies how the training and evaluation are run, and where to store the results. This configuration will be represented by a RunConfig object, which communicates everything the Estimator needs to know about the environment in which the model will be run.
  • The model function is a Python function that builds the model given the input (more on this in the next section).

Model function

The model function is a Python function that is passed to the Estimator as a first-class function. We’ll see later that TensorFlow uses first-class functions in other places too. The benefit of representing the model as a function is that the model can be recreated over and over by calling the function. This means the model can be rebuilt during training with different inputs, for example to run validation tests while training.

The model function takes the input features and the corresponding labels as tensors. It also takes a mode that signals whether the model is training, evaluating, or performing inference. The last parameter of the model function should be a collection of hyperparameters, which are the same as those passed to the Estimator. The model function should return an EstimatorSpec object, which defines the complete model.

The EstimatorSpec takes in the prediction, loss, training, and evaluation operations, so it defines the full model graph used for training, evaluation, and inference. Because the EstimatorSpec just takes in regular TensorFlow operations, we can use frameworks like TF-Slim to define our model.
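To make the contract of the model function more concrete, here is a minimal plain-Python sketch of its shape. It does not use TensorFlow: ToySpec, ToyParams, and the three mode constants are hypothetical stand-ins for the EstimatorSpec, the hyperparameter object, and the mode, and the "model" is just a single weight. The point is only to show one function building a different description of the model depending on the mode.

```python
from collections import namedtuple

# Hypothetical stand-ins for the mode values, the spec, and the hyperparameters.
# None of these are TensorFlow objects.
TRAIN, EVAL, PREDICT = "train", "eval", "predict"
ToySpec = namedtuple("ToySpec", ["mode", "predictions", "loss", "train_op", "eval_metrics"])
ToyParams = namedtuple("ToyParams", ["weight", "learning_rate"])


def model_fn(features, labels, mode, params):
    """Build a one-weight linear 'model' and describe it in a spec."""
    predictions = [params.weight * x for x in features]
    if mode == PREDICT:
        # Inference needs predictions only; no labels are available.
        return ToySpec(mode, predictions, None, None, None)

    errors = [p - y for p, y in zip(predictions, labels)]
    loss = sum(e * e for e in errors) / len(errors)
    eval_metrics = {"mean_abs_error": sum(abs(e) for e in errors) / len(errors)}

    train_op = None
    if mode == TRAIN:
        # Gradient of the mean squared error with respect to the weight.
        gradient = sum(2 * e * x for e, x in zip(errors, features)) / len(errors)

        def train_op():
            # One gradient-descent step; returns the updated weight.
            return params.weight - params.learning_rate * gradient

    return ToySpec(mode, predictions, loss, train_op, eval_metrics)


# The same function is called again for each mode, with different inputs.
params = ToyParams(weight=2.0, learning_rate=0.1)
train_spec = model_fn([1.0, 2.0], [2.0, 5.0], TRAIN, params)
eval_spec = model_fn([1.0, 2.0], [2.0, 5.0], EVAL, params)
infer_spec = model_fn([3.0], None, PREDICT, params)
```

Because the model lives in a function, rebuilding it for evaluation or inference is just another call with a different mode.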


Experiment

The Experiment class defines how to train a model and integrates nicely with the Estimator. The Experiment takes as input:

  • An estimator (for example the one we defined above).
  • Training and evaluation data as first-class functions. The same concept as the model function explained earlier is used here. By passing in a function instead of an operation, the input graph can be recreated if needed. We’ll talk more about this later.
  • Training and evaluation hooks. These hooks can be used to save or monitor specific things, or to set up certain operations in the Graph or Session. For example, we will be passing in hooks that help initialize the data loaders (again, more on this later).
  • Various parameters describing how long to train for and when to evaluate.
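The reason for passing the data in as functions becomes clearer with a small plain-Python sketch. This is not how the Experiment class is implemented: run_toy_experiment, its eval_every argument, and the two input-function factories are hypothetical, and a "training step" is reduced to consuming one number. It only illustrates that a training loop can call the evaluation input function again each time it needs a fresh pass over the evaluation data.

```python
import itertools


def make_train_input_fn(data):
    def train_input_fn():
        # An endless stream: the training loop decides when to stop.
        return itertools.cycle(data)
    return train_input_fn


def make_eval_input_fn(data):
    def eval_input_fn():
        # A single pass over the evaluation data.
        return iter(data)
    return eval_input_fn


def run_toy_experiment(train_input_fn, eval_input_fn, train_steps, eval_every):
    """Consume training data for train_steps steps, evaluating every eval_every steps."""
    train_stream = train_input_fn()
    consumed = 0
    eval_log = []
    for step in range(1, train_steps + 1):
        consumed += next(train_stream)  # stand-in for one training step
        if step % eval_every == 0:
            # Calling the function again recreates the evaluation input from scratch.
            eval_log.append((step, sum(eval_input_fn())))
    return consumed, eval_log


consumed, eval_log = run_toy_experiment(
    make_train_input_fn([1, 2, 3]),
    make_eval_input_fn([10, 20]),
    train_steps=5,
    eval_every=2,
)
```

If the evaluation data were passed in as an already-created iterator instead, the second evaluation would find it exhausted.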

Once we have defined the experiment, we can run it with learn_runner.run to train and evaluate the model. Like the model function and the data functions, the learn runner takes the function that creates the experiment as a parameter.


Dataset

We’ll be using the Dataset class and the corresponding Iterator to represent our training and evaluation data, and to create data feeders that iterate over the data during training. In this example, we will use the MNIST data that’s available in TensorFlow, and build a Dataset wrapper around it. The training input data is produced by a get_train_inputs function. Calling get_train_inputs returns a first-class function that creates the data loading operations in a TensorFlow graph, together with a hook to initialize the iterator.

The MNIST data used in this example is initially represented as a NumPy array. We create a placeholder tensor that gets the data fed in; we use a placeholder in order to avoid copying the data into the graph. Next, we create a sliced dataset with the help of from_tensor_slices. We make sure that this dataset repeats for an infinite number of epochs (the experiment can take care of limiting the number of epochs), and that the data gets shuffled and put into batches of the required size.
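The repeat, shuffle, and batch steps can be pictured with a plain-Python generator. This is only a sketch of the idea, not the Dataset API: batched_stream is a hypothetical helper, and it shuffles each epoch as a whole, whereas a buffer-based shuffle can behave differently.

```python
import itertools
import random


def batched_stream(examples, batch_size, seed=0):
    """Yield shuffled batches forever; the consumer decides when to stop."""
    rng = random.Random(seed)
    pending = []
    while True:  # repeat for an unlimited number of epochs
        epoch = list(examples)
        rng.shuffle(epoch)  # new order on every pass over the data
        for example in epoch:
            pending.append(example)
            if len(pending) == batch_size:
                yield pending
                pending = []


# Take five batches of two examples: exactly two passes over five examples.
stream = batched_stream(range(5), batch_size=2)
first_batches = list(itertools.islice(stream, 5))
```

Since the stream never ends on its own, something else has to limit the training length, which is the role the experiment plays in the real setup.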

To iterate over the data we need to create an iterator from the dataset. Because we are using a placeholder, we need to feed the NumPy data into it in the relevant session. We can do this by creating an initializable iterator. We will create a custom IteratorInitializerHook object to initialize the iterator once the session has been created.

The IteratorInitializerHook inherits from SessionRunHook. The hook’s after_create_session method is called as soon as the relevant session is created, and it initializes the placeholder with the right data. This hook is returned by our get_train_inputs function and will be passed to the Experiment object upon creation.
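The interplay between get_train_inputs, the input function, and the hook can be sketched without TensorFlow. In the sketch below, ToySessionRunHook and ToySession are hypothetical stand-ins for SessionRunHook and a session, the attribute name iterator_initializer_func is an assumption, and the strings stand in for real graph operations. The structure is what matters: the hook is created first, the input function fills in the initializer when it builds the graph, and the initializer runs once a session exists.

```python
class ToySessionRunHook:
    """Hypothetical stand-in for the SessionRunHook base class."""

    def after_create_session(self, session, coord):
        pass


class ToySession:
    """Hypothetical stand-in for a session; it only records what was run."""

    def __init__(self):
        self.runs = []

    def run(self, op, feed_dict=None):
        self.runs.append((op, feed_dict))


class IteratorInitializerHook(ToySessionRunHook):
    """Hook that initializes the data iterator once the session exists."""

    def __init__(self):
        super().__init__()
        self.iterator_initializer_func = None  # set later by the input function

    def after_create_session(self, session, coord):
        # Feed the data into the placeholders by running the initializer.
        self.iterator_initializer_func(session)


def get_train_inputs(images, labels):
    """Return the input function together with the hook that initializes it."""
    hook = IteratorInitializerHook()

    def train_inputs():
        # A real version would build placeholders, a dataset, and an iterator here.
        hook.iterator_initializer_func = lambda session: session.run(
            "iterator.initializer",
            feed_dict={"images_placeholder": images, "labels_placeholder": labels},
        )
        return "next_images", "next_labels"

    return train_inputs, hook


train_inputs, hook = get_train_inputs([[0.0, 1.0]], [1])
next_batch = train_inputs()             # graph construction
session = ToySession()
hook.after_create_session(session, None)  # called once the session is created
```

Keeping the data out of the input function's return value and feeding it through the hook is what lets the graph be built before any session exists.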

The data loading operations returned by the train_inputs function are TensorFlow operations that will return a new batch every time they are evaluated.

Running the code

Now that we have defined everything, we can run the code with the following command:

python mnist_estimator.py --model_dir ./mnist_training --data_dir ./mnist_data

If you don’t pass in any arguments, the script falls back to the default flag values defined at the top of the file to decide where to save the data and the model.
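As an illustration of such defaults, here is a minimal sketch that uses Python's argparse module rather than whatever flag mechanism the script itself uses. The parse_flags helper is hypothetical, and the default paths are assumptions chosen to match the command above.

```python
import argparse


def parse_flags(argv=None):
    """Parse the two directory flags, falling back to defaults when omitted."""
    parser = argparse.ArgumentParser(description="Train an MNIST model.")
    parser.add_argument(
        "--model_dir", default="./mnist_training",  # assumed default
        help="Where checkpoints and summaries are written.")
    parser.add_argument(
        "--data_dir", default="./mnist_data",  # assumed default
        help="Where the MNIST data is stored.")
    return parser.parse_args(argv)


defaults = parse_flags([])                             # no arguments given
custom = parse_flags(["--model_dir", "/tmp/my_model"])  # override one flag
```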

During training, information such as the global step, loss, and accuracy is printed to the terminal over time. Besides this, the Experiment and Estimator framework will log certain statistics to be visualized by TensorBoard. If we run:

tensorboard --logdir='./mnist_training'

then we can see all the training statistics, such as the training loss, evaluation accuracy, time per step, and the model graph.

Evaluation accuracy visualised in TensorBoard

I’ve written this blog because I couldn’t find much information or many examples on the TensorFlow Estimator, Experiment, and Dataset framework at the time I wrote the code example. I hope that this blog gives you a brief overview of how these frameworks work, what abstractions they tackle, and how to use them. Some notes and other documentation are below if you’re interested in using these frameworks.

Notes on the Estimator, Experiment and Dataset frameworks

Complete example

Inference on the trained model

Once we have trained the model, we can call estimator.predict to predict the class of a given image.
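Predictions come back lazily, one per input example, so the caller iterates over them. The plain-Python sketch below mimics only that consumption pattern. Everything in it is hypothetical: toy_predict stands in for the predict call, score_fn stands in for a trained model, the two-pixel "images" are made up, and the dictionary keys are assumptions.

```python
def score_fn(image):
    """Hypothetical two-class 'model': scores based on mean pixel brightness."""
    mean = sum(image) / len(image)
    return [1.0 - mean, mean]


def toy_predict(score_fn, input_fn):
    """Yield one prediction per example, picking the highest-scoring class."""
    for image in input_fn():
        scores = score_fn(image)
        best_class = max(range(len(scores)), key=lambda i: scores[i])
        yield {"class": best_class, "scores": scores}


def predict_input_fn():
    # Two tiny made-up 'images': one dark, one bright.
    return iter([[0.0, 0.2], [0.9, 0.7]])


predicted_classes = [p["class"] for p in toy_predict(score_fn, predict_input_fn)]
```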