Transfer Learning in a nutshell

How to train a neural net on a small dataset?


Deep learning is one of the most exciting areas of research today. Autonomous cars, machine translation, voice assistants, and visual search engines are on nearly everyone's lips. It is all the more remarkable given that the breakthrough of these methods began only a couple of years ago. Every day, new cutting-edge solutions come out and the results become more and more impressive.

Why is it happening right now? The recent success of deep learning models can be attributed to 3 major factors:

  1. Huge improvement in easily accessible computing power;
Standard neural net performs millions of operations per iteration.

2. Creation of various datasets consisting of thousands or millions of samples from different domains;

Sample from the classical MNIST dataset.

3. Invention of very efficient and accurate algorithms.

ResNet architecture.

In order to achieve results that can be applied to real-world and business problems, all of the aspects mentioned above need to be fulfilled. Fortunately, in most cases neither the computational power nor the algorithms are a problem. Powerful virtual machines equipped with GPUs can be rented from AWS or GCP at a reasonable price. Researching a suitable model and implementing it may be time-consuming and require experienced, skilled experts, but it is still affordable.

However, it is often very difficult, expensive, or even impossible to provide a sufficient amount of data for a specific problem. For example, gathering a large set of medical images would be very costly, as it requires a group of experts to label it, and it is not easy to acquire many photos of that kind. This is where the idea of Transfer Learning enters the stage. It can be summarized in one sentence:

Make use of the knowledge obtained earlier and apply it to the related problem.

More formally, transfer learning is the idea of overcoming the isolated learning paradigm and utilizing knowledge acquired for one task to solve related ones.

Idea behind

Let's draw an analogy to human learning:

  • Problem: company HumanLearning needs a person fluent in a new programming language called DeepScript.
  • Candidates: an experienced programmer familiar with 6 other languages, or a graduate of non-technical studies with no programming experience.
  • Finding: it is much easier to adapt already gained knowledge to a slightly different problem than to learn from scratch.

And that is exactly the idea of transfer learning: we use models pretrained on huge datasets from one domain to solve related problems in a similar area.

Use already gained knowledge for quick adaptation.

The most common approach in the computer vision field is to employ powerful, very deep convolutional neural nets pretrained on ImageNet. The latter is an enormous dataset consisting of ~1.2M images from 1000 different classes. As the best performing models contain hundreds of millions of parameters (the current leader FixEfficientNet-L2 consists of 480M parameters), their training takes days or even weeks and can raise the bill to thousands of dollars. Fortunately, most of the top pretrained models are available in popular frameworks such as PyTorch and TensorFlow, or can be downloaded from public Git repositories that keep up-to-date models which can be easily used in any project.

How does it work?

Transfer learning scheme.

The model used for transfer learning is usually called the backbone model. During the adaptation procedure, the top layers, the so-called classifier head, are replaced with randomly initialized layers whose dimension corresponds to the number of classes in the new dataset. Whereas the earlier layers of a backbone contain generic features (e.g. edge or color blob detectors), the later layers are progressively more specific to the details of the original classes.

Thus, the classifier head is often trained alone, with the bottom layers frozen. This approach is common when the dataset does not have many samples, in order to prevent overfitting. Alternatively, we may fine-tune all of the backbone's layers from the beginning. As with most deep learning problems, there is no single correct way of training the model; it is a matter of experimentation. As a wise man once said…

There is no such thing as a free lunch.

When to use transfer learning?

Transfer learning is an optimization: a shortcut to saving time or getting better performance. In general, it is not obvious whether there will be a benefit from transfer learning in a given domain until the model has been developed and evaluated.

There are 3 possible benefits to look for when using transfer learning:

  1. Higher start. The initial skill (before refining the model) on the source model is higher than it otherwise would be.
  2. Higher slope. The rate of improvement of skill during training of the source model is steeper than it otherwise would be.
  3. Higher asymptote. The converged skill of the trained model is better than it otherwise would be.
Benefits of transfer learning.

Ideally, you would see all three benefits from a successful application of transfer learning.

Experiments on benchmark dataset

Recently, my colleagues at Yosh and I performed experiments on the StanfordCars dataset. It is a popular benchmark for transfer learning problems, as it consists of ~16k images from 196 classes of cars. The data is split into 8144 training and 8041 testing images.

StanfordCars sample.

First, we tried EfficientNet-b3 as the backbone model. It is one of the smallest models in the EfficientNet family, whose representatives hold the lead in many image classification tasks. However, we obtained a rather poor result of ~80.5% accuracy compared with the current state-of-the-art scores, which exceed 94%. As usual, the devil is in the details: in order to achieve satisfying scores, a lot of different methods need to be tried.

Process of finding the right set of methods and hyperparameters.

The list below mentions some useful methods. To be competitive with SOTA (State Of The Art) results, you should experiment with at least some of them in your own project:

  • Different activation functions (Mish, Swish, LeakyReLU)
  • Advanced augmentation policies (AutoAugment, RandAugment)
Random augmentations with cutout on a single image.
  • Various optimizers (SGD, Adam, RAdam, Ranger)
  • Learning Rate scheduling methods (OneCycle, CosineAnnealing)
  • CutMix and Cutout transforms
CutMix method.
  • Different loss functions (Label-smoothing Cross Entropy)
  • Metric learning methods (Contrastive loss, Triplet loss)
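As an illustration, two of the techniques above, label-smoothing cross entropy and the OneCycle learning rate schedule, are available out of the box in PyTorch. The tiny linear model and random data below are placeholders for a real backbone and dataset, not the actual experiment.

```python
import torch
import torch.nn as nn

model = nn.Linear(32, 196)  # stand-in for the real backbone
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# OneCycle ramps the learning rate up and then anneals it back down
# over the full training run.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, total_steps=100
)

for step in range(100):
    inputs = torch.randn(8, 32)           # fake batch of 8 samples
    targets = torch.randint(0, 196, (8,))  # fake class labels
    loss = criterion(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the LR along the one-cycle curve
```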

Only after implementing some of the above methods did we manage to obtain results close to the best ever reported.

Best results obtained.

Well, it is much easier said than done! There is no single recipe for training models with transfer learning, as no two datasets are identical. However, the results can be truly impressive.

This is the first article in a series concerning deep learning in various areas (Computer Vision, Natural Language Processing, Recommender Systems), written in collaboration with my colleagues. Feel free to comment and ask questions. Stay tuned!
