Precision-Recall curve with Keras

TensorBoard is a suite of visualisations for inspecting and understanding your TensorFlow models and runs. The TensorBoard team recently released a “consistent set of APIs that allows developers to add custom visualisation plugins to TensorBoard”. There are already several plugins available.

Although TensorFlow is great, Keras “is a high-level neural networks API, written in Python and capable of running on top of TensorFlow” (as well as CNTK or Theano), and it has been my preferred choice due to its simplicity. When using the TensorFlow backend, Keras provides a TensorBoard callback so that you can take advantage of these visualisations.

Keras’ TensorBoard callback, however, still does not support all the plugins. I recently wanted to use the Precision-Recall curve plugin (pr_curve) to see how my binary classification model was doing. I ended up writing an extension of the callback that supports it.

I generally find touching TensorFlow’s graph and ops non-trivial, and this was the case again. The support is only partial (for example, it does not use weights and it runs only on the validation set), but hopefully it will help anyone else in need of similar code, since I’ve found very little material about it around the web.

The Code

Here is the code. You can also see a complete example here.

Note that I explicitly decided not to call merge_all, keeping the PR summary op on its own. When each training epoch ends, we get the predictions over the validation data and run only the PR summary. All the other regular TensorBoard visualisations still work, since the super calls delegate to the parent class.

Quick Example

For completeness, I wrote a small example using the Breast Cancer Wisconsin (Diagnostic) Data Set from the UC Irvine Machine Learning Repository. Here is the final Precision-Recall curve of my classification model:

PR Curve in TensorBoard using Keras.

And how do we use it? You can find general details on Precision-Recall curves around the web, so let’s look specifically at the TensorBoard plot.

Tool tip on hover.

When you hover the mouse over any point in the plot, it will show a little box like the one on the left.

It shows the precision and recall of your model (or several models), along with the absolute numbers of true/false positive/negative predictions over the given validation set, at the threshold value for that point. That way we can make a more informed decision about which threshold value to use in practice and which model is better.
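To make the numbers in that tooltip concrete, here is a minimal sketch in plain Python of how they relate to each other for a single threshold. This is not the plugin's code: the labels and scores are made up, the function names are my own, and treating a score equal to the threshold as positive (as well as the value returned when a denominator is zero) is an assumption for the sake of illustration.

```python
def confusion_counts(labels, scores, threshold):
    """Count TP, FP, TN, FN, predicting positive when score >= threshold.

    The >= comparison is an assumption made for this sketch.
    """
    tp = fp = tn = fn = 0
    for label, score in zip(labels, scores):
        predicted_positive = score >= threshold
        if predicted_positive and label == 1:
            tp += 1
        elif predicted_positive:
            fp += 1
        elif label == 1:
            fn += 1
        else:
            tn += 1
    return tp, fp, tn, fn


def precision_recall(tp, fp, fn):
    """Return (precision, recall) from the counts.

    Assumed conventions: precision is 1.0 when nothing is predicted
    positive, recall is 0.0 when there are no positive labels.
    """
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# Made-up validation labels and predicted probabilities, for illustration only.
labels = [1, 0, 1, 1, 0, 0, 1, 0]
scores = [0.9, 0.4, 0.65, 0.3, 0.2, 0.7, 0.8, 0.1]

tp, fp, tn, fn = confusion_counts(labels, scores, threshold=0.5)
precision, recall = precision_recall(tp, fp, fn)
print(tp, fp, tn, fn, precision, recall)
```

Each point on the curve corresponds to one such threshold, so moving along the curve amounts to re-running this count with a different threshold value.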

This, of course, depends a lot on the problem at hand. In the example case of breast cancer, would you rather have more false positives (the model detected cancer, but the patient didn’t really have it) or more false negatives (the patient did have cancer, but the model failed to predict it)? It will always be a trade-off.
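One possible way to turn that trade-off into a concrete choice is to fix the minimum recall you can accept and then take the highest threshold that still reaches it. The sketch below does this over an evenly spaced grid of thresholds. It is only an illustration: `pick_threshold`, the grid size, and the data are hypothetical, and a real decision would also weigh the cost of each kind of error.

```python
def precision_recall_at(labels, scores, threshold):
    """Precision and recall when score >= threshold is predicted positive."""
    pairs = list(zip(labels, scores))
    tp = sum(1 for y, s in pairs if s >= threshold and y == 1)
    fp = sum(1 for y, s in pairs if s >= threshold and y == 0)
    fn = sum(1 for y, s in pairs if s < threshold and y == 1)
    # Assumed conventions for empty denominators.
    precision = tp / (tp + fp) if (tp + fp) else 1.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


def pick_threshold(labels, scores, min_recall, num_thresholds=11):
    """Return (threshold, precision, recall) for the highest grid threshold
    whose recall is at least min_recall, or None if no threshold qualifies.

    Hypothetical helper: recall never increases as the threshold grows,
    so the last qualifying grid point is the highest one.
    """
    best = None
    for i in range(num_thresholds):
        threshold = i / (num_thresholds - 1)
        precision, recall = precision_recall_at(labels, scores, threshold)
        if recall >= min_recall:
            best = (threshold, precision, recall)
    return best


# Made-up validation labels and predicted probabilities, for illustration only.
labels = [1, 0, 1, 1, 0, 0, 1, 0]
scores = [0.9, 0.4, 0.65, 0.3, 0.2, 0.7, 0.8, 0.1]

# "Never miss a positive case" versus "tolerate missing one in four".
strict = pick_threshold(labels, scores, min_recall=1.0)
relaxed = pick_threshold(labels, scores, min_recall=0.75)
print(strict)
print(relaxed)
```

On this toy data, demanding full recall forces a lower threshold and therefore more false positives, which is the same tension you can read off the curve in TensorBoard.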

I hope this was useful, bye!