Practical Multitask Learning: My First Steps

Ziv Nachum
Theator Tech
5 min read · Mar 3, 2024


Multitask learning is a pretty cool buzzword, but a rather complex thing to pull off. Optimizing a single monstrous model with multiple objectives sounds almost impossible. Even just getting there, by building a proper framework with a dataset that serves very different tasks, is not easy.

In my work as a computer vision engineer, I wanted to train a model that performs very different tasks on the same input — image classification, temporal video segmentation, video classification, and temporal action localization, while keeping the code robust and straightforward. I also wanted to finish the coding part quickly and start experimenting with the model and optimizing it.

In this post, I’ll explain how I built a working multitask learning architecture, and provide practical insights on how you can do it too.

What’s the problem?

Our model consists of two parts: a feature extractor (backbone), which creates a feature representation from the video input, and task-specific model heads, such as an LSTM. Each head takes the feature representation created by the backbone as input and performs a single task like classification, segmentation, or detection.

A big question is how to train such a model.

One solution to this problem is to train each task separately and iteratively: given a data sample, pass it through the backbone, then pass the feature vector through the first task's head while keeping the other task heads frozen, compute the loss, backpropagate, and update the model weights. Then do the same for the second task, and so on.
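For reference, here is a minimal sketch of that alternating scheme. The names (backbone, heads, losses, loaders) are hypothetical stand-ins, not our actual codebase:

def train_alternating(backbone, heads, losses, loaders, optimizer):
    # heads, losses, loaders: dicts keyed by task name (assumed layout).
    for task_name, loader in loaders.items():
        # Freeze every head except the current task's.
        for name, head in heads.items():
            head.requires_grad_(name == task_name)

        for inputs, labels in loader:
            optimizer.zero_grad()
            features = backbone(inputs)           # shared forward pass
            outputs = heads[task_name](features)  # task-specific head
            loss = losses[task_name](outputs, labels)
            loss.backward()                       # updates backbone + this head
            optimizer.step()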

This alternating scheme was not good enough since, at each step, the shared part leaned towards a specific task rather than stepping in a direction that optimized all tasks at once.

In this post, I want to offer a different approach, one that trains the backbone and all task heads together. To achieve that, we had to build the dataset correctly and ensure it flowed through the different models as we wanted it to.

How we tried to slay the multi-headed dragon

In multitask learning, the model contains a single shared backbone, which extracts the same features for all tasks.

import torch
from torch import nn
from typing import List


class MultitaskModel(torch.nn.Module):
    def __init__(self, shared: nn.Module, tasks: List[Task]):
        super().__init__()
        self.shared = shared
        self.tasks = tasks
        # Register each task's head in a ModuleDict so its parameters
        # are tracked and trained along with the backbone.
        self.tasks_heads = torch.nn.ModuleDict(
            {task.name: task.generate_task_model() for task in tasks})

Our first design choice was to define separate independent tasks. Each task defined a model head, a loss function and metrics. In addition, for each sample from our data, each task defined how to handle input. For example, one task may have reshaped the samples in order to fit them correctly into its model, and another task might have augmented the data before passing it through the model.

from abc import ABC, abstractmethod
from typing import Any, Callable, Union

import torch
from torch import nn


class Task(ABC):
    name = "task"

    def __init__(self, active: bool, loss_weight: float = 1.0):
        self.active = active
        self.loss_weight = loss_weight

    @abstractmethod
    def get_label(self, sample: Any, training: bool) -> Any:
        """Extract this task's label from a raw data sample."""

    @abstractmethod
    def generate_task_model(self) -> nn.Module:
        """Build the task-specific model head."""

    @abstractmethod
    def inputs_adapter(self, inputs: Any) -> Any:
        """Reshape or augment the shared features to fit this task's head."""

    @abstractmethod
    def labels_adapter(self, labels: Any) -> Any:
        """Pick this task's labels out of the batch labels."""

    @property
    @abstractmethod
    def loss(self) -> Union[nn.Module, Callable]:
        """The task's loss function."""

    @abstractmethod
    def get_pred(self, outputs: Any) -> torch.Tensor:
        """Turn the head's raw outputs into predictions."""
A general overview of the described multitask architecture
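To make the abstraction concrete, here is a hypothetical classification task implementing this interface. The linear head, the "video_class" label field, and the mean-pooling choice are all illustrative assumptions, not our production code:

# Uses Task and the imports from the snippet above.
class VideoClassificationTask(Task):
    name = "video_classification"

    def __init__(self, feature_dim, num_classes, **kwargs):
        super().__init__(**kwargs)
        self.feature_dim = feature_dim
        self.num_classes = num_classes

    def get_label(self, sample, training):
        return sample["video_class"]  # assumed sample layout

    def generate_task_model(self):
        return nn.Linear(self.feature_dim, self.num_classes)

    def inputs_adapter(self, inputs):
        # Assume features are (batch, time, dim); pool over time to get one
        # clip-level vector, returned as the head's argument tuple.
        return (inputs.mean(dim=1),)

    def labels_adapter(self, labels):
        return labels[self.name]

    @property
    def loss(self):
        return nn.CrossEntropyLoss()

    def get_pred(self, outputs):
        return outputs.argmax(dim=-1)

Instantiating it looks like VideoClassificationTask(feature_dim=512, num_classes=10, active=True), and MultitaskModel never needs to know which concrete tasks it is running.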

Every sample corresponded to a subset of tasks. When initializing the dataset, we generated samples and labels for every task, and kept track of which sample corresponded to which tasks.
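A minimal sketch of that bookkeeping might look like this, assuming each raw sample carries a "tasks" field listing the tasks it participates in (the field names are made up for illustration):

from typing import List
from torch.utils.data import Dataset

class MultitaskDataset(Dataset):
    def __init__(self, samples: List[dict], tasks: List[Task], training: bool = True):
        self.samples = samples
        # Generate per-task labels up front, only for the tasks each
        # sample participates in.
        self.labels = [
            {task.name: task.get_label(sample, training)
             for task in tasks if task.name in sample["tasks"]}
            for sample in samples
        ]
        # Remember which samples support which task; the sampler shown
        # later relies on this mapping.
        self.task_to_indices = {
            task.name: [i for i, sample in enumerate(samples)
                        if task.name in sample["tasks"]]
            for task in tasks
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]["input"], self.labels[idx]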

Performing forward() on the model was fairly straightforward. Given an input, we ran it through the shared feature extractor, then iterated through the different tasks, modified the feature vector to fit the relevant task head, and ran that head. Then all that was left to do was calculate the loss for the batch, defined as the weighted sum of the tasks' losses.

    def forward(self, inputs):
        shared_features = self.shared(inputs)

        outputs_dict = {}
        preds_dict = {}
        for task in self.tasks:
            # Adapt the shared features to whatever this head expects.
            task_inputs = task.inputs_adapter(shared_features)
            outputs_dict[task.name] = self.tasks_heads[task.name](*task_inputs)
            preds_dict[task.name] = task.get_pred(outputs_dict[task.name])

        return outputs_dict, preds_dict

    def backward(self, outputs, labels):
        # The total loss is the weighted sum of the per-task losses.
        loss = torch.tensor(0.0)
        for task in self.tasks:
            curr_labels = task.labels_adapter(labels)
            loss += task.loss_weight * task.loss(outputs[task.name], curr_labels)
        return loss

To run a backward pass during training, the loss must come from a forward pass that involves all of the model's parameters. This caused an issue whenever a batch contained no samples for one of the tasks. Technically, we could have just set that task's loss to zero, but this might have caused unwanted behavior: in some iterations we would get losses for all tasks, and in others we wouldn't.

The solution was to define a sampler that ensures every batch contains at least one sample corresponding to each task. When creating the dataset, we already had the information of which sample was relevant to which tasks. Another important consideration in implementing this sampler was making sure each task saw as much data as possible, but also a very diverse set of data (see the sketch after the figure below).

Each batch contains a few samples which, together, must cover all participating tasks.
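Here is a simplified sketch of such a batch sampler, built on the task_to_indices mapping from the dataset sketch above; the real implementation also balanced data volume and diversity per task:

import random
from torch.utils.data import Sampler

class TaskCoverageSampler(Sampler):
    def __init__(self, task_to_indices, batch_size, num_batches):
        self.task_to_indices = task_to_indices
        self.batch_size = batch_size
        self.num_batches = num_batches

    def __iter__(self):
        all_indices = sorted({i for indices in self.task_to_indices.values()
                              for i in indices})
        for _ in range(self.num_batches):
            # Guarantee at least one sample per task...
            batch = [random.choice(indices)
                     for indices in self.task_to_indices.values()]
            # ...then fill the rest of the batch at random.
            while len(batch) < self.batch_size:
                batch.append(random.choice(all_indices))
            yield batch

    def __len__(self):
        return self.num_batches

It plugs into a DataLoader via the batch_sampler argument, e.g. DataLoader(dataset, batch_sampler=TaskCoverageSampler(dataset.task_to_indices, 16, 100)).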

What the future holds

Once this architecture was designed and implemented, exploring and optimizing this complex model was very easy. Even debugging the whole system was quite simple. The losses and metrics of each task were calculated and reported, which made optimizing the learning policy possible.
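For example, one simple way to surface per-task losses for reporting, sketched under the same assumptions as the code above, is to compute them individually before summing:

def compute_losses(model, outputs, labels):
    # Per-task losses are useful for logging and debugging; the combined
    # loss is the same weighted sum used in backward() above.
    per_task = {}
    for task in model.tasks:
        curr_labels = task.labels_adapter(labels)
        per_task[task.name] = task.loss(outputs[task.name], curr_labels)
    total = sum(task.loss_weight * per_task[task.name] for task in model.tasks)
    return total, {name: value.item() for name, value in per_task.items()}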

As we move forward, we’re going to face the challenge of optimizing the model for all tasks. This can be tackled using several approaches that require exploration, such as balancing the losses by adjusting their weights, and applying a different learning rate to each part of the model, be it the backbone or an individual task head.
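Both levers map naturally onto PyTorch primitives; a sketch with made-up values:

import torch

# Illustrative values only: tune per-task loss weights...
for task in model.tasks:
    task.loss_weight = 1.0

# ...and give the shared backbone a slower learning rate than the heads,
# using optimizer parameter groups.
optimizer = torch.optim.Adam([
    {"params": model.shared.parameters(), "lr": 1e-5},
    *[{"params": head.parameters(), "lr": 1e-4}
      for head in model.tasks_heads.values()],
])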

Training a multitask system probably won’t get each task model to its optimum, due to the tension between the different tasks. On the other hand, our backbone network might benefit from the diversity of tasks. In our case, it allowed us not only to enrich our feature representation and shorten our training cycles, but also to use a single heavy backbone model (instead of multiple models), which saved us a lot of inference time and memory.
