Learning multiple tasks with gradient descent
Learning multiple tasks is an important requirement for artificial general intelligence. I apply a basic convolutional network to two different tasks. The network is shown to exhibit successful multi-task learning. The frequency of task switchover is shown to be a crucial factor in the rate of network learning. Stochastic Gradient Descent is shown to be better at multi-task learning than Adam. Narrower networks are more forgetful.
It’s often remarked that machine learning struggles to learn multiple tasks one after another without forgetting earlier ones (“continual learning”):
The lack of algorithms to support continual learning thus remains a key barrier to the development of artificial general intelligence. 
I’m curious to see these problems first-hand, so in this blog post I’m going to run a series of multi-task learning experiments to demonstrate the limits and properties of today’s machine learning.
- What does task forgetting look like?
- Can today’s architectures learn multiple tasks?
- How does task duration affect task forgetfulness?
- Are there any simple ways to reduce task forgetfulness?
- Which optimiser forgets tasks the least?
I’m using a very simple network for these experiments:
- Input: 3 channel image, 28x28, training batches of 100
- Two hidden layers: 1 convolutional layer, 1 fully connected layer
- ReLU activation, max pooling, batch normalisation and dropout are used
- Output: Softmax classifier over 10 classes
- Stochastic Gradient Descent is used to train the network with a learning rate of 0.1
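The output layer and optimiser listed above can be sketched in plain Python. This is a minimal sketch, not the code used for these experiments. It assumes a cross-entropy loss on the softmax output (the post doesn't name the loss function), and the function names are hypothetical. The only value taken from the post is the learning rate of 0.1.

```python
import math

LEARNING_RATE = 0.1  # learning rate stated in the post

def softmax(logits):
    """Convert raw class scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy_grad(logits, label):
    """Gradient of cross-entropy loss w.r.t. the logits: softmax(z) - one_hot(label)."""
    probs = softmax(logits)
    return [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]

def sgd_step(params, grads, lr=LEARNING_RATE):
    """One plain SGD update: move each parameter a small step against its gradient."""
    return [w - lr * g for w, g in zip(params, grads)]
```

For example, treating ten class biases as the only parameters, `sgd_step(b, cross_entropy_grad(b, 3))` nudges the biases so the probability of class 3 rises. The real network applies the same update to every convolutional and fully connected weight, with gradients found by backpropagation.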
First of all, let’s look at how this network performs on each of two simple tasks on its own:
Baseline: MNIST
MNIST is a popular toy dataset of hand-written digits. There are 60,000 labelled training images and 10,000 labelled test images. Given its popularity and familiarity, it’s a good task for this investigation.
Here’s how the network performs:
Our simple network performs adequately, reaching 96% test accuracy after 2,000 training cycles and 97% after 25k cycles. State of the art is close to 100%; however, for the purposes of this investigation the network performs well enough.
Baseline: ImageNet Dogs
I’ve extracted ten dog breeds from the Stanford Dogs dataset (a subset of ImageNet) to train the network on. [# of test and train examples]
The network quickly overfits the data (i.e. the training accuracy reaches almost 100%), and the training does not generalise well: test accuracy reaches 33% after 24k training cycles. This is better than random guessing (1 in 10 classes = 10%) but nowhere near state of the art (for comparison, 96% on CIFAR-10, a different 10-class dataset).
For this investigation of multi-task learning, the network’s performance relative to this baseline is more interesting than its absolute performance. Therefore this performance is adequate for our purposes.
Learning multiple tasks
Now let’s train and test the network against both tasks.
- Every 2,000 training cycles, the network will be switched to a different task (i.e. the other of the two)
- We’ll measure the training accuracy on the current task
- We’ll measure the test accuracy on each task and on all tasks combined (which is equivalent to the average of the individual task test accuracies)
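The switching schedule and the combined metric come down to a couple of lines of arithmetic. This is a hypothetical sketch, assuming 0-based training steps and tasks visited in round-robin order; the function names are illustrative. Note that the mean of the per-task accuracies equals accuracy on all test examples pooled together only if each task is weighted equally (for example, equally sized test sets).

```python
def task_for_step(step, task_time=2000, num_tasks=2):
    """Index of the task being trained at a given 0-based training step.

    Tasks are visited round-robin, switching every `task_time` steps.
    """
    return (step // task_time) % num_tasks

def combined_test_accuracy(per_task_accuracies):
    """Overall test accuracy as the unweighted mean of per-task test accuracies."""
    return sum(per_task_accuracies) / len(per_task_accuracies)
```

With the default 2,000-step task time, steps 0 to 1,999 train task 0, steps 2,000 to 3,999 train task 1, and step 4,000 switches back to task 0.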
The network receives each MNIST greyscale image as a three channel image, two of which are zeros. Dog breed images are fed in as three channel images.
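The greyscale-to-three-channel conversion can be sketched with nested lists standing in for image tensors. This is an illustrative sketch under two assumptions not stated in the post: the layout is channels-first, and the real greyscale data goes in channel 0 (the post only says that two of the three channels are zeros).

```python
def zero_channel(height, width):
    """A fresh height x width grid of zeros (new row lists, nothing shared)."""
    return [[0.0] * width for _ in range(height)]

def greyscale_to_three_channels(image):
    """Wrap a greyscale image (a list of pixel rows) as a 3-channel image.

    Channels-first layout: [grey, zeros, zeros]. Putting the real data in
    channel 0 is an assumption.
    """
    height, width = len(image), len(image[0])
    return [image, zero_channel(height, width), zero_channel(height, width)]
```

Dog breed images already have three colour channels, so they can be passed through unchanged.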
- Training accuracy quickly rises to near 100% during each task training cycle
- The network partially forgets each previous task
- Test accuracy improves over time
- With each switchover, the network forgets less, until almost no forgetting occurs
- It appears the network learns to accommodate both tasks’ training over time. This could be because only weights that suit both tasks survive each task switchover, leading to a sort of natural selection of those weights.
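One way to put a number on the forgetting described above is to measure how much a task’s test accuracy drops while the other task is being trained. This is a hypothetical helper, assuming test accuracy is recorded at the end of every task period. It is an illustrative metric, not necessarily how the results above were produced.

```python
def forgetting_after_switchovers(task_accuracy, trained_task, task):
    """Drop in `task`'s test accuracy over each period spent training another task.

    task_accuracy[i]: test accuracy on `task`, measured at the end of period i.
    trained_task[i]:  index of the task trained during period i.
    Returns one value per switchover away from `task`; positive means forgetting.
    """
    drops = []
    for i in range(1, len(task_accuracy)):
        if trained_task[i - 1] == task and trained_task[i] != task:
            drops.append(task_accuracy[i - 1] - task_accuracy[i])
    return drops
```

A shrinking sequence of drops across successive switchovers would correspond to the pattern described above, where the network forgets less with each switchover.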
Let’s compare the multi-task training with the previous baseline single-task training:
- After 12 task switchovers (24k cycles), multi-task training performs close to or better than single-task training (dog breeds: 31.4% vs 33.2%; MNIST: 97.5% vs 97.1%)
- Contrary to the general consensus, a simple machine learning model using SGD is able to learn multiple tasks quite well
- The task switchover may act as an injection of noise into the system that helps it explore more local minima, resulting in the better MNIST test accuracy.
Is multi-task learning improved by having a task signal?
Let’s now explore whether the network can perform better given a signal of which task is being performed (which could be given by another part of the network, and it is likely the human brain has such a signal).
The network is adjusted to have four input channels. The first channel receives greyscale MNIST images (or zeros otherwise) and channels two to four receive dog images (or zeros otherwise).
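The four-channel input can be sketched the same way, again with nested lists in a channels-first layout (an assumption). Because each task only ever populates its own channels, which channels are non-zero implicitly tells the network which task it is performing. The task names `"mnist"` and `"dogs"` are hypothetical labels.

```python
def zero_channel(height, width):
    """A fresh height x width grid of zeros."""
    return [[0.0] * width for _ in range(height)]

def four_channel_input(image, task):
    """Build the 4-channel, channels-first input for either task.

    task == "mnist": `image` is one greyscale channel (rows of pixels) -> channel 0.
    task == "dogs":  `image` is a list of three colour channels -> channels 1-3.
    All unused channels are zeros.
    """
    if task == "mnist":
        height, width = len(image), len(image[0])
        return [image] + [zero_channel(height, width) for _ in range(3)]
    if task == "dogs":
        if len(image) != 3:
            raise ValueError("dog images must have exactly 3 channels")
        height, width = len(image[0]), len(image[0][0])
        return [zero_channel(height, width)] + list(image)
    raise ValueError(f"unknown task: {task!r}")
```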
Let’s perform the same training as earlier, for a task time of 5k steps, and compare with/without having this task signal:
- Having a task signal reduces MNIST forgetfulness (e.g. at training step 15k there is a 10% test accuracy improvement on that task)
- Having a task signal improves dog breed training
- However, having a task signal seems to lead to faster forgetting of dog breed test accuracy
- Overall, with a task signal the combined test accuracy is higher early on but later is no better. [todo: make this more rigorous]
(Later consider: experimenting with task-specific dropout / zeroed nodes / a one-hot task residual signal. Check why dropout works, e.g. does it factor into gradients, and does it require a uniform distribution?)
Multi-task learning vs optimiser
Let’s see how the Adam optimiser performs versus Stochastic Gradient Descent on our two-task experiment setup:
- Adam generally achieves better training accuracy, and achieves it more rapidly than SGD
- However, this leads to overfitting of the current task, to the detriment of per-task and overall test accuracy
- SGD forgets previous tasks less than Adam, and consequently achieves higher accuracy on all task tests
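To see mechanically how the two optimisers differ, here is a minimal pure-Python sketch of both update rules (Adam as described by Kingma & Ba). The SGD learning rate of 0.1 comes from the post; Adam’s learning rate isn’t stated, so the commonly used default of 0.001 is an assumption, as are the class and function names. The sketch shows that SGD’s step scales with the gradient, while Adam rescales each parameter’s step by a running estimate of its gradient magnitude, so its first step is roughly the learning rate whatever the gradient’s scale. Whether this explains the extra forgetting seen above isn’t tested here.

```python
import math

def sgd_update(w, g, lr=0.1):
    """Plain SGD: each step is proportional to the gradient."""
    return [wi - lr * gi for wi, gi in zip(w, g)]

class Adam:
    """Minimal Adam optimiser over a flat list of parameters."""

    def __init__(self, size, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * size  # running mean of gradients
        self.v = [0.0] * size  # running mean of squared gradients
        self.t = 0             # step counter used for bias correction

    def update(self, w, g):
        self.t += 1
        new_w = []
        for i, (wi, gi) in enumerate(zip(w, g)):
            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * gi
            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * gi * gi
            m_hat = self.m[i] / (1 - self.beta1 ** self.t)  # bias-corrected mean
            v_hat = self.v[i] / (1 - self.beta2 ** self.t)  # bias-corrected variance
            new_w.append(wi - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_w
```

With gradients `[1e-3, 1.0]`, `sgd_update` moves the two weights by 0.0001 and 0.1, whereas the first `Adam.update` moves both by roughly 0.001.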
Multi-task learning vs training time
It could be that the multi-task learning shown in the previous experiments happens simply because the network does not have enough time to forget the previous task. That is, thanks to the small learning rate, the network is averaging the training effect of each task. This effectively trains the network on a dataset that includes both tasks.
To test this suspicion, I’ll increase the time spent on each task by up to 20x. This experiment uses stochastic gradient descent and the task signal input format described earlier.
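The averaging hypothesis can be illustrated with a hypothetical toy: a single weight trained by SGD on two quadratic “tasks” with different targets, switching every `task_time` steps. Everything here (targets of -1 and +1, learning rate 0.01, quadratic losses) is an assumption chosen for illustration; it is not the network from these experiments. In the toy, fast switching keeps the weight near the minimum of the combined loss (pure averaging), while slow switching lets it swing to whichever task was trained last (forgetting). If the real network were only averaging, longer task periods should make it forget more in the same way.

```python
def alternate_sgd(task_time, total_steps, lr=0.01, targets=(-1.0, 1.0), w=0.0):
    """Train one weight on toy tasks in turn, switching every task_time steps.

    Task k has loss (w - targets[k])**2 with gradient 2 * (w - targets[k]).
    Returns the weight after every step.
    """
    history = []
    for step in range(total_steps):
        target = targets[(step // task_time) % len(targets)]
        w -= lr * 2.0 * (w - target)  # SGD step on the current task only
        history.append(w)
    return history

def combined_loss(w, targets=(-1.0, 1.0)):
    """Mean of the per-task losses; minimised at the mean of the targets."""
    return sum((w - t) ** 2 for t in targets) / len(targets)
```

With `task_time=1` the weight settles into a tiny oscillation around 0, the combined optimum; with `task_time=500` it ends up close to +1, having largely forgotten the first task.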
How time spent training each task affects test accuracy:
- Faster task switching leads to higher test accuracy given the same amount of training cycles
Plotting test accuracy on a timeline measured in task periods (each vertical gridline marks one task training period) shows whether a longer task training period leads to more forgetting:
- Each task time achieves similar test accuracy throughout the experiment
- The network does not forget more when it is trained longer on each task
- Therefore, the amount of learning (i.e. test accuracy) depends on the number of task switchovers rather than the number of training steps
- Therefore, the network inherently performs some multi-task learning, rather than only appearing to because it has too little time to forget
Multi-task learning versus size of hidden layer
[This is an older experiment setup; it needs re-testing and may just be redacted]
Let’s test a two-layer fully connected network on the same tasks to see how the width of the first layer affects task forgetting.
- Narrower networks are more forgetful
- This is intuitive: a narrower network has less room to store each task’s specific computations, so training must overwrite the previous task to learn the current one
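One rough way to make “less room” concrete is to count the parameters of a two-layer fully connected classifier as a function of hidden width. This sketch assumes a flattened 28x28x3 input (2,352 values) and 10 output classes, reusing the input format from the earlier experiments; the exact setup of this older experiment isn’t stated. Parameter count is only a coarse proxy for capacity.

```python
def fc_param_count(input_size, hidden_width, num_classes=10):
    """Weights plus biases in a two-layer fully connected classifier."""
    first_layer = input_size * hidden_width + hidden_width   # input -> hidden
    second_layer = hidden_width * num_classes + num_classes  # hidden -> output
    return first_layer + second_layer
```

With 2,352 inputs, a hidden width of 16 gives 37,818 parameters, while a width of 256 gives 604,938, so the narrow network has far fewer weights to share between the two tasks.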