Transfer learning: a savior in production ML

Jason Victor
Published in Advancing on Chaos · Dec 13, 2017
Transfer learning is kind of like algorithmic brainstorming.

TL;DR: when you don’t have enough data and need to train a complex model like a deep net, find public data sets similar to your problem and use basic transfer learning to get a model that generalizes well.

Transfer learning — or using knowledge of one machine learning problem to better solve another — has recently come into vogue due to its utility in training deep nets. While a lot of focus has gone into doing “magic tricks” with transfer learning, I think its greatest utility is in turning real-life problems with insufficient data into tractable problems that give life to models with genuine business value.

The problem

I was recently confronted with a machine learning problem — one of those problems that sounds easy to a layman but, in reality, is deeply complex. In this case, we had all the problems: the input domain (text documents) was unwieldy, the classes were highly imbalanced, the data (mainly the target variable) was noisy, and worse yet, there wasn’t much labeled data to be had.

This is the kind of problem where, with our current tech, one would expect deep nets to deliver state-of-the-art results, but I tested that assumption first. Random forests, gradient boosted trees, and even random kitchen sinks all failed miserably, whether I was using word embeddings, bag-of-words, the hashing trick, character histograms, or anything else. I had to use deep learning.
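For concreteness, one of those classical baselines looked roughly like the sketch below: hashing-trick features feeding a random forest. The loader and hyperparameters here are placeholders of my own, not the actual project's configuration.

```python
# A hedged sketch of one classical baseline: hashing-trick features feeding a
# random forest. load_documents() is a hypothetical loader; hyperparameters
# are placeholders, not the ones used in the actual project.
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline

texts, labels = load_documents()  # hypothetical: returns raw documents and labels

baseline = make_pipeline(
    HashingVectorizer(n_features=2**16, alternate_sign=False),
    RandomForestClassifier(n_estimators=300, class_weight="balanced"),
)

# With heavily imbalanced classes, accuracy is misleading; score on ROC AUC instead.
print(cross_val_score(baseline, texts, labels, scoring="roc_auc", cv=5))
```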

And initial experimentation showed that even though a deep net — in this case, a bidirectional LSTM — immediately overfit my data set, it actually had decent out-of-sample performance, and its output matched human intuition nicely. These results were not what I was expecting — at all. (Which is a problem, but that’ll be another post for another day.)
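For reference, the model was in the spirit of the Keras sketch below; the vocabulary size, sequence length, and layer widths are my assumptions, not the real configuration.

```python
# A rough Keras sketch of a bidirectional LSTM text classifier in the spirit
# of the one described above. All sizes are assumptions, not the real config.
from tensorflow.keras import layers, models

VOCAB_SIZE = 20000  # assumed vocabulary size
EMBED_DIM = 128     # assumed embedding width

model = models.Sequential([
    layers.Embedding(VOCAB_SIZE, EMBED_DIM),  # token ids -> dense vectors
    layers.Bidirectional(layers.LSTM(64)),    # reads the document in both directions
    layers.Dense(64, activation="relu"),      # small fully-connected head
    layers.Dropout(0.5),                      # a (weak) guard against overfitting
    layers.Dense(1, activation="sigmoid"),    # binary target
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5)
```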

How could optimization heuristics get enough traction on so many weights with so little data and still not completely overfit? I wondered.

Manual inspection showed that each net I trained performed well against the human “smell test,” but each had its quirks: things one net missed that another had figured out, so to speak. So I decided to bag them.

Bagging the nets and taking the average prediction performed okay, but actually worse than any individual net. (Again, an undesired surprise.) And when I stacked an XGBoost classifier on top of the prediction probabilities to try to be smarter about the final step, the performance was abysmal. When I replaced XGBoost with an L1-regularized logistic regression, I was able to get my best performance — I assume because a weighted average of predictions is a naturally sensible way to go about the ensemble.
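As a sketch, the two strategies look roughly like this. Here net_probs stands in for an (n_samples, n_nets) array of the bagged nets’ out-of-fold predicted probabilities and y_train for the matching labels; both names are mine, not from the original code.

```python
# Hedged sketch: simple bagged average vs. an L1-regularized logistic
# regression stacked on the nets' predicted probabilities. net_probs is an
# assumed (n_samples, n_nets) NumPy array of out-of-fold probabilities;
# y_train holds the matching labels.
from sklearn.linear_model import LogisticRegression

avg_pred = net_probs.mean(axis=1)  # plain bagging: average the nets' outputs

# The stacker effectively learns a sparse weighted average of the nets.
stacker = LogisticRegression(penalty="l1", solver="liblinear")
stacker.fit(net_probs, y_train)
stacked_pred = stacker.predict_proba(net_probs)[:, 1]
```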

But with the data set at the size it was, none of these techniques worked right away. This was partially fixable by using oversampling techniques on the meta-model outputs used to train the final classifier. I used SMOTE, and it brought the performance of the linear classifier to an acceptable level, slightly better than the mode of the individual nets’ performance.
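Concretely, the fix amounted to something like the snippet below, assuming imbalanced-learn is available and reusing the net_probs / y_train placeholders from the previous sketch.

```python
# Hedged sketch: oversample the meta-model outputs with SMOTE before fitting
# the final L1-regularized linear classifier. Reuses the assumed net_probs /
# y_train placeholders from the previous snippet; requires imbalanced-learn.
from imblearn.over_sampling import SMOTE
from sklearn.linear_model import LogisticRegression

probs_resampled, y_resampled = SMOTE(random_state=0).fit_resample(net_probs, y_train)

final_clf = LogisticRegression(penalty="l1", solver="liblinear")
final_clf.fit(probs_resampled, y_resampled)
```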

But while this improved the quantitative evaluations of the model — its ROC/AUC and PR curve — it no longer passed the human smell test. Outputs had become more extreme, there was less variation, and much of the very human-feeling intelligence of the original neural net appeared lost.

Transfer learning to the rescue

With the bagging failing me, I figured my small data set was the biggest hindrance, so I began to look to transfer learning and data augmentation. I didn’t find many good ideas for data augmentation on text documents (beyond the obvious, like scrambling sentences), so I started to investigate transfer learning.

Transfer learning is a research subject unto itself, and there are many ways to go about it. But here’s the easy button approach to doing transfer learning:

  1. Find public data sets that are similar to your problem in terms of the kinds of features that would be useful to extract.
  2. Train a deep net on the public data sets until it converges to the expected loss.
  3. Freeze the first fully-connected layer of the net — the one that functions as a feature extractor — so it can’t be retrained. (In Keras, this means you need to recompile the model, but that won’t reset the weights.)
  4. Run a few epochs of training with your real training data.

Boom. Transfer learning. If the public data sets were well chosen and well-suited to the problem, you’ll notice that your net converges quickly to a decent accuracy in the first epoch of training on your real data. Since the feature extractors (the first layer) have already been trained, only the subsequent layers need to learn, and they can do so on top of features that are already well-suited to the problem domain.
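In Keras, the recipe comes out to something like the sketch below, reusing the BiLSTM model from earlier. The data set names and epoch counts are my placeholders, and exactly which layers count as “the feature extractor” depends on your architecture; as an illustration I freeze the embedding and BiLSTM layers, which play that role in the earlier sketch.

```python
# Hedged sketch of the recipe above, reusing the earlier BiLSTM model.
# x_public/y_public, x_real/y_real, and x_val/y_val are assumed pre-tokenized
# data sets; epoch counts are illustrative.

# Steps 1-2: train on the similar public data until the loss converges.
model.fit(x_public, y_public, epochs=20, validation_split=0.1)

# Step 3: freeze the feature-extracting layers (here the embedding and BiLSTM)
# and recompile. Recompiling is needed for the freeze to take effect, but it
# does not reset the learned weights.
for layer in model.layers[:2]:
    layer.trainable = False
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Step 4: fine-tune for a few epochs on the small in-house training set.
model.fit(x_real, y_real, epochs=3, validation_data=(x_val, y_val))
```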

This can dramatically change the nature of the solution the optimizer discovers, because it is forced into a region of parameter space that is known to have relevant solutions, where relevant is defined in a very loose sense. (More on this later, too.)

What does this mean for production ML?

It means that the tiny data sets you have at work can actually be used for something — even something so complex as to require a bidirectional LSTM.

“We don’t have the data” is no longer an excuse to not do data science if the domain is well served by neural nets. While transfer learning may not be as well studied outside the context of deep nets, I believe the future is bright, and we’ll one day see transfer learning in play across a wide variety of problem domains.
