How to use TensorFlow callbacks?

Fabiana Clemente
YData
Jun 8, 2020

How to monitor and improve your deep learning model, and how to create your own custom callback

In September 2019 (miss that year, no?), we saw the much-awaited release of the stable version of TF 2.0. Google hosted an entire event demonstrating the newly introduced features and an even more accessible API for seamlessly putting machine learning applications into production. In this post, we’re going to talk about TensorFlow 2.0 callbacks.

What are callbacks?

Formally,

A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on internal states and statistics of the model during training.

Informally, if your model were a car, callbacks would be its instrument cluster. The cluster tells you everything about fuel, speed, tire pressure, and so on, which helps you drive safely and stay in complete control instead of ending up in a disaster.
Callbacks are that instrument cluster plus self-driving capabilities: you can monitor your loss, stop training early (EarlyStopping) when your model hits a dead end, and much more.
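To make the idea concrete, here is a framework-free toy: a hypothetical training loop (toy_fit) that calls hook methods at fixed stages, plus a callback that ends training by setting stop_training on the model, the same flag that Keras callbacks such as EarlyStopping set. ToyModel, toy_fit, and StopWhenBelow are illustrative names; this is not how Keras implements fit().

```python
class Callback:
    """Base class: every hook does nothing unless a subclass overrides it."""

    def set_model(self, model):
        self.model = model

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class ToyModel:
    """Hypothetical stand-in for a model; callbacks may set stop_training."""

    def __init__(self):
        self.stop_training = False


def toy_fit(model, epoch_losses, callbacks):
    """Hypothetical training loop that 'trains' by reading precomputed losses."""
    for cb in callbacks:
        cb.set_model(model)
        cb.on_train_begin({})
    history = []
    for epoch, loss in enumerate(epoch_losses):
        logs = {"loss": loss}  # a real fit() would compute this by training
        history.append(loss)
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)  # the "instrument cluster" reads the logs
        if model.stop_training:  # the "self-driving" part: a callback intervened
            break
    for cb in callbacks:
        cb.on_train_end({})
    return history


class StopWhenBelow(Callback):
    """Stops training once the loss drops below a threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def on_epoch_end(self, epoch, logs=None):
        if logs["loss"] < self.threshold:
            self.model.stop_training = True


history = toy_fit(ToyModel(), [1.0, 0.5, 0.2, 0.1], [StopWhenBelow(0.3)])
print(history)  # [1.0, 0.5, 0.2]
```

The loop never needs to know what a callback does; it only promises to call the hooks at the right moments, which is why callbacks can be mixed and matched freely.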

Types of Callbacks

TensorFlow 2.0 provides numerous built-in callbacks. Let’s go through them one by one and see what each does and how to apply it in your code.

BaseLogger:

This callback is applied to every model by default, so you don’t have to invoke a BaseLogger explicitly. When you write history = model.fit(...), the history variable is assigned a tf.keras.callbacks.History object. The history property of this object is a dictionary holding the per-epoch values of the loss and of every metric you compiled the model with (for example, accuracy), plus their validation counterparts when you pass validation data. You can also inspect the params property, which is a dictionary of the parameters used to fit the model. (BaseLogger itself was removed in Keras 3, but fit() still returns a History object.)

tf.keras.callbacks.BaseLogger()
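A minimal sketch of inspecting the History object returned by fit(), assuming TensorFlow 2.x; the tiny model and random data are placeholders, not part of the original article.

```python
import numpy as np
import tensorflow as tf

# Placeholder binary-classification data: 64 samples, 4 features.
x = np.random.rand(64, 4).astype("float32")
y = (x.sum(axis=1) > 2.0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

history = model.fit(x, y, epochs=3, batch_size=16, verbose=0)

# One list entry per epoch for the loss and for each compiled metric.
print(history.history.keys())
print(history.history["loss"])
# Parameters used by fit(), such as the number of epochs and steps per epoch.
print(history.params)
```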

CSVLogger:

CSVLogger streams the per-epoch results (the epoch index, the loss, and any compiled metrics such as accuracy) to a CSV file on disk so you can inspect them later. It’s great if you want to roll your own graphs or keep a record of your model training process over time. To use CSVLogger, add the following to your existing training script:

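import tensorflow as tf
# 'tf' is only an alias, so imports must use the real module name 'tensorflow'.
from tensorflow.keras.callbacks import CSVLogger

csv_logger = CSVLogger('training.log')  # one row per epoch; append=True continues an existing file
model.fit(X_train, Y_train, callbacks=[csv_logger])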
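To check what CSVLogger actually records, the sketch below trains a throwaway model on random data (placeholders, not from the article) and reads the log back with Python’s csv module. It assumes TensorFlow 2.x.

```python
import csv
import os
import tempfile

import numpy as np
import tensorflow as tf

# Placeholder regression data and model.
x = np.random.rand(32, 3).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

log_path = os.path.join(tempfile.mkdtemp(), "training.log")
model.fit(x, y, epochs=3, verbose=0,
          callbacks=[tf.keras.callbacks.CSVLogger(log_path)])

# Each row is one epoch; the columns are 'epoch' plus every logged value.
with open(log_path, newline="") as f:
    rows = list(csv.DictReader(f))
for row in rows:
    print(row["epoch"], row["loss"])
```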

EarlyStopping:

EarlyStopping stops training when a monitored metric, such as loss or accuracy, has stopped improving for a given number of epochs (the patience). It saves you from spending time training a model that is no longer getting better.

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

# Stop training when the training loss ('loss') has not improved for three consecutive epochs.
callback = EarlyStopping(monitor='loss', patience=3)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
                    epochs=10, batch_size=1, callbacks=[callback],
                    verbose=0)
len(history.history['loss'])  # Only 4 epochs are run.
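In practice, EarlyStopping usually watches a validation metric. The sketch below, assuming TensorFlow 2.x and placeholder random data, monitors val_loss, ignores improvements smaller than min_delta, and sets restore_best_weights=True so that, when training stops early, the model is rolled back to the weights from its best epoch rather than its last one. The specific values are arbitrary.

```python
import numpy as np
import tensorflow as tf

# Placeholder data: the target is a noisy linear function of the inputs.
rng = np.random.default_rng(0)
x = rng.random((200, 5)).astype("float32")
y = (x @ np.arange(1, 6, dtype="float32") + rng.normal(0, 0.1, 200)).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(5,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",         # watch the held-out loss, not the training loss
    min_delta=1e-4,             # smaller changes do not count as an improvement
    patience=5,                 # epochs to wait without improvement before stopping
    restore_best_weights=True,  # on early stop, roll back to the best epoch's weights
)
history = model.fit(x, y, validation_split=0.2, epochs=200, verbose=0,
                    callbacks=[early_stopping])
print("epochs run:", len(history.history["loss"]))
```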

ModelCheckpoint:

ModelCheckpoint saves your model (or only its weights) to disk at the end of every epoch, or only when it beats the best value seen so far of a monitored metric such as accuracy or loss. The file format follows the filepath: a path ending in .h5, for example, produces an HDF5 file. Output filenames can be set dynamically from the epoch number and any logged metric, for example weights.{epoch:02d}-{val_loss:.2f}.h5. This callback is especially useful when a single epoch takes a long time to run (as with deep networks), because you don’t want to lose important training progress in case of a system failure. It is also handy when training on AWS spot instances, which can be reclaimed when the spot price exceeds your maximum bid; with checkpoints on disk you can resume instead of starting over.

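import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint

EPOCHS = 10
# Keras 3 requires weights-only checkpoint paths to end in '.weights.h5';
# older tf.keras versions accept this name too and save it as HDF5.
checkpoint_filepath = '/tmp/checkpoint.weights.h5'
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',  # TF2 logs the 'accuracy' metric as 'val_accuracy', not 'val_acc'
    mode='max',
    save_best_only=True)

# The model must be compiled with metrics=['accuracy'] and given validation data
# so that 'val_accuracy' is logged. Weights are saved at the end of every epoch,
# but only if they are the best seen so far.
model.fit(x_train, y_train, validation_data=(x_val, y_val),
          epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# Load the best weights back into the model.
model.load_weights(checkpoint_filepath)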
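Besides keeping only the best weights, ModelCheckpoint can keep one file per epoch, with the epoch number and a metric in each filename. A minimal sketch, assuming TensorFlow 2.x and placeholder random data; the temporary directory stands in for your own checkpoint folder, and the .weights.h5 suffix is used because Keras 3 requires it for weights-only checkpoints.

```python
import glob
import os
import tempfile

import numpy as np
import tensorflow as tf

# Placeholder data and model.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

ckpt_dir = tempfile.mkdtemp()
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    # {epoch} is 1-based; {val_loss} is taken from that epoch's logs.
    filepath=os.path.join(ckpt_dir, "weights.{epoch:02d}-{val_loss:.2f}.weights.h5"),
    save_weights_only=True,  # save_best_only defaults to False: one file per epoch
)
model.fit(x, y, validation_split=0.25, epochs=3, verbose=0, callbacks=[checkpoint])

# One file per epoch, named with the epoch number and that epoch's val_loss.
saved = sorted(os.path.basename(p) for p in glob.glob(os.path.join(ckpt_dir, "*.weights.h5")))
print(saved)
```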

LearningRateScheduler:

One of the most important choices you’ll face when optimizing the training procedure is the learning rate, which determines the size of the steps taken during gradient descent. The problem with a constant learning rate is that the cost (objective) function of a neural network is generally not convex, so a rate large enough to make fast progress early in training can overshoot or oscillate around a minimum later on. A learning rate that changes over the course of training therefore often works better.

One common approach is to start with a relatively large value and decrease it in later epochs. To do this, write a simple function that returns the desired learning rate for the current epoch and pass it to this callback as the schedule argument.

import math
import tensorflow as tf
from tensorflow.keras.callbacks import LearningRateScheduler

# This function keeps the learning rate at 0.001 for the first ten epochs
# and decreases it exponentially after that.
def scheduler(epoch):
    if epoch < 10:
        return 0.001
    else:
        # math.exp returns a plain Python float, which Keras 3 requires here.
        return 0.001 * math.exp(0.1 * (10 - epoch))

callback = LearningRateScheduler(scheduler)
model.fit(data, labels, epochs=100, callbacks=[callback],
          validation_data=(val_data, val_labels))
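To see the shape of this schedule, you can evaluate the function on its own. The sketch below is plain Python (no TensorFlow needed, using math.exp in place of tf.math.exp): the rate stays at 0.001 through epoch 10, then shrinks by a factor of e^(-0.1) ≈ 0.905 per epoch, so by epoch 20 it is 0.001 · e^(-1) ≈ 0.000368.

```python
import math


def scheduler(epoch):
    """The schedule from the example above, evaluated without TensorFlow."""
    if epoch < 10:
        return 0.001
    return 0.001 * math.exp(0.1 * (10 - epoch))


for epoch in (0, 9, 10, 11, 20, 30):
    print(f"epoch {epoch:2d}: learning rate {scheduler(epoch):.6f}")
```

Keras also accepts a two-argument schedule(epoch, lr) that receives the current learning rate, which makes multiplicative decay (for example, return lr * 0.9) easy to express.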

RemoteMonitor

RemoteMonitor streams training events to a server: at the end of every epoch, it sends the epoch number together with that epoch’s logs (the loss and any metrics). It requires the requests library. By default, events are sent to root + '/publish/epoch/end/'. Calls are HTTP POST requests with a data argument that is a JSON-encoded dictionary of event data. If send_as_json is set to True, the content type of the request is application/json; otherwise, the serialized JSON is sent within a form.

import tensorflow as tf
from tensorflow.keras.callbacks import RemoteMonitor

remote_monitor = RemoteMonitor(
    root='http://localhost:9000', path='/publish/epoch/end/', field='data',
    headers=None, send_as_json=False
)
model.fit(data, labels, epochs=100, callbacks=[remote_monitor],
          validation_data=(val_data, val_labels))
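RemoteMonitor needs something listening at the other end. Below is a minimal sketch of a receiver built only on Python’s standard library; EpochEndHandler, start_server, and received_events are hypothetical names, and the port 9000 and the form field 'data' are assumptions chosen to match the configuration above. It handles both the form-encoded body (send_as_json=False) and the raw JSON body (send_as_json=True).

```python
import json
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
from urllib.parse import parse_qs

received_events = []  # one dict per epoch, e.g. {"epoch": 0, "loss": ...}


class EpochEndHandler(BaseHTTPRequestHandler):
    """Hypothetical receiver for the POSTs RemoteMonitor sends at each epoch end."""

    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        body = self.rfile.read(length).decode("utf-8")
        if self.headers.get("Content-Type", "").startswith("application/json"):
            # send_as_json=True: the request body is the JSON object itself.
            event = json.loads(body)
        else:
            # send_as_json=False: the JSON string arrives in the form field 'data'.
            event = json.loads(parse_qs(body)["data"][0])
        received_events.append(event)
        self.send_response(200)
        self.end_headers()

    def log_message(self, format, *args):
        pass  # silence the default per-request logging


def start_server(port=9000):
    """Serve in a background thread; port 9000 matches root='http://localhost:9000'."""
    server = HTTPServer(("localhost", port), EpochEndHandler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server
```

Call start_server() before model.fit(...); received_events then fills up as epochs finish. The handler ignores the request path, so it accepts whatever path you configure.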

TensorBoard

If you’ve worked with TensorFlow, you probably know how cool TensorBoard is. It has made developers’ lives much easier by visualizing what’s happening during training. The TensorBoard callback is one way of enabling these features for your model: it writes logs to a directory that you can then examine with the TensorBoard visualization tool.

import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard  # note the capital 'B'

tensorboard_callback = TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

Run this command to start TensorBoard on your system:

# run the tensorboard command to view the visualizations.
tensorboard --logdir=path_to_your_logs
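A common pattern, used in TensorFlow’s own TensorBoard guide, is to give each run its own timestamped subdirectory so TensorBoard can show several runs side by side, and to set histogram_freq=1 to record weight histograms every epoch. A minimal sketch, assuming TensorFlow 2.x; the random data, tiny model, and temporary base directory are placeholders.

```python
import datetime
import os
import tempfile

import numpy as np
import tensorflow as tf

# Placeholder data and model.
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(8,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# One subdirectory per run, so TensorBoard can compare runs.
base_dir = tempfile.mkdtemp()  # in a real project, something like "logs/fit"
log_dir = os.path.join(base_dir, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,  # also record weight histograms every epoch
)
model.fit(x, y, epochs=2, verbose=0, callbacks=[tensorboard_callback])
print("Now run: tensorboard --logdir", base_dir)
```

Pointing --logdir at the parent directory rather than at a single run is what lets TensorBoard list every run.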

ReduceLROnPlateau

This callback reduces the learning rate when a metric you specify, such as accuracy or loss, has stopped improving. Models often benefit from reducing the learning rate by a factor of 2–10 once learning stagnates. The callback monitors a quantity, and if no improvement is seen for a ‘patience’ number of epochs, it multiplies the learning rate by factor, without going below min_lr.

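import tensorflow as tf
from tensorflow.keras.callbacks import ReduceLROnPlateau

# min_lr is a lower bound: it must be below the optimizer's starting learning rate
# (0.001 for Adam by default), or the rate can never be reduced.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                              patience=5, min_lr=1e-5)
# 'val_loss' only exists when validation data is provided, e.g. via validation_split.
model.fit(X_train, Y_train, validation_split=0.2, callbacks=[reduce_lr])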
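The core rule is easy to trace by hand. The plain-Python sketch below is a simplified model of ReduceLROnPlateau, not the Keras source: it ignores min_delta, cooldown, and mode, and simulate_reduce_lr_on_plateau is a hypothetical name. With factor=0.2, each plateau cuts the rate by 5×, inside the 2–10 range mentioned above.

```python
def simulate_reduce_lr_on_plateau(val_losses, lr=0.01, factor=0.2, patience=5, min_lr=1e-5):
    """Simplified model of ReduceLROnPlateau (ignores min_delta, cooldown, and mode).

    Returns the learning rate in effect after each epoch's check.
    """
    best = float("inf")
    wait = 0
    lrs = []
    for loss in val_losses:
        if loss < best:  # improvement: remember it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:  # plateau: shrink the rate, but not below min_lr
                lr = max(lr * factor, min_lr)
                wait = 0
        lrs.append(lr)
    return lrs


# The loss improves for two epochs, then stalls.
lrs = simulate_reduce_lr_on_plateau([1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9])
print(lrs)
```

Five stalled epochs after the best value trigger one reduction from 0.01 to 0.002; the counter then starts over.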

LambdaCallback

LambdaCallback lets you build simple custom callbacks on the fly when none of the built-in ones meets your needs: instead of writing a class, you pass functions (often lambdas) for the stages you care about.

import tensorflow as tf
from tensorflow.keras.callbacks import LambdaCallback

A few examples of what you can achieve with LambdaCallback:

# Print the batch number at the beginning of every batch.
batch_print_callback = LambdaCallback(
    on_batch_begin=lambda batch, logs: print(batch))

Stream the epoch loss to a file in JSON format. The file content
is not well-formed JSON but rather has a JSON object per line.

import json
json_log = open('loss_log.json', mode='wt', buffering=1)
json_logging_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
    on_train_end=lambda logs: json_log.close()
)

Terminate some processes after having finished model training.

processes = ...  # complete this before running the script
cleanup_callback = LambdaCallback(
    on_train_end=lambda logs: [
        p.terminate() for p in processes if p.is_alive()])
model.fit(...,
          callbacks=[batch_print_callback,
                     json_logging_callback,
                     cleanup_callback])
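Since the model.fit(...) call above is elided, here is a small end-to-end sketch that actually runs, assuming TensorFlow 2.x and placeholder random data. It collects each epoch’s loss in an in-memory list (epoch_losses, a hypothetical name) instead of a file, so there is no file handle to close.

```python
import numpy as np
import tensorflow as tf

epoch_losses = []  # filled in by the callback below

# on_epoch_end receives (epoch, logs); on_train_end receives (logs).
record_loss = tf.keras.callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: epoch_losses.append(logs["loss"])
)
announce_end = tf.keras.callbacks.LambdaCallback(
    on_train_end=lambda logs: print(f"Training finished after {len(epoch_losses)} epochs")
)

# Placeholder data and model.
x = np.random.rand(32, 2).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = tf.keras.Sequential([tf.keras.Input(shape=(2,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
model.fit(x, y, epochs=4, verbose=0, callbacks=[record_loss, announce_end])
```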

You can also achieve this by subclassing tf.keras.callbacks.Callback. A custom callback is a powerful tool to customize the behaviour of a Keras model during training, evaluation, or inference, including reading or changing the Keras model.
Let’s build a vanilla custom callback. We’ll start by importing TensorFlow and defining a simple Sequential Keras model:
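The code for this step did not survive in this version of the post; a minimal sketch in the spirit of TensorFlow’s custom-callback guide might look like this. The single Dense layer, the RMSprop optimizer, and the loss and metric choices are assumptions made to keep the model tiny.

```python
import tensorflow as tf


def get_model():
    """A tiny Sequential model mapping a flattened 28x28 image to one number."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(784,)),
        tf.keras.layers.Dense(1),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model
```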

Then, load the MNIST data for training and testing from the Keras datasets API:
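A minimal sketch of this step, assuming TensorFlow 2.x. The images are flattened to 784 values scaled to [0, 1] so they fit a model with a 784-feature input; limiting the data to the first 1,000 samples (limit, a hypothetical parameter) is an assumption that keeps a callback demo fast. The preprocessing is split into its own function so it can be checked without downloading anything.

```python
import numpy as np
import tensorflow as tf


def preprocess(images, labels, limit=1000):
    """Flatten 28x28 images to 784 floats in [0, 1] and keep the first `limit` samples."""
    images = np.asarray(images).reshape(-1, 784).astype("float32") / 255.0
    return images[:limit], np.asarray(labels)[:limit]


def load_mnist_subset(limit=1000):
    # Downloads MNIST on first use (needs network access), then reuses the cached copy.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    return preprocess(x_train, y_train, limit), preprocess(x_test, y_test, limit)
```

Usage: (x_train, y_train), (x_test, y_test) = load_mnist_subset()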

Now, define a simple custom callback that tracks the start and end of every batch of data. In each of those calls, it prints the index of the current batch.
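A minimal sketch of such a callback, assuming TensorFlow 2.x; the class name BatchTracker is hypothetical. It overrides the two training-batch hooks of tf.keras.callbacks.Callback, and Keras passes the batch index as the first argument.

```python
import tensorflow as tf


class BatchTracker(tf.keras.callbacks.Callback):
    """Prints the index of each training batch as it starts and ends."""

    def on_train_batch_begin(self, batch, logs=None):
        print(f"Training: start of batch {batch}")

    def on_train_batch_end(self, batch, logs=None):
        print(f"Training: end of batch {batch}")
```

The evaluation and inference counterparts (on_test_batch_begin/on_test_batch_end and on_predict_batch_begin/on_predict_batch_end) work the same way for evaluate() and predict().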

Passing a callback to model methods such as tf.keras.Model.fit() ensures that its methods are called at those stages:
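To see which hooks each model method triggers, here is a minimal sketch, assuming TensorFlow 2.x, with random data standing in for MNIST so it runs offline. HookCounter is a hypothetical callback that counts batch-level calls: fit() fires the train hooks, evaluate() the test hooks, and predict() the predict hooks.

```python
import numpy as np
import tensorflow as tf


class HookCounter(tf.keras.callbacks.Callback):
    """Counts how often each batch-level hook fires."""

    def __init__(self):
        super().__init__()
        self.counts = {"train": 0, "test": 0, "predict": 0}

    def on_train_batch_end(self, batch, logs=None):
        self.counts["train"] += 1

    def on_test_batch_end(self, batch, logs=None):
        self.counts["test"] += 1

    def on_predict_batch_end(self, batch, logs=None):
        self.counts["predict"] += 1


# Random stand-in data: 256 samples with 784 features, like flattened MNIST.
x = np.random.rand(256, 784).astype("float32")
y = np.random.rand(256, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(784,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mse")

counter = HookCounter()
model.fit(x, y, batch_size=64, epochs=2, verbose=0, callbacks=[counter])  # 4 batches x 2 epochs
model.evaluate(x, y, batch_size=128, verbose=0, callbacks=[counter])      # 2 batches
model.predict(x, batch_size=128, verbose=0, callbacks=[counter])          # 2 batches
print(counter.counts)
```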

There you go! You just created your own custom callback in TensorFlow.

You’ve reached the end!

Congratulations! You just learned about one of the most significant features of TensorFlow, which will not only help you build amazing products but also help you keep a close eye on what’s happening inside your model training.

Fabiana Clemente is Chief Data Officer at YData.

Making data available with privacy by design.

YData helps data science teams deliver ML models, simplifying data acquisition, so data scientists can focus their time on things that matter.
