Transfer Learning — Reusing a pre-trained Deep Learning model on a new task

Random Nerd · Published in Analytics Vidhya · 7 min read · Dec 18, 2018

After teaching Data Science to aspirants for over a year now, I guess I won't be wrong in stating that Transfer Learning is a concept that bothers a lot of students. For most of them, it is more like "Wasn't Machine Learning, and getting deeper with Deep Learning, enough, that now I also need to understand some Transfer Learning?" But in most scenarios, aspirants aren't even aware of the term "Transfer Learning" because MOOCs (with the exception of Andrew Ng) rarely use it. The catch is that expert instructors on online education portals don't skip the topic; they just don't use the term explicitly, which leads a big chunk of inquisitive learners to pop questions like "I understand how to train a model and have already done that, but how do I now use this pre-trained model for a new task?" So let us try to conceptually break this topic down in simple terms and fill in the gaps.

Throughout this article I won't be reinventing the wheel or narrating something that hasn't already been explained by someone else. So the moment you feel "I could've done this on my own and don't need him to tell me!", my response would be "Please do that, because we already have enough resources available for free in this digital age, and there is no better approach than researching on our own." But for those learners who struggle to find relevant information, we shall take sequential steps to understand the concept and guide ourselves to resources that can help us figure out this crucial topic.

Whenever we get into a situation where we find ourselves lost, the very first instinct is to do a Google search, right? And Wikipedia often tops the search results, so let us start with what Wiki says:

"Storing knowledge gained while solving one problem and applying it to a different but related problem" is the part worth highlighting, because that is all we need to feed our grey cells to begin with. Try to recreate it in your head with a simpler example, something like this: as a cool new-era gaming millennial, if you know how to use a gamepad (handheld controller or joypad), you're pretty much set to begin video gaming, be it on an Xbox or a PlayStation.

Hypothetically, your brain got initially trained on using a gamepad while gaming on a PlayStation, and later, when you also purchased an Xbox, your brain auto-transferred that gamepad-usage model and adjusted it to the knowledge pipeline it had begun creating for gaming on the Xbox. This made the work very easy for your brain while learning the Xbox, as it partly used a pre-trained (gamepad) model and quickly customized it to the requirements of the Xbox gaming experience.

Now that we have a rough bit of conceptual understanding, let us think about why the heck it is so intrinsic to Deep Learning models. So what is the special requirement for Deep Learning models to be effective? Deep Learning models are pretty much like our girlfriends, who would never get satisfied irrespective of whatever (quantity of data) we do (feed) for them, so what we are going to get in return is just Cost (referring to the cost/loss error). But if we get lucky, we might just find the right girl (Deep Learning model) for ourselves, and when this rare thing happens, we just can't let her (the pre-trained model) go, right?

Similarly, an excellent pre-trained model can be a multi-purpose tool for us, as it can help accomplish a lot of other things, and hence transfer learning is a crucial concept for advanced predictive neural-net algorithms. Another important aspect is that, in the real world, assembling a huge set of labelled data for every task is as difficult as accompanying our girlfriend on her shopping. But what if we get to know that our girlfriend's BFF (a pre-trained model) is also in town? Wouldn't we rather set the two of them up for this shopping spree (training a new model for some other task), while we enjoy Liverpool taking on Chelsea with some beer and pizzas?

Because this hypothetical girlfriend has already helped us comprehend the need for Transfer Learning, let us move on to figure out what actually happens when we implement this concept. For example, in a computer vision scenario, deep learning algorithms generally learn to discern edges in the initial layers, shapes in the middle layers, and task-specific features in the final layers. With Transfer Learning, we reuse the initial and middle layers of our pre-trained model and re-train only the final layers for our new task. This comes in very handy when we don't have enough labelled data for our new task, and also because training algorithms like neural networks, Bayesian networks or Markov logic networks is extremely computationally expensive. Note that the process also has limitations, which we shall discuss later in this very article.
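To make that concrete, here is a minimal sketch of the idea in PyTorch (my own illustration, not code from any resource linked here), using torchvision's ImageNet-pre-trained ResNet-18 as the source model; `num_new_classes` is a hypothetical value for the new task:

```python
import torch
import torch.nn as nn
from torchvision import models

# Source model: its initial/middle layers already detect edges and shapes.
model = models.resnet18(pretrained=True)

# Freeze all pre-trained layers so their weights are reused as-is.
for param in model.parameters():
    param.requires_grad = False

# Replace only the final classification layer for the new task.
num_new_classes = 5  # hypothetical number of classes in the new dataset
model.fc = nn.Linear(model.fc.in_features, num_new_classes)

# Train just the new final layer; everything else stays intact.
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9)
```

From here, the usual training loop on the new labelled dataset only updates the weights of `model.fc`, which is exactly the "reuse early layers, re-train final layers" recipe described above.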

To expand our knowledge of Transfer Learning further on the image-recognition example we just mentioned, let us quickly run through this insightful session by Andrew Ng, where he demonstrates (with 'Radiology Diagnosis' as the new task) initializing a new set of weights for the final layer of the new model (using a radiology image dataset), while keeping the initial and middle layers of our pre-trained model intact. As a further improvement, all the weights of our source (pre-trained) model can then be modified, rather than simply reinitializing and learning the weights of the final classification layer in the target (new) model. As training-data availability increases, this additional model flexibility starts to pay dividends. Let us now listen to this living legend:
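Alongside that session, here is how the second regime he describes might look in code: a rough sketch of my own (continuing the PyTorch example above, not Andrew Ng's code) in which, once enough labelled data is available, the whole source model is unfrozen and all of its weights are fine-tuned, rather than only the re-initialized final layer.

```python
# Continuing the ResNet sketch from above: unfreeze every layer of the
# pre-trained model so all weights get fine-tuned on the new task.
for param in model.parameters():
    param.requires_grad = True

# A small learning rate keeps the pre-trained weights from changing too abruptly.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
```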

Quite informative, right? But it might still be nagging you that most references to Transfer Learning are computer-vision (specifically image) related, so let us switch gears and now look at Natural Language Processing (NLP). For years, the NLP domain lacked an established reference dataset and source task for learning generalizable base models, so the community of people fine-tuning pre-trained models wasn't that big.

However, recent papers like Howard and Ruder's "Universal Language Model Fine-tuning for Text Classification" and Radford et al.'s "Improving Language Understanding by Generative Pre-Training" demonstrate promising model fine-tuning capability in the natural-language domain. Although the source dataset varies across these papers, the community seems to be standardizing on a "language modeling" objective as the go-to for training transferable base models. The obvious benefit comes from the fact that raw text is abundantly available for every conceivable domain, and language modeling has the desirable property of not requiring labelled training data.
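As a simplified, hypothetical sketch of that recipe (my own assumptions, not the ULMFiT or GPT authors' implementations): a language-model encoder pre-trained on raw text is reused as the body of a text classifier, and only a small, newly initialized head is trained on the labelled task at first.

```python
import torch
import torch.nn as nn

def freeze(module: nn.Module) -> None:
    """Freeze a module's parameters so its pre-trained weights are reused as-is."""
    for p in module.parameters():
        p.requires_grad = False

class TransferTextClassifier(nn.Module):
    """Hypothetical classifier built on top of a pre-trained language-model encoder."""

    def __init__(self, pretrained_encoder, hidden_dim, num_classes):
        super().__init__()
        self.encoder = pretrained_encoder               # pre-trained LM body (assumed given)
        self.head = nn.Linear(hidden_dim, num_classes)  # new, randomly initialized classifier

    def forward(self, tokens):
        # Assumes the encoder returns features of shape (batch, seq_len, hidden_dim).
        features = self.encoder(tokens)
        return self.head(features[:, -1, :])            # classify from the final position
```

In practice, the encoder would be a network pre-trained with a language-modelling objective, frozen via `freeze()` at first and then gradually unfrozen, which is essentially the schedule the ULMFiT paper proposes.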

[MUST READ] A very thorough article on KDnuggets about "Effective Transfer Learning for NLP"

I could give you high-end examples where Transfer Learning has been applied in Deep Learning and widely accepted by the audience, but if you're in the initial learning phase, the architecture can quickly become complex, hence I won't do that. Still, just to name a few for your reference (in case you wish to read more, or perhaps just for interview purposes):

- Inception model by Google [simple image classification implementation using TensorFlow on the Google AI Blog] (a short TensorFlow sketch follows below)
- ResNet model by Microsoft [original paper, but start with this easy-to-understand implementation using PyTorch first]
- Transfer Learning Using Twitter Data for Improving Sentiment Classification [a detailed paper on Turkish political news data from Twitter]
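For the Inception route, here is a short sketch of the same idea on the TensorFlow/Keras side (my own illustration, not the Google AI Blog code): load Inception v3 pre-trained on ImageNet without its top classification layer, freeze it, and bolt a new head onto it.

```python
import tensorflow as tf

# Pre-trained feature extractor: ImageNet weights, no top classification layer.
base = tf.keras.applications.InceptionV3(
    weights="imagenet", include_top=False, pooling="avg"
)
base.trainable = False  # keep the pre-trained layers fixed

# New classification head for the target task.
model = tf.keras.Sequential([
    base,
    tf.keras.layers.Dense(5, activation="softmax"),  # hypothetical 5-class target task
])
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
```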

With all the hyperlinks I have attached throughout this article, you should be good to go implementing Transfer Learning in selected scenarios using TensorFlow, PyTorch, etc., mostly in the Computer Vision and NLP domains with Deep Learning models. Prior to implementation, we need to understand how transferable the features of our source model are, and whether we are doing enough to avoid the common drawbacks. There are a couple of drawbacks we need to be aware of:

- We might be slightly constrained in terms of architecture flexibility for our new dataset.
- The learning rate is generally kept very low for the target task, because we trust the weights of our source model and do not want to distort them too quickly (see the sketch below).
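On the learning-rate point, one common way to honour it (a sketch under my own assumptions, continuing the earlier PyTorch ResNet example) is to give the pre-trained body a tiny learning rate while letting the new final layer learn faster, using optimizer parameter groups:

```python
import torch

# Two parameter groups: the new head gets a normal learning rate,
# the pre-trained body gets a very small one so its weights are not distorted quickly.
optimizer = torch.optim.SGD(
    [
        {"params": model.fc.parameters(), "lr": 1e-2},             # new final layer
        {"params": [p for n, p in model.named_parameters()
                    if not n.startswith("fc.")]},                  # pre-trained layers
    ],
    lr=1e-5,      # low default learning rate, used by the pre-trained group
    momentum=0.9,
)
```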

So far we have covered a conceptual understanding of the methodology, its mathematical inclination, popular implementation domains, a few in-practice examples, and finally the common drawbacks of Transfer Learning. Moving on, the question is whether Transfer Learning in itself is so effective in every aspect that it can be a definite resolution to any Deep Learning problem. Unfortunately, the answer is NO, because "in Statistical Learning there is no free lunch".

This is where integrating other fine-tuning concepts gets into the execution pipeline. Here is a Machine Learning blog by Sebastian Ruder which covers each of these concepts in detail, along with much more, like Semi-Supervised Learning, Multitask Learning and Zero-shot Learning, so I highly recommend checking it out. There is also another blog by Jason Brownlee which is worth following. And if you would like me to cover any of those concepts, please feel free to let me know in the Comments section. Thank you for your time, and enjoy Machine Learning!

P.S.: Any gender/relationship references in this article are merely for adding slight humor and are absolutely not meant to offend anyone. :)
