Self-Organizing Maps with Fast.ai — Step 2: Training the SOM Module with a Fast.ai Learner
This is the second part of the Self-Organizing Maps with fast.ai article series.
All the code has been published in this repository and this PyPI library.
Overview
In case you’re unfamiliar with how fast.ai is organized, here’s a quick intro to get you up to speed.
Fast.ai defines a class named `Learner` that is in charge of training our models. The learning process can then be customized by using `Callback`s that the `Learner` will invoke during the training cycle.
In the previous article, we implemented a fast Self-Organizing Map using PyTorch. Now, we will refactor the model from step 1 into a `Learner` subclass. We want to achieve something like this:
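Concretely, the end state we are after looks something like this. It's a sketch with simplified stand-ins: the names `Som` and `SomLearner` match this series, but the bodies below are mocks of the fast.ai API shape, not the real classes.

```python
import torch

# Illustrative stand-ins: the names (Som, SomLearner) match this series,
# but the bodies are simplified mocks of the fast.ai API shape.
class Som:
    def __init__(self, size=(10, 10), dim=3):
        # one weight vector per map unit
        self.weights = torch.randn(size[0] * size[1], dim)

class SomLearner:
    def __init__(self, data, model, callbacks=None):
        self.data, self.model = data, model
        self.callbacks = callbacks or []

    def fit(self, epochs, lr=0.3):
        for _ in range(epochs):
            pass  # the actual training loop is built over the rest of this article

data = torch.randn(100, 3)        # stand-in for a fast.ai DataBunch
learn = SomLearner(data, Som())   # the API we are aiming for
learn.fit(5)
```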
Creating the Learner
We can start by creating a `SomLearner` class. Fast.ai `Learner`s accept a `DataBunch` item as a constructor parameter, so let's do the same:
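A sketch of that constructor (hedged: the argument names mirror fast.ai's `Learner`, but this is a simplified stand-in, not the library class):

```python
import torch

class SomLearner:
    """Sketch of a Learner-like class for SOMs.

    Simplified stand-in: the real fast.ai Learner takes the same kind of
    arguments but does far more under the hood.
    """
    def __init__(self, data, model, callbacks=None, loss_func=None):
        self.data = data                  # a DataBunch in real fast.ai
        self.model = model
        self.callbacks = callbacks or []  # Callbacks hooked into the training cycle
        self.loss_func = loss_func

learn = SomLearner(data=torch.randn(100, 3), model=None)
```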
Notice that we are also receiving a list of `Callback`s inside our constructor.
Passing data to the SomLearner
The next step is creating a `DataBunch` that our `SomLearner` can use to train the model. For this purpose, we will define another class that takes care of the transformations required to turn a `torch.Tensor` into a `DataLoader`.
I won’t go over the details here, since we will have the chance to familiarize ourselves with DataLoaders and other concepts in a dedicated article later in the series.
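Still, a rough sketch gives the idea. The class name below is hypothetical and the implementation is deliberately minimal; it just splits a `Tensor` into train/validation `DataLoader`s:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

class UnsupervisedDataBunch:
    """Sketch: turns a Tensor into train/valid DataLoaders.

    Hypothetical, minimal stand-in; the real fast.ai DataBunch does much more.
    """
    def __init__(self, train: torch.Tensor, valid_pct: float = 0.2, bs: int = 64):
        # shuffle, then carve off a validation split
        n_valid = int(len(train) * valid_pct)
        shuffled = train[torch.randperm(len(train))]
        valid_t, train_t = shuffled[:n_valid], shuffled[n_valid:]
        self.train_dl = DataLoader(TensorDataset(train_t), batch_size=bs, shuffle=True)
        self.valid_dl = DataLoader(TensorDataset(valid_t), batch_size=bs)

data = UnsupervisedDataBunch(torch.randn(100, 3), bs=16)
```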
Now that we have our `DataBunch`, we can try and train our model:
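Roughly, the attempt looks like this (stand-in code: with the real fast.ai `Learner` the failure surfaces inside the training loop when there is no loss function to backpropagate):

```python
import torch

class SomLearner:
    """Stand-in illustrating why training fails without a loss function."""
    def __init__(self, data, model, loss_func=None):
        self.data, self.model, self.loss_func = data, model, loss_func

    def fit(self, epochs):
        # fast.ai's training loop needs something to call .backward() on
        if self.loss_func is None:
            raise RuntimeError("no loss function to backpropagate")

learn = SomLearner(torch.randn(100, 3), model=None)
try:
    learn.fit(5)
except RuntimeError as err:
    failure = str(err)
```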
Unfortunately, this will not work. We are missing one key component in the training loop: a loss function.
Adding a loss function
As we said before, the SOM model has a weight update mechanism that differs slightly from other neural architectures: there is no explicit loss function to backpropagate; instead, the model relies on the weight update rule we saw in the previous article (remember?).
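As a quick refresher, the rule nudges each weight toward the input, scaled by the learning rate and a neighborhood function of each unit's grid distance to the Best Matching Unit. A sketch, assuming a Gaussian neighborhood:

```python
import torch

def som_update(weights, x, bmu, positions, alpha=0.3, sigma=1.0):
    """One SOM update step: w += alpha * h(bmu, unit) * (x - w)."""
    # squared grid distance of every unit to the BMU's position
    d2 = ((positions - positions[bmu]) ** 2).sum(dim=1).float()
    # Gaussian neighborhood: 1.0 at the BMU, decaying with grid distance
    h = torch.exp(-d2 / (2 * sigma ** 2))
    return weights + alpha * h.unsqueeze(1) * (x - weights)

# 2x2 grid with 3-dimensional codebook vectors
pos = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]])
w = torch.zeros(4, 3)
w2 = som_update(w, torch.ones(3), bmu=0, positions=pos)
```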
In our example training loop, we were invoking the `backward()` step manually:
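For reference, the step-1 loop looked roughly like this (a reconstructed sketch, assuming a `Som` class with `forward`/`backward` methods as in the previous article; the neighborhood term is omitted for brevity):

```python
import torch

class Som:
    def __init__(self, n_units=100, dim=3, alpha=0.3):
        self.weights, self.alpha = torch.randn(n_units, dim), alpha

    def forward(self, xb):
        # Best Matching Unit = closest codebook vector for each input
        self._xb = xb
        self._bmus = torch.cdist(xb, self.weights).argmin(dim=1)
        return self._bmus

    def backward(self):
        # weight update rule in place of gradient descent
        for x, b in zip(self._xb, self._bmus):
            self.weights[b] += self.alpha * (x - self.weights[b])

model, data = Som(), torch.randn(64, 3)
for epoch in range(5):
    model.forward(data)   # find the Best Matching Units
    model.backward()      # manual update step: no loss involved
```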
Fast.ai, on the other hand, expects a loss function, and will call `loss.backward()` on the output `Tensor` to back-propagate the error into each layer (usually with PyTorch's autograd feature).
One way to solve this problem would be to manually redirect the `loss.backward()` call to our `model.backward()`. To do this, we'll need a couple more classes:
Let’s create a placeholder loss function that always returns zero, just to make sure everything is working as intended:
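One possible way to sketch this redirection (hedged: the series' library may wire it differently) is a custom `torch.autograd.Function` whose `backward` triggers the model's own update, paired with the zero placeholder loss. A fake model stands in for the SOM here, just to show the call reaches `model.backward()`:

```python
import torch

class SomLossShim(torch.autograd.Function):
    """Routes loss.backward() to model.backward() (one possible approach)."""
    @staticmethod
    def forward(ctx, dummy, model):
        ctx.model = model
        return dummy.clone()

    @staticmethod
    def backward(ctx, grad_output):
        ctx.model.backward()   # run the SOM weight update instead of autograd
        return grad_output, None

def som_loss(output, target, model):
    """Placeholder loss: always zero, but its backward reaches the model."""
    dummy = torch.zeros(1, requires_grad=True)
    return SomLossShim.apply(dummy, model).sum()

class FakeSom:
    """Minimal model recording backward() calls, for demonstration only."""
    def __init__(self):
        self.backward_calls = 0
    def backward(self):
        self.backward_calls += 1

model = FakeSom()
loss = som_loss(None, None, model)
loss.backward()
```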
Plus, we need to change our SOM to make sure that input data is available to update the weights during the `backward()` step. To do so, we will store the last input batch and its Best Matching Units in a dictionary.
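A sketch of that change (the dictionary name is ours; the point is simply to cache whatever `backward()` needs):

```python
import torch

class Som:
    def __init__(self, n_units=100, dim=3, alpha=0.3):
        self.weights, self.alpha = torch.randn(n_units, dim), alpha
        self._cache = {}

    def forward(self, xb):
        bmus = torch.cdist(xb, self.weights).argmin(dim=1)
        # stash the batch and its Best Matching Units for the backward step
        self._cache = {"xb": xb, "bmus": bmus}
        return bmus

    def backward(self):
        # the weight update now reads its inputs from the cache
        xb, bmus = self._cache["xb"], self._cache["bmus"]
        for x, b in zip(xb, bmus):
            self.weights[b] += self.alpha * (x - self.weights[b])

model = Som()
model.forward(torch.randn(32, 3))
model.backward()
```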
Now that everything is in place, we can train our model with Fast.ai.
Visualizing loss with Fast.ai
Now that everything is handled by our Learner, we can easily see how our model is performing. First, let’s define an actual loss function:
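A reasonable choice for monitoring a SOM is the mean quantization error: the average distance between each input and its BMU's weights. This is a sketch; the series' library may use a different metric:

```python
import torch

def mean_quantization_error(xb, weights):
    """Average distance from each input to its closest codebook vector."""
    d = torch.cdist(xb, weights)       # (batch, n_units) distance matrix
    return d.min(dim=1).values.mean()  # distance to each BMU, averaged

# all-zero codebook vs. all-one inputs: every distance is sqrt(3)
w = torch.zeros(4, 3)
x = torch.ones(8, 3)
err = mean_quantization_error(x, w)
```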
Then, let’s pass it to our Learner and plot the loss:
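With stand-ins for the fast.ai pieces, the flow looks like this. In real fast.ai the recorded values live in `learn.recorder` and `learn.recorder.plot_losses()` draws the chart; here we just record them in a list (and, for brevity, the "model" is a raw codebook tensor):

```python
import torch

def mean_quantization_error(xb, weights):
    return torch.cdist(xb, weights).min(dim=1).values.mean()

class SomLearner:
    """Stand-in Learner that records the loss once per epoch."""
    def __init__(self, data, model, loss_func):
        self.data, self.model, self.loss_func = data, model, loss_func
        self.losses = []  # fast.ai keeps these inside learn.recorder

    def fit(self, epochs, alpha=0.3):
        for _ in range(epochs):
            bmus = torch.cdist(self.data, self.model).argmin(dim=1)
            for x, b in zip(self.data, bmus):
                self.model[b] += alpha * (x - self.model[b])
            self.losses.append(self.loss_func(self.data, self.model).item())

data = torch.randn(100, 3)
learn = SomLearner(data, model=torch.randn(25, 3),
                   loss_func=mean_quantization_error)
learn.fit(10)
# learn.recorder.plot_losses() would chart these values in real fast.ai
```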
Our `SomLearner` now works perfectly! Unfortunately, the loss plot looks really bad, because we're missing a key step in our SOM training process: hyperparameter scaling.
In the next article, we will take an in-depth look at different hyperparameter update strategies and at changing our model's behaviour by using `Callback`s.