Transfer learning: the dos and don’ts
If you have recently started doing work in deep learning, especially image recognition, you might have seen the abundance of blog posts all over the internet, promising to teach you how to build a world-class image classifier in a dozen or fewer lines and just a few minutes on a modern GPU. What’s shocking is not the promise but the fact that most of these tutorials end up delivering on it. How is that possible? To those trained in ‘conventional’ machine learning techniques, the very idea that a model developed for one data set could simply be applied to a different one sounds absurd.
The answer is, of course, transfer learning, one of the most fascinating features of deep neural networks. In this post, we’ll first look at what transfer learning is, when it will work, when it might work, and why it won’t work in some cases, and finally conclude with some pointers to best practices for transfer learning.
What is transfer learning?
In their seminal paper on the subject, Pan and Yang (2010) give an elegant mathematical definition (see Subsection 2.2) of transfer learning. For our purposes, however, a much simpler definition will suffice: transfer learning uses what was learned for a particular task (sometimes called the source task) to solve a different task (the destination task). The assumption here is, of course, that the source and the destination task are sufficiently similar. That assumption is at the heart of transfer learning, and to see when it holds, we must first understand why transfer learning works at all.
You might recall that your typical contemporary neural network (say, a multi-layer deep convolutional neural network, or dCNN) consists of layers of neurons. Each layer feeds its result forward according to the weights and biases the network learned during training. This architecture was loosely patterned after a particular pathway in the human brain, namely the ventral visual stream.
The ventral stream begins in the primary visual cortex, or V1, a part of the occipital lobe. V1 receives visual information from the lateral geniculate nucleus (LGN), a part of the thalamus that relays visual information to the occipital lobe while segregating it into two types: inputs that are more useful for determining what you’re looking at (the parvocellular layers of the LGN) and inputs that are more useful for determining where something is (the magnocellular layers of the LGN). Both end up in the primary visual cortex, but the ventral visual stream is fed chiefly by the ‘what’ information from the parvocellular layers: slow, but sustained and detailed information about objects, originating from retinal ganglion cells called P cells.
The outputs (sometimes called projections) of P cells are then processed as they traverse the ventral visual stream, from V1 through V2 and V4 and eventually to the inferior temporal (IT) cortex. Experiments measuring which stimuli excite individual neurons have shown that cells in V1 respond to fairly simple patterns (primarily edges, their orientation, colours and spatial frequencies). Cells in V2 respond to more complex features, including certain Gestalt phenomena such as ‘subjective contours’. Further along the stream, towards the inferior temporal cortex, neurons respond to increasingly complex patterns.
We see this replicated in deep neural networks. A technique called a Zeiler-Fergus deconvolutional network (DeconvNet) allows us to see which input patterns a particular unit in a given layer (often referred to as a ‘filter’ in this context) responds to most strongly. The closely related technique of activation maximisation asks the same question directly: what kinds of structures excite that particular neuron the most? (You might be familiar with some of these shapes from Google’s DeepDream algorithm.) The deeper the layer, the more complex the shapes. The first layers show simple edges or even blocks of colour, and later layers reveal more complex patterns. The patterns preferred by units in the final layers are often recognisable as the intended class. They also become increasingly semantic. Combinations of line detectors are weighted together into a filter that recognises a triangle, and combinations of triangle filters recognise two ears. Combined with other filters, these let the network begin to recognise a dog’s face as distinct from a cat’s.
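For a single linear unit, activation maximisation can be worked through in a few lines. The plain-Python sketch below is a toy of our own, not the Zeiler-Fergus method, which instead projects activations back into pixel space through a deconvolutional network. It treats one hypothetical 3×3 filter as a linear unit and runs gradient ascent on the input patch. After every step it rescales the patch to unit length, so the response cannot grow without bound. The filter values, step size and number of steps are arbitrary assumptions. Because the unit is linear, the answer is known in advance: the patch most aligned with the filter, i.e. the normalised filter itself. That makes the sketch easy to check.

```python
import math

# Hypothetical 3x3 filter, flattened row by row: a vertical edge detector
# (dark on the left, bright on the right). The values are an arbitrary choice.
FILTER = [-1.0, 0.0, 1.0,
          -2.0, 0.0, 2.0,
          -1.0, 0.0, 1.0]


def response(patch, weights=FILTER):
    """Activation of a linear unit: the dot product of filter and input patch."""
    return sum(w * x for w, x in zip(weights, patch))


def normalise(vec):
    """Rescale a vector to unit L2 norm, so the ascent cannot grow without bound."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def maximise_activation(weights=FILTER, steps=100, lr=0.1):
    """Gradient ascent on the input patch to maximise the unit's response.

    For a linear unit, the gradient of the response with respect to the input
    is simply the weight vector, so each step nudges the patch towards the filter.
    """
    patch = normalise([0.1] * len(weights))  # start from a flat, uniform patch
    for _ in range(steps):
        grad = list(weights)  # d(response)/d(patch) for a linear unit
        patch = normalise([x + lr * g for x, g in zip(patch, grad)])
    return patch


if __name__ == "__main__":
    best = maximise_activation()
    for row in range(3):
        print(["%.3f" % v for v in best[3 * row:3 * row + 3]])
```

Printed as a 3×3 grid, the result reproduces the filter’s own dark-to-bright vertical edge. Real networks are deep and non-linear, so the same ascent there relies on automatic differentiation rather than a hand-written gradient.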
The idea of transfer learning rests on the fact that neural networks are layer-wise self-contained. You can remove all layers after a particular layer and bolt on a fully connected layer with a different number of neurons and random weights. Once that new layer has been trained, you have a working neural network. This is the basis of transfer learning: we take what well-trained, well-constructed networks have learned from large data sets and use it to boost the performance of a detector on a smaller (usually by several orders of magnitude!) data set.
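The cut-and-bolt-on operation can be sketched with a toy fully connected network in plain Python. Everything in it is hypothetical. The Dense class, the layer sizes and the transfer helper are illustrative stand-ins rather than any framework’s API. The ‘source’ network is merely randomly initialised here; in practice it would have been trained on a large source data set. Note that the kept layers are reused as-is (same objects, same weights), and only the new head starts from random values.

```python
import random


def relu(values):
    """Element-wise rectified linear unit."""
    return [max(0.0, v) for v in values]


class Dense:
    """A fully connected layer computing activation(W x + b)."""

    def __init__(self, n_in, n_out, activation=None, rng=None):
        rng = rng if rng is not None else random.Random()
        # Small random weights, as for a freshly initialised layer.
        self.weights = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
        self.bias = [0.0] * n_out
        self.activation = activation
        self.n_in, self.n_out = n_in, n_out

    def __call__(self, x):
        z = [sum(w * xi for w, xi in zip(row, x)) + b
             for row, b in zip(self.weights, self.bias)]
        return self.activation(z) if self.activation else z


def forward(layers, x):
    """Feed x through the layers in order."""
    for layer in layers:
        x = layer(x)
    return x


def transfer(source_layers, keep, n_classes, seed=0):
    """Keep the first `keep` layers, drop the rest and bolt on a new random head."""
    body = source_layers[:keep]
    head = Dense(body[-1].n_out, n_classes, rng=random.Random(seed))
    return body + [head]


# Hypothetical source network: 4 inputs -> 8 -> 8 -> 10 source classes.
rng = random.Random(42)
source = [Dense(4, 8, relu, rng), Dense(8, 8, relu, rng), Dense(8, 10, rng=rng)]

# Destination task with 3 classes: keep both hidden layers, replace the output layer.
target = transfer(source, keep=2, n_classes=3)
print(len(forward(source, [1.0, 0.5, -0.5, 2.0])))  # 10 source-class scores
print(len(forward(target, [1.0, 0.5, -0.5, 2.0])))  # 3 destination-class scores
```

The new head’s outputs are meaningless until it has been trained on the destination data.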
To transfer or not to transfer
From the above, some facts emerge about the utility (and disutility) of transfer learning. The biggest benefit of transfer learning shows when the destination data set is relatively small. In many of these cases, the model may be prone to overfitting, and data augmentation may not fully solve the problem. Transfer learning is therefore best applied where the source task’s model has been trained on a vastly bigger training set than could be acquired for the destination task. This may be because instances of a particular thing are hard to come by (e.g. when working on the synthesis or recognition of a voice of which only a few samples exist). It may also be because labelled instances are difficult to obtain (e.g. in diagnostic radiology, labelled images are often hard to get, especially for rare conditions).
Even where there is approximately the same amount of data for each task, a model might still benefit from transfer learning if there is a risk of overfitting. This often occurs when the destination task is highly domain-specific. In fact, training a large domain-specific dCNN from scratch might be counterproductive, as it might overfit to the particular domain. It is therefore sometimes advisable to use transfer learning even where the source and destination tasks’ training sets are of the same size.
In practice, in computer vision, it is very common to use gold-standard networks trained on massive image data sets (such as ImageNet’s 1.2 million images over 1,000 categories) as the point of departure for even quite domain-specific tasks, such as evaluating chest radiographs. A number of modern machine-learning packages, especially high-level packages like Keras, come with their own built-in model zoo, allowing easy access to a pre-trained dCNN. It is, then, only a question of removing the topmost layers, adding one or more new layers and retraining (fine-tuning) the model. The new model will immediately benefit from the weeks of laborious training that go into creating a model like InceptionV3. Overall, when used appropriately, transfer learning will give you a trifecta of benefits: a higher starting accuracy, faster convergence and higher asymptotic accuracy (the accuracy level to which the training converges). Recently, some fairly promising sites have sprung up that catalogue various pre-trained models. My favourites are ModelDepot and ModelZoo, the latter having a vast database filterable by framework and solution, including many pre-trained GANs.
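The simplest form of the retraining step keeps the transferred layers frozen and trains only the newly added layer. The plain-Python sketch below illustrates that idea under stated assumptions. The frozen layer (FROZEN_W), the four-point toy data set and the hyperparameters are all invented for illustration. The new head is a single logistic-regression unit trained with full-batch gradient descent. A real setup would instead load a pre-trained model from a framework’s model zoo and train the new layers with that framework’s optimiser.

```python
import math
import random

# Hypothetical frozen "pretrained" layer: fixed weights standing in for features
# learned on a large source data set. Nothing below ever modifies them.
FROZEN_W = [[1.0, -1.0], [0.5, 0.5], [-1.0, 2.0]]

# Hypothetical destination data set: (input, binary label) pairs.
DATA = [([1.0, 0.0], 1), ([2.0, 0.5], 1), ([0.0, 1.0], 0), ([-0.5, 2.0], 0)]


def frozen_features(x):
    """Frozen body: a fixed linear map followed by ReLU."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in FROZEN_W]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def log_loss(head, data):
    """Mean binary cross-entropy of the new head on top of the frozen features."""
    w, b = head
    total = 0.0
    for x, y in data:
        p = sigmoid(sum(wi * f for wi, f in zip(w, frozen_features(x))) + b)
        p = min(max(p, 1e-12), 1 - 1e-12)  # guard against log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total / len(data)


def train_head(data, epochs=200, lr=0.1, seed=0):
    """Train only the new logistic-regression head; the body stays frozen."""
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 0.01) for _ in FROZEN_W]  # fresh, randomly initialised head
    b = 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * len(w), 0.0
        for x, y in data:
            f = frozen_features(x)
            err = sigmoid(sum(wi * fi for wi, fi in zip(w, f)) + b) - y
            grad_w = [g + err * fi for g, fi in zip(grad_w, f)]
            grad_b += err
        n = len(data)
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


if __name__ == "__main__":
    untrained = ([0.0] * len(FROZEN_W), 0.0)
    trained = train_head(DATA)
    print("loss before: %.3f, after: %.3f"
          % (log_loss(untrained, DATA), log_loss(trained, DATA)))
```

Fine-tuning in the fuller sense would go one step further and also update some or all of the transferred weights, usually with smaller learning rates.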
Some best practices
- Most deep learning frameworks allow you to ‘selectively unfreeze’ the last n layers of a deep neural network, leaving the learned weights of the rest frozen. Overall, this feature is not as useful as it first sounds. Experience suggests that time spent on thorough model introspection, trying to determine where to draw the line between frozen and unfrozen layers, is almost never worth it. One exception is if you’re training a very large network that might not fit into your GPU memory; in this case, resource constraints will decide how much you can afford to unfreeze.
- Instead of unfreezing specific layers, it’s probably a better idea to use a differential learning rate, where the learning rate is set on a per-layer basis. The bottom layers then get a very low learning rate, because they generalise quite well, responding principally to edges, blobs and other trivial geometries. The layers responding to more complex features get a larger learning rate. In the past, the 2:4:6 rule (negative powers of 10) has worked quite well for me: a learning rate of 10^-6 for the bottommost few layers, 10^-4 for the other transfer layers and 10^-2 for any additional layers we added. I have also heard of others using 2:3:4 with different architectures. For ResNets and their derivatives, I have always felt more comfortable with 2:3:4, but I have absolutely no empirical evidence to back this up.
- Retraining any layers at all is not always a good idea. The destination task may be based on a small data set that is nevertheless very similar to the one the network was trained on (e.g. recognising a new class of animal or vehicle not included among the classes of an ImageNet-trained ResNet50). In that case, leaving the weights frozen and putting a linear classifier on top of the network’s outputs is likely to be more useful. It yields largely similar results with far less risk of overfitting.
- When transferring to a task with a relatively large data set, you might want to train the network from scratch (which would make it not transfer learning at all). At the same time, such a network would otherwise be initialised with random values, so you have nothing to lose by using the pretrained weights! Unfreeze the entire network, replace the output layer with one matching the number of destination task classes, and fine-tune the whole network.
- Finally: know what you’re reusing. Being aware of what network you are transferring from is not just good practice; it’s essential deep learning tradecraft. Commonly trusted staples like ResNet have time and again proven to be well-built, solid networks. Even so, picking a network that suits your task, and is likely to be the most efficient of several alternatives where they exist, is exactly what deep learning professionals are paid the big bucks for. If you are using a non-standard network (anything that doesn’t have hundreds of peer-reviewed applications counts as non-standard), satisfy yourself that the network architecture is sound and indeed the best choice. Two tools help here: TensorBoard’s graph visualiser for TensorFlow and Netscope for networks defined in Caffe’s prototxt format. (If you just want to play around, Yang Lei has a fun drag-and-drop prototxt editor.) Both give good insight into the bowels of a neural network and can be helpful in assessing its architecture. When using an unfamiliar network, I have also found introspection to be of great benefit. This includes DeconvNet and other reverse-engineering approaches that reveal layer- and neuron-level activation maxima.
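The differential learning rate scheme from the list above, including the 2:4:6 rule, can be sketched in plain Python. The text does not say how many layers count as the ‘bottommost few’, so the sketch leaves that as a parameter, and the helper names are hypothetical. Setting a layer’s rate to zero leaves its weights untouched, which is exactly what freezing does, so the same mechanism also covers selective unfreezing. In a real framework the per-layer rates would be handed to the optimiser instead. PyTorch optimisers, for instance, accept a list of parameter groups, each with its own learning rate.

```python
# Exponents k such that a layer group gets a learning rate of 10^-k.
RULE_2_4_6 = {"bottom": 6, "transfer": 4, "new": 2}
RULE_2_3_4 = {"bottom": 4, "transfer": 3, "new": 2}


def layer_learning_rates(n_transferred, n_bottom, n_new, rule=RULE_2_4_6):
    """Return one learning rate per layer, bottommost layer first.

    n_transferred: layers kept from the source network (including the bottom ones)
    n_bottom:      how many of those count as the "bottommost few" layers
    n_new:         layers freshly added on top for the destination task
    """
    if not 0 <= n_bottom <= n_transferred:
        raise ValueError("n_bottom must lie between 0 and n_transferred")
    rates = []
    for i in range(n_transferred):
        group = "bottom" if i < n_bottom else "transfer"
        rates.append(10.0 ** -rule[group])
    rates.extend(10.0 ** -rule["new"] for _ in range(n_new))
    return rates


def sgd_step(params, grads, rates):
    """One plain SGD step in which layer i uses its own learning rate rates[i].

    params and grads hold one flat list of numbers per layer.
    A rate of 0.0 leaves a layer unchanged, i.e. freezes it.
    """
    return [[p - lr * g for p, g in zip(layer_params, layer_grads)]
            for layer_params, layer_grads, lr in zip(params, grads, rates)]


if __name__ == "__main__":
    # Hypothetical network: 5 transferred layers (bottom 2 treated as generic) + 1 new head.
    for i, lr in enumerate(layer_learning_rates(5, 2, 1)):
        print("layer %d: lr = %.0e" % (i, lr))
```

Passing rule=RULE_2_3_4 gives the 2:3:4 variant mentioned above.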