Active Transfer Learning with PyTorch

Robert (Munro) Monarch
Published in PyTorch
11 min read · Jan 11, 2020

Machine Learning models can be adapted to predict their own errors, and they can trust that unlabeled data points will later get the correct human labels and therefore no longer be errors. This article goes into detail about Active Transfer Learning, the combination of Active Learning and Transfer Learning techniques that lets us take advantage of this insight. It is excerpted from the most recently released chapter of my book, Human-in-the-Loop Machine Learning, and comes with open PyTorch implementations of all the methods.

Before getting started

In my previous article written for PyTorch, Active Learning with PyTorch, I covered the building blocks for Active Learning. You should start there if you are not familiar with Active Learning, and also see my articles on the two types of Active Learning, Uncertainty Sampling & Diversity Sampling, and on the Advanced Active Learning techniques that combine them.

There are also clear examples of all the Active Learning algorithms, including the new ones introduced in this article, in my free PyTorch library.

Ideally, you should personally try implementing the simpler Active Learning strategies before jumping into the more advanced methods in this article.

What is Transfer Learning?

Transfer Learning is the process of taking a Machine Learning model that was built for one specific task and adapting it to another task.

“We don’t need roads where we’re going”. Mechanical Transfer Learning: a car can be repurposed to become a time machine or (more achievable) a boat (Image Source: SF Chronicle)

Adapting technology from one use case to another is fun. I’ll never forget the thrill of looking out of a train window one day near San Francisco and seeing a car racing the train in the water of the Brisbane Lagoon. Someone had converted a DeLorean, of Back to the Future fame, into a hovercraft.

You can feel the same delight every time that a Machine Learning model that was built for one purpose is adapted to a completely new use case. If that use case happens to be Active Learning, then we are taking one of the most fun parts of Machine Learning and applying it to solve one of the most important problems in Machine Learning: how can humans and AI solve problems together?

Transfer Learning in current Machine Learning typically means taking an existing neural model and then retraining the last layer (or last few layers) for a new task, which can be represented like this:

An example of Transfer Learning. We have a model that predicts a label as “A”, “B”, “C”, or “D”, and a separate dataset with the labels “W”, “X”, “Y”, and “Z”. By retraining just the last layer of the model, the model is now able to predict the labels “W”, “X”, “Y”, and “Z”.

The biggest advantage of Transfer Learning is that you need far fewer human-labeled examples than if you were training a model from scratch, which means that you can get higher-accuracy models with less data. If you come from a Computer Vision background, you’ve probably used Transfer Learning to adapt a model from the ImageNet classification task, and if you come from a Natural Language Processing background, you’ve probably used Transfer Learning to adapt a pre-trained model like BERT.
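For illustration, here is a minimal sketch of that pattern in PyTorch (this is generic example code, not code from this article’s library): it starts from a torchvision model pre-trained on ImageNet and swaps in a new final layer for a hypothetical four-label task.

import torch.nn as nn
from torchvision import models

# a minimal Transfer Learning sketch: start from a model pre-trained on ImageNet
model = models.resnet18(pretrained=True)

# freeze the existing layers so that only the new final layer is trained
for param in model.parameters():
    param.requires_grad = False

# replace the final layer with a new one for four hypothetical new labels ("W", "X", "Y", "Z")
model.fc = nn.Linear(model.fc.in_features, 4)

# ...then train as usual on the new data; only model.fc's parameters will be updated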

Making your model predict its own errors

The new labels for Transfer Learning can be any categories that you want. This includes information about the task itself! This is the first of three core insights for Active Transfer Learning:

Insight 1: You can use transfer learning to ask your model where it is confused, by making it predict its own errors.

This article covers three variations of Active Transfer Learning, the simplest being a binary “correct/incorrect” task to predict where a model might make errors:

Active Transfer Learning for Uncertainty Sampling. Validation items are predicted by the model and bucketed as “Correct” or “Incorrect” according to whether they were classified correctly or not. The last layer of the model is then retrained to predict whether items are “Correct” or “Incorrect”, effectively turning the two buckets into new labels.

There are three steps to this process:

  1. Apply the model to a validation data set and capture which validation items were classified correctly and incorrectly. This is your new training data: your validation items now have an additional label of “Correct” or “Incorrect”.
  2. Create a new output layer for the model and train that new layer on your new training data, predicting your new “Correct”/”Incorrect” labels.
  3. Run your unlabeled data items through the new model and sample the items that are predicted to be “Incorrect” with the highest confidence.

PyTorch makes this incredibly simple with the ability to pass the activation of every neuron back to other processes, allowing us to build our Active Transfer Learning model on top of our original model. Let’s assume we have a simple network with one hidden layer, with this forward() function:

def forward(self, feature_vec, return_all_layers=False):
    hidden1 = self.linear1(feature_vec).clamp(min=0)
    output = self.linear2(hidden1)
    log_softmax = F.log_softmax(output, dim=1)
    if return_all_layers:
        return [hidden1, output, log_softmax]
    else:
        return log_softmax

We can then iterate through our validation data and assign each validation item a value of whether it was “correct” or “incorrect”, and store its hidden layer as input to our new model:

correct_predictions = [] # validation items predicted correctly
incorrect_predictions = [] # validation items predicted incorrectly
item_hidden_layers = {} # hidden layer of each item, by id

for item in validation_data:
    # assume "item" contains id, label & features of each data point
    id = item["id"]
    label = item["label"]
    feature_vector = item["feature_vector"]

    hidden, logits, log_probs = model(feature_vector, True)
    item_hidden_layers[id] = hidden # record hidden layer value

    if is_correct(label, log_probs):
        correct_predictions.append(item)
    else:
        incorrect_predictions.append(item)

We can then train a new model to predict “Correct” or “Incorrect”, using the hidden layer as the new input (feature) vector. Let’s assume we’ve called that new model correct_model in our code.
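As a rough sketch of what training correct_model could look like (the architecture, hidden size, optimizer settings, and label convention below are assumptions for illustration, not the code from the library), a single new linear layer over the hidden vector is enough for the binary prediction:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class CorrectModel(nn.Module):
    # binary "Correct"/"Incorrect" model that takes the hidden layer
    # of the original model as its input features (illustrative sketch)
    def __init__(self, hidden_size, num_labels=2):
        super().__init__()
        self.linear = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_vec, return_all_layers=False):
        output = self.linear(hidden_vec)
        log_softmax = F.log_softmax(output, dim=1)
        if return_all_layers:
            return [output, log_softmax]
        else:
            return log_softmax

correct_model = CorrectModel(hidden_size=128)  # assumes 128 hidden units in the original model
loss_function = nn.NLLLoss()
optimizer = optim.SGD(correct_model.parameters(), lr=0.01)

for epoch in range(20):  # arbitrary number of epochs for this sketch
    # convention assumed here: label 1 = "Correct", label 0 = "Incorrect"
    for target, items in [(1, correct_predictions), (0, incorrect_predictions)]:
        for item in items:
            correct_model.zero_grad()
            # detach so we only train the new layer, not the original model
            hidden = item_hidden_layers[item["id"]].detach()
            log_probs = correct_model(hidden)
            loss = loss_function(log_probs, torch.LongTensor([target]))
            loss.backward()
            optimizer.step()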

After that new model is trained, the only (slightly) tricky part is that we need to get predictions from both models for the unlabeled data: one prediction to get the hidden layer from the first model, and then a second prediction with the new “Correct/Incorrect” model:

active_transfer_preds = []

with torch.no_grad():
    for item in unlabeled_data:
        id = item["id"]
        label = item["label"]
        feature_vector = item["feature_vector"]

        # get prediction from initial model
        hidden, logits, log_probs = model(feature_vector, True)

        # get prediction from correct/incorrect model
        correct_log_probs = correct_model(hidden, False)

At this point in the code, correct_log_probs has the probability that the unlabeled item will be predicted correctly. By sampling the items with the lowest confidence of being predicted correctly, you are sampling the items that should be of the highest value for a human to review and label.
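For illustration, one simple way to turn those predictions into a sample is to record each item’s probability of being “Correct” inside the loop above and then sort, taking the least-confident items. This sketch assumes the label at index 1 is “Correct”, matching the convention used in the training sketch above:

import math

# inside the loop above: record the probability that this item will be predicted correctly
prob_correct = math.exp(correct_log_probs.data.tolist()[0][1])  # assumes index 1 = "Correct"
active_transfer_preds.append((prob_correct, item))

# after the loop: the items least likely to be predicted correctly are
# the most valuable ones for a human to review and label
active_transfer_preds.sort(key=lambda x: x[0])
sampled_items = [item for _, item in active_transfer_preds[:10]]  # e.g. take the top 10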

The code in this article is a slightly simplified version of the code in the advanced_active_learning.py file in the free PyTorch library: https://github.com/rmunro/pytorch_active_learning/blob/master/advanced_active_learning.py

You can run it immediately on that use case — identifying disaster-related messages — with the following command:

python advanced_active_learning.py --transfer_learned_uncertainty=10

This will run the entire process and then present you with the 10 most uncertain items for you to provide the correct label.

At this point, the model might not be any better than the simpler Uncertainty Sampling algorithms, so it is also a good idea to implement the simpler methods as a baseline first. But don’t quit yet: this is the first step towards building a more powerful version of this algorithm.

The biggest advantage that we get from Transfer Learning over the simpler methods is that it makes it much easier for our Active Learning strategy to be Adaptive. A common problem with Active Learning strategies is that they will sample unlabeled items that are all from one part of the feature space and therefore lack diversity, so Diversity Sampling methods like clustering are needed to avoid this problem. There are Advanced Active Learning techniques that combine Uncertainty Sampling and Diversity Sampling as separate steps, but the following methods in this article have the advantage of combining the two into a single architecture.

Often, it is difficult to get human labels in real time, and it is more practical to sample a large number of unlabeled items and have them labeled as a batch. In these cases, Active Transfer Learning for Adaptive Representative Sampling can be adaptive during the sampling process, even though we don’t yet know what the labels will be.

Active Transfer Learning for Representative Sampling

For many real-world use cases, your data changes over time. For example, in the autonomous vehicle use case, new types of objects are always being encountered, and the scope might be expanded, like driving on open water in addition to roads.

Representative Sampling is a form of Diversity Sampling that aims to sample unlabeled items that are most like the application domain of a Machine Learning model, relative to the current training data.

Because our sampled items will later get a human label, we can assume that they become part of the training data without needing to know what the label is.

Active Transfer Learning for Adaptive Representative Sampling. Because our sampled items will later get a human label, we can assume that they become part of the training data without needing to know what the label is. Our new model begins with items from our “Training” and “Application” distributions as the respective labels, and the “Application” items are incrementally added to the “Training” items as we sample them.

The steps look like this:

  1. Take validation data from the same distribution as the training data and give it a “Training” label. Take unlabeled data from our target domain and give it an “Application” label.
  2. Train a new output layer to predict the Training/Application labels, giving it access to all layers of the model.
  3. Apply the new model to the unlabeled data and sample the items that are most confidently predicted as “Application”.
  4. Assume that the newly sampled items will later get a label and become part of the training data: change the label of those items from “Application” to “Training”, and then repeat from Step 2.

This is an incredibly powerful algorithm because it avoids only sampling items from one part of the feature space, sampling a diverse set of items before any human labeling.
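Here is a rough sketch of that loop. The helper names (train_new_output_layer, prob_application) and the variable number_to_sample are hypothetical stand-ins for your own implementation, not functions from the library:

# Step 1: start with "Training" items (from the training distribution)
# and "Application" items (unlabeled data from the target domain)
training_items = list(validation_data)
application_items = list(unlabeled_data)
sampled_items = []

for i in range(number_to_sample):
    # Step 2: retrain a new output layer to predict "Training" vs "Application"
    new_model = train_new_output_layer(model, training_items, application_items)

    # Step 3: find the item most confidently predicted as "Application"
    most_application = max(application_items,
                           key=lambda item: prob_application(new_model, item))

    # Step 4: assume it will later get a human label and become training data,
    # so move it from "Application" to "Training" and repeat
    application_items.remove(most_application)
    training_items.append(most_application)
    sampled_items.append(most_application)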

Insight 2: You can assume that an unlabeled item will later get a label, even if you don’t know what the label is yet.

Active Transfer Learning for Adaptive Sampling (ATLAS)

The most sophisticated use of Active Transfer Learning is Active Transfer Learning for Adaptive Sampling (ATLAS). It brings together the principles from the two earlier models in this article: predicting uncertainty and adapting to the data before any human labels are added.

This is where our time-travel analogy helps. Imagine that you have converted your car into a time machine but you have to drive it down the road at 88 miles per hour in order to travel in time. You can send that car into the future knowing that a road will later be there, even though you don’t yet know what that road will look like or what else will be around the road. You could then start making plans for the future that take into account that the car will be there, even without the full knowledge of that context.

We can do the same thing with our models, assuming that we have knowledge of the data that we will later label and using that partial knowledge to take future actions, namely sampling even more data for human review:

Active Transfer Learning for Adaptive Sampling. Because our sampled items will later get a human label, we can assume that the model will later label those items correctly, because models are typically the most accurate on the actual items that they trained on. To begin with, validation items are predicted by the model and bucketed as “Correct” or “Incorrect” according to whether they were classified correctly or not. The last layer of the model is then retrained to predict whether items are “Correct” or “Incorrect”, effectively turning the two buckets into new labels. We then apply that new model to the unlabeled data, predicting whether each item will be “Correct” or “Incorrect”, and sample the items most confidently predicted to be “Incorrect”. Then, we can assume that those sampled items will later get labeled and become part of the training data, at which point the model that trained on them will predict them correctly. So, we can change their label from “Incorrect” to “Correct” and retrain our final layer(s) on the new dataset.

The steps look like this:

  1. Apply the model to a validation data set and capture which validation items were classified correctly and incorrectly. This is your new training data: your validation items now have an additional label of “Correct” or “Incorrect”.
  2. Create a new output layer for the model and train that new layer on your new training data, predicting your new “Correct”/”Incorrect” labels.
  3. Run your unlabeled data items through the new model and sample the items that are predicted to be “Incorrect” with the highest confidence.
  4. Assume that the newly sampled items will later get labels and that the model will later predict those items correctly after training on them: change the label of those items from “Incorrect” to “Correct”, and then repeat from Step 2.
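As a rough outline of those four steps (again, the helper names evaluate_on_validation, train_correct_model, and prob_incorrect are hypothetical stand-ins, not functions from the library), the loop mirrors the Adaptive Representative Sampling sketch above:

# Step 1: bucket validation items into "Correct" and "Incorrect" predictions
correct_items, incorrect_items = evaluate_on_validation(model, validation_data)
remaining_unlabeled = list(unlabeled_data)
sampled_items = []

for i in range(number_to_sample):
    # Step 2: retrain the "Correct"/"Incorrect" output layer on the current buckets
    correct_model = train_correct_model(model, correct_items, incorrect_items)

    # Step 3: find the unlabeled item most confidently predicted as "Incorrect"
    most_incorrect = max(remaining_unlabeled,
                         key=lambda item: prob_incorrect(model, correct_model, item))

    # Step 4: assume it will later be labeled and then predicted correctly,
    # so move it to the "Correct" bucket and repeat
    remaining_unlabeled.remove(most_incorrect)
    correct_items.append(most_incorrect)
    sampled_items.append(most_incorrect)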

By combining the Active Transfer Learning techniques for Uncertainty Sampling and for Adaptive Representative Sampling, we now have a model that can predict its future state. It doesn’t know what the labels will be for the items that are initially sampled, but it knows that they will get a label, so it can make smarter sampling decisions based on that anticipated future event.

Insight 3: You can assume that your model will correctly predict the label of unlabeled items that are similar to items that will later get a label, even if you don’t know what the labels are yet.

The full ATLAS implementation is in the same file as above, advanced_active_learning.py, in the free PyTorch library: https://github.com/rmunro/pytorch_active_learning/blob/master/advanced_active_learning.py

You can run it from the command line with:

python advanced_active_learning.py --atlas=10

Active Transfer Learning Cheatsheet

Here’s a one-page cheatsheet that you can reference when building the algorithms in this article:

Active Transfer Learning cheatsheet

For quick reference, you can download a PDF version of the cheatsheet here: http://www.robertmunro.com/Active_Transfer_Learning_Cheatsheet.pdf

This article and the cheatsheet are excerpted from my book, Human-in-the-Loop Machine Learning: https://www.manning.com/books/human-in-the-loop-machine-learning. The chapters of my book are being published as they are written and the chapter containing Active Transfer Learning techniques including ATLAS is available now!

Robert Munro | January 2020 | @WWRob

Notes on Active Transfer Learning architectures:

See the book for a deeper dive into the architecture choices you can make for Active Transfer Learning. Here’s a handful of notes to get started:

  1. It is (mathematically) equivalent to remove the last layer and retrain a new layer (as in the images above) or instead to take the output from the last hidden layer(s) and use that as input to a new model (as in the code examples). I think that the former is more intuitive visually, but the latter is less error-prone to code because it is purely additive and you don’t need to worry about how changing your model could impact other parts of your code. Whichever implementation of Transfer Learning you prefer in your own code is fine. That’s also true if you want to play around with tuning existing layers with the new data/labels instead of removing layers entirely: that approach is compatible with the Active Transfer Learning techniques shared here.
  2. Note that the Representative Sampling example uses all the hidden layers and also adds an additional new layer, while the Uncertainty Sampling & ATLAS examples are a simple binary prediction after the final hidden layer. This is by design as a good starting point for your architectures, but you can experiment with different architectures in all cases. The reasoning behind these starting points is that the final layers of our model can’t distinguish low activation that comes from items that are not well represented in the data from low activation that comes from items that are well represented but have features that are mostly irrelevant to the model in its current state. So, Representative Sampling should perform better with information from earlier layers. By contrast, the Uncertainty Sampling and ATLAS examples only use the last layer, because the last layer of the model is already optimized to minimize uncertainty, so you are unlikely to find more signal in earlier layers and will be more prone to overfitting if you do include them.
  3. You might consider multiple models and/or variable predictions from a single model via Monte Carlo Sampling. These examples rely on validation data from the same distribution as your training domain, and you could easily overfit the particular items in that validation set. If you’re splitting your training data 90:10 into training:validation, as in the code examples here, then one easy method is to repeat this for all 90:10 combinations. Note that for the Uncertainty Sampling & ATLAS examples, you’re only creating a single new binary predictor, so you don’t need too much data for the results to be robust. That’s a nice property of these models: a single additional binary prediction is easy to train with relatively little data and often no manual tuning.
  4. Active Transfer Learning can work on more complicated tasks like Object Detection, Semantic Segmentation, Sequence Labeling, and Text Generation. Almost any type of neural model can add a new layer (or head) to predict a “Correct/Incorrect” label or a “Training/Application” label, so this is a very versatile technique. I cover the best way to approach these use cases and more in an upcoming chapter of the book!
