Understanding FastAI v2 Training with a Computer Vision Example- Part 3: FastAI Learner and Callbacks

Rakesh Sukumar
Published in Analytics Vidhya
12 min read · Oct 20, 2020
Image Source: https://images.app.goo.gl/juHhyxgoGhUPRNY38

This is the third article in this series, which is aimed at those who are already familiar with FastAI and want to dig a little deeper and understand what is happening behind the scenes. In this article, we will use the resnet model built in the first article to understand the FastAI Learner & Callbacks. The overall structure of this series is as below:

  1. Study the resnet34 model architecture and build it using plain Python & PyTorch.
  2. Deep dive into FastAI optimizers & implement a NAdam optimizer.
  3. Study FastAI Learner and Callbacks & implement a learning rate finder (lr_find method) with callbacks.

We will use Google Colab to run our code. You can find the code files for this article here. Let’s get started!

First, we will quickly recreate our model from the first article.

Let’s create a learner using our model.

Let’s understand the arguments to the Learner constructor:

  • dls: The dataloaders object containing training & validation data.
  • model: The neural network model.
  • loss_func: The loss function to be used. If no loss function is specified, the default loss function for the dataloaders object is used. The dataloaders object selects an appropriate loss function based on the type of the target. Here, since we have used a CategoryBlock (see dblock creation above) for our target variable, a cross-entropy loss (CrossEntropyLossFlat, FastAI’s wrapper around PyTorch’s CrossEntropyLoss) is added as the default loss function. FastAI has a good tutorial on creating custom loss functions here.
  • opt_func: Function(Callable) used to create the optimizer object. By default an Adam optimizer is added.
  • lr: Default learning rate.
  • splitter: A function that takes the model as input and returns a list of parameter groups. The default function ‘trainable_params()’ returns all trainable parameters of the model (parameters for which requires_grad = True) as a single list, creating a single parameter group.
  • cbs: A list of callbacks to customize the training loop. By default, FastAI adds the following callbacks to the learner: TrainEvalCallback, Recorder and ProgressCallback. We will study learner callbacks in detail in this article.
  • metrics: An optional list of metrics, that can be either functions or an object of FastAI Metric class. No metrics are selected by default.
  • path & model_dir: path and model_dir are used to save and/or load models.
  • wd: Default weight decay used.
  • moms: Default momentum used in learn.fit_one_cycle() method.
  • wd_bn_bias: Controls if weight decay is applied to batchnorm type layers & bias parameters. Default behavior is not to apply wd for these layers.
  • train_bn: Controls if batchnorm type layers are to be trained even if they belong to the frozen part of the model. The default behavior is to train batchnorm type layers.

FastAI Training Loop

We will use the learn.fit() method to study the FastAI training loop. Let’s train our model for 1 epoch for a reference.

Let’s look at the arguments for learn.fit() method.

The learn.fit() method accepts the following arguments:

  • n_epoch: the number of epochs for which the model must be trained
  • lr & wd: the learning rate & weight decay to be used by the optimizer
  • cbs: Any additional callbacks to be added to the learner for the fit operation. Note that these are in addition to the callbacks added with the learner constructor. These callbacks will be removed from the learner object after the fit() operation.
  • reset_opt: A Boolean flag to indicate if the optimizer associated with learner object needs to be reset.
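The temporary nature of the callbacks passed to fit() can be sketched with a context manager. This is a toy stand-in written for illustration, not FastAI's code: ToyLearner, its string "callbacks" and the MyLRFinder name are assumptions.

```python
from contextlib import contextmanager

class ToyLearner:
    """Toy stand-in for a Learner; callbacks are plain strings here (assumption)."""
    def __init__(self, cbs=()):
        self.cbs = list(cbs)

    @contextmanager
    def added_cbs(self, cbs):
        # Attach the extra callbacks only for the duration of the `with` block.
        self.cbs.extend(cbs)
        try:
            yield
        finally:
            for cb in cbs:
                self.cbs.remove(cb)

    def fit(self, n_epoch, cbs=()):
        with self.added_cbs(list(cbs)):
            # A real fit would train here; we just report the active callbacks.
            return list(self.cbs)

learn = ToyLearner(cbs=["TrainEvalCallback", "Recorder"])
during_fit = learn.fit(1, cbs=["MyLRFinder"])   # hypothetical extra callback
after_fit = learn.cbs
print(during_fit)
print(after_fit)
```

The `finally` clause is what guarantees the extra callbacks are detached even if training is interrupted by an exception.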

A fit operation (learn.fit() method call) consists of n_epoch epochs through the dataset. An epoch consists of one full iteration through the training dataset followed by a full iteration through the validation dataset. FastAI uses events & exceptions for control flow. There are five event types & associated exceptions (the exceptions are provided in brackets) in a FastAI fit operation:

  1. fit (CancelFitException)
  2. epoch (CancelEpochException)
  3. train (CancelTrainException)
  4. validate (CancelValidException)
  5. batch (CancelBatchException)

These event types have the following events associated with them:

  • before_{event_type}
  • after_cancel_{event_type}
  • after_{event_type}.

The control flow for these event types is as follows (see the figure below for a better understanding):

  1. before_{event_type} event. I will explain the control flow for an event later.
  2. Call the function/method associated with the event type.
  3. after_cancel_{event_type} event, run only if the associated exception is raised during step 1 or step 2. If the exception occurs during step 1, step 2 is skipped entirely.
  4. after_{event_type} event. Steps 4 & 5 always run, even if the exception was raised.
  5. An optional ‘final’ function can be run at the end of the event type. In the learn.fit() method call, only the ‘fit’ event type has a ‘final’ function (self._end_cleanup()) associated with it (see the figure below).
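The five steps above can be sketched in a few lines of plain Python. This is a simplified model for illustration, not FastAI's source: with_events, the emit callable and the locally defined CancelEpochException are stand-ins.

```python
class CancelEpochException(Exception):
    """Local stand-in for FastAI's Cancel{EventType}Exception classes."""

def with_events(emit, f, event_type, ex, final=lambda: None):
    """Wrap `f` in before / after_cancel / after events for `event_type`."""
    try:
        emit(f"before_{event_type}")        # step 1
        f()                                 # step 2
    except ex:
        emit(f"after_cancel_{event_type}")  # step 3: only on cancellation
    emit(f"after_{event_type}")             # step 4: always runs
    final()                                 # step 5: always runs

def run(cancel_in_before):
    log = []
    def emit(name):
        log.append(name)
        if cancel_in_before and name == "before_epoch":
            raise CancelEpochException()    # simulate a callback cancelling the epoch
    with_events(emit, lambda: log.append("do_epoch"), "epoch",
                CancelEpochException, final=lambda: log.append("final"))
    return log

normal_log = run(cancel_in_before=False)
cancelled_log = run(cancel_in_before=True)
print(normal_log)
print(cancelled_log)
```

Comparing the two logs shows that cancelling during the before event skips only the wrapped function, while the after event and the final function still run.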

In addition to above events, FastAI also has the following events that are encountered for every batch during a training/validation/inference run.

  • after_pred
  • after_loss
  • before_backward (only for training run)
  • after_backward (only for training run)
  • after_step (only for training run)

An event calls the learner object, passing in the event name (example: the ‘after_pred’ event calls learn(‘after_pred’)). I will explain the control flow for an event in detail later. First let’s study the fit operation. The control flow during the fit() operation is best explained with the figure below. Note that if no exception occurs during the before_{event type} event, the associated method is executed and the control passes to the after_{event type} event. However, if a Cancel{event type}Exception is raised during before_{event type}, the associated method is skipped & the control passes to the after_cancel_{event type} event.

Control flow during a FastAI fit operation.

I will briefly explain what each of these methods/events does. Note that ‘learn’/‘self’ below refers to the learner object.

  1. learn.fit(): The learn.fit() method creates a context with additional callbacks (passed to the fit method) added to the learner object. This ensures that the additional callbacks are removed from the learner after the execution of the fit() method. Next, it creates a new optimizer object if an optimizer object is not already associated with the learner object or if reset_opt argument is True. It then stores n_epoch as an attribute of the learner object and initializes loss value learn.loss with torch.Tensor([0.]). Next, it initiates the control flow for ‘fit’ event type starting with the ‘before_fit’ event.
  2. before_fit: before_fit event calls learn(‘before_fit’)
  3. learn._do_fit(): Initiates a for loop to train the model for learn.n_epoch epochs. For each iteration, stores the current epoch number as learn.epoch and initiates the control flow for ‘epoch’ event type starting with ‘before_epoch’ event.
  4. before_epoch: before_epoch event calls learn(‘before_epoch’)
  5. self._do_epoch(): First calls self._do_epoch_train() to run 1 epoch on training dataset then calls self._do_epoch_validate() to run 1 epoch (compute loss & metrics) on validation dataset.
  6. self._do_epoch_train(): Sets training dataloader as self.dl and initiates the control flow for ‘train’ event type starting with ‘before_train’ event. After the execution completes the ‘train’ event type, self._do_epoch_validate() is called which sets the validation dataloader as self.dl and initiates the control flow for ‘validate’ event type with torch.no_grad() so that gradient calculations are disabled for the validation run. Note that self._do_epoch_validate() method is also used at inference time in learn.predict() call to get predictions for new samples.
  7. before_train: before_train calls learn(‘before_train’).
  8. self.all_batches(): Sets self.n_iter to the length (i.e. no of batches) of self.dl. Note that self.dl is set to training data loader during the training run (event type: ‘train’) and validation dataloader during the validation run (event type: ‘validate’). Next, it calls self.one_batch(i, b) on each i, and batch ‘b’ in enumerate(self.dl).
  9. self.one_batch(i, b): Sets ‘i’ as self.iter and splits the batch ‘b’ as self.xb - the input and self.yb - the target. Then, it initiates the control flow for ‘batch’ event type starting with ‘before_batch’ event.
  10. before_batch: before_batch calls learn(‘before_batch’)
  11. self._do_one_batch(): self._do_one_batch() does the following:

(i) Gets the prediction for the batch and stores it to self.pred

(ii) after_pred: Calls learn(‘after_pred’)

(iii) Computes loss by calling self.loss_func() & stores it to self.loss

(iv) after_loss: Calls learn(‘after_loss’).

(v) before_backward: Calls learn(‘before_backward’). (v) to (x) are executed only for the training run.

(vi) self._backward(): Computes gradients by calling self.loss.backward()

(vii) after_backward: Calls learn(‘after_backward’)

(viii) self._step(): Calls self.opt.step() to update the parameters.

(ix) after_step: Calls learn(‘after_step’)

(x) self.opt.zero_grad(): Zeroes the gradients.

12. self._end_cleanup(): Called at the end of the fit event type to reset self.xb, self.yb, self.dl, self.pred and self.loss.
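The per-batch sequence in step 11 can be made concrete with a toy learner that only records the order of events. Everything here is an illustrative assumption: ToyLearner, the dummy "model" (multiply by 2), the mean squared error loss and the string placeholders for the backward pass and optimizer step.

```python
class ToyLearner:
    """Toy stand-in for a Learner that logs events instead of dispatching callbacks."""
    def __init__(self):
        self.log = []
        self.training = True

    def __call__(self, event_name):          # mimics learn('event-name')
        self.log.append(event_name)

    def _do_one_batch(self, xb, yb):
        self.pred = [2 * x for x in xb]      # forward pass with a dummy model
        self("after_pred")
        # Mean squared error as a dummy loss function
        self.loss = sum((p - y) ** 2 for p, y in zip(self.pred, yb)) / len(yb)
        self("after_loss")
        if not self.training:                # validation/inference stops here
            return
        self("before_backward")
        self.log.append("_backward")         # placeholder for self.loss.backward()
        self("after_backward")
        self.log.append("_step")             # placeholder for self.opt.step()
        self("after_step")
        self.log.append("zero_grad")         # placeholder for self.opt.zero_grad()

train_learn = ToyLearner()
train_learn._do_one_batch([1, 2], [2, 4])

valid_learn = ToyLearner()
valid_learn.training = False
valid_learn._do_one_batch([1, 2], [2, 4])

print(train_learn.log)
print(valid_learn.log)
```

The validation log ends after after_loss, which is why the backward and step events never reach callbacks during validation or inference.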

Let’s look at the control flow during an event, when learn(‘event-name’) is called.

Control flow during an Event

FastAI does the following when learn(‘event-name’) is called:

  1. Get all callbacks associated with the learner object.
  2. Sort the callbacks in the correct order. (I’ll explain the sorting logic shortly.) You can check the sorted order of callbacks using the sort_by_run() function.
  3. For each callback (let’s say cb) in the sorted list, run the method with the same name as ‘event-name’ (i.e. cb.‘event-name’) if the callback has such a method and if it satisfies any of the following 3 conditions:

(i) (cb.run = True) AND (event-name is not one of the *inner_events)

(ii) (cb.run = True) AND (cb.run_train = True) AND (learn.training = True, if learn has a “training” attribute.)

(iii) (cb.run = True) AND (cb.run_valid = True) AND (learn.training = False, if learn has a “training” attribute)

*inner_events: before_batch, after_pred, after_loss, before_backward, after_backward, after_step, after_cancel_batch, and after_batch. In other words, events outside the per-batch loop run whenever cb.run is True, while the inner events additionally respect cb.run_train and cb.run_valid.

The callback attributes cb.run, cb.run_valid and cb.run_train can be used to enable/disable the callback and they are all set to True by default. As the names indicate, the attributes ‘run_train’ and ‘run_valid’ can be used to run the callback selectively during training or validation/inference respectively. Also, note that the cb.run attribute is reset to True at the end of a fit operation during the ‘after_fit’ event.
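The dispatch rule can be sketched as a small gate inside the callback's __call__. This is a simplified model based on my reading of FastAI's Callback class, not a copy of it: ToyCallback, ValidOnly and the SimpleNamespace learner are illustrative assumptions.

```python
from types import SimpleNamespace

INNER_EVENTS = ("before_batch after_pred after_loss before_backward "
                "after_backward after_step after_cancel_batch after_batch").split()

class ToyCallback:
    run, run_train, run_valid = True, True, True

    def __init__(self, learn):
        self.learn, self.calls = learn, []

    def __call__(self, event_name):
        allowed = (event_name not in INNER_EVENTS
                   or (self.run_train and getattr(self.learn, "training", True))
                   or (self.run_valid and not getattr(self.learn, "training", False)))
        handler = getattr(self, event_name, None)   # None if the event is not implemented
        if self.run and allowed and handler is not None:
            handler()
        if event_name == "after_fit":
            self.run = True                         # re-enable the callback after each fit

class ValidOnly(ToyCallback):
    run_train = False                               # skip inner events while training
    def before_fit(self): self.calls.append("before_fit")
    def after_loss(self): self.calls.append("after_loss")

learn = SimpleNamespace(training=True)
cb = ValidOnly(learn)
cb("before_fit")         # runs: not an inner event
cb("after_loss")         # skipped: inner event during training, run_train is False
learn.training = False
cb("after_loss")         # runs: inner event during validation, run_valid is True
cb.run = False
cb("after_loss")         # skipped: callback disabled
cb("after_fit")          # no handler, but resets cb.run
print(cb.calls, cb.run)
```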

Callback sorting logic:

Sorted order of callbacks
  1. Get the list of callbacks, let’s call it ‘cbs’. Create an empty list, say ‘res’, as a placeholder to store the sorted callbacks.
  2. Move all callbacks with the ‘toward_end’ attribute = True to the end of cbs.
  3. For each callback ‘cb’ in cbs, starting from the first one, check if cb has a ‘run_after’ attribute; if yes, check if cb.run_after is equal to any other callback in cbs.
  4. For all other callbacks ‘o’ in cbs, check if o has a ‘run_before’ attribute; if yes, check if o.run_before is equal to cb.
  5. If the answers to both 3 and 4 are False, add (append) cb to the sorted list ‘res’ and remove it from cbs. Repeat 3 to 5 until all callbacks are moved to res.
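The sorting steps above can be sketched as follows. This is a simplified re-implementation for illustration, not FastAI's sort_by_run: the toy callback classes (TrainEval, Rec, Prog, Pre, Logger) are assumptions, and run_after/run_before are assumed to hold a single class.

```python
def _as_list(x):
    """Wrap a single class (or None) into a list."""
    if x is None:
        return []
    return list(x) if isinstance(x, (list, tuple)) else [x]

def is_first(cb, pending):
    """True if nothing still pending has to run before `cb`."""
    # Step 3: cb.run_after must not match any callback still pending.
    for target in _as_list(getattr(cb, "run_after", None)):
        if any(isinstance(o, target) for o in pending):
            return False
    # Step 4: no pending callback may declare run_before = type(cb).
    for o in pending:
        if any(isinstance(cb, t) for t in _as_list(getattr(o, "run_before", None))):
            return False
    return True

def sort_by_run_sketch(cbs):
    # Step 2: callbacks flagged toward_end go to the back of the queue.
    pending = ([cb for cb in cbs if not getattr(cb, "toward_end", False)]
               + [cb for cb in cbs if getattr(cb, "toward_end", False)])
    res = []
    while pending:
        for i, cb in enumerate(pending):
            if is_first(cb, pending):          # Step 5: move cb to the result
                res.append(pending.pop(i))
                break
        else:
            raise Exception("Impossible to sort")  # circular constraints
    return res

class TrainEval: pass
class Rec: run_after = TrainEval
class Prog: run_after = Rec
class Pre: run_before = Rec
class Logger: toward_end = True

ordered = sort_by_run_sketch([Logger(), Prog(), Rec(), Pre(), TrainEval()])
order_names = [type(cb).__name__ for cb in ordered]
print(order_names)
```

The `for ... else` raises when a full pass finds no eligible callback, which happens only if the run_after/run_before constraints contradict each other.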

We can use the learn.show_training_loop() method to see the events & callbacks during a fit operation.

Now, let’s try to implement our own LRFinder callback. As usual I have copied most of the code below from FastAI’s github repo.

LR_Finder Callback

We want the lr finder callback to do the following:

  • The lr finder should not disturb the weight initialization, hence the lr finder callback should first save the model parameters to a local file before it starts and reload them at the end.
  • The lr finder should run the fit for 100 iterations (batches) max using exponentially increasing learning rates, hence it should find the appropriate number of epochs to run the fit.
  • Set up an exponentially increasing learning rate schedule.
  • Cancel the training if the number of iterations > 100 or if the loss increases above 4 times the best loss obtained.
  • Zero the gradients at the end of fit operation.
  • Create a plot of loss vs learning rate after the fit operation.
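The scheduling and stopping requirements above can be sketched without any training. This is an illustrative simulation under stated assumptions: toy_loss is an invented loss curve, the start/end learning rates and the 30 batches per epoch are made-up values, and the raw loss is compared against the best loss (FastAI's own finder uses a smoothed loss).

```python
import math

def sched_exp(start, end, pos):
    """Exponential interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start * (end / start) ** pos

def n_epochs_needed(num_it, batches_per_epoch):
    """Smallest number of whole epochs that covers at least `num_it` batches."""
    return num_it // batches_per_epoch + 1

def toy_loss(lr):
    """Invented loss curve with a minimum near lr = 1e-2 (assumption)."""
    return 1 + (math.log10(lr) + 2) ** 2

start_lr, end_lr, num_it, batches_per_epoch = 1e-7, 10.0, 100, 30
n_epoch = n_epochs_needed(num_it, batches_per_epoch)

best_loss, lrs, losses, stopped_at = float("inf"), [], [], None
for it in range(n_epoch * batches_per_epoch):
    if it >= num_it:                      # iteration budget exhausted
        stopped_at = it
        break
    lr = sched_exp(start_lr, end_lr, it / num_it)
    loss = toy_loss(lr)
    lrs.append(lr)
    losses.append(loss)
    best_loss = min(best_loss, loss)
    if loss > 4 * best_loss:              # loss has diverged, cancel the fit
        stopped_at = it
        break

print(n_epoch, stopped_at, lrs[0], lrs[-1])
```

In the real callback the two `break` statements correspond to raising CancelFitException, and the recorded lrs and losses are what the final plot displays.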

Next, we will add a custom method to the learner to run our callback. @patch is a fastcore decorator (used throughout FastAI) that adds the decorated function as a method to the class given in the type annotation of its first argument.
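To see what this kind of decorator does, here is a stripped-down imitation. It is a sketch for illustration only: patch_sketch, ToyLearner and double_lr are hypothetical names, and the real @patch handles more cases than this.

```python
import inspect

def patch_sketch(func):
    """Attach `func` as a method on the class annotating its first parameter."""
    first_param = next(iter(inspect.signature(func).parameters.values()))
    cls = first_param.annotation          # e.g. ToyLearner in `self: ToyLearner`
    setattr(cls, func.__name__, func)
    return func

class ToyLearner:
    def __init__(self, lr):
        self.lr = lr

@patch_sketch
def double_lr(self: ToyLearner):
    """Hypothetical method added to ToyLearner after the class was defined."""
    return self.lr * 2

result = ToyLearner(0.5).double_lr()
print(result)
```

The same mechanism is what lets a custom lr finder method be called as learn.my_lr_find() without subclassing Learner.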

Let’s use our lr_find method now. But first we will re-initialize the model weights.

We will now try to use our learning rate finder in a transfer learning setting. We will use the Pets dataset and download FastAI’s xresnet model pretrained on Imagenet data. I decided not to use our model trained on Imagenette data as we only had a single dog breed in our target class (for Imagenette data).

First let’s build the dataloaders for Pets dataset.

We have 37 dog and cat breeds in Pets dataset.

Let’s download the pretrained xresnet34 model from FastAI.

By default, FastAI cuts a pretrained model at the pooling layer and adds a custom head according to the number of target classes. See the documentation for Vision Learner here for details. Let’s replicate the same for our study. First we will create a ‘head’ for our model.

We will initialize the parameters of the head before joining it with the pretrained part of the model as we do not want to modify the pretrained parameters.

Note: I am only showing the top & the bottom part of the output here.

We will use differential learning rates & train the head of our model at 100x the lr of the body. Hence, we will use a custom splitter function to create two parameter groups for our optimizer, one containing the parameters for the head and the other containing the rest. Note that our model head has just 6 parameters.

Next, we will freeze the pretrained part of the model & just train the parameters of the head. We will still train the parameters of batchnorm layers in the frozen part. Research has shown that training the batchnorm layers in a transfer learning setting generally gives better results. This behavior is controlled by the “train_bn” (default: True) argument to the Learner constructor.
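The effect of freezing with train_bn can be illustrated on a toy list of parameters. This is a simplified model of the behaviour described above, not FastAI's optimizer code: ToyParam, the parameter names and the freeze_to helper are assumptions.

```python
from dataclasses import dataclass

@dataclass
class ToyParam:
    """Toy stand-in for a model parameter (assumption)."""
    name: str
    is_batchnorm: bool = False
    requires_grad: bool = True

def freeze_to(param_groups, n, train_bn=True):
    """Freeze the groups before index `n`; later groups stay trainable."""
    frozen_idx = n if n >= 0 else len(param_groups) + n
    for i, group in enumerate(param_groups):
        for p in group:
            # Batchnorm parameters stay trainable in frozen groups if train_bn is set.
            p.requires_grad = i >= frozen_idx or (train_bn and p.is_batchnorm)

def make_groups():
    body = [ToyParam("conv.weight"),
            ToyParam("bn.weight", is_batchnorm=True),
            ToyParam("bn.bias", is_batchnorm=True)]
    head = [ToyParam("linear.weight"), ToyParam("linear.bias")]
    return [body, head]

groups = make_groups()
freeze_to(groups, -1)                       # freeze everything except the last group
trainable_with_bn = [p.name for g in groups for p in g if p.requires_grad]

groups = make_groups()
freeze_to(groups, -1, train_bn=False)
trainable_without_bn = [p.name for g in groups for p in g if p.requires_grad]

print(trainable_with_bn)
print(trainable_without_bn)
```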

Let’s check the trainable parameters in the parameter group. We can see that all parameters in the first parameter group belong to batchnorm layers (compare the parameter names with the display of pet3_resnet34 architecture above).

Let’s find a suitable learning rate for our model under the constraint that the learning rate for the head should be 100 times the learning rate for the rest of the model (body). This can be achieved by passing a numpy array to the start_lr & end_lr arguments of the learn.lr_find() method. However, FastAI only tracks the learning rate for the last parameter group in the Recorder callback. I will slightly modify a few methods of the Recorder callback below so that it tracks the learning rate for both parameter groups and also plots them both. I will use the @patch decorator to make the changes in-place.
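A quick check shows why an exponential schedule preserves the 100x constraint when both the start and end rates differ by the same factor. The values below are illustrative assumptions, and plain lists stand in for the numpy arrays mentioned above.

```python
def sched_exp(start, end, pos):
    """Exponential interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start * (end / start) ** pos

# [body, head] learning rates; the head is 100x the body at both ends (assumed values)
start_lrs = [1e-9, 1e-7]
end_lrs = [0.1, 10.0]
num_it = 100

schedule = [[sched_exp(s, e, i / num_it) for s, e in zip(start_lrs, end_lrs)]
            for i in range(num_it)]
ratios = [head / body for body, head in schedule]

print(schedule[0], schedule[-1])
print(min(ratios), max(ratios))
```

Because both groups share the same end/start factor, the head-to-body ratio stays at 100 for every iteration, up to floating point rounding.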

Let’s use our own lr_find to find an appropriate learning rate:

We will train the model for 3 epochs.

Next, unfreeze the entire model and train again. We will use my_lr_find() again to find an appropriate learning rate.
