Monitoring aggregations

The Aggregator class is a subclass of torch.nn.Module. This allows registering hooks, which can be used to monitor information about aggregations. The following code example registers two forward hooks: one on the UPGrad aggregator, which computes and prints the cosine similarity between the aggregation performed by UPGrad and the average of the gradients, and one on the weighting of UPGrad, which prints the weights that it produces.

Updating the parameters of the model with the average gradient is equivalent to using gradient descent on the average of the losses. Observing a cosine similarity smaller than 1 means that Jacobian descent is doing something different than gradient descent. With UPGrad, this happens when the original gradients conflict (i.e. they have a negative inner product).

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.nn.functional import cosine_similarity
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

def print_weights(_, __, weights: torch.Tensor) -> None:
    """Report the weights handed to this hook.

    The first two positional arguments are ignored; only ``weights`` is formatted and
    written to standard output.
    """
    message = f"Weights: {weights}"
    print(message)

def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.Tensor) -> None:
    """Report how closely the aggregation is aligned with the mean of the matrix rows.

    The first element of ``inputs`` is averaged over its first dimension, and the cosine
    similarity between that average and ``aggregation`` is printed with four decimals.
    """
    mean_gradient = inputs[0].mean(dim=0)
    cos = cosine_similarity(aggregation, mean_gradient, dim=0).item()
    print(f"Cosine similarity: {cos:.4f}")

# Shared trunk (two Linear+ReLU stages) followed by one single-output linear head per task.
shared_module = Sequential(
    Linear(10, 5),
    ReLU(),
    Linear(5, 3),
    ReLU(),
)
task1_module, task2_module = Linear(3, 1), Linear(3, 1)

# Flat parameter list: shared parameters first, then those of task 1, then those of task 2.
params = [
    parameter
    for module in (shared_module, task1_module, task2_module)
    for parameter in module.parameters()
]

# Create the aggregator and attach the two monitoring hooks: one on the aggregator itself
# and one on its weighting.
aggregator = UPGrad()
aggregator.register_forward_hook(print_gd_similarity)
aggregator.weighting.register_forward_hook(print_weights)

# Loss and optimizer used by the training loop.
loss_fn = MSELoss()
optimizer = SGD(params=params, lr=0.1)

# Random data: 8 batches of 16 samples each, with input vectors of length 10 and one
# scalar target per sample for each of the two tasks.
inputs = torch.randn((8, 16, 10))
task1_targets = torch.randn((8, 16, 1))  # targets for the first task
task2_targets = torch.randn((8, 16, 1))  # targets for the second task

# One optimization step per batch. The loop variable is named `batch` rather than `input`
# so that the built-in `input` function is not shadowed at module level.
for batch, target1, target2 in zip(inputs, task1_targets, task2_targets):
    # Features shared by both tasks, then one prediction and one loss per task.
    features = shared_module(batch)
    output1 = task1_module(features)
    output2 = task2_module(features)
    loss1 = loss_fn(output1, target1)
    loss2 = loss_fn(output2, target2)

    # Clear the gradients left by the previous step before computing new ones.
    optimizer.zero_grad()
    # The aggregator is handed to this call; the hooks registered on it are expected to
    # print their output while it runs.
    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
    optimizer.step()