An easy recipe for multi-task learning in PyTorch that you can do at home

Adam K
Published in GumGum Tech Blog
10 min read · Dec 27, 2022
An image of a robot chef, generated by DALL-E mini, a model itself trained on multiple tasks.

In my last blogpost, I touched on what multi-task learning is, how it works, and why it is becoming the de facto standard for machine learning development. In this blogpost, I want to share a simple implementation of a multi-task learning model that you can experiment with yourself or adapt to whatever task (or tasks!) you’re interested in. I’ll show the example in PyTorch using the same natural language data as my last post, movie and Yelp reviews, but the architecture I’m offering is data-agnostic and could work for images, tabular data or any other kind of data.

The target audience for this blogpost is someone who already knows the basics of neural networks and deep learning, as well as PyTorch; I’m not intending this to be a complete intro to those topics. Check out the official docs or this or that tutorial if you’re a PyTorch beginner or need a refresher. I’m also not aiming to have the most elegant or efficient code possible, but rather something that communicates what the code is doing: runnable pseudocode, if you will.

Let’s start by designing a simple PyTorch Dataset, which will handle the data’s loading, storing and preprocessing. This Dataset object is very simple; it takes in a SciPy sparse matrix for the input variable, in this case sparse bag-of-words representations, and a NumPy array for the binary or one-hot encoded output variable. One note: since my input data is in a sparse format via a scikit-learn CountVectorizer, I’m being a cool guy and converting each row to a PyTorch Tensor on the fly to be more memory efficient. Other than that, this Dataset object is pretty standard and can easily be changed to handle other kinds of data.

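```python
import numpy as np
import scipy.sparse
import torch
from torch.utils.data import DataLoader, Dataset

class Task_Dataset(Dataset):
    def __init__(self, X: scipy.sparse.csr_matrix, y: np.ndarray):
        self.X = X  # kept sparse; rows are densified on the fly in __getitem__
        self.y = torch.from_numpy(y).float()
        assert self.X.shape[0] == self.y.shape[0]

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Densify one row: (1, vocab_size) -> (vocab_size,). Casting straight to
        # float32 avoids the int8 overflow that word counts above 127 would cause.
        X = torch.from_numpy(self.X[idx].toarray().astype(np.float32)).squeeze()
        y = self.y[idx]
        return X, y

# movie_X_train / yelp_X_train (CountVectorizer output) and the 2-D label arrays
# movie_y_train / yelp_y_train are prepared beforehand (not shown in this post)
movie_ds = Task_Dataset(movie_X_train, movie_y_train)
movie_dl = DataLoader(movie_ds, batch_size=64, shuffle=True)

yelp_ds = Task_Dataset(yelp_X_train, yelp_y_train)
yelp_dl = DataLoader(yelp_ds, batch_size=64, shuffle=True)
```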
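To make the expected inputs concrete, here is a small sketch with hypothetical toy data (none of these arrays come from the post). It shows one plausible way to shape the labels: the models below read `y.shape[1]`, so the binary movie labels are assumed to be a 2-D `(n, 1)` array and the three-class Yelp labels a one-hot `(n, 3)` array. It also shows the per-row densification that `Task_Dataset` performs.

```python
import numpy as np
import scipy.sparse
import torch

# Hypothetical toy labels for the binary movie task: 1 = positive, 0 = negative
movie_labels = np.array([1, 0, 1, 1])
# Reshape to (n, 1) so that y.shape[1] == 1 can serve as the output dimension
movie_y = movie_labels.reshape(-1, 1).astype(np.float32)

# Hypothetical toy labels for the Yelp task: 0 = negative, 1 = neutral, 2 = positive
yelp_labels = np.array([0, 2, 1, 2])
# One-hot encode by indexing rows of a 3x3 identity matrix -> shape (n, 3)
yelp_y = np.eye(3, dtype=np.float32)[yelp_labels]

# A toy sparse bag-of-words matrix: 4 documents, vocabulary of 5 words
X = scipy.sparse.csr_matrix(np.array([
    [1, 0, 2, 0, 0],
    [0, 1, 0, 0, 3],
    [0, 0, 0, 1, 0],
    [2, 1, 0, 0, 0],
]))

# Densify a single row on the fly, the way Task_Dataset.__getitem__ does
row = torch.from_numpy(X[1].toarray().astype(np.float32)).squeeze()
print(movie_y.shape, yelp_y.shape, row.tolist())
# (4, 1) (4, 3) [0.0, 1.0, 0.0, 0.0, 3.0]
```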

With the data object defined, let’s move on to building the PyTorch Module for a single-task problem. I know you’re reading this to learn how to build artisanal multi-task models, but we need this single-task model to compare with the multi-task version later. For your own artisanal multi-task project, you can use whatever architecture you want. For demonstration, I’m sticking with a run-of-the-mill multi-layer perceptron, complete with a single hidden layer and a final layer. This final layer is important and will be the key to our multi-task architecture.

```python
import torch
from torch import nn

class SingleTask_Network(nn.Module):
    def __init__(self, input_dim: int,
                 output_dim: int = 1,
                 hidden_dim: int = 300):
        super(SingleTask_Network, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.hidden = nn.Linear(self.input_dim, self.hidden_dim)
        self.final = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, x: torch.Tensor):
        x = self.hidden(x)
        x = torch.sigmoid(x)
        x = self.final(x)  # raw logits; the loss function handles sigmoid/softmax
        return x
```

For those who are familiar with PyTorch, this is preaching to the choir, but I want to call out what the forward method is doing here since it will be relevant later. This method takes an input Tensor object and applies some functions and transformations to it before finally passing the result through the final layer and returning that value. Note that the final layer returns raw scores (logits) rather than probabilities; the loss functions we’ll use take care of that conversion. In this example, the transformations are very simple, but for your project, you can have whatever you need for the target task.
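As a sanity check on what `forward` computes, the same hidden layer, sigmoid and final layer can be written as an `nn.Sequential`. This is only an equivalent sketch with illustrative sizes, not a replacement for the class; the explicit `forward` method is what we’ll need to modify for the multi-task model.

```python
import torch
from torch import nn

input_dim, hidden_dim, output_dim = 10, 300, 1  # illustrative sizes

# Same layer stack as SingleTask_Network.forward: Linear -> sigmoid -> Linear
seq_model = nn.Sequential(
    nn.Linear(input_dim, hidden_dim),
    nn.Sigmoid(),
    nn.Linear(hidden_dim, output_dim),
)

batch = torch.rand(64, input_dim)  # stand-in for 64 bag-of-words rows
logits = seq_model(batch)

# Applying the layers by hand, as forward() does, gives the same result
manual = seq_model[2](torch.sigmoid(seq_model[0](batch)))
print(logits.shape, torch.allclose(logits, manual))  # torch.Size([64, 1]) True
```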

Jumping a bit ahead, the forward method will be nearly the same for the multi-task model, albeit with one small change. Before we get to that, let’s talk about how to train this model. To do so, we need to define a loss function and an optimizer. In this case, we’ll build this model for the movie dataset, which is binary (is this movie review positive or negative?), making the proper loss function binary cross-entropy, BCEWithLogitsLoss (for the nerds, why BCEWithLogitsLoss instead of BCELoss? Check out this). For an optimizer, we’ll go with Adam because it’s ̶m̶y̶ ̶n̶a̶m̶e̶ very robust and dependable.
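For the curious, here is a quick illustration of the BCEWithLogitsLoss vs. BCELoss difference, using made-up values. For moderate logits the two agree, but BCEWithLogitsLoss works on the logits directly (the PyTorch docs describe it as more numerically stable than a sigmoid followed by BCELoss), whereas a very negative logit pushes the sigmoid to exactly 0 in float32 and BCELoss falls back on clamping its log to -100.

```python
import torch
from torch import nn

logits = torch.tensor([[2.0], [-1.0], [0.5]])   # made-up model outputs
targets = torch.tensor([[1.0], [0.0], [1.0]])

with_logits = nn.BCEWithLogitsLoss()(logits, targets)
plain = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(with_logits, plain))  # True: same loss for moderate logits

# An extreme logit: sigmoid(-200) underflows to exactly 0 in float32, so BCELoss
# hits log(0) and clamps it to -100, while BCEWithLogitsLoss stays exact (~200).
big_logit, positive = torch.tensor([[-200.0]]), torch.tensor([[1.0]])
big_with_logits = nn.BCEWithLogitsLoss()(big_logit, positive).item()
big_plain = nn.BCELoss()(torch.sigmoid(big_logit), positive).item()
print(big_with_logits, big_plain)
```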

The training loop is straightforward and looks like this:

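```python
import torch
from torch import nn

# y is 2-D, so y.shape[1] is the number of outputs (1 for the binary movie task)
model = SingleTask_Network(movie_ds.X.shape[1], movie_ds.y.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.BCEWithLogitsLoss()

for epoch in range(6):
    for batch_X, batch_y in movie_dl:
        preds = model(batch_X)  # logits, shape (batch_size, 1)
        loss = loss_fn(preds, batch_y)

        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()        # backpropagate
        optimizer.step()       # update the weights
```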

Pretty simple. For each mini-batch, we make predictions and compute the loss value for that mini-batch. Then, we zero the gradients, do a little backprop, and then update the weights.
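Since the movie data isn’t bundled with the post, here is a self-contained sketch of the same loop on hypothetical synthetic data (a binary label that depends on the first feature). It uses an `nn.Sequential` stand-in for `SingleTask_Network` and a larger learning rate than the post’s 1e-3 so it converges quickly, and it tracks the average loss per epoch to show the loop learning.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical synthetic binary task: label is 1 when the first feature > 0.5
X = torch.rand(512, 20)
y = (X[:, :1] > 0.5).float()  # shape (512, 1), matching output_dim=1
dl = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

# Same layer stack as SingleTask_Network (Linear -> sigmoid -> Linear)
model = nn.Sequential(nn.Linear(20, 32), nn.Sigmoid(), nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.BCEWithLogitsLoss()

epoch_losses = []
for epoch in range(20):
    total = 0.0
    for batch_X, batch_y in dl:
        loss = loss_fn(model(batch_X), batch_y)
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # backprop
        optimizer.step()       # update the weights
        total += loss.item() * len(batch_X)
    epoch_losses.append(total / len(X))  # average loss over the epoch

print(f"first epoch: {epoch_losses[0]:.3f}, last epoch: {epoch_losses[-1]:.3f}")
```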

Moving to a single-task model for the Yelp data, training is the same with one tiny but important difference. I’ve broken the Yelp dataset into three labels: negative, neutral or positive. Because of this, we’ll need to use CrossEntropyLoss as our loss function, since it’s now a multiclass problem (each review gets exactly one of three labels). Other than changing the loss function, the training loop will look the same as with the movie data. Just so it’s clear, the two tasks are different kinds of problems, and as such have different output shapes (binary vs. multiclass), and our multi-task model will need to be able to handle that.
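One detail worth checking: the Yelp labels are one-hot encoded, and `nn.CrossEntropyLoss` accepts either integer class indices or per-class probabilities as targets (the latter since PyTorch 1.10). A one-hot target is just a probability distribution with all its mass on one class, so both forms give the same loss, as this small sketch with made-up logits shows.

```python
import torch
from torch import nn

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])  # 2 reviews x 3 classes (made-up)
class_idx = torch.tensor([0, 2])          # negative, positive
one_hot = nn.functional.one_hot(class_idx, num_classes=3).float()

loss_fn = nn.CrossEntropyLoss()
loss_from_idx = loss_fn(logits, class_idx)    # integer class targets
loss_from_one_hot = loss_fn(logits, one_hot)  # probability targets (PyTorch 1.10+)
print(torch.allclose(loss_from_idx, loss_from_one_hot))  # True
```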

Below is a chart of the performance of each model across six epochs. The final F1 is .82 for movie reviews and .65 for Yelp. That doesn’t mean much now, but it will be a helpful baseline when we train a single model on the two tasks simultaneously.

F1 on test data for each single-task model. Note that for Yelp, performance starts around .4 before reaching .65 at the end.
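The post doesn’t say exactly how F1 was computed, in particular which averaging was used for the three-class Yelp task. As an assumption, here is a sketch that scores the binary task as the F1 of the positive class (logit above 0, i.e., probability above 0.5) and the Yelp task as macro-averaged F1 over argmax predictions, using made-up logits. The helper `f1_per_class` is hypothetical.

```python
import torch

def f1_per_class(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> list:
    """Hypothetical helper: F1 for each class from integer predictions/targets."""
    scores = []
    for c in range(num_classes):
        tp = ((preds == c) & (targets == c)).sum().item()
        fp = ((preds == c) & (targets != c)).sum().item()
        fn = ((preds != c) & (targets == c)).sum().item()
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return scores

# Binary (movie-style): one logit per review, thresholded at 0 (probability 0.5)
movie_logits = torch.tensor([[1.2], [-0.3], [0.4], [-2.0]])
movie_true = torch.tensor([1, 0, 0, 0])
movie_pred = (movie_logits.squeeze(1) > 0).long()
movie_f1 = f1_per_class(movie_pred, movie_true, num_classes=2)[1]  # positive class

# Multiclass (Yelp-style): argmax over 3 logits, then macro-average across classes
yelp_logits = torch.tensor([[2.0, 0.1, 0.3], [0.2, 0.1, 1.5],
                            [0.0, 2.0, 0.1], [1.0, 0.2, 0.1]])
yelp_true = torch.tensor([0, 2, 2, 0])
yelp_scores = f1_per_class(yelp_logits.argmax(dim=1), yelp_true, num_classes=3)
yelp_macro_f1 = sum(yelp_scores) / len(yelp_scores)
print(round(movie_f1, 3), round(yelp_macro_f1, 3))  # 0.667 0.556
```

If you already use scikit-learn, its `f1_score` with `average="binary"` or `average="macro"` covers the same two cases.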

Let’s move on to talk about how we can build a multi-task model. I mentioned earlier that the final layer is the key. At a high level, the multi-task model architecture will look very similar to that of a single-task model with two main differences:

1. The model will have multiple final layers, one for each task. Each final layer will reflect the nature of its task, e.g., binary vs. multiclass.
2. The forward method will still apply a series of transformations to the input, but will take an additional argument, task_id, which determines which final layer to use. All tasks will share these penultimate transformations, and that’s where the magic of multi-task learning lies.

It seems very simple, but that’s the beauty of PyTorch. You can really do a lot with relatively few code changes. Here’s what that looks like:

```python
import torch
from torch import nn

class MultiTask_Network(nn.Module):
    def __init__(self, input_dim: int,
                 output_dim_0: int = 1, output_dim_1: int = 3,
                 hidden_dim: int = 200):
        super(MultiTask_Network, self).__init__()
        self.input_dim = input_dim
        self.output_dim_0 = output_dim_0
        self.output_dim_1 = output_dim_1
        self.hidden_dim = hidden_dim

        # Shared by both tasks
        self.hidden = nn.Linear(self.input_dim, self.hidden_dim)
        # One task-specific final layer per task
        self.final_0 = nn.Linear(self.hidden_dim, self.output_dim_0)
        self.final_1 = nn.Linear(self.hidden_dim, self.output_dim_1)

    def forward(self, x: torch.Tensor, task_id: int):
        x = self.hidden(x)
        x = torch.sigmoid(x)
        if task_id == 0:
            x = self.final_0(x)
        elif task_id == 1:
            x = self.final_1(x)
        else:
            # raise rather than assert, since asserts are skipped under python -O
            raise ValueError(f'Bad task ID passed: {task_id}')

        return x
```

And that’s it. We added a few extra variables in the __init__ method and an if block in the forward method, but the flow of the forward method is pretty much the same: apply some transformations or other functions before applying a final layer. The main difference is that though all penultimate layers are shared, the final layer varies from task to task. As far as the PyTorch model object is concerned, we’re done. Pat yourself on the back! You now know how to build your very own multi-task deep learning model!
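If you have more than two tasks, the if/elif chain can be replaced by indexing into a list of final layers. Here is a hypothetical generalization (`MultiHead_Network` is not from the post) that keeps the same shared-hidden-layer design. Storing the heads in `nn.ModuleList` matters: a plain Python list would hide their parameters from `model.parameters()`, and therefore from the optimizer.

```python
from typing import Sequence

import torch
from torch import nn

class MultiHead_Network(nn.Module):
    """Hypothetical generalization of MultiTask_Network to any number of tasks."""

    def __init__(self, input_dim: int, output_dims: Sequence[int], hidden_dim: int = 200):
        super().__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)  # shared by every task
        # ModuleList registers each head, so model.parameters() includes them all
        self.finals = nn.ModuleList([nn.Linear(hidden_dim, d) for d in output_dims])

    def forward(self, x: torch.Tensor, task_id: int) -> torch.Tensor:
        if not 0 <= task_id < len(self.finals):
            raise ValueError(f"Bad task_id {task_id}; expected 0 to {len(self.finals) - 1}")
        x = torch.sigmoid(self.hidden(x))
        return self.finals[task_id](x)

model = MultiHead_Network(input_dim=50, output_dims=[1, 3])
x = torch.rand(8, 50)
out_0, out_1 = model(x, task_id=0), model(x, task_id=1)
print(out_0.shape, out_1.shape)  # torch.Size([8, 1]) torch.Size([8, 3])
```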

But there’s one thing we still need to consider: how can we train a network with this architecture? A naive approach would be to train each task separately, a complete epoch of one task after the other. However, that would likely run into the dramatically named problem of catastrophic forgetting, where training on one task overwrites much of what the model learned on the previous one. In other words, after training, the model would only generalize well on data for the most recently trained task (check out this for more info). To offset this, it’s important to have a good training curriculum, where the model is trained on data ordered in such a way that it best aids learning.

Exactly what the best curriculum for multi-task learning is remains an open question (check out section 4.4 of this overview paper), but a generally successful strategy is to intersperse batches of each task in a round-robin sequence. In other words, we train on a mini-batch of task 1, switch to one of task 2, then train on a different mini-batch of task 1, and so on until we’ve gone through an entire epoch’s data for each task.
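The round-robin idea can be sketched as a small generator over any number of task loaders. This is a hypothetical helper, not from the post: it alternates one batch per task and lets a task drop out once its data is exhausted, so every task still sees its full epoch.

```python
from itertools import zip_longest

def round_robin(*loaders):
    """Hypothetical helper: yield (task_id, batch), one batch per task in turn.

    When a task runs out of batches it simply drops out, so every task still
    completes a full epoch of its own data.
    """
    exhausted = object()  # sentinel that zip_longest uses to pad shorter loaders
    for batches in zip_longest(*loaders, fillvalue=exhausted):
        for task_id, batch in enumerate(batches):
            if batch is not exhausted:
                yield task_id, batch

# Plain lists stand in for DataLoaders here; any iterable of batches works
schedule = list(round_robin(["m1", "m2", "m3"], ["y1", "y2"]))
print(schedule)
# [(0, 'm1'), (1, 'y1'), (0, 'm2'), (1, 'y2'), (0, 'm3')]
```

In a training loop, each yielded `(task_id, batch)` pair would get its own forward pass with `model(batch_X, task_id=task_id)`, its own loss and its own optimizer step. That is a variant of what this post does, which is to sum one batch from each task into a single loss per step.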

The astute reader will ask what happens when the tasks’ datasets are different sizes. In that case, you’ll likely want to either downsample the larger dataset or upsample the smaller one. Which of those two strategies works best is an open question, and the answer will vary based on your data and your tasks. If you do have differently sized datasets, check out the CombinedLoader from PyTorch Lightning. It will save you a lot of headaches (I know from personal experience).
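For a rough idea of what upsampling and downsampling look like with plain PyTorch utilities (the datasets here are hypothetical random tensors): one option is a `RandomSampler` with `replacement=True` and `num_samples` set to the larger dataset’s size, and another is a random `Subset` of the larger dataset. Either way, both tasks end up yielding the same number of mini-batches per epoch.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, Subset, TensorDataset

torch.manual_seed(0)
big_ds = TensorDataset(torch.rand(1000, 5))   # hypothetical larger task
small_ds = TensorDataset(torch.rand(300, 5))  # hypothetical smaller task

# Upsample: draw len(big_ds) indices from small_ds, with replacement
up_sampler = RandomSampler(small_ds, replacement=True, num_samples=len(big_ds))
small_up_dl = DataLoader(small_ds, batch_size=64, sampler=up_sampler)
big_dl = DataLoader(big_ds, batch_size=64, shuffle=True)

# Downsample: keep a random subset of big_ds the same size as small_ds
keep = torch.randperm(len(big_ds))[: len(small_ds)].tolist()
big_down_dl = DataLoader(Subset(big_ds, keep), batch_size=64, shuffle=True)
small_dl = DataLoader(small_ds, batch_size=64, shuffle=True)

print(len(small_up_dl), len(big_dl))    # 16 16
print(len(big_down_dl), len(small_dl))  # 5 5
```

As written, the downsampled subset is fixed once; re-drawing `keep` every epoch would let the model eventually see all of the larger dataset.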

For the sake of simplicity, though, I’m going to assume the datasets are the same size, because we have something else to worry about: computing the loss for this network. Each mini-batch for each task in this round-robin will have its own loss, and the question is how to combine those so we know what updates to make during backpropagation. The answer is actually simple: just sum the per-task losses. The next steps are just like the single-task training loop, where we zero the gradients, perform a backward pass using the summed loss and update the weights.

Again, the astute reader might pause. By summing the losses, won’t we get larger updates than we’d get for just one task? Won’t some loss functions yield values in different ranges than others? And what if we wanted to weight each task’s loss differently? I encourage the astute to check out this paper on just those topics, namely how to optimally weight the loss for each task. In the meantime, summing the losses works just fine. At the end of the day, the loss is just a useful measure of how wrong our model is. What matters for backpropagation is that the gradient of a sum is the sum of the gradients: each task-specific final layer receives gradient only from its own task’s loss, while the shared hidden layer receives the sum of both tasks’ gradients, computed from that mini-batch’s activations. That is how we know how to dole out updates.
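Here is a small check of that last point, on hypothetical random data with tiny layers: the gradient of the summed loss with respect to the shared layer equals the sum of the two per-task gradients, while each task-specific final layer gets exactly the gradient from its own task’s loss.

```python
import torch
from torch import nn

torch.manual_seed(0)
shared = nn.Linear(4, 3)                             # shared hidden layer
final_0, final_1 = nn.Linear(3, 1), nn.Linear(3, 3)  # task-specific heads
x0, y0 = torch.rand(5, 4), torch.randint(0, 2, (5, 1)).float()  # binary task
x1, y1 = torch.rand(5, 4), torch.randint(0, 3, (5,))            # 3-class task

def task_losses():
    loss_0 = nn.BCEWithLogitsLoss()(final_0(torch.sigmoid(shared(x0))), y0)
    loss_1 = nn.CrossEntropyLoss()(final_1(torch.sigmoid(shared(x1))), y1)
    return loss_0, loss_1

# Gradients of each task's loss on its own
loss_0, loss_1 = task_losses()
g_shared_0, g_final_0 = torch.autograd.grad(loss_0, [shared.weight, final_0.weight])
g_shared_1, g_final_1 = torch.autograd.grad(loss_1, [shared.weight, final_1.weight])

# Gradients of the summed loss (fresh forward pass, since grad() freed the graphs)
loss_0, loss_1 = task_losses()
g_shared_sum, g_final_0_sum, g_final_1_sum = torch.autograd.grad(
    loss_0 + loss_1, [shared.weight, final_0.weight, final_1.weight])

print(torch.allclose(g_shared_sum, g_shared_0 + g_shared_1))  # True
print(torch.allclose(g_final_0_sum, g_final_0))               # True
print(torch.allclose(g_final_1_sum, g_final_1))               # True
```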

Let’s take a look at the code for the training loop of a multi-task model:

```python
import torch
from torch import nn

model = MultiTask_Network(movie_ds.X.shape[1],
                          output_dim_0=movie_ds.y.shape[1],
                          output_dim_1=yelp_ds.y.shape[1])

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
movie_loss_fn = nn.BCEWithLogitsLoss()
# The Yelp labels are one-hot floats, which CrossEntropyLoss accepts as
# class probabilities in PyTorch 1.10+
yelp_loss_fn = nn.CrossEntropyLoss()
batch_losses = []  # combined loss at every training step, for plotting

for epoch in range(6):
    # zip stops at the shorter DataLoader; the datasets are assumed equal in size
    zipped_dls = zip(movie_dl, yelp_dl)
    for (movie_batch_X, movie_batch_y), (yelp_batch_X, yelp_batch_y) in zipped_dls:
        movie_preds = model(movie_batch_X, task_id=0)
        movie_loss = movie_loss_fn(movie_preds, movie_batch_y)

        yelp_preds = model(yelp_batch_X, task_id=1)
        yelp_loss = yelp_loss_fn(yelp_preds, yelp_batch_y)

        loss = movie_loss + yelp_loss  # sum the per-task losses
        batch_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

To touch on what’s different in this training section: we now have two DataLoader objects, so the first thing to do is zip them together into a single iterator that yields one mini-batch per task at each step of the main loop. Note that zip stops as soon as the shorter DataLoader runs out, which is one more reason the equal-size assumption matters. Inside this loop is where the real fun is. We make predictions and compute the loss for each task as if we were doing a single-task training step for each task separately. We sum these loss values, and that’s our final loss, which will be used for the backward pass. Because addition is commutative, it doesn’t matter which task is first or second. The next steps are just like the single-task training loop, where we zero the gradients, perform that backward pass and update the weights.
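To see the whole recipe run end to end without the movie and Yelp data, here is a self-contained sketch on two hypothetical synthetic tasks (one binary, one with three classes) that share a hidden layer. It follows the post’s loop: zip one batch per task, sum the two losses, one backward pass and one optimizer step. The learning rate and sizes are arbitrary choices, and the per-task losses are tracked separately so you can see that both go down.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
n, d = 640, 30
# Two hypothetical synthetic tasks with inputs of the same dimension d
X_a, X_b = torch.rand(n, d), torch.rand(n, d)
s_a, s_b = X_a[:, :5].sum(dim=1), X_b[:, :5].sum(dim=1)
y_a = (s_a > 2.5).float().unsqueeze(1)                # binary, shape (n, 1)
y_b = torch.bucketize(s_b, torch.tensor([2.0, 3.0]))  # classes 0/1/2, shape (n,)

dl_a = DataLoader(TensorDataset(X_a, y_a), batch_size=64, shuffle=True)
dl_b = DataLoader(TensorDataset(X_b, y_b), batch_size=64, shuffle=True)

hidden = nn.Linear(d, 64)  # shared by both tasks
final_a, final_b = nn.Linear(64, 1), nn.Linear(64, 3)
params = [*hidden.parameters(), *final_a.parameters(), *final_b.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-2)
loss_fn_a, loss_fn_b = nn.BCEWithLogitsLoss(), nn.CrossEntropyLoss()

history = []  # (mean task-a loss, mean task-b loss) per epoch
for epoch in range(15):
    total_a = total_b = 0.0
    for (xa, ya), (xb, yb) in zip(dl_a, dl_b):  # one batch per task per step
        loss_a = loss_fn_a(final_a(torch.sigmoid(hidden(xa))), ya)
        loss_b = loss_fn_b(final_b(torch.sigmoid(hidden(xb))), yb)
        loss = loss_a + loss_b  # sum the per-task losses

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_a += loss_a.item()
        total_b += loss_b.item()
    history.append((total_a / len(dl_a), total_b / len(dl_b)))

print(history[0], history[-1])
```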

Does this work? Does this model learn, and is it learning how we expect it to? Let’s look at some graphs!

The first graph is a simple one: the model’s performance over time. As expected, the performance for both tasks improves over time, showing that our implementation is learning both tasks simultaneously. In fact, the final F1 scores for movie reviews and Yelp data are .82 and .67, respectively, matching the single-task movie model and marginally beating the single-task Yelp model. The really exciting thing to note, however, isn’t the final score, but that Yelp’s performance begins at .6 and increases slowly, whereas for the single-task model it began close to .4. It seems that just a single mini-batch of movie data, which is task 0, is able to lift the performance of the other task!
(The interim results for both the single-task and multi-task models are computed after updating the weights.)

F1 score for both tasks over time. The multi-task model successfully learns to generalize for both tasks, albeit at different rates.

In this second graph, let’s look at the norm of the gradients for the three layers of our multi-task model: the hidden layer shared by both tasks and the task-specific final layers. Though this may seem redundant — we already know the model is learning from the previous chart — it helps confirm that the learning is happening for all layers; it’s not like only the hidden layer is updating its weights and the performance increase is just dumb, albeit statistically surprising, luck.

Gradient norm per mini-batch for the learnable parameters. For each layer, it’s greater than 0, showing that each layer is updating its weights per step.
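The post doesn’t show how the gradient norms were computed. One plausible way, sketched here on a tiny hypothetical model with random data, is to run a backward pass on the summed loss and then take the L2 norm of each layer’s parameter gradients (weights and bias together). The helper `grad_norm` is hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden = nn.Linear(10, 8)
final_0, final_1 = nn.Linear(8, 1), nn.Linear(8, 3)

# One mini-batch per task (hypothetical random data), summed loss, one backward pass
x0, y0 = torch.rand(4, 10), torch.ones(4, 1)
x1, y1 = torch.rand(4, 10), torch.tensor([0, 1, 2, 1])
loss = (nn.BCEWithLogitsLoss()(final_0(torch.sigmoid(hidden(x0))), y0)
        + nn.CrossEntropyLoss()(final_1(torch.sigmoid(hidden(x1))), y1))
loss.backward()

def grad_norm(layer: nn.Module) -> float:
    """Hypothetical helper: L2 norm of all of a layer's gradients, flattened together."""
    grads = [p.grad.flatten() for p in layer.parameters() if p.grad is not None]
    return torch.cat(grads).norm(2).item()

norms = {name: grad_norm(layer)
         for name, layer in [("hidden", hidden), ("final_0", final_0), ("final_1", final_1)]}
print(norms)
```

Recording `norms` right after each `loss.backward()` in the training loop would give one point per mini-batch, like the plot above.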

Conclusion

So, there you go! A very simple implementation of multi-task learning in PyTorch. To summarize this method: we create a separate final layer for each task in the model and then add a task_id argument to the forward method, which tells the model which final layer to use. During training, we run a mini-batch of each task one after the other and sum their losses before backpropagating. That’s it. Easy peasy.

Here is the notebook with the code behind these models and graphs. Have fun training your own multi-task model!
