Find the bottleneck of your Keras model using TF trace

Xianbo QIAN
Sep 2, 2018

Is your Keras model slow in training or inference, and you don’t know why? TF tracing can help.

Note that this only works in TF 1.11 or later. Until TF 1.11 is officially released, use tf-nightly instead.

First, ask the TF session to record a full trace of the execution, and create a RunMetadata object to hold the tracing output.

import tensorflow as tf
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)  # record a full trace
run_metadata = tf.RunMetadata()  # receives the trace (step_stats)

Then pass these settings to model.compile. Keras remembers them and forwards them to sess.run whenever it is called (during model.fit, model.predict, and so on).

model.compile(..., options=run_options, run_metadata=run_metadata)

That’s it. Let’s give it a try.

model.fit(...)

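Then extract the trace from run_metadata and write it to a file in Chrome's trace format. Note that run_metadata is overwritten on every sess.run call, so it holds the trace of the most recent call only (typically the last batch).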

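import os
from tensorflow.python.client import timeline

# Convert the recorded step stats into Chrome's trace-event JSON format.
tl = timeline.Timeline(run_metadata.step_stats)
ctf = tl.generate_chrome_trace_format()

os.makedirs('output', exist_ok=True)  # make sure the target directory exists
with open('output/timeline.json', 'w') as f:
    f.write(ctf)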

Now feel the magic: open chrome://tracing in Chrome and load the timeline.json file. Enjoy debugging :-)

Tracing output visualized in Chrome
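For a large trace, it can help to rank events by total time before scrolling through chrome://tracing. The sketch below is a minimal example and is not part of TensorFlow. summarize_trace is a hypothetical helper. It assumes the standard Chrome trace event format:

* Timed events are complete events ("ph": "X") with durations in microseconds ("dur").
* "process_name" metadata events label each pid. TF's timeline uses one pid per device.

```python
import json
from collections import defaultdict


def summarize_trace(path, top_n=10):
    """Rank complete events in a Chrome-trace JSON file by total duration.

    Hypothetical helper (not part of TensorFlow). Assumes the Chrome trace
    event format: timed events have "ph" == "X" and "dur" in microseconds,
    and "process_name" metadata events ("ph" == "M") label each pid.
    Returns (device, name, total_us, count) tuples, largest total first.
    """
    with open(path) as f:
        trace = json.load(f)

    # The format allows either {"traceEvents": [...]} or a bare event list.
    events = trace["traceEvents"] if isinstance(trace, dict) else trace

    # Map each pid to its process name (TF's timeline uses one pid per device).
    pid_names = {}
    for ev in events:
        if ev.get("ph") == "M" and ev.get("name") == "process_name":
            pid_names[ev.get("pid")] = ev.get("args", {}).get("name", str(ev.get("pid")))

    # Sum durations and count occurrences per (device, event name).
    totals = defaultdict(float)
    counts = defaultdict(int)
    for ev in events:
        if ev.get("ph") != "X":
            continue
        pid = ev.get("pid")
        key = (pid_names.get(pid, str(pid)), ev.get("name", "?"))
        totals[key] += ev.get("dur", 0)
        counts[key] += 1

    ranked = sorted(totals.items(), key=lambda kv: kv[1], reverse=True)
    return [(dev, name, total, counts[(dev, name)])
            for (dev, name), total in ranked[:top_n]]
```

For example, summarize_trace('output/timeline.json') returns up to ten (device, event name, total microseconds, count) tuples, largest first. In TF's timeline, the event name is typically the op type, so repeated ops are aggregated together.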

You can find more detailed explanations about TF profiling here.
