MoDo
class torchjd.aggregation.MoDoWeighting(gamma=0.1, rho=0.1)
StatefulWeighting[Matrix] from Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance (JMLR 2024).
Warning
The input matrix must be \(G = J_1 J_2^\top\), computed from two independent mini-batches via
torchjd.autojac.jac(). Using a single-batch Gramian (\(J_1 J_1^\top\)) breaks the convergence guarantee. See the usage examples below.
Parameters:
Note
The Euclidean projection onto the simplex used in the \(\lambda\) update is adapted from the official implementation.
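The \(\lambda\) update mentioned in the note can be sketched in plain Python. The sort-based routine below is the standard algorithm for Euclidean projection onto the probability simplex; it is not copied from torchjd or from the official implementation. The names project_onto_simplex and modo_lambda_step are hypothetical, and the update form \(\lambda \leftarrow \Pi_\Delta(\lambda - \gamma (G + \rho I) \lambda)\) is an assumption about how gamma and rho enter, so this is an illustration rather than the library's code.

```python
from typing import List


def project_onto_simplex(v: List[float]) -> List[float]:
    """Euclidean projection of v onto {w : w_i >= 0, sum(w) = 1} (sort-based)."""
    u = sorted(v, reverse=True)
    cumulative = 0.0
    theta = 0.0
    for k, u_k in enumerate(u, start=1):
        cumulative += u_k
        candidate = (cumulative - 1.0) / k
        # The condition holds for the first few k and then fails, so the last
        # candidate that satisfies it is the shift to subtract.
        if u_k - candidate > 0:
            theta = candidate
    return [max(x - theta, 0.0) for x in v]


def modo_lambda_step(
    lam: List[float], G: List[List[float]], gamma: float, rho: float
) -> List[float]:
    """One projected step on lam. ASSUMED form: proj(lam - gamma * (G @ lam + rho * lam))."""
    m = len(lam)
    direction = [
        sum(G[i][j] * lam[j] for j in range(m)) + rho * lam[i] for i in range(m)
    ]
    return project_onto_simplex([lam[i] - gamma * direction[i] for i in range(m)])


# Toy 2x2 matrix standing in for G = J_1 @ J_2.T (hypothetical values).
G = [[1.0, 0.0], [0.0, 3.0]]
lam = modo_lambda_step([0.5, 0.5], G, gamma=0.1, rho=0.0)
print(lam)  # approximately [0.55, 0.45]: weight moves away from the second row
```

The result stays on the simplex (non-negative entries summing to 1), which is the property the projection is there to enforce after each step.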
Example (two batches per step)
The following example reproduces basic MoDo using two independent mini-batches per step. This matches MoDo as described in the paper, and the behavior of the official implementation when three_grads is False.

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import MoDoWeighting
    from torchjd.autojac import jac

    # Generate data (8 batches of 16 examples of dim 5) for the sake of the example.
    inputs = torch.randn(8, 16, 5)
    targets = torch.randn(8, 16)

    model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
    optimizer = SGD(model.parameters())
    criterion = MSELoss(reduction="none")
    weighting = MoDoWeighting(gamma=0.1, rho=0.0)

    params = list(model.parameters())

    # Consume two consecutive (independent) batches per step.
    for i in range(len(inputs) // 2):
        input_1, input_2 = inputs[2 * i], inputs[2 * i + 1]
        target_1, target_2 = targets[2 * i], targets[2 * i + 1]

        # retain_graph=True so both graphs survive for the backward step below.
        losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
        jacs_1 = jac(losses_1, params, retain_graph=True)
        J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

        losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
        jacs_2 = jac(losses_2, params, retain_graph=True)
        J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

        G = J_1 @ J_2.T
        weights = weighting(G)

        # Equation 2.9b: the parameter update uses the mean of both batches' losses.
        losses = (losses_1 + losses_2) / 2.0
        losses.backward(weights)
        optimizer.step()
        optimizer.zero_grad()
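One way to see why the two batches must be independent is a scalar toy simulation. It is not taken from the paper or from torchjd, and the values of mu, sigma and n are hypothetical. The product of two independent noisy gradient estimates is an unbiased estimate of the squared true gradient, whereas squaring a single estimate adds the noise variance. This is offered as a plausible intuition for the warning about \(J_1 J_1^\top\), not as the paper's argument.

```python
import random

random.seed(0)

# Hypothetical true gradient, noise level, and number of trials.
mu, sigma, n = 1.0, 2.0, 20000

single_batch = 0.0  # running average of g_1 * g_1 (same draw used twice)
two_batch = 0.0     # running average of g_1 * g_2 (independent draws)
for _ in range(n):
    g_1 = random.gauss(mu, sigma)
    g_2 = random.gauss(mu, sigma)
    single_batch += g_1 * g_1 / n
    two_batch += g_1 * g_2 / n

print(f"target mu^2          : {mu ** 2:.2f}")
print(f"two-batch estimate   : {two_batch:.2f}")     # close to mu^2
print(f"single-batch estimate: {single_batch:.2f}")  # close to mu^2 + sigma^2
```

In the matrix setting the same effect appears entry by entry: the expectation of \(J_1 J_2^\top\) is the product of the expected Jacobians, while the expectation of \(J_1 J_1^\top\) carries an extra covariance term.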
Example (three batches per step)
The following example reproduces basic MoDo using three independent mini-batches per step, keeping the \(\lambda\) update and the parameter update on separate draws. This matches the behavior of LibMTL and of the official implementation when three_grads is True.

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import MoDoWeighting
    from torchjd.autojac import jac

    # Generate data (9 batches of 16 examples of dim 5) for the sake of the example.
    inputs = torch.randn(9, 16, 5)
    targets = torch.randn(9, 16)

    model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
    optimizer = SGD(model.parameters())
    criterion = MSELoss(reduction="none")
    weighting = MoDoWeighting(gamma=0.1, rho=0.0)

    params = list(model.parameters())

    # Consume three consecutive (independent) batches per step.
    for i in range(len(inputs) // 3):
        input_1, input_2, input_3 = inputs[3 * i], inputs[3 * i + 1], inputs[3 * i + 2]
        target_1, target_2, target_3 = targets[3 * i], targets[3 * i + 1], targets[3 * i + 2]

        losses_1 = criterion(model(input_1).squeeze(dim=1), target_1)
        jacs_1 = jac(losses_1, params)
        J_1 = torch.cat([j.flatten(1) for j in jacs_1], dim=1)

        losses_2 = criterion(model(input_2).squeeze(dim=1), target_2)
        jacs_2 = jac(losses_2, params)
        J_2 = torch.cat([j.flatten(1) for j in jacs_2], dim=1)

        G = J_1 @ J_2.T
        weights = weighting(G)

        losses_3 = criterion(model(input_3).squeeze(dim=1), target_3)
        losses_3.backward(weights)
        optimizer.step()
        optimizer.zero_grad()