Effective TensorFlow 2.0: Best Practices and What’s Changed

Posted by the TensorFlow Team

In a recent article, we mentioned that TensorFlow 2.0 has been redesigned with a focus on developer productivity, simplicity, and ease of use.

To take a closer look at what’s changed, and to learn about best practices, check out the new Effective TensorFlow 2.0 guide (published on GitHub). This article provides a quick summary of the content you’ll find there. If any of these topics interest you, head to the guide to learn more!

A brief summary of major changes

There are many changes in TensorFlow 2.0 to make users more productive, including removing redundant APIs, making APIs more consistent (Unified RNNs, Unified Optimizers), and better integrating with the Python runtime with Eager execution.

Many RFCs (check them out, if you’re new to them!) have explained the changes and thinking that have gone into making TensorFlow 2.0. This guide presents a vision for what development in TensorFlow 2.0 should look like. It’s assumed you have some familiarity with TensorFlow 1.x.

API Cleanup

Many APIs have been removed or moved in TF 2.0, and some have been replaced with their 2.0 equivalents: tf.summary, tf.keras.metrics, and tf.keras.optimizers. The easiest way to apply these renames automatically is to use the v2 upgrade script.

Eager execution

TensorFlow 1.X requires users to manually stitch together an abstract syntax tree (the graph) by making tf.* API calls. It then requires users to manually compile the abstract syntax tree by passing a set of output tensors and input tensors to a session.run() call. By contrast, TensorFlow 2.0 executes eagerly (like Python normally does), and graphs and sessions should feel like implementation details.
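As a rough illustration of what eager execution means in practice, the sketch below runs a matrix multiplication and reads the result back immediately, with no graph construction or session. The tensor values are arbitrary and chosen only for the example.

```python
import tensorflow as tf

# Operations run as soon as they are called and return concrete values.
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
y = tf.matmul(x, x)

# The result can be converted to a NumPy array right away.
result = y.numpy()
print(result)
```

Nothing here needs a session.run() call: the Python statement that computes y is also the statement that executes it.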

No more globals

TensorFlow 1.X relied heavily on implicitly global namespaces. When you called tf.Variable(), it would be put into the default graph, and it would remain there, even if you lost track of the Python variable pointing to it. You could then recover that tf.Variable, but only if you knew the name that it had been created with. This was difficult to do if you were not in control of the variable’s creation. As a result, all sorts of mechanisms proliferated to attempt to help users find their variables again.

TensorFlow 2.0 eliminates all of these mechanisms (Variables 2.0 RFC) in favor of the default mechanism: Keep track of your variables! If you lose track of a tf.Variable, it gets garbage collected. See the guide for more details.
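One way to "keep track of your variables" is to let a Python object own them. The sketch below assumes tf.Module as the owner; the Scale class and its factor attribute are hypothetical names introduced only for illustration.

```python
import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical module that owns the variable it uses."""

    def __init__(self, factor):
        super().__init__()
        # The variable lives as long as this object holds a reference to it.
        self.factor = tf.Variable(factor, name="factor")

    def __call__(self, x):
        return self.factor * x


scale = Scale(3.0)
scaled = scale(tf.constant(2.0))

# The variable is found through the object that owns it, not through a global name lookup.
owned = scale.trainable_variables
print(float(scaled), len(owned))
```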

Functions, not sessions

A session.run() call is almost like a function call: You specify the inputs and the function to be called, and you get back a set of outputs. In TensorFlow 2.0, you can decorate a Python function using tf.function() to mark it for JIT compilation so that TensorFlow runs it as a single graph (Functions 2.0 RFC).

This mechanism allows TensorFlow 2.0 to gain all of the benefits of graph mode:

  • Performance: The function can be optimized (node pruning, kernel fusion, etc.)
  • Portability: The function can be exported/reimported (SavedModel 2.0 RFC), allowing users to reuse and share modular TensorFlow functions.
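A minimal sketch of the decorator in use follows. The function name affine_relu and the input values are assumptions made for the example; the point is only that the decorated function is called like ordinary Python while TensorFlow traces and runs it as a graph.

```python
import tensorflow as tf


@tf.function
def affine_relu(x, w, b):
    # Traced into a graph on the first call, then reused for matching inputs.
    return tf.nn.relu(tf.matmul(x, w) + b)


x = tf.constant([[1.0, 2.0]])
w = tf.constant([[1.0], [1.0]])
b = tf.constant([0.5])

out = affine_relu(x, w, b)
print(out.numpy())
```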

With the power to freely intersperse Python and TensorFlow code, you can take full advantage of Python’s expressiveness. But portable TensorFlow executes in contexts without a Python interpreter — mobile, C++, and JS. To help users avoid having to rewrite their code when adding @tf.function, AutoGraph will convert a subset of Python constructs into their TensorFlow equivalents.

See the guide for more details.

Recommendations for idiomatic TensorFlow 2.0

Refactor your code into smaller functions

A common usage pattern in TensorFlow 1.X was the “kitchen sink” strategy, where the union of all possible computations was preemptively laid out, and then selected tensors were evaluated via session.run(). In TensorFlow 2.0, users should refactor their code into smaller functions which are called as needed. In general, it’s not necessary to decorate each of these smaller functions with tf.function; only use tf.function to decorate high-level computations — for example, one step of training, or the forward pass of your model.
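The sketch below shows one possible shape for this refactoring, using a toy linear regression. The helper names (predict, mse, train_step), the data, and the learning rate are all assumptions for illustration. Only the high-level training step is decorated; the small helpers are plain Python functions that get traced as part of it.

```python
import tensorflow as tf

LEARNING_RATE = 0.05  # arbitrary value chosen for this toy problem
w = tf.Variable(0.0)
b = tf.Variable(0.0)


def predict(x):
    # Small helper: no decorator needed.
    return w * x + b


def mse(y_pred, y_true):
    # Small helper: no decorator needed.
    return tf.reduce_mean(tf.square(y_pred - y_true))


@tf.function
def train_step(x, y):
    # High-level computation: one step of training, run as a single graph.
    with tf.GradientTape() as tape:
        loss = mse(predict(x), y)
    grad_w, grad_b = tape.gradient(loss, [w, b])
    w.assign_sub(LEARNING_RATE * grad_w)
    b.assign_sub(LEARNING_RATE * grad_b)
    return loss


xs = tf.constant([1.0, 2.0, 3.0, 4.0])
ys = 2.0 * xs + 1.0
losses = [float(train_step(xs, ys)) for _ in range(50)]
print(losses[0], losses[-1])
```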

Use Keras layers and models to manage variables

Keras models and layers offer the convenient variables and trainable_variables properties, which recursively gather up all dependent variables. This makes it easy to manage variables locally to where they are being used.
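As a small illustration, the sketch below defines a hypothetical two-layer model (the MLP class and its layer sizes are assumptions) and reads back the variables that Keras gathered from its sublayers.

```python
import tensorflow as tf


class MLP(tf.keras.Model):
    """Hypothetical two-layer model used only for illustration."""

    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.out = tf.keras.layers.Dense(1)

    def call(self, x):
        return self.out(self.hidden(x))


model = MLP()
# Variables are created the first time the model sees an input.
_ = model(tf.zeros([2, 3]))

# trainable_variables collects the kernels and biases of both Dense layers.
shapes = [tuple(v.shape) for v in model.trainable_variables]
print(shapes)
```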

Keras layers/models inherit from tf.train.Checkpointable and are integrated with @tf.function, which makes it possible to directly checkpoint or export SavedModels from Keras objects. You do not necessarily have to use Keras’s .fit() API to take advantage of these integrations.

See the guide for more details.

Combine tf.data.Datasets and @tf.function

When iterating over training data that fits in memory, feel free to use regular Python iteration. Otherwise, tf.data.Dataset is the best way to stream training data from disk. Datasets are iterables (not iterators), and work just like other Python iterables in Eager mode. You can fully utilize dataset async prefetching/streaming features by wrapping your code in tf.function(), which replaces Python iteration with the equivalent graph operations using AutoGraph.

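import tensorflow as tf

# loss_fn is assumed to be defined elsewhere.
@tf.function
def train(model, dataset, optimizer):
    for x, y in dataset:
        with tf.GradientTape() as tape:
            prediction = model(x)
            loss = loss_fn(prediction, y)
        # Compute gradients after leaving the tape context, then apply them
        # as (gradient, variable) pairs.
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))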
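To see the two iteration styles described above in isolation, here is a small sketch over a toy in-memory dataset. The data, the batch size, and the sum_labels function are assumptions made for the example. The same for loop runs as plain Python in eager mode and is converted to graph operations inside tf.function.

```python
import tensorflow as tf

features = tf.reshape(tf.range(16, dtype=tf.float32), [8, 2])
labels = tf.range(8, dtype=tf.float32)
dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .batch(4)
    .prefetch(1)
)

# Eager mode: the dataset behaves like any other Python iterable.
n_batches = 0
for x, y in dataset:
    n_batches += 1


@tf.function
def sum_labels(ds):
    # Inside tf.function, AutoGraph rewrites this loop as dataset graph ops.
    total = tf.constant(0.0)
    for _, y in ds:
        total += tf.reduce_sum(y)
    return total


total = sum_labels(dataset)
print(n_batches, float(total))
```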

If you use the Keras .fit() API, you won’t have to worry about dataset iteration.

model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
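For a self-contained version of the snippet above, the sketch below builds a toy dataset and a one-layer model and hands the dataset directly to .fit(). The data, model, optimizer settings, and epoch count are assumptions chosen only to make the example runnable.

```python
import tensorflow as tf

# Toy regression data: the label is the sum of the four features.
features = tf.random.normal([32, 4])
labels = tf.reduce_sum(features, axis=1, keepdims=True)
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1), loss="mse")

# .fit() handles iterating over the dataset for each epoch.
history = model.fit(dataset, epochs=3, verbose=0)
print(history.history["loss"])
```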

Take advantage of AutoGraph with Python control flow

AutoGraph provides a way to convert data-dependent control flow into graph-mode equivalents like tf.cond and tf.while_loop.
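A minimal sketch of data-dependent control flow follows. The collatz_steps function is a hypothetical example, not something from the guide. Because the loop condition and the branch both depend on a tensor value, AutoGraph converts the while and if statements into their graph equivalents when the function is traced.

```python
import tensorflow as tf


@tf.function
def collatz_steps(n):
    # Counts the steps needed to reach 1 under the Collatz rule.
    steps = tf.constant(0)
    while n > 1:          # tensor-dependent loop condition
        if n % 2 == 0:    # tensor-dependent branch
            n = n // 2
        else:
            n = 3 * n + 1
        steps += 1
    return steps


steps = collatz_steps(tf.constant(6))
print(int(steps))
```

The function body is ordinary Python; no explicit tf.cond or tf.while_loop calls are needed.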

One common place where data-dependent control flow appears is in sequence models. tf.keras.layers.RNN wraps an RNN cell, allowing you to either statically or dynamically unroll the recurrence. For demonstration’s sake, you could reimplement dynamic unroll as follows:

class DynamicRNN(tf.keras.Model):

    def __init__(self, rnn_cell):
        super(DynamicRNN, self).__init__()
        self.cell = rnn_cell

    def call(self, input_data):
        # [batch, time, features] -> [time, batch, features]
        input_data = tf.transpose(input_data, [1, 0, 2])
        outputs = tf.TensorArray(tf.float32, input_data.shape[0])
        # tf.keras cells expose get_initial_state rather than the 1.X zero_state.
        state = self.cell.get_initial_state(
            batch_size=input_data.shape[1], dtype=tf.float32)
        # When traced by tf.function, AutoGraph converts this loop over a
        # tensor range into a tf.while_loop.
        for i in tf.range(input_data.shape[0]):
            output, state = self.cell(input_data[i], state)
            outputs = outputs.write(i, output)
        return tf.transpose(outputs.stack(), [1, 0, 2]), state

See the guide for more details.

Use tf.metrics to aggregate data and tf.summary to log it

Finally, a complete set of tf.summary symbols is coming soon. You can access the 2.0 version of tf.summary with:

from tensorflow.python.ops import summary_ops_v2
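As a rough sketch of the aggregate-then-log pattern, the example below accumulates values in a metric and writes the aggregate as a scalar summary. It assumes the public tf.keras.metrics and tf.summary APIs as they appear in released TensorFlow 2.x, rather than the internal import shown above; the metric name, the values, and the temporary log directory are made up for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Aggregate per-batch values with a metric object.
loss_metric = tf.keras.metrics.Mean(name="loss")
for batch_loss in [1.0, 2.0, 3.0]:
    loss_metric.update_state(batch_loss)
mean_loss = float(loss_metric.result())

# Log the aggregated value with tf.summary.
logdir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(logdir)
with writer.as_default():
    tf.summary.scalar("loss", mean_loss, step=0)
writer.flush()

event_files = os.listdir(logdir)
print(mean_loss, event_files)
```

Calling loss_metric.reset_state() (or reset_states() in early 2.x releases) between epochs would start a fresh aggregate.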

See the guide for more details.

Next steps

This article provided a quick summary of the Effective TF 2.0 guide. If you’re interested in these topics, head there to learn more! To learn more about TensorFlow 2.0, we also recommend these recent articles:

And please tune in for the TensorFlow Developer Summit on March 6th and 7th. As always, all the talks will be uploaded to YouTube for folks who can’t make it in person.