MoDo

class torchjd.aggregation.MoDoWeighting(gamma=0.1, rho=0.1)

Stateful Weighting [Matrix] from Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance (JMLR 2024).

Warning

The input matrix must be \(G = J_1 J_2^\top\), where \(J_1\) and \(J_2\) are the Jacobians of the losses with respect to the model parameters, computed on two independent mini-batches via torchjd.autojac.jac(). Passing a single-batch Gramian (\(J_1 J_1^\top\)) instead voids the convergence guarantee. See the usage examples below.
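One way to build intuition for this warning is a toy Monte Carlo sketch (plain PyTorch, not torchjd code). It assumes, for illustration only, that each mini-batch Jacobian is a fixed "true" Jacobian plus independent zero-mean Gaussian noise. Under that assumption, the average of \(J_1 J_2^\top\) matches the product of the true Jacobians, whereas the single-batch Gramian \(J_1 J_1^\top\) picks up the noise variance as a bias on its diagonal. Note also that \(J_1 J_2^\top\) is generally not symmetric. The names J_bar, sigma and sample_jacobians are hypothetical, and the sketch is an intuition, not a proof of the convergence result.

```python
import torch

torch.manual_seed(0)

# Hypothetical setup: a fixed "true" Jacobian of 2 objectives w.r.t. 3 parameters.
J_bar = torch.tensor([[1.0, -0.5, 0.0], [0.3, 0.8, -1.0]])
sigma = 1.0  # standard deviation of the simulated mini-batch noise
num_draws = 100_000


def sample_jacobians(n: int) -> torch.Tensor:
    """Simulate n mini-batch Jacobians: the true Jacobian plus i.i.d. zero-mean Gaussian noise."""
    return J_bar + sigma * torch.randn(n, *J_bar.shape)


J_1 = sample_jacobians(num_draws)  # first batch of each draw
J_2 = sample_jacobians(num_draws)  # second, independent batch of each draw

target = J_bar @ J_bar.T
# Monte Carlo averages of J_1 J_2^T and J_1 J_1^T over all draws.
cross = torch.einsum("nik,njk->ij", J_1, J_2) / num_draws
same = torch.einsum("nik,njk->ij", J_1, J_1) / num_draws

cross_error = (cross - target).abs().max().item()
# Roughly sigma**2 * 3 * I: the noise variance of each row leaks onto the diagonal.
same_bias = same - target

print(f"max deviation of the two-batch estimate: {cross_error:.3f}")
print(same_bias)
```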

Parameters:
  • gamma (float) – Step size \(\gamma\) of the update of the task weights \(\lambda\). Must be positive.

  • rho (float) – Coefficient \(\rho\) of the \(\ell_2\) regularisation applied to the task weights \(\lambda\) in their update. Must be non-negative.
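For orientation, here is our reading of the task-weight update in the paper, showing where gamma and rho enter. \(\Pi_{\Delta}\) denotes the Euclidean projection onto the probability simplex, and \(G_t = J_1 J_2^\top\) is the input matrix at step \(t\). This is a summary under that reading, not a guarantee that torchjd's implementation matches it line for line:

\[
\lambda_{t+1} = \Pi_{\Delta}\left(\lambda_t - \gamma \left(G_t \lambda_t + \rho \lambda_t\right)\right)
\]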

Note

The Euclidean projection onto the probability simplex used in the \(\lambda\) update is adapted from the official implementation.
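The following sketch shows a standard sort-based Euclidean projection onto the probability simplex, combined with one step of the task-weight update. It is our own illustration and not the code adapted from the official implementation. The update form lam <- Proj(lam - gamma * (G @ lam + rho * lam)) is an assumption based on our reading of the paper. The function names project_onto_simplex and modo_lambda_step are hypothetical.

```python
import torch


def project_onto_simplex(v: torch.Tensor) -> torch.Tensor:
    """Euclidean projection of a 1-D tensor onto the probability simplex (sort-based)."""
    n = v.numel()
    u, _ = torch.sort(v, descending=True)
    cumsum = torch.cumsum(u, dim=0)
    k = torch.arange(1, n + 1, dtype=v.dtype, device=v.device)
    # Last sorted position that stays positive after the shift (position 0 always qualifies).
    support = torch.nonzero(u - (cumsum - 1) / k > 0).max()
    theta = (cumsum[support] - 1) / (support + 1)
    return torch.clamp(v - theta, min=0.0)


def modo_lambda_step(lam: torch.Tensor, G: torch.Tensor, gamma: float, rho: float) -> torch.Tensor:
    """Hypothetical helper: one task-weight step, assuming lam <- Proj(lam - gamma * (G @ lam + rho * lam))."""
    return project_onto_simplex(lam - gamma * (G @ lam + rho * lam))


G = torch.tensor([[2.0, 0.5], [0.5, 1.0]])  # toy stand-in for J_1 @ J_2.T
lam = torch.full((2,), 0.5)  # uniform starting weights
lam = modo_lambda_step(lam, G, gamma=0.1, rho=0.1)
print(lam)  # approximately tensor([0.4750, 0.5250])
```

In this toy case, the first objective, whose row of G is larger, ends up with less weight after the step.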

Example (two batches per step)

The following example reproduces basic MoDo using two independent mini-batches per step. This matches MoDo as described in the paper, and the behavior of the official implementation when three_grads is False.

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import MoDoWeighting
from torchjd.autojac import jac

# Generate data (8 batches of 16 examples of dim 5) for the sake of the example.
inputs = torch.randn(8, 16, 5)
targets = torch.randn(8, 16)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters(), lr=0.1)
criterion = MSELoss(reduction="none")
weighting = MoDoWeighting(gamma=0.1, rho=0.0)
params = list(model.parameters())

# Consume two consecutive (independent) batches per step.
for i in range(len(inputs) // 2):
    input_1, input_2 = inputs[2 * i], inputs[2 * i + 1]
    target_1, target_2 = targets[2 * i], targets[2 * i + 1]

    # retain_graph=True so both graphs survive for the backward step below.
    losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
    jacs_1 = jac(losses_1, params, retain_graph=True)
    J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

    losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
    jacs_2 = jac(losses_2, params, retain_graph=True)
    J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

    G = J_1 @ J_2.T
    weights = weighting(G)

    # Equation 2.9b: the parameter update uses the mean of both batches' losses.
    losses = (losses_1 + losses_2) / 2.0
    losses.backward(weights)
    optimizer.step()
    optimizer.zero_grad()
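In the example above, losses.backward(weights) relies on standard PyTorch behaviour: calling backward on a non-scalar tensor with a gradient argument accumulates the vector-Jacobian product. Each parameter therefore receives the gradient of the weighted sum of the losses. The following self-contained check in plain PyTorch illustrates this equivalence. Arbitrary fixed weights stand in for the output of MoDoWeighting, which is an assumption made only for illustration.

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential

torch.manual_seed(0)
model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
criterion = MSELoss(reduction="none")
inputs, targets = torch.randn(16, 5), torch.randn(16)
weights = torch.softmax(torch.randn(16), dim=0)  # stand-in for weighting(G)

# Variant 1: vector-Jacobian product, as in the example.
losses = criterion(model(inputs).squeeze(dim=1), targets)
losses.backward(weights)
grads_vjp = [p.grad.clone() for p in model.parameters()]
model.zero_grad()

# Variant 2: backward through the explicitly weighted scalar sum.
losses = criterion(model(inputs).squeeze(dim=1), targets)
(weights * losses).sum().backward()
grads_sum = [p.grad.clone() for p in model.parameters()]

for g_vjp, g_sum in zip(grads_vjp, grads_sum):
    print(torch.allclose(g_vjp, g_sum, atol=1e-6))  # True for every parameter
```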

Example (three batches per step)

The following example reproduces basic MoDo using three independent mini-batches per step, keeping the \(\lambda\) update and the parameter update on separate draws. This matches the behavior of LibMTL and of the official implementation when three_grads is True.

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import MoDoWeighting
from torchjd.autojac import jac

# Generate data (9 batches of 16 examples of dim 5) for the sake of the example.
inputs = torch.randn(9, 16, 5)
targets = torch.randn(9, 16)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters(), lr=0.1)
criterion = MSELoss(reduction="none")
weighting = MoDoWeighting(gamma=0.1, rho=0.0)
params = list(model.parameters())

# Consume three consecutive (independent) batches per step.
for i in range(len(inputs) // 3):
    input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2]
    target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2]

    losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
    jacs_1 = jac(losses_1, params)
    J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

    losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
    jacs_2 = jac(losses_2, params)
    J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

    G = J_1 @ J_2.T
    weights = weighting(G)

    losses_3 = criterion(model(input_3).squeeze(dim=1), target_3)
    losses_3.backward(weights)
    optimizer.step()
    optimizer.zero_grad()

reset()

Clears the stored task weights, so that the next call starts again from uniform weights.

Return type:

None

__call__(matrix, /)

Computes the vector of weights from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The matrix from which the weights must be extracted.

Return type:

Tensor