Transfer Learning in a nutshell
How to train a neural net on a small dataset?
Deep learning is one of the most exciting areas of research today. Autonomous cars, machine translation, voice assistants and visual search engines are on nearly everyone’s lips. It is all the more remarkable given that the breakthrough of these methods began only a few years ago. New cutting-edge solutions come out every day, and the results keep getting more impressive.
Why is it happening right now? The recent success of deep learning models can be encapsulated in 3 major points:
1. Huge improvements in easily accessible computing power;
2. The creation of various datasets consisting of thousands or millions of samples from different domains;
3. The invention of very efficient and accurate algorithms.
To achieve results that can be applied to real-world business problems, all of these aspects need to be in place. Fortunately, in most cases neither computing power nor capable algorithms are the bottleneck. Powerful virtual machines equipped with GPUs can be rented from AWS or GCP at a reasonable price. Finding the right model and implementing it may be time-consuming and requires skilled, experienced experts, but it is still affordable.
However, it is often very difficult, expensive or even impossible to gather a sufficient amount of data for a specific problem. For example, building a large set of medical images would be very costly, as it requires a group of experts to label it, and it is not easy to acquire many photos of that kind in the first place. This is where the idea of transfer learning enters the stage. It can be summarized in one sentence:
Make use of the knowledge obtained earlier and apply it to the related problem.
More formally, transfer learning is the idea of overcoming the isolated learning paradigm and utilizing knowledge acquired for one task to solve related ones.
Let’s draw an analogy to human learning:
- Problem: the company HumanLearning needs a person fluent in a new programming language called DeepScript.
- Candidates: an experienced programmer familiar with 6 other languages, or a graduate of non-technical studies with no programming experience.
- Finding: it is much easier to adapt already acquired knowledge to a slightly different problem than to learn from scratch.
And that is exactly the idea of transfer learning: we use models pretrained on huge datasets from one domain to solve related problems in a similar area.
The most common approach in the computer vision field is to employ powerful, very deep convolutional neural nets pretrained on ImageNet, an enormous dataset consisting of ~1.2M images from 1000 different classes. As the best performing models have hundreds of millions of parameters (the current leader, FixEfficientNet-L2, has 480M), their training takes days or even weeks and can run the bill up to thousands of dollars. Fortunately, most of the top pretrained models are available in the most popular frameworks, like PyTorch or TensorFlow, or can be downloaded from public Git repositories. For instance, you should definitely take a look at the repository below, as it contains up-to-date models that can be easily plugged into any project.
Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train…
How does it work?
The model used for transfer learning is usually called the backbone model. During the adaptation procedure, the top layers, the so-called classifier head, are replaced with randomly initialized layers whose output dimension matches the number of classes in the new dataset. Whereas the earlier layers of a backbone contain more generic features (e.g. edge or color-blob detectors), the later layers are progressively more specific to the details of the original classes.
Thus, the classifier head is often trained alone, with the bottom layers frozen. This approach is common when the dataset has few samples, in order to prevent overfitting. Alternatively, we may fine-tune all of the backbone’s layers from the beginning. As with most deep learning problems, there is no single right way of training the model; it is a matter of experimentation. As a wise man once said…
There is no such thing as a free lunch.
When to use transfer learning?
Transfer learning is an optimization: a shortcut to saving time or getting better performance. In general, it is not obvious whether transfer learning will help in a given domain until the model has been developed and evaluated.
There are 3 possible benefits to look for when using transfer learning:
- Higher start. The initial skill (before refining the model) on the source model is higher than it otherwise would be.
- Higher slope. The rate of improvement of skill during training of the source model is steeper than it otherwise would be.
- Higher asymptote. The converged skill of the trained model is better than it otherwise would be.
Ideally, you would see all three benefits from a successful application of transfer learning.
Experiments on benchmark dataset
Recently, my colleagues at Yosh and I performed experiments on the Stanford Cars dataset. It is a popular benchmark for transfer learning problems, as it consists of ~16k images from 196 classes of cars, split into 8144 training and 8041 testing images.
First, we tried EfficientNet-b3 as the backbone model. It is one of the smallest models in the EfficientNet family, whose members lead many image classification benchmarks. However, we obtained a rather poor result of ~80.5% accuracy, compared with current state-of-the-art scores exceeding 94%. As usual, the devil is in the details: to achieve satisfying scores, a number of different methods need to be tried.
The list below mentions some useful methods. You should experiment with at least some of them in your own project to be competitive with SOTA (state-of-the-art) results:
- Different activation functions (Mish, Swish, LeakyReLU)
- Advanced augmentation policies (AutoAugment, RandAugment)
- Various optimizers (SGD, Adam, RAdam, Ranger)
- Learning Rate scheduling methods (OneCycle, CosineAnnealing)
- CutMix and Cutout transforms
- Different loss functions (Label-smoothing Cross Entropy)
- Metric learning methods (Contrastive loss, Triplet loss)
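Several items from this list drop straight into a standard PyTorch training loop. The sketch below combines three of them, label-smoothing cross entropy, an SGD optimizer, and OneCycle learning-rate scheduling; the model and hyperparameters are illustrative placeholders, not our exact setup:

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 196)  # stand-in for a full backbone + head

# Label-smoothing cross entropy: softens the one-hot targets.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# OneCycle: the learning rate rises toward max_lr, then anneals to near zero.
steps = 100
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.1, total_steps=steps)

for _ in range(steps):
    x = torch.randn(32, 512)            # placeholder batch of features
    y = torch.randint(0, 196, (32,))    # placeholder labels
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # one scheduler step per batch, not per epoch
```

Augmentation policies and CutMix/Cutout act on the data pipeline instead, so they slot in before the batch reaches the model.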
Only after implementing some of the above methods did we manage to obtain results close to the best ever reported.
Well, that is much easier said than done! There is no single recipe for training models in transfer learning, since no two datasets are identical. However, the results can be truly impressive.
This is the first article in a series concerning deep learning in various areas (Computer Vision, Natural Language Processing, Recommender Systems) in collaboration with the colleagues from Yosh.ai. Feel free to comment and ask questions. Stay tuned!