How to plot the model training in Keras — using custom callback function and using TensorBoard

Kapil Varshney
Aug 6, 2018

I started exploring different ways to visualize the training process while working on the Dog Breed Identification dataset from Kaggle. I had created a basic model that I wanted to test, and I was using a Jupyter notebook for quick prototyping. When I ran the code, I got the training accuracy, validation accuracy, training loss, and validation loss. I could see the numbers, but they weren't enough to understand how well the network learned over time.

I wanted a plot that shows the training metrics and updates automatically after each epoch, so I started exploring Keras callbacks. I created a custom class called TrainingPlot, instantiated it, and passed the object to the callbacks argument when fitting the model. You can check out the Jupyter notebook here.
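For reference, here is a minimal sketch of a notebook-oriented callback along these lines. It assumes a recent Keras (2.x or 3.x) together with matplotlib and IPython; only the class name TrainingPlot comes from the article, and the body is an illustrative reconstruction rather than the author's exact code (see the repository linked at the end for the original).

```python
import keras
import matplotlib.pyplot as plt
from IPython.display import clear_output


class TrainingPlot(keras.callbacks.Callback):
    """Redraws every logged metric in a notebook cell after each epoch."""

    def on_train_begin(self, logs=None):
        # metric name -> list of per-epoch values, e.g. {"loss": [...], "val_loss": [...]}
        self.metric_history = {}

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes this epoch's metrics (loss, accuracy, val_loss, ...) in `logs`.
        for name, value in (logs or {}).items():
            self.metric_history.setdefault(name, []).append(float(value))

        # Wait until there are at least two epochs, so there is a line to draw.
        if len(self.metric_history.get("loss", [])) < 2:
            return

        clear_output(wait=True)  # replace the previous figure instead of stacking a new one
        plt.figure(figsize=(8, 5))
        for name, values in self.metric_history.items():
            plt.plot(range(1, len(values) + 1), values, label=name)
        plt.title(f"Training Loss and Accuracy [Epoch {epoch + 1}]")
        plt.xlabel("Epoch #")
        plt.ylabel("Loss/Accuracy")
        plt.legend()
        plt.show()
        plt.close()  # release the figure once it has been displayed
```

With a compiled model and your own data (the variable names here are placeholders), the callback is passed in as model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=20, callbacks=[TrainingPlot()]). Plotting every key found in logs avoids hard-coding metric names, so the same class works whether or not validation data is supplied.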

Live plot in a Jupyter notebook

I shared the link to the notebook with a friend who works on computer vision. He asked me if I could make the same code work in a script, not just in a Jupyter notebook. The challenge piqued my curiosity, so I quickly modified my code to work with a script. I wrote two separate Python scripts: one for the TrainingPlot class and the other for the training itself.

The scripts work slightly differently. You won't see a live-updating plot as you do in a Jupyter notebook. Instead, you can save the plot to disk, either as a new image for each epoch or by overwriting the previous one. You can then open the output directory and look at the plots while your model trains.
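A minimal sketch of the script variant, assuming Keras 2.x/3.x and matplotlib. The output_dir and overwrite parameters and the file names are hypothetical choices made for this example: overwrite=True rewrites a single image, while overwrite=False keeps one image per epoch, matching the two options described above.

```python
import os

import keras
import matplotlib

matplotlib.use("Agg")  # file-only backend: no display or notebook needed
import matplotlib.pyplot as plt


class TrainingPlot(keras.callbacks.Callback):
    """Saves loss/accuracy curves to disk after each epoch (script version)."""

    def __init__(self, output_dir="training_plots", overwrite=True):
        super().__init__()
        self.output_dir = output_dir  # hypothetical default directory
        self.overwrite = overwrite    # True: rewrite one file; False: one file per epoch

    def on_train_begin(self, logs=None):
        self.metric_history = {}
        os.makedirs(self.output_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        # Accumulate this epoch's metrics reported by Keras.
        for name, value in (logs or {}).items():
            self.metric_history.setdefault(name, []).append(float(value))

        fig, ax = plt.subplots(figsize=(8, 5))
        for name, values in self.metric_history.items():
            # Markers keep a single-epoch curve visible.
            ax.plot(range(1, len(values) + 1), values, marker="o", label=name)
        ax.set_title(f"Training Loss and Accuracy [Epoch {epoch + 1}]")
        ax.set_xlabel("Epoch #")
        ax.set_ylabel("Loss/Accuracy")
        ax.legend()

        if self.overwrite:
            filename = "training_plot.png"
        else:
            filename = f"epoch_{epoch + 1:03d}.png"
        fig.savefig(os.path.join(self.output_dir, filename))
        plt.close(fig)  # free the figure; otherwise memory grows every epoch
```

Saved in its own module (for example a hypothetical trainingplot.py), the class can be imported by the training script and passed to model.fit through the callbacks argument; the images can then be opened from output_dir while training runs. Closing each figure explicitly matters in a script, because nothing else discards it.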

Training Loss and Accuracy plot (when using scripts)

Using TensorBoard

TensorBoard is a visualization tool provided with TensorFlow, and it can also be used with Keras.

First, install TensorBoard if you haven't already:

pip install tensorboard

With TensorBoard, you don't need to define the plot explicitly; TensorBoard takes care of it. You can check out the code here.
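Since the original code is not reproduced here, the following is a minimal sketch, assuming Keras 2.x/3.x with the TensorFlow backend (the TensorBoard callback writes its logs through TensorFlow). The toy model and random data are placeholders; the relevant part is passing keras.callbacks.TensorBoard to fit() with a log_dir that matches the --logdir given to the tensorboard command.

```python
import keras
import numpy as np

LOG_DIR = "logs/"  # must match the path given to: tensorboard --logdir=logs/

# Placeholder data: 256 random samples with 20 features and binary labels.
x_train = np.random.rand(256, 20).astype("float32")
y_train = np.random.randint(0, 2, size=(256, 1)).astype("float32")

model = keras.Sequential([
    keras.Input(shape=(20,)),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# The TensorBoard callback writes loss/metric summaries to LOG_DIR after each epoch.
tensorboard_cb = keras.callbacks.TensorBoard(log_dir=LOG_DIR)

history = model.fit(
    x_train,
    y_train,
    validation_split=0.2,
    epochs=5,
    callbacks=[tensorboard_cb],
    verbose=0,
)
```

Unlike the TrainingPlot approach, no plotting code is needed: TensorBoard reads the event files from LOG_DIR and draws the training and validation curves itself.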

Before running the training script, run the following command in a separate terminal:

tensorboard --logdir=logs/

Here, --logdir must point to the same log directory that your training code writes to. You can name the directory whatever you want; just be consistent.
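One common convention that keeps this consistent across several experiments is to give each training run its own timestamped subdirectory under the shared log directory, while tensorboard keeps pointing at the parent. The helper below is a hypothetical sketch of that idea (make_run_log_dir is a name invented for this example), using only the Python standard library.

```python
import os
from datetime import datetime


def make_run_log_dir(base_dir="logs", now=None):
    """Create and return a timestamped subdirectory such as logs/20180806-142501.

    TensorBoard treats each subdirectory of --logdir that contains event
    files as a separate run, so one dashboard can compare several runs.
    """
    now = now or datetime.now()
    run_dir = os.path.join(base_dir, now.strftime("%Y%m%d-%H%M%S"))
    os.makedirs(run_dir, exist_ok=True)
    return run_dir
```

In the training code, the callback would then be created as keras.callbacks.TensorBoard(log_dir=make_run_log_dir()), while the command stays tensorboard --logdir=logs/.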

Open the link TensorBoard prints (http://dl:6006 in this example) to view the TensorBoard dashboard. The dashboard won't display any plots until model training begins. Go back to the other terminal and run the training script.

If you are training on a remote machine, there are a few more steps. Simply opening the link TensorBoard prints won't work, because the TensorBoard server on the remote machine usually isn't directly reachable from your local browser (newer TensorBoard versions, for instance, listen only on localhost by default). Instead, run the following command on your local machine (remember that TensorBoard and your model are running on the remote machine):

ssh -N -f -L localhost:16006:localhost:6006 <user@remote>

Explanation of the ssh command:
-N : do not execute a remote command (only forward ports)
-f : put ssh in the background
-L <localmachine>:<portA>:<remotemachine>:<portB> : forward <remotemachine>:<portB> (as seen from the remote machine) to <localmachine>:<portA> (on your local machine)

Example:
ssh -N -f -L 10.191.2.227:16006:10.191.10.216:6006 kapil@10.191.10.216

Then navigate to http://10.191.2.227:16006 in the browser on your local machine to view the TensorBoard dashboard. With the first, localhost-based command, the address would be http://localhost:16006.
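If you open the tunnel often, it can be scripted. The sketch below uses hypothetical Python helpers (tunnel_command and open_tunnel are names made up for this example) that assemble the same ssh arguments described above and run them with subprocess; it assumes an OpenSSH client on the PATH and working login to the remote machine.

```python
import shlex
import subprocess


def tunnel_command(remote, local_port=16006, remote_port=6006,
                   local_host="localhost", remote_host="localhost"):
    """Return the ssh argv that forwards remote_host:remote_port to local_host:local_port.

    remote is the "user@host" login for the machine running TensorBoard.
    remote_host is resolved on that machine, so "localhost" means the remote itself.
    """
    forward = f"{local_host}:{local_port}:{remote_host}:{remote_port}"
    # -N: no remote command, -f: go to the background, -L: local port forward
    return ["ssh", "-N", "-f", "-L", forward, remote]


def open_tunnel(remote, **kwargs):
    """Start the tunnel; with -f, ssh backgrounds itself after login, so this returns."""
    cmd = tunnel_command(remote, **kwargs)
    print("Running:", shlex.join(cmd))
    subprocess.run(cmd, check=True)
```

For instance, open_tunnel("kapil@10.191.10.216", local_host="10.191.2.227", remote_host="10.191.10.216") reproduces the example command above. Passing an argument list to subprocess, instead of a single shell string, avoids any shell-quoting issues.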

(To read more about running TensorBoard on a remote server, check out this Stack Overflow thread.)

Link to github repository: https://github.com/kapil-varshney/utilities/tree/master/training_plot

Please leave any comments/feedback/suggestions. I hope this helps.

Kapil Varshney

Data Scientist (Computer Vision) @ Esri R&D New Delhi. Here to share what I learn and do. Connect with me at https://www.linkedin.com/in/kapilvarshney14/