Xfer: an open-source library for neural network transfer learning
Transfer learning is a set of techniques for reusing and repurposing already trained machine learning models in new situations. It brings particular advantages in the domain of deep learning, where training a model from scratch (rather than reusing an existing model) requires a lot of computational and data resources, as well as expertise. This blog post contains a quick overview of transfer learning through the introduction of Xfer, an open-source library that enables easy application and prototyping of transfer learning approaches.
Transfer learning
Neural networks are machine learning models that learn functions and patterns from data. They underpin numerous modern AI-enabled technologies, with applications in conversational agents, self-driving cars, self-learning agents that play board games, and many more.
In all of the above scenarios, the utility of a deep neural network comes from its ability to learn associations between input training data (e.g. images) and output training data (e.g. labels, such as “human” or “cat”), a task known as supervised learning. At deployment time, the deep neural network leverages these learned associations to predict the outputs for newly encountered inputs. Often, however, the conditions encountered in the real world during deployment differ from the conditions under which the deep neural network was trained. When this happens, the associations learned at training time become less useful. The problem is worse still if we wish to deploy a model trained in one scenario in a different situation or with different data. For example, a Pokémon hunter might want to use a deep neural network trained on natural animal images to build an application that recognizes Pokémon characters. The statistics of the 2D cartoonish Pokémon images are quite different from the statistics of the real-world animal images used for training; furthermore, the animal images training dataset does not contain labels relevant to the new task (i.e. names of Pokémon characters), so it is impossible to get the predictions we want by deploying the pre-trained network as is.
We would like to avoid having to train a new neural network every time the training dataset changes or a new use-case appears, because doing so is time- and resource-consuming. In fact, training a new neural network from scratch is often infeasible due to limited training data (indeed, Pokémon encounters are rare, to continue the example from above). Furthermore, training deep neural networks requires a certain level of expertise and commitment. Instead, it is more beneficial to repurpose already trained neural networks. This is known as transfer learning [1, 2, 3]. In the example above, transfer learning would allow us to take a neural network trained on labelled animal images, automatically “extract” any knowledge relevant to the new task (i.e. recognizing Pokémon characters) and build that knowledge into a new predictive model. This works because, although the statistics of the source domain’s data (animal images) and those of the target domain’s data (Pokémon images) are different, they nevertheless share common patterns (such as the arrangement of pixels representing eyes or tails), which can be exploited for transfer learning.
Automating transfer learning with Xfer
Xfer is a transfer learning library for MXNet designed for deep learning practitioners and researchers who wish to quickly:
- repurpose a pre-trained neural network for deployment in a new scenario/task
- prototype new neural network solutions based on existing architectures.
Given a new machine learning task, Xfer lets you solve it with neural networks without training from scratch: instead, you repurpose a previously trained neural network. The library can be applied to arbitrary data and networks, including the common cases of image or text data.
What are the key motivations for using Xfer?
- Resource efficiency: you don’t have to train big neural networks from scratch (save on human and CPU/GPU resources).
- Data efficiency: by transferring knowledge, you can classify complex data even if you have very few labels.
- Easy access to neural networks: you don’t need to be a machine learning expert in order to leverage the power of neural networks. With Xfer you can easily reuse or even modify existing architectures and create your own solution.
- Feature extraction: utilities for extracting the learned representations from within pre-trained neural networks.
- Rapid prototyping: Xfer’s ModelHandler module (see below) allows you to easily modify a neural network architecture, e.g. by providing one-liners for adding / removing / freezing layers.
- Uncertainty modeling: With the Bayesian neural network (BNN) or the Gaussian process (GP) repurposing methods (see below), you can obtain uncertainty in the predictions of the model which is repurposed for the target task.
How does Xfer work?
Xfer consists of two main components that work together; both are detailed below.
Repurposing methods
These are methods for extracting the relevant information from a pre-trained (source) neural network and injecting it into another (target) machine learning model. The target model, called a Repurposer, can then be deployed to make predictions in the new target task. Xfer comes with five pre-built repurposing methods, grouped into two categories: meta-model based and fine-tuning based. These are explained with two minimal demos later in this blog post. The library is easily extensible, making it simple to add and benchmark new repurposing methods.
ModelHandler
ModelHandler is an Xfer module that allows for easy manipulation of MXNet deep neural network models; it is useful for transfer learning but also in its own right. With one line of code, ModelHandler can alter an existing neural network architecture by removing, adding, or freezing a collection of layers. This supports rapid prototyping, letting you build neural networks quickly even if you are not an MXNet or deep learning expert; for transfer learning, it makes it easy to create target (repurposed) models as modified versions of an original (pre-trained) neural network, and it can serve as the basis for custom repurposing methods. Furthermore, ModelHandler can be used for feature extraction: obtaining the learned representations from within a pre-trained neural network, which is useful for downstream representation learning tasks and for creating repurposed models that leverage these representations.
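As an illustration, here is a minimal feature-extraction sketch; `source_model` (a pre-trained MXNet module), `data_iterator`, and the layer names `'fc7'` and `'fc8'` are assumptions made for the example.

```python
import xfer

# Wrap a pre-trained MXNet module for inspection and manipulation
mh = xfer.model_handler.ModelHandler(source_model)  # source_model: assumed pre-trained module

# List the layers available in the architecture
print(mh.layer_names)

# Pull the learned representations out of two named layers,
# along with the labels provided by the iterator
features, labels = mh.get_layer_output(data_iterator, ['fc7', 'fc8'])
```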
Minimal demo (meta-model based transfer)
After defining an MXNet source model and data iterators for your target task, you can perform transfer learning with just 3 lines of code:
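(The snippet below is a minimal sketch in the spirit of the library’s quick-start; `source_model`, `train_iterator`, and `test_iterator`, as well as the feature layer name `'fc7'`, are placeholders to substitute with your own.)

```python
import xfer

# Fit a logistic-regression ('Lr') meta-model on features extracted
# from the pre-trained source model's 'fc7' layer
repurposer = xfer.LrRepurposer(source_model=source_model, feature_layer_names=['fc7'])
repurposer.repurpose(train_iterator)

# Predict labels for the new target task
predictions = repurposer.predict_label(test_iterator)
```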
The animated figure below demonstrates the operation of the meta-model family of repurposing methods.
In this example, ModelHandler is used to fetch the parameters, W, of a neural network that has been pre-trained on the source task. The next step is to pass the target input data through the source pre-trained neural network while keeping the source parameters W fixed. This process gives us features (representations) that describe the target data but also carry information from the source task, since they were generated with the parameters W obtained from it. In the final step of meta-model based transfer, these features are used as inputs to a meta-model classifier, by invoking the Repurposer module. In the code given above, the meta-model used is ‘Lr’, which stands for Logistic Regression.
Notice that if the meta-model used is a Gaussian process or a Bayesian neural network (both built into Xfer), then one can obtain uncertainty estimates in the predictions made for the target task. The uncertainty-supporting repurposers, therefore, “know what they don’t know”. This can be important, for example, when the target task has very few labelled data, which is often the case in transfer learning applications; this is demonstrated in the image classification example later in this blog post.
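As a sketch of what this looks like in code (reusing the assumed `source_model` and iterators from the demo above), one can swap in the Gaussian process repurposer and ask for class probabilities rather than hard labels:

```python
import xfer

# A GP meta-model yields calibrated class probabilities, not just labels
gp_repurposer = xfer.GpRepurposer(source_model=source_model, feature_layer_names=['fc7'])
gp_repurposer.repurpose(train_iterator)

# Each row holds per-class probabilities; a flat, spread-out row
# flags an input the model is uncertain about
probabilities = gp_repurposer.predict_probability(test_iterator)
```

The Bayesian neural network repurposer can be swapped in following the same pattern.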
Minimal demo (fine-tuning based transfer)
A different family of repurposing methods is based on fine-tuning. Specifically, these methods allow the user to first refine the architecture of a pre-trained neural network, for example by adding or removing layers using the ModelHandler module, and then fine-tune the resulting network for the target task through gradient-based optimization. Using ModelHandler, one can also experiment with custom fine-tuning based repurposers; for example, you can easily select which layers to transfer or freeze from a pre-trained neural network. A snapshot of ModelHandler’s functionality for transfer is shown below:
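(The snippet below is a hedged reconstruction of such a snapshot; the layer name `'conv1'`, the variable `num_target_classes`, and the training settings are illustrative assumptions.)

```python
import mxnet as mx
import xfer

mh = xfer.model_handler.ModelHandler(source_model)

# Replace the source task's output layer with a fresh classifier head
mh.drop_layer_top()
mh.add_layer_top([mx.sym.FullyConnected(num_hidden=num_target_classes),
                  mx.sym.SoftmaxOutput(name='softmax')])

# Build a trainable module, keeping the earliest layer's parameters
# frozen while the rest are fine-tuned on the target data
target_model = mh.get_module(train_iterator,
                             fixed_layer_parameters=mh.get_layer_parameters(['conv1']))
target_model.fit(train_iterator, num_epoch=5)
```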
A graphic illustration of the above code is shown in the animated figure below.
As in the previous demo, ModelHandler is used to fetch the parameters, W, of a neural network pre-trained on the source task; the source model could be, for example, a pre-trained VGGNet. With ModelHandler, we can now refine the architecture of the source network; in the example here, we add a new layer (with freshly initialized parameters) to the bottom of the architecture. The next step of the fine-tuning repurposing method is to train the new architecture to adapt it to the target task data. Xfer enables this fine-tuning step by defining different treatments for the original parameters (shown in blue in the middle graph above) and the new parameters (shown in orange): the original parameters already contain information about the source task, so they should not diverge too much from their learned values, whereas the new parameters were initialized randomly, so they should be optimized with a larger learning rate.
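One way to achieve this two-speed treatment in plain MXNet is with per-parameter learning-rate multipliers. The sketch below assumes the newly added layer’s parameters are named `fc_new_weight` and `fc_new_bias`; it is an illustration, not the library’s built-in mechanism.

```python
import mxnet as mx

# Small base learning rate for the pre-trained (blue) parameters...
optimizer = mx.optimizer.SGD(learning_rate=0.001)

# ...and a 10x multiplier for the freshly initialized (orange) ones,
# so the new layer adapts quickly while the rest drifts only slowly
optimizer.set_lr_mult({'fc_new_weight': 10.0, 'fc_new_bias': 10.0})

# Fine-tune the repurposed model on the target task data
target_model.fit(train_iterator, optimizer=optimizer, num_epoch=5)
```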
Applications of deep transfer learning
Transfer learning with deep neural networks has shown great promise in computer vision, natural language processing, speech technologies and many other application areas. Below we discuss two popular applications that have particularly benefited from transfer learning.
Image classification
Deep neural networks, in particular convolutional neural networks (CNNs), have arguably introduced a paradigm shift in computer vision. These models are capable of leveraging large datasets in order to learn powerful representations. Once training is complete, a new image can be passed through the network and represented in the feature space the CNN has discovered. Example feature spaces are visualized in the picture below.
As can be seen, the deep neural network creates features at different levels of abstraction. The key intuition behind transfer learning in image classification is that, at the right level of abstraction, features extracted from a deep neural network can generalize beyond a single task and beyond the dataset on which the training happened. To demonstrate this, let us consider a simplified but rather extreme scenario: we have a deep neural network pre-trained on a large set of images, such as ImageNet, and wish to use it to classify a very small set of hand-drawn images, as shown in the picture below:
First, let us see what predictions we obtain without repurposing, i.e. by using the pre-trained model as it is:
The predictions are, of course, not very satisfactory. For one thing, the pre-trained network might not contain labels, such as “cheese”, that exist in the target dataset of hand-drawn images. For another, the image statistics of the target dataset are clearly different from those of the larger, colorful, real-world images of ImageNet. Training a new neural network from scratch using a set of labelled images from the target domain is not an option either, since we have so few target domain images that overfitting would most likely occur.
Instead, we can repurpose the source model in order to leverage information (such as edges, basic shapes, etc.) that should be relevant to the target domain. After using Xfer to repurpose the source model and perform predictions, we obtain the result below:
Note that here we have used a meta-model repurposer which supports well-calibrated uncertainty estimates in its predictions. The full repurposing pipeline for the above example can be found in the tutorial.
Text classification
Transfer learning through repurposing can also be used in the text domain, for example to transfer knowledge across different languages or across different corpora. Let us demonstrate the latter case. In particular, we trained a CNN on 13 of the 20 categories of the 20 Newsgroups text dataset. For the target task, we assume that we have access to a much smaller dataset of only 100 texts from the remaining 7 categories. As shown in the Xfer tutorial which contains the full pipeline, repurposing offers a 9% boost in accuracy versus training from scratch, and is much quicker and cheaper to run (no GPU needed).
Considerations for transfer learning
Transfer learning expands the range of situations in which deep neural networks are applicable and generally makes their application easier and more frugal. However, several considerations must be taken into account when designing or using a transfer learning tool, and there are still open challenges. This blog post focuses on the type of generic deep transfer learning implemented in Xfer, but it is worth pointing out that there are many other approaches, usually specializing in a narrower family of scenarios. Note also that, although the discussion below can help guide the application of transfer learning to a given problem, ultimately some experimentation will be needed in order to settle on the best transfer learning pipeline design (enabling quick prototyping is, in fact, one of the motivations behind Xfer’s design).
Choosing a transfer learning method
One obvious question is how to decide which transfer learning method to use. The level of similarity between the source and target data distributions can provide hints for answering this question. For example, a high level of similarity would favor a meta-model based approach, because the extracted features used by the meta-model would be highly relevant to the target task. Furthermore, if we need well-calibrated uncertainty in the predictions for the target task, it would be desirable to apply a meta-model based repurposing method which supports that, such as a Gaussian process repurposer.
Measuring transferability
A related question, especially when using meta-model based repurposers, is which layers of the source network to use for transfer learning. Xfer allows you to use multiple layers, but to guide the selection we might need to think about the level of transferability between the source and target models/tasks. Measuring transferability is an important challenge, with both theoretical and practical implications. For example, a very low level of transferability can indicate that negative transfer is likely to overwhelm positive transfer, and therefore it is preferable to train a new model from scratch in the target domain.
Yosinski et al. (2014) examine the transferability question with a focus on image classification, while Mou et al. (2016) focus on transferability for NLP applications. There is also a body of research attempting to measure transferability in a more formal way, e.g. by formulating metrics and bounds.
When measuring transferability, it is important to consider the following notions of similarity:
- Statistical similarities of the source and target data domains. For example, the domain of colorful indoors images is more “similar” to the domain of colorful outdoors images than to the domain of scanned document images.
- Properties of the source and target training data. It is important to also consider the properties of the actual data at hand, which are finite-size instantiations of the data domains mentioned above. For example, a tiny training set for the target domain might rule out the fine-tuning repurposing method, due to the danger of overfitting.
Start using Xfer today
- Get Xfer from GitHub and start doing transfer learning for MXNet models
- Learn more about its functionality by reading Xfer’s Documentation
- Tutorial: Introduction and transfer learning for image data
- Tutorial: Transfer learning with automatic hyperparameter tuning
- Tutorial: Transfer learning for text data
- Tutorial: Creating your own custom repurposer
- Tutorial: xfer.ModelHandler for easy manipulation and inspection of MXNet models
Acknowledgements: Jordan Massiah, Keerthana Elango, Pablo G. Moreno, Nikos Aletras, Cliff McCollum, Neil Lawrence