<blink>You read it right, HMTL: a Hierarchical Multi-Task Learning model.</blink>
There is a rising tide 🌊 in NLP, and more broadly across Deep Learning and Artificial Intelligence, called Multi-Task Learning!
I’ve been experimenting with Multi-Task Learning for almost a year now and the result is HMTL, a model that beats the state-of-the-art on several NLP tasks and which will be presented at the very selective AAAI conference. We’ve released the paper 📃 and the training code ⌨️, so be sure to check them out.
Now, what is Multi-Task Learning?
Multi-Task Learning is a general method in which a single architecture is trained towards learning several different tasks at the same time.
Here is an example: we’ve built a pretty nice online demo that runs HMTL interactively, so try it for yourself! 🎮
Traditionally, a specific model was independently trained for each of these NLP tasks (Named-Entity Recognition, Entity Mention Detection, Relation Extraction, Coreference Resolution).
In the case of HMTL, all these results are given by a single model, in a single forward pass!
But Multi-Task Learning is more than just a way to reduce the computation by using a single model instead of several models.
Multi-Task Learning (MTL) can be used to encourage the model to learn embeddings that can be shared between different tasks. One fundamental motivation behind Multi-Task Learning is that related (or loosely related) tasks can benefit from each other by inducing richer representations.
In this post, I will try to give you a sense of how powerful and versatile multi-task learning can be for NLP applications. First, I will share some intuition on why multi-task learning is an exciting trend, with pointers to relevant references if you want to go deeper 📚. Then we will see how we can quickly build a Multi-Task trainer in Python 🐍, and I will gather a few lessons we learned 👨🎓👩🎓 while conducting the research project that led to state-of-the-art performance.
A brief introduction to cool 😎 stuff in Multi-Task Learning and why it matters.
In a classical Machine Learning setup, we train a single model by optimizing a single loss function to perform a single specified task. While focusing the training on a (single) task of interest is still the common approach for a lot of problems in Machine Learning, it does not leverage the information that other related (or loosely related) tasks can bring to further improve the performance.
As an analogy, Usain Bolt, probably one of the greatest sprinters of all time 🏃🏿, eight-time Olympic gold medalist 🥇 and still holder of several world records 🏆 (as of November 2018), followed a very intense and broad training regime and spent a significant part of it not actually running but doing other exercises. For instance, he lifted weights and did box jumps, bounds, etc. These exercises are not directly related to running, but they develop the muscular strength and explosiveness that serve the ultimate goal: sprinting.
“Multi-task Learning is an approach to inductive transfer that improves generalization by using the domain information contained in the training signals of related tasks as an inductive bias. It does this by learning tasks in parallel while using a shared representation; what is learned for each task can help other tasks be learned better.” R. Caruana 
In Natural Language Processing, MTL was first leveraged in neural-based approaches by R. Collobert and J. Weston. The model they proposed is an MTL instance in which several different tasks (with task-specific layers) rely on the same shared embeddings, which are trained by the supervision of the different tasks.
Sharing the same representation among different tasks can sound like a really low-level way to transfer relevant knowledge from one task to another, but it has proven really useful, in particular for its capacity to improve the generalization ability of a model.
While it is common and straightforward to fix in advance how the information will be transferred across tasks, we can also let the model decide by itself which parameters and layers it should share, along with the layers that are best suited for a given task, as shown in Ruder et al., 2017.
More recently, these ideas of a shared representation have come back into the spotlight, notably through the quest for “Universal Sentence Embeddings” which could be used across domains and are not task-specific (cf. Conneau et al.). Several attempts rely on MTL: for instance, Subramanian et al. observe that to be able to generalize over a wide range of diverse tasks, it is necessary to encode multiple linguistic aspects of the sentence. They proposed GenSen, an MTL architecture with a shared encoder representation followed by several task-specific layers. The 6 tasks used in this work are weakly related and range from Natural Language Inference to Machine Translation and Constituency Parsing.
To dive deeper into the current state-of-the-art in sentence embeddings, you can refer to our detailed blog post on Universal Word and Sentence Embeddings 📚.
In short, Multi-Task Learning is getting a lot of traction and is becoming a must-know for a broad variety of problems in NLP 🔠, but obviously also in Computer Vision 👀. Benchmarks such as the GLUE benchmark (General Language Understanding Evaluation, Wang et al.) have been introduced recently to evaluate the generalization ability of MTL architectures and, more generally, of Language Understanding models.
For a more comprehensive overview of MTL in NLP, you can refer to this blog post by S. Ruder. 📚
Multi-Task Learning in Python 🐍
🔥 Now let’s try some code to see how MTL looks in practice.
A very important piece of a multi-task learning scheme is the Trainer: how should we train the network? In which order should we tackle the various tasks? Should we switch tasks periodically? Should all the tasks be trained for the same number of epochs? There is no clear consensus today on these questions, and many different training procedures have been proposed in the literature 🙇🏻, so let’s be pragmatic!
First, let’s start with a simple and general piece of code that will be agnostic to the training procedure we pick 👌:
- Select a task (whatever your selection algorithm is).
- Select a batch in the dataset for the chosen task (randomly sampling a batch is usually a safe choice).
- Perform a forward pass.
- Propagate the loss (backward pass) through the network.
These 4 steps should suit most of the use cases 🙃.
During the forward pass, the model computes the loss of the task of interest. During the backward pass, the gradients computed from the loss are propagated through the network to optimize both the task-specific layers and the shared embeddings (and all other relevant trainable parameters).
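To make these four steps concrete, here is a minimal pure-Python sketch. Everything in it is an assumption made for illustration: the toy model (a prediction `shared * head[task] * x` with a squared-error loss), the names `forward`, `backward` and `train`, and the tiny datasets. It is not the HMTL code; it only shows how one update on a single task moves both the shared weight and that task's own weight.

```python
import random

# Toy model (hypothetical): y_hat = shared * heads[task] * x, squared-error loss.
params = {"shared": 0.5, "heads": {"task_a": 1.0, "task_b": 1.0}}

# Each task's dataset is a list of batches; a batch is a list of (x, y) pairs.
datasets = {
    "task_a": [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]],
    "task_b": [[(1.0, -1.0)], [(2.0, -2.0)]],
}


def forward(params, task_name, batch):
    """Forward pass: mean squared error of the toy model on one batch."""
    w = params["shared"] * params["heads"][task_name]
    return sum((w * x - y) ** 2 for x, y in batch) / len(batch)


def backward(params, task_name, batch, lr=0.01):
    """Backward pass: analytic gradients applied to shared and task weights."""
    shared, head = params["shared"], params["heads"][task_name]
    # Gradient of the loss with respect to the product w = shared * head.
    grad_w = sum(2 * (shared * head * x - y) * x for x, y in batch) / len(batch)
    # Chain rule: the shared weight and the task-specific weight both move.
    params["shared"] -= lr * grad_w * head
    params["heads"][task_name] -= lr * grad_w * shared


def train(params, datasets, n_updates, seed=0):
    rng = random.Random(seed)
    task_names = sorted(datasets)
    for _ in range(n_updates):
        task_name = rng.choice(task_names)        # 1. select a task
        batch = rng.choice(datasets[task_name])   # 2. select a batch
        forward(params, task_name, batch)         # 3. forward pass (loss)
        backward(params, task_name, batch)        # 4. backward pass (update)
    return params


train(params, datasets, n_updates=2000)
```

In a real setup, an autograd framework computes the gradients for you; the structure of the loop stays the same.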
At Hugging Face, we love 💛 the AllenNLP library that is being developed by the Allen Institute for AI. It’s a powerful and versatile tool for conducting research in NLP that combines the flexibility of PyTorch with smart modules for data loading and processing that were carefully designed for NLP.
If you haven’t checked it out yet, I highly recommend that you do. The team did an amazing job with the on-boarding tutorials, so you have no excuses! 😬
I will now show a simple piece of code for creating an MTL trainer based on AllenNLP.
Let’s first introduce a class Task, which will contain task-specific datasets and all the attributes directly related to the tasks.
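As a rough idea of what such a class can hold, here is a minimal sketch in plain Python. The attribute and method names (`load_batches`, `n_train_batches`, `record_validation_metric`) are assumptions made for illustration, apart from `_val_metric_decreases`, which is mentioned later in this post. The real class in the training code also takes care of reading datasets, which is left out here.

```python
from typing import Any, List, Optional


class Task:
    """Holds the data and bookkeeping for one task (illustrative sketch)."""

    def __init__(self, name: str, validation_metric_name: str,
                 validation_metric_decreases: bool = False) -> None:
        self._name = name
        self._validation_metric_name = validation_metric_name
        # True when lower is better (e.g. a loss),
        # False when higher is better (e.g. F1 or accuracy).
        self._val_metric_decreases = validation_metric_decreases
        self._train_batches: List[Any] = []
        self._validation_batches: List[Any] = []
        self._metric_history: List[float] = []

    def load_batches(self, train_batches: List[Any],
                     validation_batches: Optional[List[Any]] = None) -> None:
        """Attach already-batched data to the task."""
        self._train_batches = list(train_batches)
        self._validation_batches = list(validation_batches or [])

    @property
    def n_train_batches(self) -> int:
        return len(self._train_batches)

    def record_validation_metric(self, value: float) -> bool:
        """Store a validation score; return True if it is the best so far."""
        history = self._metric_history
        if self._val_metric_decreases:
            is_best = not history or value < min(history)
        else:
            is_best = not history or value > max(history)
        history.append(value)
        return is_best
```

For example, `Task("ner", "f1")` describes a task where a higher validation score is better, while `Task("lm", "loss", validation_metric_decreases=True)` describes one where lower is better.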
Now that we have our class Task, we can define our Model.
Creating a model in AllenNLP is pretty easy. Just make your class inherit from the allennlp.models.model.Model class. Lots of useful methods will be automatically supplied, such as get_regularization_penalty(), which applies penalties (e.g. L1 or L2 regularization) during the training phase.
Let’s talk about the two main methods that we need: forward() and get_metrics(). These methods respectively compute the forward pass (up to the loss computation) and the training/evaluation metrics for the current task during training.
The important element for Multi-Task Learning is to add a specific argument task_name, which will be used to specify the current task of interest during training. Let’s have a look:
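Here is a minimal sketch of the `task_name` routing idea. To keep it runnable without any dependency, it is a plain-Python stand-in and not a real `allennlp.models.model.Model` subclass. The class name `MultiTaskModel`, the toy encoder and heads, and the error-rate "loss" are all assumptions made for illustration.

```python
from typing import Any, Callable, Dict, List, Tuple


class MultiTaskModel:
    """One shared encoder, one head per task, selected by `task_name`."""

    def __init__(self, shared_encoder: Callable[[Any], Any],
                 task_heads: Dict[str, Callable[[Any], Any]]) -> None:
        self._shared_encoder = shared_encoder
        self._task_heads = task_heads
        # Per-task counters backing get_metrics().
        self._correct = {name: 0 for name in task_heads}
        self._total = {name: 0 for name in task_heads}

    def forward(self, batch: List[Tuple[Any, Any]], task_name: str) -> Dict[str, Any]:
        """Encode with the shared part, then apply only the requested head."""
        head = self._task_heads[task_name]
        predictions = [head(self._shared_encoder(x)) for x, _ in batch]
        errors = sum(p != gold for p, (_, gold) in zip(predictions, batch))
        self._correct[task_name] += len(batch) - errors
        self._total[task_name] += len(batch)
        # Toy "loss": the error rate on the batch (a real model would
        # return a differentiable loss here).
        return {"predictions": predictions, "loss": errors / len(batch)}

    def get_metrics(self, task_name: str, reset: bool = False) -> Dict[str, float]:
        """Return the running accuracy of the requested task only."""
        total = self._total[task_name]
        accuracy = self._correct[task_name] / total if total else 0.0
        if reset:
            self._correct[task_name] = 0
            self._total[task_name] = 0
        return {"accuracy": accuracy}


model = MultiTaskModel(
    shared_encoder=lambda text: text.lower().split(),
    task_heads={"length": len, "has_paris": lambda tokens: "paris" in tokens},
)
output = model.forward([("I love Paris", 3), ("Hello", 2)], task_name="length")
```

The point of the design is that the shared encoder is used by every call, while the heads and the metrics are looked up by task name.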
We said that a crucial point in MTL is choosing the training task order. The most straightforward way to select a task is to sample a task uniformly 🎲 after each parameter update (forward + backward passes). This algorithm was used in several prior works, like GenSen, which we mentioned earlier.
But we can be a little bit smarter: we choose a task randomly following a probability distribution in which the probability of choosing a task is proportional to its number of training batches relative to the total number of batches across all tasks. This sampling procedure turns out to be pretty useful as we will see later, and is a pretty elegant way to prevent catastrophic forgetting.
The following snippet of code implements this procedure. Here, task_list denotes a list of Task objects on which we want to train our model.
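A minimal sketch of this proportional sampling, using only the standard library. The function names and the `batches_per_task` dictionary (task name mapped to its number of training batches) are assumptions made for illustration. With a list of Task objects you would build that dictionary from each task's batch count.

```python
import random
from typing import Dict, Optional


def task_sampling_probabilities(batches_per_task: Dict[str, int]) -> Dict[str, float]:
    """Probability of each task = its share of all training batches."""
    total = sum(batches_per_task.values())
    if total <= 0:
        raise ValueError("At least one task must have training batches.")
    return {name: n / total for name, n in batches_per_task.items()}


def sample_task(batches_per_task: Dict[str, int],
                rng: Optional[random.Random] = None) -> str:
    """Draw one task name, weighted by its number of training batches."""
    rng = rng or random.Random()
    names = list(batches_per_task)
    weights = [batches_per_task[name] for name in names]
    return rng.choices(names, weights=weights, k=1)[0]


batches = {"ner": 100, "emd": 300, "coref": 600}
probabilities = task_sampling_probabilities(batches)
next_task = sample_task(batches, rng=random.Random(0))
```

A task with six times more batches is picked about six times more often, so on average every dataset is traversed at the same pace.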
Let’s try our MTL trainer.
The following snippet of code illustrates how we can assemble the elementary pieces we’ve built so far 💪.
The train() method iterates over the tasks according to the probability distribution over the tasks and optimizes the parameters of the MTL model, update after update.
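Here is a compact sketch of such a trainer in plain Python. The class name `MultiTaskTrainer`, the `task_batches` dictionary and the `update_fn` callback (which stands for "forward pass + backward pass + optimizer step" and returns the loss) are all assumptions made for illustration; this is not the AllenNLP-based trainer from the released code.

```python
import random
from typing import Any, Callable, Dict, List


class MultiTaskTrainer:
    """Samples a task per update, proportionally to its number of batches."""

    def __init__(self, task_batches: Dict[str, List[Any]],
                 update_fn: Callable[[str, Any], float], seed: int = 0) -> None:
        self._task_batches = task_batches
        # update_fn(task_name, batch) performs one parameter update
        # and returns the loss for that batch.
        self._update_fn = update_fn
        self._rng = random.Random(seed)
        self._names = list(task_batches)
        self._weights = [len(task_batches[name]) for name in self._names]

    def train(self, n_updates: int) -> Dict[str, List[float]]:
        """Run n_updates parameter updates; return the losses seen per task."""
        losses: Dict[str, List[float]] = {name: [] for name in self._names}
        for _ in range(n_updates):
            task_name = self._rng.choices(self._names, weights=self._weights, k=1)[0]
            batch = self._rng.choice(self._task_batches[task_name])
            losses[task_name].append(self._update_fn(task_name, batch))
        return losses


# Usage with a dummy update function that only pretends to train.
trainer = MultiTaskTrainer(
    task_batches={"ner": ["n0", "n1"], "coref": ["c0", "c1", "c2", "c3"]},
    update_fn=lambda task_name, batch: 1.0,
)
history = trainer.train(n_updates=100)
```

Keeping the update logic behind a callback means the sampling strategy can be changed without touching the model code.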
Note that it is always a good idea to include a stopping condition on the training based on the validation metrics (cf. _val_metric_decreases in class Task). For instance, we can stop training when the validation metrics stop improving for patience epochs. This check is usually performed after each training epoch. We haven’t done it here, but you should be able to easily modify the previous snippet to take these enhancements into account, or simply have a look at the more complete training code.
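A patience-based stopping check can be sketched as a small helper. The function name `should_stop_early` and its arguments are assumptions made for illustration; `metric_decreases` plays the same role as the `_val_metric_decreases` flag mentioned above.

```python
from typing import List


def should_stop_early(metric_history: List[float], patience: int,
                      metric_decreases: bool = False) -> bool:
    """Return True if the best validation score is at least `patience` epochs old.

    metric_history holds one validation score per finished epoch.
    """
    if patience <= 0:
        raise ValueError("patience must be a positive number of epochs")
    if len(metric_history) <= patience:
        return False  # not enough epochs yet to judge
    best = min(metric_history) if metric_decreases else max(metric_history)
    best_epoch = metric_history.index(best)  # first epoch that reached the best score
    epochs_since_best = len(metric_history) - 1 - best_epoch
    return epochs_since_best >= patience


# F1 peaked at the second epoch and has not improved for 2 epochs.
stop = should_stop_early([0.70, 0.75, 0.74, 0.73], patience=2)
```

In a multi-task setting you still have to decide how to combine the per-task answers, for instance stopping only once every task has stopped improving. That choice is left open here.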
There are many other techniques you can use to train an MTL model that I don’t have time to cover in depth in this blog post ⌛️, so I’ll just point out a few references you’ll find worth reading now that you have the basic ideas:
- Successive regularization: one of the main issues that arises when training an MTL model is catastrophic forgetting, where the model abruptly forgets part of the knowledge related to a previously learned task as a new task is learned. This phenomenon is especially common when multiple tasks are trained sequentially. Hashimoto et al. introduce successive regularization: it prevents the parameter updates from moving too far from the parameters at the previous epoch by adding an L2 penalty to the loss. In this particular setting, the MTL trainer does not switch tasks after each parameter update but goes through the whole training dataset of the task of interest.
- Multi-Task as Question Answering: recently, McCann et al. introduced a new paradigm to perform Multi-Task Learning. Each task is reformulated as a question-answering task, and a single unified model (MQAN) is jointly trained on the 10 different tasks considered in this work. MQAN achieves state-of-the-art results on several tasks such as the WikiSQL semantic parsing task. More generally, this work discusses the limits of single-task learning and the relations of Multi-Task Learning with Transfer Learning.
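To illustrate the successive regularization bullet above, here is a minimal sketch of an L2 penalty that pulls the current parameters towards a previous snapshot. It follows the description given in this post and is not the exact formulation of Hashimoto et al.; the function names, the flat parameter lists and the `delta` weight are assumptions made for illustration.

```python
from typing import Sequence


def successive_regularization_penalty(current_params: Sequence[float],
                                      previous_params: Sequence[float],
                                      delta: float) -> float:
    """delta * squared L2 distance between current and snapshot parameters."""
    if len(current_params) != len(previous_params):
        raise ValueError("Parameter vectors must have the same length.")
    squared_distance = sum((c - p) ** 2
                           for c, p in zip(current_params, previous_params))
    return delta * squared_distance


def regularized_loss(task_loss: float, current_params: Sequence[float],
                     previous_params: Sequence[float], delta: float = 0.01) -> float:
    """Task loss plus the penalty that discourages drifting from the snapshot."""
    return task_loss + successive_regularization_penalty(
        current_params, previous_params, delta)


# Parameters drifted from the snapshot, so the loss is increased by the penalty.
loss = regularized_loss(1.0, current_params=[1.0, 3.0],
                        previous_params=[0.0, 1.0], delta=0.5)
```

The larger `delta` is, the more strongly the model is held to what it learned before switching tasks.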
Improving the state-of-the-art 📈 in semantic tasks: A Hierarchical Multi-Task Learning model (HMTL)
Now that we’ve talked about the training scheme, how can we develop a model that would get the most benefit out of our multi-task learning scheme?
In the recent work that we will present at AAAI in January, we propose to design such a model in a hierarchical way.
More precisely, we build a hierarchy between a set of carefully selected semantic tasks in order to reflect the linguistic hierarchies between the different tasks (see also Hashimoto et al.).
The intuition behind such a hierarchy is that some tasks may be simple and require a limited amount of modification to the input while others may require more knowledge and a more complex processing of the inputs.
The set of semantic tasks we considered is composed of Named Entity Recognition, Entity Mention Detection, Relation Extraction and Coreference Resolution.
The model is organized hierarchically, as illustrated in the figure on the left, with “simpler” tasks supervised at the lower levels of the neural network and “more complex” tasks supervised at its higher layers.
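The idea of supervising tasks at different depths can be sketched in a few lines of plain Python. This is only an illustration of the principle, not the HMTL architecture: the class name `HierarchicalEncoder`, the toy layers and, in particular, the assignment of tasks to levels in `task_levels` are assumptions made for the example.

```python
from typing import Any, Callable, Dict, List


class HierarchicalEncoder:
    """Stack of layers where each task reads the output of a given depth."""

    def __init__(self, layers: List[Callable[[Any], Any]],
                 task_levels: Dict[str, int]) -> None:
        self._layers = layers
        # task_levels[name] = index of the last layer used by that task.
        self._task_levels = task_levels

    def encode(self, inputs: Any, task_name: str) -> Any:
        """Run the input through the layers up to the task's level only."""
        level = self._task_levels[task_name]
        hidden = inputs
        for layer in self._layers[: level + 1]:
            hidden = layer(hidden)
        return hidden


# Toy layers that just record which layers the input went through.
layers = [lambda h, i=i: h + ["layer%d" % i] for i in range(3)]
encoder = HierarchicalEncoder(
    layers,
    # Hypothetical placement: "simpler" tasks low, "more complex" tasks high.
    task_levels={"ner": 0, "emd": 1, "relation_extraction": 2, "coreference": 2},
)
low_level = encoder.encode([], task_name="ner")
high_level = encoder.encode([], task_name="coreference")
```

A task supervised at a low level only updates the bottom layers, while a task supervised at the top sends its training signal through the whole stack.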
In our experiments we observed that these tasks can benefit from each other through Multi-Task Learning:
- the combination of these 4 tasks leads to state-of-the-art performance 📊 on 3 of the tasks (Named Entity Recognition, Relation Extraction and Entity Mention Detection).
- the MTL framework considerably speeds up training ⏱ compared to single-task training frameworks.
We also analyzed the embeddings that are learned and shared in HMTL. For the analysis, we used SentEval, a set of 10 probing tasks introduced by Conneau et al. These probing tasks aim at evaluating how well sentence embeddings are able to capture a wide range of linguistic properties (syntactic, surface and semantic).
Our analysis indicated that the lower level shared embeddings already encode a rich representation and that as we move from the bottom to the top layers of the model, the hidden states of the layers tend to represent more complex semantic information.
This concludes our introduction to Multi-Task Learning. If you want to learn more about our hierarchical model (HMTL), you now have all the tools you need to dive into our paper 📃 and the training code ⌨️.
We also built a nice online demo 🥑 so you can try HMTL for yourself! 🎮
If you like 👍 this post, don’t forget to clap 👏 and tweet about it🦆. If you don’t like this post, tell your favorite aunt 👵 to give some claps 👏.
R. Collobert and J. Weston, A Unified Architecture for Natural Language Processing: Deep Neural Networks with Multitask Learning, 2008
S. Ruder, J. Bingel, I. Augenstein and A. Søgaard, Learning what to share between loosely related tasks, 2017
A. Conneau, D. Kiela, H. Schwenk, L. Barrault and A. Bordes, Supervised Learning of Universal Sentence Representations from Natural Language Inference Data, 2017
S. Subramanian, A. Trischler, Y. Bengio and C. J. Pal, Learning General Purpose Distributed Sentence Representations via Large Scale Multi-task Learning, 2018
K. Hashimoto, C. Xiong, Y. Tsuruoka and R. Socher, A Joint Many-Task Model: Growing a Neural Network for Multiple NLP Tasks, 2017
B. McCann, N. S. Keskar, C. Xiong and R. Socher, The Natural Language Decathlon: Multitask Learning as Question Answering, 2018
A. Conneau and D. Kiela, SentEval: An Evaluation Toolkit for Universal Sentence Representations, 2018