Multi-Task Learning (MTL)
In the context of multi-task learning, multiple tasks are performed simultaneously on a common input. Typically, a feature extractor is applied to the input to obtain a shared representation, useful for all tasks. Then, task-specific heads are applied to these features to obtain each task’s result. A loss can then be computed for each task. Fundamentally, multi-task learning is a multi-objective optimization problem in which we minimize the vector of task losses.
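One way to write this formally, for $m$ tasks, is shown below. The notation is introduced here only for illustration and is not taken from TorchJD. $f(\cdot;\theta_0)$ is the shared feature extractor, $h_i(\cdot;\theta_i)$ is the head of task $i$, and $\ell_i$ is its loss on input $x$ with target $y_i$. "Minimizing" a vector is understood in the multi-objective (Pareto) sense.

```latex
\min_{\theta_0, \theta_1, \dots, \theta_m} \;
\mathbf{L}(\theta_0, \theta_1, \dots, \theta_m) =
\begin{bmatrix}
\mathcal{L}_1(\theta_0, \theta_1) \\
\vdots \\
\mathcal{L}_m(\theta_0, \theta_m)
\end{bmatrix},
\qquad
\mathcal{L}_i(\theta_0, \theta_i) = \ell_i\big(h_i(f(x;\theta_0);\theta_i),\, y_i\big).
```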
A common trick to train multi-task models is to cast the problem as a single-objective one by minimizing a weighted sum of the losses. This works well in some cases, but conflicts among tasks (for instance, task gradients pointing in opposing directions) can make the optimization of the shared parameters very hard. Moreover, the weight associated with each loss is effectively a hyperparameter, and finding their optimal values is generally expensive.
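To make the conflict issue concrete, here is a small plain-Python sketch with two made-up gradient vectors. The names `g1`, `g2`, `w1`, `w2` and the helper `dot` are hypothetical and chosen only for illustration. With equal weights, a gradient step on the weighted sum decreases the second loss but, to first order, increases the first one.

```python
# Hypothetical gradients of two task losses w.r.t. the same two shared parameters.
g1 = [1.0, 0.0]
g2 = [-2.0, 0.1]


def dot(u, v):
    """Inner product of two equal-length lists of floats."""
    return sum(a * b for a, b in zip(u, v))


# Weighted-sum scalarization: the gradient of w1*L1 + w2*L2 is w1*g1 + w2*g2.
w1, w2 = 1.0, 1.0
combined = [w1 * a + w2 * b for a, b in zip(g1, g2)]
step = [-c for c in combined]  # gradient descent moves against the combined gradient

# First-order change of each loss for a small step in that direction.
change1 = dot(g1, step)  # positive: L1 increases
change2 = dot(g2, step)  # negative: L2 decreases
print(dot(g1, g2), change1, change2)
```

With `w2 = 0.4` instead, the signs flip (the first loss decreases and the second increases), which illustrates why the weights behave like hyperparameters that must be tuned.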
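Alternatively, the vector of losses can be minimized directly using Jacobian descent: the Jacobian of the losses with respect to the (shared) parameters, whose rows are the individual task gradients, is reduced to a single update direction by an aggregator, such as `UPGrad` in the example below. The following example shows how to use TorchJD to train a very simple multi-task model with two regression tasks. For the sake of the example, we generate a fake dataset consisting of 8 batches of 16 random input vectors of dimension 10, along with their corresponding scalar labels for both tasks.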
```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

# Shared feature extractor and one linear head per task.
shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
A = UPGrad()

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)
    output1 = task1_module(features)
    output2 = task2_module(features)
    loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)

    optimizer.zero_grad()
    # Task-specific parameters get their own loss's gradient; the Jacobian of
    # [loss1, loss2] w.r.t. the shared parameters is aggregated by A.
    mtl_backward(
        losses=[loss1, loss2],
        features=features,
        tasks_params=[task1_module.parameters(), task2_module.parameters()],
        shared_params=shared_module.parameters(),
        A=A,
    )
    optimizer.step()
```
Note: In this example, the Jacobian is computed only with respect to the shared parameters. The task-specific parameters are simply updated using the gradient of their own task's loss with respect to them.
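To illustrate the split described in the note, here is a rough plain-PyTorch sketch of the kind of gradients that end up in the `.grad` fields for one batch. It is not TorchJD's actual implementation. The helper `flat_grad` is hypothetical, and the mean of the Jacobian rows is used as a simple stand-in aggregator in place of UPGrad.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential

torch.manual_seed(0)

# Same architecture as in the example above, applied to a single random batch.
shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
loss_fn = MSELoss()

x = torch.randn(16, 10)
y1 = torch.randn(16, 1)
y2 = torch.randn(16, 1)

features = shared_module(x)
loss1 = loss_fn(task1_module(features), y1)
loss2 = loss_fn(task2_module(features), y2)

shared_params = list(shared_module.parameters())


def flat_grad(loss, params):
    """Hypothetical helper: gradient of `loss` w.r.t. `params`, flattened into one vector."""
    # retain_graph=True because both losses share the graph through `features`.
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads])


# Jacobian of [loss1, loss2] w.r.t. the shared parameters: one row per task.
jacobian = torch.stack([flat_grad(loss1, shared_params), flat_grad(loss2, shared_params)])

# Stand-in aggregator: the mean of the rows (TorchJD's UPGrad uses a different rule).
direction = jacobian.mean(dim=0)

# Write the aggregated direction into the shared parameters' .grad fields
# (overwriting, which matches accumulating into freshly zeroed gradients).
offset = 0
for p in shared_params:
    n = p.numel()
    p.grad = direction[offset:offset + n].view_as(p).clone()
    offset += n

# Task-specific parameters receive the plain gradient of their own task's loss.
loss1.backward(inputs=list(task1_module.parameters()), retain_graph=True)
loss2.backward(inputs=list(task2_module.parameters()))
```

With the mean as aggregator, the shared update is just the gradient of `(loss1 + loss2) / 2`, that is, the weighted-sum approach with equal weights. Swapping in a non-conflicting aggregator such as UPGrad is what distinguishes Jacobian descent from that baseline.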