Keras Callbacks Explained In Three Minutes
A gentle introduction to callbacks in Keras. Learn about EarlyStopping, ModelCheckpoint, and other callback functions with code examples.
Building Deep Learning models without callbacks is like driving a car with no functioning brakes: you have little to no control over the process, and the result is very likely to be a disaster. In this article, you will learn how to monitor and improve your Deep Learning models using Keras callbacks like ModelCheckpoint and EarlyStopping.
What are callbacks?
From the Keras documentation:
A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on internal states and statistics of the model during training.
You define and use a callback when you want to automate tasks at given stages of training (for example, at the end of every batch or epoch) so that you stay in control of the training process. Typical uses include stopping training when you reach a certain accuracy or loss score, saving your model as a checkpoint after each successful epoch, and adjusting the learning rate over time. Let’s dive into some callback functions!
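As a minimal sketch of the idea, assuming a recent Keras version, the hypothetical callback below subclasses keras.callbacks.Callback and overrides on_epoch_end, one of the stage hooks the documentation refers to. It stops training once the training loss falls below a threshold by setting self.model.stop_training. The class name StopAtLossThreshold, the toy data, and the threshold value are illustrative assumptions, not part of Keras.

```python
import numpy as np
import keras
from keras.callbacks import Callback


class StopAtLossThreshold(Callback):
    """Hypothetical custom callback: stop training once the training loss
    drops below `threshold` at the end of an epoch."""

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold
        self.stopped_epoch = None  # epoch index at which we asked Keras to stop

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        loss = logs.get("loss")
        if loss is not None and loss < self.threshold:
            self.stopped_epoch = epoch
            # fit() checks this flag and ends training when it is True.
            self.model.stop_training = True


# Toy regression data (assumed, for illustration only).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# A deliberately huge threshold, so the callback fires after the first epoch.
stopper = StopAtLossThreshold(threshold=1e9)
history = model.fit(x, y, epochs=10, verbose=0, callbacks=[stopper])
print("epochs run:", len(history.history["loss"]))
```

In practice you would pick a threshold that matches your problem; the built-in callbacks below follow the same pattern of hooking into stages of fit().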
EarlyStopping
Overfitting is a nightmare for Machine Learning practitioners. One way to avoid overfitting is to stop the training process early. The EarlyStopping callback has various arguments that you can set to control when training should stop. Here are some relevant arguments:
- monitor: the value being monitored, e.g. val_loss
- min_delta: minimum change in the monitored value that counts as an improvement. For example, with min_delta=1, a change of less than 1 in the monitored value is treated as no improvement
- patience: number of epochs with no improvement after which training will be stopped
- restore_best_weights: set this argument to True if you want to restore the weights from the epoch with the best monitored value once training stops
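To make the interplay of min_delta and patience concrete, here is a simplified pure-Python sketch of the stopping rule for a metric that should decrease (like val_loss). It is an illustration, not the Keras source: the function name early_stopping_epoch is hypothetical, and the real callback also handles options such as mode='max' and baseline.

```python
def early_stopping_epoch(val_losses, min_delta=0.0, patience=3):
    """Return the 0-based epoch at which training would stop, or None.

    Hypothetical, simplified version of the EarlyStopping rule for a
    metric that should go down (mode='min').
    """
    best = float("inf")
    wait = 0  # epochs since the last real improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improved by more than min_delta: remember it and reset the counter.
            best = loss
            wait = 0
        else:
            # A too-small (or no) improvement counts as "no improvement".
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Plateau: the best value is 0.9 at epoch 1, then 3 epochs without improvement.
print(early_stopping_epoch([1.0, 0.9, 0.95, 0.91, 0.92], patience=3))

# Steady but small gains: each step improves by 0.05, below min_delta=0.2.
print(early_stopping_epoch([1.0, 0.95, 0.9, 0.85], min_delta=0.2, patience=3))
```

The first call stops at epoch 4 and the second at epoch 3. The second case shows why min_delta matters: val_loss is still falling, but too slowly to count as progress.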
The code example below defines an EarlyStopping callback that tracks val_loss, stops training if val_loss has not improved for 3 consecutive epochs, and restores the best weights once training stops:
from keras.callbacks import EarlyStopping

earlystop = EarlyStopping(monitor='val_loss',
                          min_delta=0,
                          patience=3,
                          verbose=1,
                          restore_best_weights=True)
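The earlystop object only takes effect once it is passed to fit() through the callbacks argument. A minimal sketch, assuming a toy regression model and random data (both hypothetical); note that val_loss is only produced when fit() is given validation data, here via validation_split.

```python
import numpy as np
import keras
from keras.callbacks import EarlyStopping

# Toy regression data (assumed, for illustration only).
rng = np.random.default_rng(0)
x = rng.normal(size=(128, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

earlystop = EarlyStopping(monitor="val_loss", min_delta=0, patience=3,
                          verbose=1, restore_best_weights=True)

# val_loss only exists when fit() receives validation data, so hold out 25%.
history = model.fit(x, y, validation_split=0.25, epochs=50,
                    verbose=0, callbacks=[earlystop])

# Fewer than 50 entries means EarlyStopping ended training early.
print("epochs run:", len(history.history["val_loss"]))
```

If val_loss keeps improving, all 50 epochs run and the callback never triggers.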
ModelCheckpoint
This callback saves the model after every epoch. Here are some relevant arguments:
- filepath: the file path you want to save your model in
- monitor: the value being monitored
- save_best_only: set this to True to save only when the monitored value improves, so that the latest best model is not overwritten by a worse one
- mode: auto, min, or max. For example, set mode='min' if the monitored value is val_loss and you want to minimize it.
from keras.callbacks import ModelCheckpoint
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', mode='min', save_best_only=True)
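A minimal end-to-end sketch of checkpointing, assuming a toy model and random data (hypothetical). The file name best_model.keras is also an assumption; recent Keras versions expect a .keras (or .h5) extension when saving the full model. A temporary directory keeps the example from writing into your working folder.

```python
import os
import tempfile

import numpy as np
import keras
from keras.callbacks import ModelCheckpoint

# Toy regression data (assumed, for illustration only).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Hypothetical checkpoint location inside a fresh temporary directory.
filepath = os.path.join(tempfile.mkdtemp(), "best_model.keras")
checkpoint = ModelCheckpoint(filepath, monitor="val_loss", mode="min",
                             save_best_only=True)

# val_loss requires validation data; the file is rewritten only when it improves.
model.fit(x, y, validation_split=0.25, epochs=5, verbose=0,
          callbacks=[checkpoint])

# Reload the model from the epoch with the lowest val_loss.
best_model = keras.models.load_model(filepath)
print(best_model.count_params())
```

The filepath may also contain named formatting options such as {epoch:02d} or {val_loss:.2f}, which Keras fills in from the epoch number and the logs at save time.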
LearningRateScheduler
This one is pretty straightforward: it adjusts the learning rate over time using a schedule function that you write beforehand. The schedule function takes the current epoch index as input (recent Keras versions also pass the current learning rate) and returns the desired learning rate as output.
from keras.callbacks import LearningRateScheduler

scheduler = LearningRateScheduler(schedule, verbose=0)  # schedule is a function
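A minimal sketch of a schedule function, assuming the two-argument form schedule(epoch, lr) accepted by recent Keras versions. The step-decay rule (halve the rate every two epochs), the name step_decay, and the toy model are hypothetical choices for illustration.

```python
import math

import numpy as np
import keras
from keras.callbacks import LearningRateScheduler


def step_decay(epoch, lr):
    """Hypothetical schedule: halve the learning rate every 2 epochs."""
    if epoch > 0 and epoch % 2 == 0:
        return lr * 0.5
    return lr


# Toy regression data (assumed, for illustration only).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.01), loss="mse")

scheduler = LearningRateScheduler(step_decay, verbose=0)
model.fit(x, y, epochs=4, verbose=0, callbacks=[scheduler])

# Read back the learning rate the optimizer ended up with.
final_lr = float(np.array(model.optimizer.learning_rate))
print(final_lr)
```

Because the scheduler calls step_decay at the start of every epoch, the four epochs run with learning rates of 0.01, 0.01, 0.005, and 0.005.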
Other callback functions
Along with the above functions, there are other callbacks that you might encounter or want to use in your Deep Learning project:
- History and BaseLogger: callbacks that Keras applies to your model automatically
- TensorBoard: This is hands down my favorite Keras callback. This callback writes a log for TensorBoard, which is TensorFlow’s excellent visualization tool. If you have installed TensorFlow with pip, you should be able to launch TensorBoard from the command line:
tensorboard --logdir=/full_path_to_your_logs
- CSVLogger: This callback streams epoch results to a CSV file
- LambdaCallback: This callback allows you to build simple custom callbacks on the fly from plain functions
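A minimal sketch combining several of these, assuming a toy model and random data (hypothetical). It uses the History object that fit() returns, a CSVLogger writing to a hypothetical training_log.csv in a temporary directory, and a LambdaCallback that records epoch indices into a list. TensorBoard is omitted to keep the sketch light; it goes into the same callbacks list, for example TensorBoard(log_dir='logs'), where 'logs' is a directory of your choice and the one you would point --logdir at.

```python
import csv
import os
import tempfile

import numpy as np
import keras
from keras.callbacks import CSVLogger, LambdaCallback

# Toy regression data (assumed, for illustration only).
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# CSVLogger appends one row per epoch to this (hypothetical) file.
log_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")
csv_logger = CSVLogger(log_path)

# LambdaCallback builds a callback from plain functions; here it records epochs.
epochs_seen = []
epoch_tracker = LambdaCallback(
    on_epoch_end=lambda epoch, logs: epochs_seen.append(epoch))

# fit() returns the History object that Keras attaches automatically.
history = model.fit(x, y, epochs=3, verbose=0,
                    callbacks=[csv_logger, epoch_tracker])

with open(log_path, newline="") as f:
    rows = list(csv.DictReader(f))
print(list(rows[0].keys()))
```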
In this article, you have learned the main concept behind callbacks in Keras and several commonly used callback functions. The Keras documentation has a very comprehensive page on callbacks that you should definitely check out: http://keras.io/callbacks/