An easy recipe for multi-task learning in PyTorch that you can do at home

Image: a robot chef, generated by DALL-E mini, a model itself trained on multiple tasks.
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class Task_Dataset(Dataset):
    def __init__(self, X : sp.csr_matrix,
                 y : np.ndarray):
        # Keep the features sparse; only the labels live as a dense tensor.
        self.X = X
        self.y = torch.from_numpy(y).float()
        assert self.X.shape[0] == self.y.shape[0]

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Densify one row at a time so the full matrix never sits in memory.
        X = torch.from_numpy(self.X[idx].astype(np.int8).toarray()).float().squeeze()
        y = self.y[idx]
        return X, y
movie_ds = Task_Dataset(movie_X_train, movie_y_train)
movie_dl = DataLoader(movie_ds, batch_size = 64, shuffle = True)

yelp_ds = Task_Dataset(yelp_X_train, yelp_y_train)
yelp_dl = DataLoader(yelp_ds, batch_size = 64, shuffle = True)
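Before training, it's worth pulling a single batch to confirm the per-row densification works as intended. A quick sanity check (not in the original post), just inspecting shapes:

batch_X, batch_y = next(iter(movie_dl))
# Expect (64, vocab_size) features and (64, n_labels) targets.
print(batch_X.shape, batch_y.shape)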
class SingleTask_Network(nn.Module):
    def __init__(self, input_dim : int,
                 output_dim : int = 1,
                 hidden_dim : int = 300):
        super(SingleTask_Network, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        # One hidden layer followed by a single output layer.
        self.hidden = nn.Linear(self.input_dim, self.hidden_dim)
        self.final = nn.Linear(self.hidden_dim, self.output_dim)

    def forward(self, x : torch.Tensor):
        x = self.hidden(x)
        x = torch.sigmoid(x)
        x = self.final(x)  # returns raw logits; the loss applies the activation
        return x
model = SingleTask_Network(movie_ds.X.shape[1], movie_ds.y.shape[1])
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
loss_fn = nn.BCEWithLogitsLoss()
for i in range(6):
    for j, (batch_X, batch_y) in enumerate(movie_dl):
        preds = model(batch_X)
        loss = loss_fn(preds, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Figure: F1 on test data for each single-task model. Note that for Yelp, performance starts around 0.4 before reaching 0.65 at the end.
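Those scores come from an evaluation pass over held-out data. A minimal sketch of that pass, assuming movie_X_test / movie_y_test splits (names not shown in this excerpt), a test set small enough to densify in one go, and micro-averaged F1:

from sklearn.metrics import f1_score

model.eval()
with torch.no_grad():
    test_X = torch.from_numpy(movie_X_test.astype(np.int8).toarray()).float()
    logits = model(test_X)
    # The model emits logits, so threshold the sigmoid for multi-label predictions.
    preds = (torch.sigmoid(logits) > 0.5).int().numpy()
print(f1_score(movie_y_test, preds, average = 'micro'))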
To extend this to multiple tasks, two things change:
  1. The model will have multiple final layers, one for each task. Each final layer reflects the nature of its task, e.g., binary vs. multi-label.
  2. The forward method will still apply a series of transformations to the input, but will take an additional argument, task_id, which determines which final layer to use. All tasks will share these penultimate transformations, and that’s where the magic of multi-task learning is.
class MultiTask_Network(nn.Module):
    def __init__(self, input_dim,
                 output_dim_0 : int = 1, output_dim_1 : int = 3,
                 hidden_dim : int = 200):
        super(MultiTask_Network, self).__init__()

        self.input_dim = input_dim
        self.output_dim_0 = output_dim_0
        self.output_dim_1 = output_dim_1
        self.hidden_dim = hidden_dim

        # One shared hidden layer; each task gets its own output head.
        self.hidden = nn.Linear(self.input_dim, self.hidden_dim)
        self.final_0 = nn.Linear(self.hidden_dim, self.output_dim_0)
        self.final_1 = nn.Linear(self.hidden_dim, self.output_dim_1)

    def forward(self, x : torch.Tensor, task_id : int):
        x = self.hidden(x)
        x = torch.sigmoid(x)
        # Route the shared representation through the head for the requested task.
        if task_id == 0:
            x = self.final_0(x)
        elif task_id == 1:
            x = self.final_1(x)
        else:
            raise ValueError(f'Bad task_id passed: {task_id}')

        return x
model = MultiTask_Network(movie_ds.X.shape[1],
                          output_dim_0 = movie_ds.y.shape[1],
                          output_dim_1 = yelp_ds.y.shape[1])

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
movie_loss_fn = nn.BCEWithLogitsLoss()  # multi-label task
yelp_loss_fn = nn.CrossEntropyLoss()    # single-label task; one-hot float targets need PyTorch >= 1.10
losses_per_epoch = []
for i in range(6):
    # zip stops at the shorter loader, so each step pairs one batch from each task.
    zipped_dls = zip(movie_dl, yelp_dl)
    for j, ((movie_batch_X, movie_batch_y), (yelp_batch_X, yelp_batch_y)) in enumerate(zipped_dls):
        movie_preds = model(movie_batch_X, task_id = 0)
        movie_loss = movie_loss_fn(movie_preds, movie_batch_y)

        yelp_preds = model(yelp_batch_X, task_id = 1)
        yelp_loss = yelp_loss_fn(yelp_preds, yelp_batch_y)

        # One backward pass on the summed loss pushes gradients from both
        # tasks through the shared hidden layer.
        loss = movie_loss + yelp_loss
        losses_per_epoch.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
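Evaluation follows the same pattern as the single-task case, except each test set is routed through its own head via task_id. A sketch, again assuming held-out *_X_test / *_y_test splits and one-hot Yelp labels:

from sklearn.metrics import f1_score

model.eval()
with torch.no_grad():
    movie_logits = model(torch.from_numpy(movie_X_test.astype(np.int8).toarray()).float(), task_id = 0)
    movie_preds = (torch.sigmoid(movie_logits) > 0.5).int().numpy()

    yelp_logits = model(torch.from_numpy(yelp_X_test.astype(np.int8).toarray()).float(), task_id = 1)
    # Single-label task: take the highest-scoring class.
    yelp_preds = yelp_logits.argmax(dim = 1).numpy()

print(f1_score(movie_y_test, movie_preds, average = 'micro'))
print(f1_score(yelp_y_test.argmax(axis = 1), yelp_preds, average = 'macro'))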
Figure: F1 score for both tasks over time. The multi-task model successfully learns to generalize on both tasks, albeit at different rates.
Figure: Gradient norm per mini-batch for the learnable parameters. Every layer's norm stays above zero, showing that each layer updates its weights at every step.
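Those norms can be collected with a few lines placed between loss.backward() and optimizer.step(); a sketch (not from the original post) using named_parameters:

# Inside the training loop, after loss.backward() and before optimizer.step():
grad_norms = {name: param.grad.norm().item()
              for name, param in model.named_parameters()
              if param.grad is not None}
# e.g. {'hidden.weight': ..., 'final_0.weight': ..., 'final_1.weight': ...}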

Conclusion
