Partial Jacobian Descent for IWRM

This example demonstrates how to perform Partial Jacobian Descent using TorchJD. This technique minimizes a vector of per-instance losses while resolving conflicts based only on a submatrix of the Jacobian, namely the columns corresponding to a selected subset of the model's parameters. It trades some precision in the aggregation decision for a lower computational cost, since the Gramian is computed from this submatrix rather than from the full Jacobian. For the complete, non-partial version, see the IWRM example.
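To make this concrete, here is a minimal sketch of the underlying algebra; the block notation is ours and not part of TorchJD's API. Partition the Jacobian of the loss vector by parameter groups as $J = [J_1 \; J_2]$, where $J_1$ contains the columns for the excluded parameters and $J_2$ the columns for the selected ones. The full Gramian then decomposes as

$$J J^\top = J_1 J_1^\top + J_2 J_2^\top,$$

and partial descent computes the aggregation weights from $J_2 J_2^\top$ alone, so $J_1$ never needs to be materialized.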

In this example, our model consists of three Linear layers separated by ReLU layers. We perform the partial descent by considering only the parameters of the last two Linear layers. By doing this, we avoid computing the Jacobian and its Gramian with respect to the parameters of the first Linear layer, thereby reducing memory usage and computation time.

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGradWeighting
from torchjd.autogram import Engine

# 8 batches of 16 examples: inputs of dimension 10 and scalar targets.
X = torch.randn(8, 16, 10)
Y = torch.randn(8, 16)

model = Sequential(Linear(10, 8), ReLU(), Linear(8, 5), ReLU(), Linear(5, 1))
loss_fn = MSELoss(reduction="none")

# The weighting maps the Gramian of the Jacobian to a vector of loss weights.
weighting = UPGradWeighting()

# Create the autogram engine that will compute the Gramian of the Jacobian with
# respect to the last two Linear layers' parameters. model[2:] is the sub-network
# Linear(8, 5) -> ReLU -> Linear(5, 1); the ReLU has no parameters.
engine = Engine(model[2:], batch_dim=0)

params = model.parameters()
optimizer = SGD(params, lr=0.1)

for x, y in zip(X, Y):
    y_hat = model(x).squeeze(dim=1)  # shape: [16]
    losses = loss_fn(y_hat, y)  # shape: [16]
    optimizer.zero_grad()
    gramian = engine.compute_gramian(losses)  # shape: [16, 16]
    weights = weighting(gramian)  # shape: [16]
    losses.backward(weights)
    optimizer.step()
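
Note that losses.backward(weights) still propagates through the entire model, so the parameters of the first Linear layer receive gradients and are updated by the optimizer; only the computation of the weights is based on the partial Gramian. As an illustrative sanity check (our addition, not part of the original example), the following lines can be run after the loop:

# The Gramian is square with one row and column per loss, regardless of how many
# parameters were selected, and the first Linear layer still has gradients.
assert gramian.shape == (16, 16)
assert model[0].weight.grad is not None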