SDMGrad
class torchjd.aggregation.SDMGradWeighting(lr=10.0, momentum=0.5, n_iter=20, lambda_=0.3, pref_vector=None)
StatefulWeighting[Matrix] from the paper Direction-oriented Multi-objective Learning: Simple and Provable Stochastic Algorithms (NeurIPS 2023).

Warning
The input matrix must be \(A = J_1 J_2^\top\), where \(J_1\) and \(J_2\) are Jacobians computed on two independent mini-batches via torchjd.autojac.jac(). It is not a Gramian: in general it is neither symmetric nor positive semi-definite. See the usage example below.

Parameters:
- lr (float) – Learning rate of the inner SGD that solves for the task weights. Must be positive.
- momentum (float) – Momentum of the inner SGD. Must be in \([0, 1)\).
- n_iter (int) – Number of inner SGD iterations performed at each call. Must be positive.
- lambda_ (float) – Coefficient controlling how strongly the descent direction is pulled toward the preference direction. Must be non-negative.
- pref_vector (Tensor | None) – The preference vector \(\tilde w\) defining the target direction. If not provided, defaults to the uniform vector \([1/m, \ldots, 1/m]\), where \(m\) is the number of objectives, so that the target direction is the average gradient.
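The roles of lambda_ and pref_vector can be illustrated with the plain-Python sketch below. It assumes that the returned weights blend the inner-solver solution \(w\) with the preference vector as \((w + \lambda \tilde w) / (1 + \lambda)\). That form mirrors the SDMGrad update direction \((G w + \lambda G \tilde w) / (1 + \lambda)\) as we understand it and is not stated on this page, so check the torchjd source before relying on it. The helper names default_pref_vector and combine are hypothetical.

```python
def default_pref_vector(m):
    """Uniform preference vector [1/m, ..., 1/m] (the documented default)."""
    return [1.0 / m] * m


def combine(w, pref, lambda_):
    """Hypothetical final step: blend inner weights w with the preference vector.

    ASSUMPTION: the combination is (w + lambda_ * pref) / (1 + lambda_).
    With lambda_ = 0 the preference is ignored; as lambda_ grows the
    result approaches pref.
    """
    if lambda_ < 0:
        raise ValueError("lambda_ must be non-negative")
    return [(w_i + lambda_ * p_i) / (1.0 + lambda_) for w_i, p_i in zip(w, pref)]


pref = default_pref_vector(4)  # [0.25, 0.25, 0.25, 0.25]
w = [1.0, 0.0, 0.0, 0.0]       # e.g. an inner solution at a vertex of the simplex

print(combine(w, pref, 0.0))   # [1.0, 0.0, 0.0, 0.0]
print(combine(w, pref, 1.0))   # [0.625, 0.125, 0.125, 0.125]
```

The coefficients \(1/(1 + \lambda)\) and \(\lambda/(1 + \lambda)\) sum to 1, so the result is a convex combination of \(w\) and \(\tilde w\). It therefore stays on the simplex whenever both inputs do.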
Note
The inner simplex-projected solver is adapted from the official implementation. The official class default for lambda_ is 0.6, but the authors override it to 0.3 in their own experiments. That is the value used here (and in LibMTL).

Before the inner solve, the input matrix is scale-normalized by the mean of the square roots of its non-negative diagonal entries (following both the official implementation and LibMTL). This makes the inner SGD learning rate invariant to the overall magnitude of the gradients. The normalization is briefly described in Section 6.1 of the paper.
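The preprocessing and the simplex-projected inner solve described in this note can be sketched in plain Python as below. This is an illustrative sketch, not torchjd's code, and it rests on several assumptions:

- The helper names are hypothetical.
- Negative diagonal entries are clamped to zero.
- \(A\) is divided by the square of the mean. Multiplying every gradient by a common factor \(c\) scales \(A\) by \(c^2\) and the mean by \(c\), so dividing by the square is the choice that makes the result invariant.
- The projection is the standard sort-based Euclidean projection onto the probability simplex.
- The quadratic \(w^\top M w + 2\lambda w^\top M \tilde w\) is only a toy stand-in for SDMGrad's inner objective.

```python
import math


def normalize_scale(A):
    """Divide A by s**2, where s is the mean of sqrt(max(A_ii, 0)).

    ASSUMPTION: negative diagonal entries are clamped to zero.
    """
    m = len(A)
    s = sum(math.sqrt(max(A[i][i], 0.0)) for i in range(m)) / m
    if s == 0.0:
        return [row[:] for row in A]  # nothing to normalize by
    return [[a / s**2 for a in row] for row in A]


def project_to_simplex(v):
    """Euclidean projection of v onto {w : w_i >= 0, sum(w) = 1} (sort-based)."""
    u = sorted(v, reverse=True)
    cumsum, theta = 0.0, 0.0
    for j, u_j in enumerate(u, start=1):
        cumsum += u_j
        t = (cumsum - 1.0) / j
        if u_j - t > 0:
            theta = t  # valid j form a prefix, so this ends at the largest one
    return [max(v_i - theta, 0.0) for v_i in v]


def inner_solve(grad_fn, w0, lr=10.0, momentum=0.5, n_iter=20):
    """Projected SGD with heavy-ball momentum (PyTorch-style buffer, no dampening)."""
    w, buf = list(w0), [0.0] * len(w0)
    for _ in range(n_iter):
        g = grad_fn(w)
        buf = [momentum * b + g_i for b, g_i in zip(buf, g)]
        w = project_to_simplex([w_i - lr * b for w_i, b in zip(w, buf)])
    return w


def toy_grad(M, pref, lambda_):
    """Gradient of the hypothetical f(w) = w^T M w + 2 * lambda_ * w^T M pref."""
    def grad(w):
        m = len(w)
        Mw = [sum(M[i][k] * w[k] for k in range(m)) for i in range(m)]
        MTw = [sum(M[k][i] * w[k] for k in range(m)) for i in range(m)]
        Mp = [sum(M[i][k] * pref[k] for k in range(m)) for i in range(m)]
        return [Mw[i] + MTw[i] + 2 * lambda_ * Mp[i] for i in range(m)]
    return grad


A = [[4.0, 1.0], [-2.0, 9.0]]  # a non-symmetric toy input
M = normalize_scale(A)
pref = [0.5, 0.5]
w = inner_solve(toy_grad(M, pref, lambda_=0.3), w0=pref)
print(w)  # some point on the probability simplex
```

Multiplying A by 4 (that is, all gradients by 2) leaves normalize_scale(A) unchanged. Without this step, the fixed default lr=10.0 would take steps of very different relative size depending on gradient magnitude.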
Example (three batches per step)
The following example shows how to train with the SDMGrad algorithm.
```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import SDMGradWeighting
from torchjd.autojac import jac

# Generate data (9 batches of 16 examples of dim 5) for the sake of the example.
inputs = torch.randn(9, 16, 5)
targets = torch.randn(9, 16)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters())
criterion = MSELoss(reduction="none")
weighting = SDMGradWeighting(lambda_=0.3)
params = list(model.parameters())

# Consume three consecutive (independent) batches per step.
for i in range(len(inputs) // 3):
    # Batches corresponding to ξ, ξ' and ζ in the paper's algorithm.
    input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2]
    target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2]

    # One Jacobian per parameter tensor; flatten and concatenate them
    # into J_1 (one row per loss).
    losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
    jacs_1 = jac(losses_1, params)
    J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

    # Same for the second, independent batch.
    losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
    jacs_2 = jac(losses_2, params)
    J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

    # Cross-batch matrix (not a Gramian), as required by the warning above.
    A = J_1 @ J_2.T
    weights = weighting(A)

    # Weighted backward pass on the third batch, then an optimizer step.
    losses_3 = criterion(model(input_3).squeeze(dim=1), target_3)
    losses_3.backward(weights)
    optimizer.step()
    optimizer.zero_grad()
```
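The warning about the input matrix can be illustrated with a minimal plain-Python sketch. It uses hand-picked 2 × 2 "Jacobians" that are not produced by torchjd. The cross product \(J_1 J_2^\top\) can be asymmetric and can even have a negative diagonal entry, which rules out positive semi-definiteness. By contrast, \(J_1 J_1^\top\) is a genuine Gramian. Independence is still what makes \(A\) useful: for independent, identically distributed batches, \(\mathbb{E}[J_1 J_2^\top] = \mathbb{E}[J_1]\,\mathbb{E}[J_2]^\top\), the Gramian of the expected Jacobian. The helper names cross and is_symmetric are hypothetical.

```python
def cross(J_1, J_2):
    """Return J_1 @ J_2^T for matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(r_1, r_2)) for r_2 in J_2] for r_1 in J_1]


def is_symmetric(M):
    """Exact symmetry check for a square matrix stored as lists of rows."""
    return all(M[i][j] == M[j][i] for i in range(len(M)) for j in range(len(M)))


# Hand-picked toy Jacobians (2 objectives x 2 parameters), standing in for
# the Jacobians of two independent mini-batches.
J_1 = [[1.0, 0.0], [0.0, 1.0]]
J_2 = [[1.0, 2.0], [0.0, -1.0]]

A = cross(J_1, J_2)  # the kind of matrix SDMGradWeighting expects
G = cross(J_1, J_1)  # an ordinary Gramian, for comparison

print(A)                # [[1.0, 0.0], [2.0, -1.0]]
print(is_symmetric(A))  # False
print(A[1][1] < 0)      # True: x = (0, 1) gives x^T A x = -1, so A is not PSD
print(is_symmetric(G))  # True
```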