Multi-task Deep Learning Experiment using fastai for PyTorch

This post is an abstract of a Jupyter notebook containing a line-by-line example of a multi-task deep learning model, implemented using the fastai v1 library for PyTorch. The model takes an image of a human face and predicts the person's gender, race, and age.

The notebook aims to show:

  1. an example of a multi-task deep learning model;
  2. that the multi-task model makes better predictions than the individual (single-task) models; and
  3. how to use the fastai library to easily implement the model.

The Jupyter notebook is complete and runnable, so you can run and change the code if you like (if nothing else, it's fun to randomly sample faces; you may recognize some of them).

Because this post is only an abstract, it does not contain the code. The full notebook with code is here:

Multi-task learning

First, a quick introduction to multi-task learning.

What it is

Wikipedia: Multi-task learning (MTL) is a subfield of machine learning in which multiple learning tasks are solved at the same time, while exploiting commonalities and differences across tasks.

In this post, the multiple learning tasks are predicting gender and race (both classification tasks) and age (a regression task).
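One way to make "solved at the same time" concrete is to give each task its own loss and have the model minimize a weighted sum. This is an assumption: it is a common formulation, but the notebook's exact loss is not reproduced in this abstract.

```latex
\mathcal{L}_{\text{total}}
  = \lambda_g \,\mathcal{L}_{\text{CE}}(\mathbf{z}_g, y_g)
  + \lambda_r \,\mathcal{L}_{\text{CE}}(\mathbf{z}_r, y_r)
  + \lambda_a \,\lvert \hat{a} - a \rvert,
\qquad
\mathcal{L}_{\text{CE}}(\mathbf{z}, y) = -\log \frac{e^{z_y}}{\sum_k e^{z_k}}
```

The symbols are as follows:

  • z_g (2 values) and z_r (5 values) are the gender and race logits.
  • y_g and y_r are the true class indices.
  • â and a are the predicted and true age in years.
  • λ_g, λ_r, λ_a are hypothetical task weights.

The absolute-error age term is one choice among several; a squared error would also work. All three terms depend on the same shared network parameters, so the gradient from each task shapes the same features. That shared dependence is where the "commonalities across tasks" get exploited.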

Why do it

Wikipedia: This can result in improved learning efficiency and prediction accuracy for the task-specific models, when compared to training the models separately.

In this post:

  • Regarding improved learning efficiency: we train the multi-task model once instead of training three single-task models (all four models have similar run times).
  • Regarding improved prediction accuracy: we will show that the multi-task model cuts the average age prediction error in half, from 10 years to 5 years, while keeping gender and race prediction accuracy almost unchanged.

Why it works

Wikipedia: Multi-task learning works because regularization induced by requiring an algorithm to perform well on a related task can be superior to regularization that prevents overfitting by penalizing all complexity uniformly.

We will not cover the math here. Wikipedia provides some good resources.


fastai

We use the fastai v1 library for PyTorch to implement our model. The library is built by the organization of the same name, founded by Jeremy Howard and Rachel Thomas to “make neural nets uncool again”. Its free online deep learning courses are considered by many researchers to be the best available. The fastai library “simplifies training fast and accurate neural nets using modern best practices”; in other words, it is easy to use and it works well.


The data and the problem

We use the UTKFace dataset for this experiment.

UTKFace is a large-scale face dataset with a long age span (ranging from 0 to 116 years old). It consists of over 20,000 face images with annotations of age, gender, and ethnicity. The images cover large variation in pose, facial expression, illumination, occlusion, resolution, etc. The dataset can be used for a variety of tasks, e.g., face detection, age estimation, age progression/regression, and landmark localization.
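UTKFace encodes each image's labels in its file name, so the three targets can be read without a separate annotation file. The sketch below is not the notebook's code. It assumes the naming convention `[age]_[gender]_[race]_[date&time].jpg` as we understand it from the dataset's documentation:

  • gender 0 = male, 1 = female;
  • race 0 to 4 = White, Black, Asian, Indian, Others.

The function name `parse_utkface_name` is hypothetical. The 'm'/'f' codes and race names are chosen to match the counts in the exploration below.

```python
from pathlib import Path
from typing import NamedTuple, Optional

# Assumed UTKFace label encoding (from the dataset's file-name convention).
GENDERS = {0: "m", 1: "f"}
RACES = {0: "White", 1: "Black", 2: "Asian", 3: "Indian", 4: "Others"}


class FaceLabels(NamedTuple):
    age: int     # regression target, in years
    gender: str  # 'm' or 'f' (classification target)
    race: str    # one of RACES.values() (classification target)


def parse_utkface_name(path: str) -> Optional[FaceLabels]:
    """Read labels from a name like '25_1_3_20170116174525125.jpg.chip.jpg'.

    Returns None for names that do not follow the expected pattern.
    """
    fields = Path(path).name.split("_")
    if len(fields) < 4:  # age, gender, race, timestamp
        return None
    try:
        age, gender, race = (int(f) for f in fields[:3])
    except ValueError:
        return None
    if gender not in GENDERS or race not in RACES:
        return None
    return FaceLabels(age, GENDERS[gender], RACES[race])


print(parse_utkface_name("data/UTKFace/25_1_3_20170116174525125.jpg.chip.jpg"))
# → FaceLabels(age=25, gender='f', race='Indian')
```

Returning None instead of raising lets a caller filter out any oddly named files before building the dataset.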

We will try to predict gender and race (both classification tasks) and age (a regression task) from the image. In the multi-task setting, we can think of age as the central task, partly because it is the most difficult of the three.

Explore data

Some sanity checks on the data.

Gender counts seem roughly balanced.

Counter({'f': 11314, 'm': 12391})

Race counts. Note that this race distribution may differ from that of the data in your own problem.

[('White', 10078),
('Black', 4526),
('Indian', 3975),
('Asian', 3434),
('Others', 1692)]

Gender counts per race. These seem relatively balanced.

[(('Asian', 'f'), 1859),
(('Asian', 'm'), 1575),
(('Black', 'f'), 2208),
(('Black', 'm'), 2318),
(('Indian', 'f'), 1714),
(('Indian', 'm'), 2261),
(('Others', 'f'), 932),
(('Others', 'm'), 760),
(('White', 'f'), 4601),
(('White', 'm'), 5477)]
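To put "relatively balanced" in numbers, the short sketch below recomputes the female share within each race from the counts listed above. It also checks that the joint counts add up to the race and gender totals reported earlier. The counts are copied from the notebook's output; the code itself is ours.

```python
from collections import Counter

# Joint (race, gender) counts copied from the exploration output above.
gender_per_race = {
    ("Asian", "f"): 1859, ("Asian", "m"): 1575,
    ("Black", "f"): 2208, ("Black", "m"): 2318,
    ("Indian", "f"): 1714, ("Indian", "m"): 2261,
    ("Others", "f"): 932, ("Others", "m"): 760,
    ("White", "f"): 4601, ("White", "m"): 5477,
}

# Recover the marginal totals from the joint counts.
race_totals = Counter()
gender_totals = Counter()
for (race, gender), n in gender_per_race.items():
    race_totals[race] += n
    gender_totals[gender] += n

# Fraction of female faces within each race.
female_share = {
    race: gender_per_race[(race, "f")] / total
    for race, total in sorted(race_totals.items())
}

for race, share in female_share.items():
    print(f"{race:<7} {share:.1%} female (n={race_totals[race]})")
```

The recovered totals match the race counts and the overall gender counts (11,314 f, 12,391 m). The female share ranges from about 43.1% (Indian) to about 55.1% (Others), so the groups are reasonably but not perfectly balanced.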

Let’s look at the distribution of age. It looks like babies and 20-year-olds get their pictures taken the most.

Let’s also look at some pictures as a sanity check.


Single-task models

For comparison, we will first build single-task models for the three tasks: gender, race, and age predictions. We build the multi-task model in the next section.

gender model

We achieve a validation accuracy of 94%.

race model

We achieve a validation accuracy of 86%.

age model

The age prediction task is a regression problem. We achieve an average prediction error of 10 years.
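The single-task models are compared using two kinds of metric: accuracy for gender and race, and an average error in years for age. The abstract does not say which error metric is used for age. The sketch below assumes mean absolute error (MAE), the most natural reading of "average prediction error". The helper names and toy numbers are hypothetical, not results from the notebook.

```python
from typing import Sequence


def accuracy(predicted: Sequence[str], actual: Sequence[str]) -> float:
    """Fraction of classification predictions that match the true label."""
    if len(predicted) != len(actual) or not actual:
        raise ValueError("need two non-empty sequences of equal length")
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


def mean_absolute_error(predicted: Sequence[float], actual: Sequence[float]) -> float:
    """Average absolute difference, in the unit of the targets (here, years)."""
    if len(predicted) != len(actual) or not actual:
        raise ValueError("need two non-empty sequences of equal length")
    return sum(abs(p - a) for p, a in zip(predicted, actual)) / len(actual)


# Toy values for illustration only.
print(accuracy(["f", "m", "m", "f"], ["f", "m", "f", "f"]))        # 0.75
print(mean_absolute_error([30.0, 12.0, 55.0], [25.0, 20.0, 52.0]))  # (5 + 8 + 3) / 3 ≈ 5.33
```

Under this reading, "a 10-year average prediction error" means the model is off by 10 years per face on average, in either direction.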


Multi-task model

We now combine the three problems into a single multi-task model.

As noted earlier, we train the multi-task model once instead of training three single-task models (all four models have similar run times).
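The abstract does not show how the three predictions come out of one network. A common pattern, and only an assumption about the notebook, uses a shared backbone whose final layer emits 2 + 5 + 1 = 8 numbers:

  • two gender logits;
  • five race logits;
  • one age value.

The loss slices that vector, takes a cross-entropy term for each classification task and an absolute-error term for age, and adds them with weights. The plain-Python sketch below shows this arithmetic for a single example. In PyTorch the same idea would be applied to batched tensors, e.g. with `torch.nn.functional.cross_entropy` and `torch.nn.functional.l1_loss`. The weights `w_gender`, `w_race`, and `w_age` are hypothetical knobs.

```python
import math
from typing import Sequence

# Assumed output layout: [2 gender logits | 5 race logits | 1 age value].
N_GENDER, N_RACE = 2, 5


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Negative log-softmax probability of the target class (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def multitask_loss(output: Sequence[float], gender: int, race: int, age: float,
                   w_gender: float = 1.0, w_race: float = 1.0,
                   w_age: float = 1.0) -> float:
    """Weighted sum of the three per-task losses for one example."""
    if len(output) != N_GENDER + N_RACE + 1:
        raise ValueError("expected 8 outputs: 2 gender, 5 race, 1 age")
    gender_logits = output[:N_GENDER]
    race_logits = output[N_GENDER:N_GENDER + N_RACE]
    age_pred = output[-1]
    return (w_gender * cross_entropy(gender_logits, gender)
            + w_race * cross_entropy(race_logits, race)
            + w_age * abs(age_pred - age))


# Uniform logits cost log(2) + log(5) = log(10); the age guess is 3 years off.
out = [0.0] * 7 + [27.0]
print(multitask_loss(out, gender=1, race=3, age=30.0))  # ≈ 5.3026
```

The age term is measured in years, while the cross-entropy terms are typically a few nats at most. An untuned age term can therefore dominate the sum, which is one reason to expose per-task weights.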

Inspect results of the multi-task model

Now let’s look at the results on the validation set to make sure things work.

We can also take a look at the true vs. predicted age.


Closing comments

In this post we showed:

  1. an example of a multi-task deep learning model: predicting gender and race (both classification tasks) and age (a regression task) from a face image;
  2. that the multi-task model makes better predictions than the individual (single-task) models; and
  3. how to use the fastai library to implement the model.

Hopefully this convinces you to

  1. try building a multi-task model when related labels are available (which is very common in real-life data); and
  2. try using the fastai library for it, as it is easy to use and provides best practices out of the box.

There is a lot we haven’t considered:

  1. We’ve only scratched the surface of multi-task learning. There is a large variety of techniques to learn and experiment with. Wikipedia and Sebastian Ruder’s article are good references.
  2. We kept the model-building part simple on purpose to focus on the multi-task part. We didn’t experiment with the model architecture (including the pretrained backbone and the top layers), we didn’t tune dropout, and we didn’t try progressively increasing the image size during training, an effective technique taught by fastai. In a real-world application these are worth considering and testing, and the fastai library usually makes them easy.