Transfer Learning

Atulanand
Published in CodeX
14 min read · Mar 23, 2023

Deep learning has provided extraordinary advances in problem spaces that are poorly solved by other approaches. This success is due to several key departures from traditional machine learning that allow it to excel when applied to unstructured data. Today, deep learning models can play games, detect cancer, talk to humans, and drive cars.

But the differences that make deep learning powerful also make it costly. You may have heard that deep learning success requires massive data, expensive hardware, and even more expensive elite engineering talent.

Here, we’re going to look at transfer learning, a technique that enables the transfer of knowledge from one task to another. Rather than developing an entirely customized solution to your problem, transfer learning allows you to transfer knowledge from related problems to help solve your custom problem more easily. By transferring that knowledge, you take advantage of the expensive resources that were used to acquire it (training data, hardware, researchers) without incurring the cost. Let’s see how and when this approach is effective.

Why is deep learning different?

Transfer learning is not a new technique, nor is it specific to deep learning, but it is newly exciting in light of the recent progress in deep learning. So first, it’s important to spell out the ways in which deep learning differs from traditional machine learning.

Deep learning operates at a lower level of abstraction

Machine learning is a way for machines to automatically learn functions which assign predictions or labels to numerical inputs, i.e. data.

The difficult part here is determining exactly how the function produces the output from the provided input. Without any restrictions on the function, the possibilities (and complexities) are endless. In order to simplify this task, we usually impose some type of structure on the function — based on the type of problem we’re solving, domain expertise, or simply trial and error. That structure defines a type of machine learning model.

In theory, there are infinitely many possible structures, but in practice most machine learning use cases can be solved by applying one of only a handful of structures: linear models, ensembles of trees, and support vector machines make up a solid core. The data scientist’s job is then to choose the correct structure from this small set.

These models are available as black box objects from a variety of mature machine learning libraries, and can be trained in just a few lines of code. For example, you can train a random forest model using Python’s scikit-learn like this:

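```python
from sklearn.ensemble import RandomForestClassifier

# training_data, target and test_data are your own feature arrays and labels
clf = RandomForestClassifier()
clf.fit(training_data, target)
pred = clf.predict(test_data)
```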

Deep learning, however, operates at a lower level. Rather than choosing among a small, finite set of model structures, deep learning allows practitioners to compose arbitrary structures. The building blocks are modules or layers that can be thought of as basic, fundamental data transformations. This means that we need to open up the black box when applying deep learning, instead of treating it as fixed by the algorithm.

This allows more powerful models to be built, but it also adds an entirely new dimension to the model building process. Composing these transformations effectively can be a difficult process, despite the volume of published deep learning research, practical guidelines and folk wisdom.

Consider an extremely simple Convolutional Neural Network image classifier, defined here in the popular deep learning library PyTorch.

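```python
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Two convolutional layers: 1 input channel (grayscale) -> 10 -> 20 feature maps
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # 20 maps of 4x4 after two conv/pool stages (assuming 28x28 inputs) = 320 features
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten feature maps into a vector per image
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # log-probabilities over 10 classes
```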

Because we are working with low-level building blocks, we can change a single component of the model (swapping F.relu for torch.sigmoid, for instance). This gives us an entirely new model architecture that may yield dramatically different results, and the space of such variations is effectively unlimited.

Deep learning is not yet well-understood

Even given a fixed neural network architecture, training is notoriously difficult. First, deep learning loss functions are generally not convex, which means that training is not guaranteed to find the best possible solution. Second, deep learning is still very new, and many of its components are not well understood. For example, batch normalization has received attention because including it in some models seems to be critical for good results, but experts cannot agree on why. Researcher Ali Rahimi caused some controversy when he went so far as to liken deep learning to alchemy at the NIPS 2017 conference.
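To make the non-convexity point concrete, here is a minimal sketch in plain Python that runs gradient descent on a hypothetical one-dimensional non-convex "loss", f(w) = w^4 - 3w^2 + w. The function, learning rate and step count are illustrative assumptions, not taken from any real network; the point is only that the starting point decides which minimum training ends up in.

```python
def loss(w: float) -> float:
    # Hypothetical non-convex loss with two local minima
    return w**4 - 3 * w**2 + w


def grad(w: float) -> float:
    # Analytic derivative of loss(w)
    return 4 * w**3 - 6 * w + 1


def gradient_descent(w0: float, lr: float = 0.01, steps: int = 1000) -> float:
    """Plain gradient descent starting from the initial point w0."""
    w = w0
    for _ in range(steps):
        w -= lr * grad(w)
    return w


if __name__ == "__main__":
    for start in (-2.0, 2.0):
        w = gradient_descent(start)
        print(f"start={start:+.1f} -> w={w:.3f}, loss={loss(w):.3f}")
```

Started at -2.0, the run should settle near w ≈ -1.30, the lower of the two minima; started at 2.0, it should settle near w ≈ 1.13, a worse local minimum, even though the update rule is identical.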

Automatic Feature Engineering

The added complexity in deep learning enables a technique called representation learning, which is why it’s often stated that neural networks do “automatic feature engineering.” In short, instead of having a human hand-engineer helpful features from a dataset, we build models in such a way that they can learn whatever features are necessary and helpful for the task at hand. Offloading feature engineering onto the model is immensely powerful, but comes with the cost of models that require massive amounts of data and, consequently, massive computing power.

What can you do about it?

Deep learning is so complex compared with other machine learning methods that it can seem overwhelming to incorporate into your business. For resource-constrained organizations, this feeling is magnified.

For organizations that truly need to operate on the bleeding edge, it may indeed be necessary to hire experts and purchase specialized hardware. But this is not necessary in many cases. There are ways to apply it effectively without making enormous investments. This is where transfer learning comes in.

Transfer learning enables the transfer of knowledge from one machine learning model to another. These models may be the result of years of research into model structure, trained on colossal datasets, and optimized over years of compute time. With transfer learning, you can get much of the benefit of this work for a fraction of the cost.

What is transfer learning?

Most machine learning tasks start with zero knowledge, meaning that the model’s parameters begin as random guesses. This is what we mean when we say a model is learned from scratch.

A cat detector model trained from scratch starts by guessing. It gradually learns what a cat is by aggregating common patterns across the many different cats it is shown.

In this situation, everything the model learns comes from the data that you show it. But is this the only way of solving a problem? In some cases, it might seem like it.

A cat detector model is likely useless in unrelated applications, like fraud detection. It only knows how to make sense of cat pictures, not credit card transactions.

But in other cases, it seems like we should be able to share information between tasks.

A cat detector is helpful in related tasks, like cat face localization. The detector should already know how to detect whiskers, noses, and eyes, all of which help in locating the cat’s face.

This is the essence of transfer learning: taking a model that has learned how to do one task very well and transferring some (or all) of that knowledge to a related task.

This makes sense when we examine our own learning experiences; we regularly transfer skills learned in the past to more quickly learn new skills. For example, someone who has learned to throw a baseball does not need to completely re-learn the mechanics of throwing a projectile to learn how to throw a football. These things are inherently related, and the ability to do one of them well naturally translates into the ability to do the other.

In the machine learning world, there is perhaps no better example than the field of computer vision over the last five years. It is now exceedingly rare to train an image model from scratch. Instead, we start with a pretrained model that already knows how to classify simple objects such as cats and dogs and umbrellas. Models that learn to classify images do so by first learning to detect general image features, such as edges, shapes, text, and faces. The pretrained model has these fundamental skills (as well as more specific skills, such as distinguishing between dogs and cats).
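A minimal sketch of this idea, in plain Python so it runs without a deep learning library: a "pretrained" feature extractor is kept frozen, and only a small new output layer (a logistic-regression head) is trained on the new task. FROZEN_W, extract_features and train_head are hypothetical names, and the weights are made up; in a real project the frozen part would be a pretrained network whose early layers detect edges and shapes, and the toy two-dimensional points would be images.

```python
import math
import random

# Hypothetical "pretrained" weights standing in for layers learned on a large
# source task. They are treated as frozen: nothing below modifies them.
FROZEN_W = [[1.0, 0.0], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]]


def extract_features(x):
    """Frozen feature extractor: a fixed linear layer followed by ReLU."""
    return [max(0.0, w[0] * x[0] + w[1] * x[1]) for w in FROZEN_W]


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_head(data, lr=0.1, epochs=200):
    """Train only a new logistic-regression head on top of the frozen features."""
    feats = [(extract_features(x), y) for x, y in data]  # computed once, never updated
    head = [0.0] * len(FROZEN_W)
    bias = 0.0
    for _ in range(epochs):
        for f, y in feats:
            p = sigmoid(sum(h * fi for h, fi in zip(head, f)) + bias)
            err = p - y  # gradient of the log-loss with respect to the logit
            head = [h - lr * err * fi for h, fi in zip(head, f)]
            bias -= lr * err
    return head, bias


def predict(head, bias, x):
    f = extract_features(x)
    return int(sum(h * fi for h, fi in zip(head, f)) + bias > 0)


if __name__ == "__main__":
    rng = random.Random(0)
    # Small labelled dataset for the new task: the label is 1 when x[0] > 0
    points = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(40)]
    data = [(p, int(p[0] > 0)) for p in points]
    head, bias = train_head(data)
    acc = sum(predict(head, bias, x) == y for x, y in data) / len(data)
    print(f"training accuracy on the new task: {acc:.2f}")
```

Because the extractor is never updated, its features can be computed once and reused, and the only parameters learned for the new task are the handful in the head.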

Transfer learning needs less training data

When you re-use your favorite cat detection model in a new, cat-related task, your model already has “the wisdom of one million cats,” which means that you don’t need nearly as many cat pictures to train the new task. Needing less training data lets you train in settings where very little data is available and more data may be expensive or impossible to obtain, and it can also allow you to train models faster on cheaper hardware.

Models learned by transfer learning generalize better

Transfer learning often improves generalization, or the ability of the model to perform well on data that it wasn’t trained on. This is because pre-trained models are purposefully trained on tasks that force them to learn generic features that are useful in related contexts. When the model is transferred to a new task, it is harder to overfit the new training data, since the model only learns incrementally on top of a very general knowledge base. Building a model that generalizes well is one of the hardest and most important parts of machine learning.

The transfer learning training process is less brittle

Starting with a pre-trained model also helps overcome the frustrating, brittle, and confusing process of training a complex model with millions of parameters. When most of the pre-trained weights are frozen and only a new output layer is trained, transfer learning can reduce the number of trainable parameters by well over 99%, making training more stable and easier to debug.

Transfer learning makes deep learning easier

Finally, transfer learning helps make deep learning more accessible, since you don’t need to be an expert yourself to obtain expert-level results. Consider the popular image classification model ResNet-50.

How was that particular architecture chosen? It is the result of years of research and experimentation by various deep learning experts. Within this complicated structure there are about 25 million weights, and optimizing these weights from scratch can be nearly impossible without extensive knowledge of each of the model’s components. Fortunately, with transfer learning you can re-use both the complicated structure and the optimized weights, significantly lowering the barrier to entry for deep learning.
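As a rough illustration of how much this shrinks the training problem, the arithmetic below counts the weights that remain trainable when a ResNet-50-style backbone is frozen and only a new output layer is learned. The figures are assumptions: 25,557,032 parameters in total (the count usually reported for torchvision's ResNet-50), a 2,048-dimensional feature vector feeding the final layer, a 1,000-class original head, and a hypothetical 10-class target task.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Number of weights plus biases in a fully connected layer."""
    return in_features * out_features + out_features


# Assumed figures for a ResNet-50-style network
TOTAL_PARAMS = 25_557_032  # whole network, including its 1,000-class head
FEATURE_DIM = 2048  # size of the feature vector feeding the head
ORIGINAL_HEAD = linear_params(FEATURE_DIM, 1000)
BACKBONE = TOTAL_PARAMS - ORIGINAL_HEAD  # weights that transfer learning re-uses


def trainable_fraction(num_classes: int) -> float:
    """Share of weights still trained when the backbone is frozen and only
    a new num_classes-way output layer is learned."""
    new_head = linear_params(FEATURE_DIM, num_classes)
    return new_head / (BACKBONE + new_head)


if __name__ == "__main__":
    frac = trainable_fraction(10)  # hypothetical 10-class target task
    print(f"backbone weights: {BACKBONE:,}")
    print(f"trainable share: {frac:.4%}, reduction: {1 - frac:.4%}")
```

Under these assumptions, only about 20,000 of roughly 23.5 million weights are trained, a reduction of more than 99.9%.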

Multi-task learning

Transfer learning belongs to a family of knowledge-sharing techniques for training machine learning models that have proven to be extremely effective. Currently, the two most interesting of these techniques are transfer learning and multi-task learning. In transfer learning, a model is trained on a single task and then used as a starting point for a related task. While learning the related task, the transferred model specializes in the new task without regard to how that affects its performance on the original task. In multi-task learning, a single model learns to do multiple tasks at once, and its performance is evaluated on how well it learns to do all of those tasks.

In Machine Learning (ML), we typically care about optimizing for a particular metric, whether this is a score on a certain benchmark or a business KPI. To do this, we generally train a single model or an ensemble of models to perform our desired task. We then fine-tune and tweak these models until their performance no longer increases. While we can generally achieve acceptable performance this way, by being laser-focused on our single task we ignore information that might help us do even better on the metric we care about. Specifically, this information comes from the training signals of related tasks. By sharing representations between related tasks, we can enable our model to generalize better on our original task. This approach is called Multi-Task Learning (MTL), and it is the topic of this section.

Multi-task learning has been used successfully across many applications of machine learning, from natural language processing and speech recognition to computer vision and drug discovery. MTL comes in many guises: joint learning, learning to learn, and learning with auxiliary tasks are only some of the names that have been used to refer to it. Generally, as soon as you find yourself optimizing more than one loss function, you are effectively doing multi-task learning (in contrast to single-task learning). In those scenarios, it helps to think explicitly in terms of MTL about what you are trying to do and to draw insights from it.
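Written out, a common (though not the only) way to combine several losses is a weighted sum. In this assumed formulation, θ_sh are parameters shared by all T tasks, θ_t are the parameters specific to task t, and λ_t are hand-chosen task weights:

```latex
\mathcal{L}_{\mathrm{MTL}}(\theta_{sh}, \theta_1, \ldots, \theta_T)
  = \sum_{t=1}^{T} \lambda_t \, \mathcal{L}_t(\theta_{sh}, \theta_t)
```

With T = 1 this reduces to ordinary single-task learning; an auxiliary task is simply an extra term in the sum whose performance we do not evaluate directly.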

Even if you’re only optimizing one loss, as is the typical case, chances are there is an auxiliary task that will help you improve upon your main task. Rich Caruana summarizes the goal of MTL succinctly: “MTL improves generalization by leveraging the domain-specific information contained in the training signals of related tasks”.

In the rest of this section, I will give an overview of multi-task learning with deep neural networks: first the two most frequently employed methods for MTL in deep learning, and then the mechanisms that together illustrate why MTL works in practice.

Two MTL methods for Deep Learning

To make the ideas of MTL more concrete, we will now look at the two most commonly used ways to perform multi-task learning in deep neural networks. In the context of Deep Learning, multi-task learning is typically done with either hard or soft parameter sharing of hidden layers.

Hard parameter sharing

Hard parameter sharing is the most commonly used approach to MTL in neural networks. It is generally applied by sharing the hidden layers between all tasks, while keeping several task-specific output layers.
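The sketch below shows hard parameter sharing in plain Python: one hidden layer W is shared by every task, each task has its own linear output head, and each gradient step updates W with the errors of all tasks at once. The class name, layer sizes, learning rate and the two toy regression tasks are illustrative assumptions; a real implementation would use a deep learning library and deeper shared layers.

```python
import random


class HardSharedModel:
    """Hard parameter sharing: a shared hidden layer plus one linear head per task."""

    def __init__(self, in_dim, hidden, tasks, seed=0):
        rng = random.Random(seed)
        # Shared hidden-layer weights (no bias), used by every task
        self.W = [[rng.gauss(0, 0.5) for _ in range(in_dim)] for _ in range(hidden)]
        # Task-specific output heads, one weight vector per task
        self.heads = {t: [rng.gauss(0, 0.5) for _ in range(hidden)] for t in tasks}

    def hidden(self, x):
        """Shared representation: ReLU(W x)."""
        return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in self.W]

    def predict(self, x, task):
        return sum(v * h for v, h in zip(self.heads[task], self.hidden(x)))

    def loss(self, data):
        """Sum over tasks of each task's mean squared error.

        data maps task name -> list of (x, y) pairs.
        """
        return sum(
            sum((self.predict(x, t) - y) ** 2 for x, y in pairs) / len(pairs)
            for t, pairs in data.items()
        )

    def step(self, data, lr=0.05):
        """One full-batch gradient-descent step on the summed multi-task loss."""
        grad_W = [[0.0] * len(row) for row in self.W]
        grad_heads = {t: [0.0] * len(v) for t, v in self.heads.items()}
        for t, pairs in data.items():
            head = self.heads[t]
            for x, y in pairs:
                h = self.hidden(x)
                # Derivative of this task's mean squared error w.r.t. the prediction
                err = 2.0 * (sum(v * hi for v, hi in zip(head, h)) - y) / len(pairs)
                for i, hi in enumerate(h):
                    grad_heads[t][i] += err * hi
                    if hi > 0.0:  # ReLU passes gradient only where active
                        for j, xj in enumerate(x):
                            # Errors from every task accumulate in the shared layer
                            grad_W[i][j] += err * head[i] * xj
        for row, g_row in zip(self.W, grad_W):
            for j, g in enumerate(g_row):
                row[j] -= lr * g
        for t, head in self.heads.items():
            for i, g in enumerate(grad_heads[t]):
                head[i] -= lr * g


if __name__ == "__main__":
    rng = random.Random(1)
    xs = [[rng.uniform(-1, 1), rng.uniform(-1, 1)] for _ in range(50)]
    # Two hypothetical related regression tasks on the same inputs
    data = {
        "a": [(x, x[0] + x[1]) for x in xs],
        "b": [(x, x[0] - 0.5 * x[1]) for x in xs],
    }
    model = HardSharedModel(in_dim=2, hidden=8, tasks=data)
    before = model.loss(data)
    for _ in range(200):
        model.step(data)
    print(f"joint loss before: {before:.3f}, after: {model.loss(data):.3f}")
```

The key line is the update of grad_W: it accumulates errors from every task, which is what pushes the shared layer toward a representation that serves all of them, while each head only sees its own task.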

Hard parameter sharing greatly reduces the risk of overfitting. In fact, Baxter showed that the risk of overfitting the shared parameters is smaller, by a factor on the order of N (the number of tasks), than the risk of overfitting the task-specific parameters, i.e. the output layers. This makes sense intuitively: the more tasks we are learning simultaneously, the more our model has to find a representation that captures all of the tasks, and the lower our chance of overfitting on our original task.

Soft parameter sharing

In soft parameter sharing, on the other hand, each task has its own model with its own parameters. The distance between the parameters of the models is then regularized to encourage the parameters to be similar. Duong, Cohn, Bird, and Cook (2015), for instance, use the ℓ2 norm for regularization, while other work uses the trace norm.
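A minimal sketch of the ℓ2 variant, assuming a squared ℓ2 distance between the parameter vectors of two task-specific models (the exact form used by Duong, Cohn, Bird, and Cook may differ). The function names, the strength lam and the parameter values are hypothetical; in practice the penalty gradients are added to each model's own task-loss gradients rather than applied on their own as here.

```python
def l2_penalty(theta_a, theta_b, lam):
    """Soft-sharing regularizer: lam * squared L2 distance between two tasks' parameters."""
    return lam * sum((a - b) ** 2 for a, b in zip(theta_a, theta_b))


def l2_penalty_grads(theta_a, theta_b, lam):
    """Gradients of l2_penalty with respect to theta_a and theta_b."""
    grad_a = [2 * lam * (a - b) for a, b in zip(theta_a, theta_b)]
    grad_b = [-g for g in grad_a]
    return grad_a, grad_b


if __name__ == "__main__":
    # Hypothetical parameters of two task-specific models
    theta_a = [1.0, 2.0]
    theta_b = [3.0, -1.0]
    lam, lr = 0.1, 0.5
    print("penalty before:", l2_penalty(theta_a, theta_b, lam))
    for _ in range(10):
        # In a real model these gradients are added to each task's own loss gradients
        ga, gb = l2_penalty_grads(theta_a, theta_b, lam)
        theta_a = [p - lr * g for p, g in zip(theta_a, ga)]
        theta_b = [p - lr * g for p, g in zip(theta_b, gb)]
    print("penalty after:", l2_penalty(theta_a, theta_b, lam))
```

Applied on its own, each step shrinks the gap between the two parameter vectors by a constant factor (here 1 - 4 · lr · lam = 0.8) while leaving their average unchanged; combined with the task losses, it pulls the two models toward each other without forcing them to be identical.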

The constraints used for soft parameter sharing in deep neural networks have been greatly inspired by regularization techniques for MTL that have been developed for other models, which we will soon discuss.

Why does MTL work?

Even though an inductive bias obtained through multi-task learning seems intuitively plausible, we need to look at the mechanisms that underlie it in order to understand MTL better. Most of these were first proposed by Caruana (1998). For all examples, we will assume that we have two related tasks A and B, which rely on a common hidden-layer representation F.

Implicit data augmentation

MTL effectively increases the sample size that we are using for training our model. As all tasks are at least somewhat noisy, when training a model on some task A, our aim is to learn a good representation for task A that ideally ignores the data-dependent noise and generalizes well. As different tasks have different noise patterns, a model that learns two tasks simultaneously is able to learn a more general representation. Learning just task A bears the risk of overfitting to task A, while learning A and B jointly enables the model to obtain a better representation F through averaging the noise patterns.
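A small simulation can illustrate this under a strong simplifying assumption: tasks A and B depend on exactly the same shared factor (here a single slope, true_w), but each has its own independent label noise. Estimating the shared factor from A alone is compared with estimating it from A and B together. The names, noise level and sample sizes are hypothetical, and real tasks rarely share a representation this exactly.

```python
import random
import statistics


def slope_estimate(pairs):
    """Least-squares slope of a line through the origin: sum(x*y) / sum(x*x)."""
    return sum(x * y for x, y in pairs) / sum(x * x for x, _ in pairs)


def simulate(trials=2000, n=20, true_w=1.5, noise=1.0, seed=0):
    """Return the variance of the shared-factor estimate from task A alone
    and from tasks A and B pooled, across repeated random datasets."""
    rng = random.Random(seed)
    single, joint = [], []
    for _ in range(trials):
        xs = [rng.uniform(-1, 1) for _ in range(n)]
        # Both tasks depend on the same shared factor true_w,
        # but each has its own independent label noise
        task_a = [(x, true_w * x + rng.gauss(0, noise)) for x in xs]
        task_b = [(x, true_w * x + rng.gauss(0, noise)) for x in xs]
        single.append(slope_estimate(task_a))
        joint.append(slope_estimate(task_a + task_b))
    return statistics.variance(single), statistics.variance(joint)


if __name__ == "__main__":
    var_single, var_joint = simulate()
    print(f"variance, task A only: {var_single:.4f}")
    print(f"variance, A and B jointly: {var_joint:.4f}")
```

With independent noise of equal size, the joint estimate should have roughly half the variance of the single-task one, which is the noise-averaging effect in its simplest form.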

Attention focusing

If a task is very noisy or data is limited and high-dimensional, it can be difficult for a model to differentiate between relevant and irrelevant features. MTL can help the model focus its attention on those features that actually matter as other tasks will provide additional evidence for the relevance or irrelevance of those features.

Eavesdropping

Some features G are easy to learn for some task B, while being difficult to learn for another task A. This might be either because A interacts with the features in a more complex way or because other features are impeding the model’s ability to learn G. Through MTL, we can allow the model to eavesdrop, i.e. learn G through task B. The easiest way to do this is through hints, i.e. directly training the model to predict the most important features.

Representation bias

MTL biases the model to prefer representations that other tasks also prefer. This will also help the model to generalize to new tasks in the future, as a hypothesis space that performs well for a sufficiently large number of training tasks will also perform well for learning novel tasks, as long as they are from the same environment.

Regularization

Finally, MTL acts as a regularizer by introducing an inductive bias. As such, it reduces the risk of overfitting as well as the Rademacher complexity of the model, i.e. its ability to fit random noise.

Continuous learning

While multi-task learning allows us to retain the knowledge across many tasks without suffering a performance penalty on our source tasks, this is only possible if all tasks are present at training time. For each new task, we would generally need to retrain our model on all tasks again.

In the real world, however, we would like an agent to be able to deal with tasks that gradually become more complex by leveraging its past experience. To this end, we need to enable a model to learn continuously without forgetting. This area of machine learning is known as learning to learn, meta-learning, life-long learning, or continuous learning. It has seen some recent developments in the context of reinforcement learning (RL), most notably by Google DeepMind on their quest towards general learning agents, and it is also being applied to sequence-to-sequence models.

Zero-shot learning

Finally, if we take transfer learning to the extreme and aim to learn from only a few, one or even zero instances of a class, we arrive at few-shot, one-shot, and zero-shot learning respectively. Enabling models to perform one-shot and zero-shot learning is admittedly among the hardest problems in machine learning. At the same time, it is something that comes naturally to us humans: Toddlers only need to be told once what a dog is in order to be able to identify any other dog, while adults can understand the essence of an object just by reading about it in context, without ever having encountered it before.

Recent advances in one-shot learning have leveraged the insight that models need to be trained explicitly to perform one-shot learning in order to achieve good performance at test time. Meanwhile, the more realistic generalized zero-shot learning setting, in which the training classes are also present at test time, has garnered attention lately.
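A minimal sketch of how one-shot classification is often carried out at test time: each new class is represented by the embedding of its single labelled example, and a query is assigned to the class whose example is most similar (cosine similarity here). The embeddings are hypothetical and are assumed to come from an encoder trained on other classes; learning such an encoder, which is the hard part, is not shown.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def one_shot_classify(support, query):
    """Label the query with the class of its most similar support embedding.

    support: dict mapping class name -> the single embedding seen for that class.
    """
    return max(support, key=lambda label: cosine(support[label], query))


if __name__ == "__main__":
    # Hypothetical embeddings, assumed to come from an encoder trained on other classes
    support = {"dog": [0.9, 0.1, 0.0], "cat": [0.1, 0.9, 0.1], "car": [0.0, 0.1, 0.9]}
    print(one_shot_classify(support, [0.8, 0.3, 0.1]))  # → dog
```

Adding a new class requires no retraining at all: one more entry in support is enough, which is why a good embedding is so valuable in this setting.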


Data Scientist @Deloitte | Infosys Certified Machine Learning Professional | Google Certified Data Analytics