Implementing callbacks in fast.ai

When I first encountered callbacks in fast.ai, I was a bit intimidated. They seemed quite confusing, and they looked more complex than anything that I’d previously learned.

In this post, I’ll explain how callbacks work and how to implement them from scratch in fast.ai.

Note — this post is intended to help you understand how callbacks work. If you just want to implement some new callbacks, you don’t need to follow all of these steps. See the fast.ai documentation for more information.

Recall: A basic training loop

A basic training loop

Here’s what a basic training loop looks like (without any callbacks). For each epoch, and for each batch within that epoch, we follow these five steps:

  • Calculate predictions with our model
  • Calculate the loss of those predictions
  • Calculate the gradients for the model’s parameters
  • Update the parameters
  • Zero out the gradients
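The five steps above can be sketched in plain Python. This is a minimal sketch, not the original code. A hypothetical toy `LinearModel` with hand-written gradients and an `mse` loss stand in for a PyTorch model, loss function and optimizer, so the example runs with no dependencies. The gradients are accumulated with `+=` (the way PyTorch accumulates `.grad`), which is why step 5 is needed.

```python
# Hypothetical stand-ins for a PyTorch model and loss, so the sketch runs with no dependencies.
class LinearModel:
    """Toy model y = w*x + b, with gradient slots stored next to the parameters."""
    def __init__(self):
        self.w = self.b = 0.0
        self.grad_w = self.grad_b = 0.0

    def __call__(self, xb):
        return [self.w * x + self.b for x in xb]


def mse(preds, yb):
    return sum((p - y) ** 2 for p, y in zip(preds, yb)) / len(yb)


def fit(epochs, model, batches, lr=0.1):
    loss = None
    for epoch in range(epochs):                  # for each epoch...
        for xb, yb in batches:                   # ...and each batch in it
            preds = model(xb)                    # 1. calculate predictions
            loss = mse(preds, yb)                # 2. calculate the loss
            n = len(yb)                          # 3. gradients of the MSE (accumulated)
            model.grad_w += sum(2 * (p - y) * x for p, y, x in zip(preds, yb, xb)) / n
            model.grad_b += sum(2 * (p - y) for p, y in zip(preds, yb)) / n
            model.w -= lr * model.grad_w         # 4. update the parameters
            model.b -= lr * model.grad_b
            model.grad_w = model.grad_b = 0.0    # 5. zero out the gradients
    return loss


# Usage: learn y = 2x + 1 from two batches of five points each.
xs = [i / 10 for i in range(10)]
ys = [2 * x + 1 for x in xs]
batches = [(xs[:5], ys[:5]), (xs[5:], ys[5:])]
model = LinearModel()
final_loss = fit(1000, model, batches)
print(f"w={model.w:.3f}, b={model.b:.3f}, loss={final_loss:.6f}")
```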

What are callbacks and why do we need them?

While the above training loop works perfectly well, there will probably be times when we want to customize it. There are many things we might want to do, e.g.:

  • Print out some custom metrics during the training
  • Alter the learning rate throughout the training
  • Switch our training between different models (e.g. in a GAN architecture)

Callbacks allow us to do these types of things. Essentially, a callback is an object that inserts custom code into the training loop at specific points (e.g. when an epoch begins or after each batch).

How can we create callbacks?

There are three main steps to creating callbacks in our training loop:

  1. Create some callback objects
  2. Create a CallbackHandler (an object where we will store our callback objects)
  3. Incorporate the callbacks into our training loop

I’ll now go through each of these steps in detail.

Step 1: Create some callback objects

The first step is to write some callback objects. Here are examples of three simple callback objects:

Example callback objects

Here’s what each callback object does:

  • BatchCounter: every time we begin an epoch, we call begin_epoch, which sets batch_counter to 1. Every time we complete a batch (i.e. every time we update the parameters), we call after_step, which keeps track of how many batches have been completed in the current epoch and prints that number every 200 batches
  • TimeCheck: when we begin training, we call begin_fit, which sets epoch_counter to 1. Every time we begin an epoch, we call begin_epoch, which keeps track of which epoch we are on and prints the epoch number and the current time
  • PrintLoss: every time we complete an epoch, we call after_epoch, which prints the most recent loss
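Here is a minimal sketch of these three callbacks. It is not the original code: the base `Callback` class is a stripped-down assumption (the appendix explains why it exists), and the hooks this post doesn’t name (`begin_batch`, `after_loss`, `after_backward`) are modelled on the fast.ai lesson’s naming but should be treated as assumptions. Each hook returns `True`, meaning "carry on with the training loop", and `PrintLoss` relies on the base class storing the latest loss in `after_loss`.

```python
from datetime import datetime


class Callback:
    # Minimal base class (see the appendix): every hook exists and returns
    # True, meaning "carry on with the training loop".
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb): return True
    def after_loss(self, loss):
        self.loss = loss                 # PrintLoss reads this later
        return True
    def after_backward(self): return True
    def after_step(self): return True


class BatchCounter(Callback):
    def begin_epoch(self, epoch):
        super().begin_epoch(epoch)
        self.batch_counter = 1           # reset at the start of every epoch
        return True

    def after_step(self):
        # Here batch_counter equals the number of batches completed this epoch.
        if self.batch_counter % 200 == 0:
            print(f"{self.batch_counter} batches completed")
        self.batch_counter += 1
        return True


class TimeCheck(Callback):
    def begin_fit(self, learn):
        super().begin_fit(learn)
        self.epoch_counter = 1
        return True

    def begin_epoch(self, epoch):
        super().begin_epoch(epoch)
        print(f"Epoch {self.epoch_counter} started at {datetime.now():%H:%M:%S}")
        self.epoch_counter += 1
        return True


class PrintLoss(Callback):
    def after_epoch(self):
        print(f"Loss: {self.loss:.4f}")  # most recent loss, stored by after_loss
        return True
```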

Here’s a diagram summarizing what we have created so far:


You may have noticed in the above examples that we inherited from a generic Callback class. I explain this in the appendix (see bottom of page).

Step 2: Create a CallbackHandler

Now that we have our callback objects created, we need to create a new class named CallbackHandler:

An example CallbackHandler class (note: this code is truncated; see the appendix for where to find the full code)

This class serves two main purposes:

  • It gives us a place to store all of our individual callbacks. In the above code, the individual callbacks are stored in a list named self.cbs (this happens in the “__init__” function)
  • It allows us to call all of our individual callbacks at once. For each point in the training loop, we create a function (begin_fit, after_fit, begin_epoch, etc.). When we call one of these functions, it loops through all of our individual callbacks (BatchCounter, TimeCheck, PrintLoss) and calls each one’s function of the same name. Note that these functions don’t exist solely to call the callbacks (e.g. begin_epoch also calls learn.model.train()), but that detail isn’t important for now
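A minimal sketch of such a handler follows, assuming a base `Callback` class whose hooks all return `True` (the appendix explains why that class exists). It stores the callbacks in `self.cbs` and gives each point in the loop its own method that forwards the call to every callback. `EpochLogger` is a hypothetical example callback, and extra work such as calling `learn.model.train()` is left out.

```python
class Callback:
    # Minimal base class: every hook exists and returns True ("carry on").
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb): return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self): return True
    def after_step(self): return True


class CallbackHandler:
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []        # where the individual callbacks live

    def _call_all(self, hook, *args):
        # Call the same-named method on every callback; the list is built in
        # full so every callback runs even if an earlier one returns False.
        results = [getattr(cb, hook)(*args) for cb in self.cbs]
        return all(results)

    # One method per point in the training loop.
    def begin_fit(self, learn):
        self.learn = learn
        return self._call_all("begin_fit", learn)
    def after_fit(self): return self._call_all("after_fit")
    def begin_epoch(self, epoch): return self._call_all("begin_epoch", epoch)
    def after_epoch(self): return self._call_all("after_epoch")
    def begin_batch(self, xb, yb): return self._call_all("begin_batch", xb, yb)
    def after_loss(self, loss): return self._call_all("after_loss", loss)
    def after_backward(self): return self._call_all("after_backward")
    def after_step(self): return self._call_all("after_step")


class EpochLogger(Callback):
    # Hypothetical example callback: records every epoch number it sees.
    def __init__(self):
        self.seen = []
    def begin_epoch(self, epoch):
        super().begin_epoch(epoch)
        self.seen.append(epoch)
        return True


logger = EpochLogger()
cbh = CallbackHandler([logger, Callback()])
cbh.begin_fit(learn=None)
cbh.begin_epoch(0)
cbh.begin_epoch(1)
print(logger.seen)   # [0, 1]
```

Building the full list before calling `all()` is a deliberate choice: a chained `res = res and cb.hook(...)` would short-circuit and silently skip the remaining callbacks once one of them returned `False`.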

Here’s a diagram summarizing what we have created so far:


Step 3: Incorporate the callbacks into our training loop

Now, all we need to do is edit our training loop so that it uses these callbacks.

We pass our CallbackHandler to the training loop as an input, and then we add some new lines into the training loop that call each of the functions in the CallbackHandler.

Here’s what our updated training loop looks like:

An updated training loop (with callbacks incorporated)

The training loop is very similar to what we had originally, just with some new lines that call the functions in our CallbackHandler.

You may have noticed that there are a lot of if statements in the updated code. These let us stop training at any point (or skip specific steps). For example, if we only wanted to train for 20 batches, we could create a callback that returns False after the 20th batch is complete.
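Putting the pieces together, here is a minimal, dependency-free sketch of a training loop with callbacks, including a callback that stops training after 20 batches. It rests on several assumptions: `LinearModel`, `Learner` and `StopAfterNBatches` are hypothetical; the `CallbackHandler` is condensed with `__getattr__` so that every hook name is forwarded to every callback; and in this sketch a `False` from `after_step` means "stop training", while a `False` from `begin_epoch`, `begin_batch` or `after_loss` skips the rest of that epoch or batch. Other conventions would work equally well.

```python
class Callback:
    # Minimal base class (explained in the appendix): every hook returns True.
    def begin_fit(self, learn):
        self.learn = learn
        return True
    def after_fit(self): return True
    def begin_epoch(self, epoch):
        self.epoch = epoch
        return True
    def after_epoch(self): return True
    def begin_batch(self, xb, yb): return True
    def after_loss(self, loss):
        self.loss = loss
        return True
    def after_backward(self): return True
    def after_step(self): return True

class CallbackHandler:
    # Condensed handler: cbh.<hook>(...) calls <hook>(...) on every callback
    # and returns False if any of them returned False.
    def __init__(self, cbs=None):
        self.cbs = cbs if cbs else []
    def __getattr__(self, hook):
        def call_all(*args):
            return all([getattr(cb, hook)(*args) for cb in self.cbs])
        return call_all

class LinearModel:
    # Hypothetical toy model y = w*x + b with hand-written MSE gradients.
    def __init__(self):
        self.w = self.b = self.grad_w = self.grad_b = 0.0
    def __call__(self, xb):
        return [self.w * x + self.b for x in xb]
    def backward(self, xb, yb, preds):
        n = len(yb)
        self.grad_w += sum(2 * (p - y) * x for p, y, x in zip(preds, yb, xb)) / n
        self.grad_b += sum(2 * (p - y) for p, y in zip(preds, yb)) / n
    def step(self, lr):
        self.w -= lr * self.grad_w
        self.b -= lr * self.grad_b
    def zero_grad(self):
        self.grad_w = self.grad_b = 0.0

def mse(preds, yb):
    return sum((p - y) ** 2 for p, y in zip(preds, yb)) / len(yb)

class Learner:
    # Hypothetical container for the model, data and learning rate.
    def __init__(self, model, batches, lr):
        self.model, self.batches, self.lr = model, batches, lr

def one_batch(xb, yb, learn, cbh):
    """Run one batch; return False if a callback asked to stop training."""
    if not cbh.begin_batch(xb, yb):
        return True                            # skip this batch
    preds = learn.model(xb)                    # 1. predictions
    loss = mse(preds, yb)                      # 2. loss
    if not cbh.after_loss(loss):
        return True                            # skip backward pass and update
    learn.model.backward(xb, yb, preds)        # 3. gradients
    if cbh.after_backward():
        learn.model.step(learn.lr)             # 4. update
    keep_going = cbh.after_step()
    learn.model.zero_grad()                    # 5. zero the gradients
    return keep_going

def fit(epochs, learn, cbh):
    if not cbh.begin_fit(learn):
        return
    for epoch in range(epochs):
        if not cbh.begin_epoch(epoch):
            continue                           # skip this epoch
        stop = False
        for xb, yb in learn.batches:
            if not one_batch(xb, yb, learn, cbh):
                stop = True
                break
        if not cbh.after_epoch() or stop:
            break
    cbh.after_fit()

class StopAfterNBatches(Callback):
    # Hypothetical callback: after_step returns False once n batches are done.
    def __init__(self, n):
        self.n, self.n_batches = n, 0
    def after_step(self):
        self.n_batches += 1
        return self.n_batches < self.n

xs = [i / 10 for i in range(10)]
ys = [2 * x + 1 for x in xs]                   # toy data: y = 2x + 1
learn = Learner(LinearModel(), [(xs[:5], ys[:5]), (xs[5:], ys[5:])], lr=0.1)
stopper = StopAfterNBatches(20)
fit(100, learn, CallbackHandler([stopper]))
print(stopper.n_batches)                       # 20
```

The loop itself never changes: new behaviour comes from writing another `Callback` subclass and passing it to the handler, not from editing `fit`.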

And here’s what the output looks like when we run it:

Output from our callbacks

That’s it, we’re done. Hopefully you now have a better understanding of callbacks!

Thank you to Jeremy Howard from fast.ai — most of the code comes from one of his lessons.

Appendix

  • You may have noticed that when we wrote the callbacks, we inherited from a generic Callback class. This makes sure that every callback object has a function defined for every point in time (begin_fit, after_fit, begin_epoch, etc.). If these functions aren’t all defined, we will get an exception (an AttributeError) when the training loop runs, because the CallbackHandler calls each of these functions on every callback
  • Here’s what a generic callback class looks like (notice that most of the functions don’t actually do anything — they exist only to prevent exceptions):
Example of a generic callback class
  • I included only a truncated version of the CallbackHandler code, because the full code is very long. You can access it on GitHub
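A minimal sketch of a generic Callback class, assuming the hook names begin_fit, after_fit, begin_epoch, after_epoch, begin_batch, after_loss, after_backward and after_step (the ones beyond those named in this post are assumptions). Each hook records its arguments where useful and returns True. The usage shows a hypothetical subclass that overrides a single hook, and a hypothetical class that doesn’t inherit from Callback failing with an AttributeError.

```python
class Callback:
    """Generic base class: defines every hook the training loop may call.

    Each hook returns True ("carry on"); subclasses override only what they need.
    """
    def begin_fit(self, learn):
        self.learn = learn           # keep a reference to the learner
        return True

    def after_fit(self):
        return True

    def begin_epoch(self, epoch):
        self.epoch = epoch           # the current epoch number
        return True

    def after_epoch(self):
        return True

    def begin_batch(self, xb, yb):
        self.xb, self.yb = xb, yb    # the current batch
        return True

    def after_loss(self, loss):
        self.loss = loss             # the most recent loss
        return True

    def after_backward(self):
        return True

    def after_step(self):
        return True


class PrintEpoch(Callback):
    # Hypothetical callback that overrides a single hook.
    def after_epoch(self):
        print(f"finished epoch {self.epoch}")
        return True


class Bare:
    # Hypothetical callback that does NOT inherit from Callback.
    def after_epoch(self):
        return True


cb = PrintEpoch()
cb.begin_fit(None)      # inherited hook: no exception
cb.begin_epoch(0)
cb.after_epoch()        # prints "finished epoch 0"

try:
    Bare().begin_fit(None)
except AttributeError as err:
    print("AttributeError:", err)
```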

Written by

Management consultant & amateur deep learning practitioner
