SDMGrad

class torchjd.aggregation.SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lambda_=0.3, pref_vector=None)

Stateful Weighting [Matrix] from Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms (NeurIPS 2023).

Warning

The input matrix must be \(A = J_1 J_2^\top\), where \(J_1\) and \(J_2\) are Jacobians of the same objectives computed on two independent mini-batches via torchjd.autojac.jac(). It is not a Gramian and, in general, it is neither symmetric nor positive semi-definite. See the usage example below.
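A small numerical illustration of this warning, using random matrices as stand-ins for the two Jacobians (the sizes and data are arbitrary assumptions; no model or torchjd call is involved). Because the batches are independent, \(\mathbb{E}[J_1 J_2^\top] = \mathbb{E}[J_1]\,\mathbb{E}[J_2]^\top\), so \(A\) estimates the Gramian of the expected Jacobian without the bias that \(J_1 J_1^\top\) would carry. The price is that a single sample of \(A\) is not symmetric and may have negative diagonal entries.

```python
import torch

torch.manual_seed(0)
m, n = 3, 10  # m objectives, n parameters (toy sizes)

# Stand-ins for the Jacobians of the same m objectives on two independent batches
J_1 = torch.randn(m, n)
J_2 = torch.randn(m, n)

A = J_1 @ J_2.T  # the [m, m] matrix SDMGradWeighting expects
G = J_1 @ J_1.T  # a Gramian, for comparison

print("A symmetric:", torch.allclose(A, A.T))
print("G symmetric:", torch.allclose(G, G.T))
print("smallest eigenvalue of G:", torch.linalg.eigvalsh(G).min().item())
print("diagonal of A:", torch.diagonal(A))  # entries may be negative
```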

Parameters:
  • lr (float) – Learning rate of the inner SGD that solves for the task weights. Must be positive.

  • momentum (float) – Momentum of the inner SGD. Must be in \([0, 1)\).

  • n_iter (int) – Number of inner SGD iterations performed at each call. Must be positive.

  • lambda_ (float) – Coefficient controlling how strongly the descent direction is pulled toward the preference direction; larger values pull it closer. Must be non-negative.

  • pref_vector (Tensor | None) – The preference vector \(\tilde w\) defining the target direction. If not provided, defaults to the uniform vector \([1/m, \ldots, 1/m]\) (i.e., the target direction is the average gradient).
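A quick check of what the default preference vector means, assuming (as the description of pref_vector implies) that the target direction is the combination \(J^\top \tilde w\) of the objectives' gradients, where row \(i\) of \(J\) is the gradient of objective \(i\). Random data stands in for a real Jacobian, and custom_pref is a hypothetical example value.

```python
import torch

torch.manual_seed(0)
m, n = 4, 6
J = torch.randn(m, n)  # row i: gradient of objective i

# Default preference vector: the uniform vector [1/m, ..., 1/m]
default_pref = torch.full((m,), 1.0 / m)
g_0 = J.T @ default_pref  # target direction
print(torch.allclose(g_0, J.mean(dim=0), atol=1e-6))  # uniform weights give the average gradient

# A hypothetical non-uniform preference that favours the first objective
custom_pref = torch.tensor([0.7, 0.1, 0.1, 0.1])
g_custom = J.T @ custom_pref
```

A vector such as custom_pref is what would be passed as pref_vector to tilt the target direction toward the first objective.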

Note

The inner simplex-projected solver is adapted from the official implementation. The official class default for lambda_ is 0.6, but the authors override it to 0.3 in their own experiments; 0.3 is the value used here (and in LibMTL).
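A minimal sketch of an inner simplex-projected solver in the spirit of this note, written from scratch for illustration; it is not torchjd's or the official code, and the names project_onto_simplex and sdmgrad_inner_solve are hypothetical. Assumptions: the objective is \(w^\top A w + 2\lambda\, w^\top A \tilde w + \lambda^2\, \tilde w^\top A \tilde w\) (the expansion of \(\lVert J^\top (w + \lambda \tilde w)\rVert^2\) when \(A = J J^\top\)), the iterate is projected onto the simplex after every momentum-SGD step using the standard sort-based Euclidean projection, and the returned weights are \((w^\ast + \lambda \tilde w)/(1 + \lambda)\). The arguments lr, momentum, n_iter, lambda_ and pref_vector play the roles documented in the parameter list.

```python
from typing import Optional, Tuple

import torch


def project_onto_simplex(v: torch.Tensor) -> torch.Tensor:
    """Euclidean projection of a 1-D tensor onto {w : w >= 0, sum(w) = 1}."""
    u, _ = torch.sort(v, descending=True)
    cssv = torch.cumsum(u, dim=0) - 1.0
    k = torch.arange(1, v.numel() + 1, dtype=v.dtype, device=v.device)
    # rho: number of coordinates that remain positive after shifting by theta
    rho = int(torch.nonzero(u - cssv / k > 0)[-1].item()) + 1
    theta = cssv[rho - 1] / rho
    return torch.clamp(v - theta, min=0.0)


def sdmgrad_inner_solve(
    A: torch.Tensor,
    lr: float = 10.0,
    momentum: float = 0.5,
    n_iter: int = 20,
    lambda_: float = 0.3,
    pref_vector: Optional[torch.Tensor] = None,
    w_init: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch of the inner solve; returns (final weights, solver state)."""
    m = A.shape[0]
    A = A.detach()
    uniform = torch.full((m,), 1.0 / m, dtype=A.dtype, device=A.device)
    pref = uniform if pref_vector is None else pref_vector.to(A)
    # Warm start from stored weights if given, otherwise from uniform weights
    w = (uniform if w_init is None else w_init.to(A)).clone().requires_grad_(True)
    opt = torch.optim.SGD([w], lr=lr, momentum=momentum)
    A_pref = A @ pref
    for _ in range(n_iter):
        opt.zero_grad()
        # Assumed objective: w^T A w + 2*lambda*w^T A w~ + lambda^2 * w~^T A w~
        obj = w @ A @ w + 2 * lambda_ * (w @ A_pref) + lambda_**2 * (pref @ A_pref)
        obj.backward()
        opt.step()
        # Project the iterate back onto the probability simplex after each step
        with torch.no_grad():
            w.copy_(project_onto_simplex(w))
    w_star = w.detach()
    # Assumed final mix: pull the solution toward the preference vector
    weights = (w_star + lambda_ * pref) / (1.0 + lambda_)
    return weights, w_star


torch.manual_seed(0)
J_1 = torch.randn(3, 8)
J_2 = J_1 + 0.1 * torch.randn(3, 8)  # a second, noisy estimate of the same gradients
weights, state = sdmgrad_inner_solve(J_1 @ J_2.T)
# Next step: warm-start from the previous solution
weights_next, state = sdmgrad_inner_solve(J_2 @ J_1.T, w_init=state)
```

Passing the returned state back as w_init mimics what we assume the stateful weighting does between calls; starting again without it would correspond to calling reset().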

Before the inner solve, the input matrix is scale-normalized using the mean of the square roots of its non-negative diagonal entries (following both the official implementation and LibMTL). This makes the inner SGD learning rate invariant to the overall magnitude of the gradients. The normalization is briefly described in Section 6.1 of the paper.
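A sketch of this normalization under two assumptions not spelled out above: negative diagonal entries are clamped to zero before taking square roots, and \(A\) is divided by the square of the resulting mean. The square is what makes the result scale-free: entries of \(A\) are products of two gradients, so scaling every gradient by \(c\) multiplies \(A\) by \(c^2\) and the mean square root by \(c\). The function name normalize is hypothetical.

```python
import torch


def normalize(A: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch of the scale normalization applied before the inner solve."""
    # Negative diagonal entries are possible because A is not a Gramian; clamp them to 0
    diag = torch.clamp(torch.diagonal(A), min=0.0)
    scale = torch.sqrt(diag).mean()  # a typical gradient norm
    # Entries of A scale like squared gradient norms, hence the square
    return A / (scale**2 + eps)


torch.manual_seed(0)
J_1 = torch.randn(3, 8)
J_2 = J_1 + 0.1 * torch.randn(3, 8)
A = J_1 @ J_2.T
A_big = (100.0 * J_1) @ (100.0 * J_2).T  # same gradients, 100x larger
print(torch.allclose(normalize(A), normalize(A_big), atol=1e-5))
```

The final check shows the normalized matrix is unchanged when all gradients grow by a factor of 100, which is why a fixed inner learning rate such as lr=10.0 can work across models.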

Example (three batches per step)

The following example shows how to train with the SDMGrad algorithm. Two batches are used to build \(A\) and a third, independent batch is used to compute the update.

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import SDMGradWeighting
from torchjd.autojac import jac

# Generate data (9 batches of 16 examples of dim 5) for the sake of the example.
inputs = torch.randn(9, 16, 5)
targets = torch.randn(9, 16)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters(), lr=0.1)
criterion = MSELoss(reduction="none")
weighting = SDMGradWeighting(lambda_=0.3)
params = list(model.parameters())

# Consume three consecutive (independent) batches per step.
for i in range(len(inputs) // 3):
    # Batches corresponding to ξ, ξ' and ζ in the paper's algorithm.
    input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2]
    target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2]

    losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
    jacs_1 = jac(losses_1, params)
    J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

    losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
    jacs_2 = jac(losses_2, params)
    J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

    A = J_1 @ J_2.T
    weights = weighting(A)

    losses_3 = criterion(model(input_3).squeeze(dim=1), target_3)
    losses_3.backward(weights)
    optimizer.step()
    optimizer.zero_grad()

__call__(matrix, /)

Computes the vector of weights from the input matrix and applies all registered hooks.

Parameters:
  • matrix (Tensor) – The matrix from which the weights must be extracted.

Return type:

Tensor

reset()

Clears the stored task weights, so that the next call starts the inner solve from uniform weights.

Return type:

None