ExcessMTL

class torchjd.aggregation.ExcessMTL(robust_step_size=1.0, n_warmup_steps=0)[source]

Stateful WeightedAggregator from "Robust Multi-Task Learning with Excess Risks" (ICML 2024).

At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven by per-task excess risk estimates. See ExcessMTLWeighting for details on the algorithm and state management.

Parameters:
  • robust_step_size (float) – Step size \(\eta_\alpha\) for the exponentiated weight update. Must be positive.

  • n_warmup_steps (int) – Number of forward calls during which weights stay uniform (\([1/m, \ldots, 1/m]\)) and gradient statistics are collected. The baseline excess risk is then set to the average excess risk observed during warmup. When 0 (default), the first call’s excess risk is used immediately as the baseline, matching the behavior of the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting statistics for 3 full epochs, i.e. n_warmup_steps = 3 * len(dataloader).
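As a rough illustration of the 3-epoch recommendation, the warmup length can be derived from the dataset size and batch size. This is a sketch that assumes the aggregator is called exactly once per batch; the helper name warmup_steps_for_epochs and the example sizes are hypothetical and not part of torchjd.

```python
import math


def warmup_steps_for_epochs(
    n_samples: int, batch_size: int, n_epochs: int = 3, drop_last: bool = False
) -> int:
    """Hypothetical helper: number of batches seen in n_epochs full epochs.

    Mirrors n_epochs * len(dataloader), assuming one aggregator call per batch.
    """
    if drop_last:
        # The last incomplete batch is discarded.
        steps_per_epoch = n_samples // batch_size
    else:
        # The last incomplete batch still counts as a step.
        steps_per_epoch = math.ceil(n_samples / batch_size)
    return n_epochs * steps_per_epoch


# Example sizes chosen for illustration only.
n_warmup_steps = warmup_steps_for_epochs(n_samples=50_000, batch_size=128)
print(n_warmup_steps)  # 3 * ceil(50000 / 128) = 3 * 391 = 1173
```

If the aggregator is called more or less often than once per batch (for example with gradient accumulation), the count would need to be scaled accordingly.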

__call__(matrix, /)[source]

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:
  • matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor

reset()[source]

Clears all state so that the next call starts from uniform weights and re-enters warmup.

Return type:

None

class torchjd.aggregation.ExcessMTLWeighting(robust_step_size=1.0, n_warmup_steps=0)[source]

Stateful Weighting[Matrix] from "Robust Multi-Task Learning with Excess Risks" (ICML 2024).

At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven by per-task excess risk estimates. The excess risk for task \(i\) is approximated via a second-order Taylor expansion (Equations 6-7).
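The general shape of an exponentiated gradient step on the probability simplex can be sketched in plain Python. This is a generic illustration, not the torchjd implementation: the function name exponentiated_update is hypothetical, and the per-task excess risk estimates are taken as given inputs rather than computed from the Taylor expansion.

```python
import math


def exponentiated_update(weights, excess_risks, step_size=1.0):
    """Generic exponentiated gradient step (illustrative sketch).

    Each weight is multiplied by exp(step_size * excess_risk), then the
    weights are renormalized so that they sum to 1.
    """
    # Subtracting the max is only for numerical stability: the common
    # factor cancels out in the normalization below.
    shift = max(excess_risks)
    scaled = [
        w * math.exp(step_size * (e - shift))
        for w, e in zip(weights, excess_risks)
    ]
    total = sum(scaled)
    return [s / total for s in scaled]


m = 3
weights = [1.0 / m] * m          # uniform start on the simplex
excess = [0.5, 1.0, 2.0]         # made-up excess risk estimates
new_weights = exponentiated_update(weights, excess, step_size=1.0)
print(new_weights)
```

Tasks with a larger excess risk estimate receive a larger weight after the step, and a larger step_size makes the shift toward those tasks more aggressive.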

Parameters:
  • robust_step_size (float) – Step size \(\eta_\alpha\) for the exponentiated weight update. Must be positive.

  • n_warmup_steps (int) – Number of forward calls during which weights stay uniform (\([1/m, \ldots, 1/m]\)) and gradient statistics are collected. The baseline excess risk is then set to the average excess risk observed during warmup. When 0 (default), the first call’s excess risk is used immediately as the baseline, matching the behavior of the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting statistics for 3 full epochs, i.e. n_warmup_steps = 3 * len(dataloader).
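The warmup behavior described above can be pictured with a small sketch: weights stay uniform while per-task excess risk estimates are recorded, and the baseline is their per-task mean. The function warmup_baseline and the numbers are hypothetical; how the baseline enters the later weight updates is not shown here.

```python
def warmup_baseline(excess_history):
    """Hypothetical helper: per-task mean of the excess risk estimates
    recorded during warmup (one inner list per call, one entry per task)."""
    n_steps = len(excess_history)
    n_tasks = len(excess_history[0])
    return [
        sum(step[i] for step in excess_history) / n_steps
        for i in range(n_tasks)
    ]


# Made-up estimates for 3 warmup calls and 2 tasks.
history = [[2.0, 4.0], [4.0, 8.0], [6.0, 12.0]]
baseline = warmup_baseline(history)

# During warmup, the returned weights remain uniform: [1/m, ..., 1/m].
m = len(baseline)
uniform_weights = [1.0 / m] * m
print(baseline, uniform_weights)
```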

Warning

The state tensor \(S \in \mathbb{R}^{m \times n}\) accumulates squared gradients across calls, where \(m\) is the number of tasks and \(n\) is the total number of model parameters. For large models this can be a significant memory cost. Call reset() between experiments to clear this state.
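For a sense of scale, the size of an \(m \times n\) state tensor can be estimated with simple arithmetic. The sketch below assumes 4-byte (float32) elements and uses made-up model sizes; state_bytes is a hypothetical helper, not a torchjd function.

```python
def state_bytes(n_tasks: int, n_params: int, bytes_per_element: int = 4) -> int:
    """Hypothetical helper: bytes needed for an n_tasks x n_params tensor."""
    return n_tasks * n_params * bytes_per_element


# Example: 3 tasks and a model with 25 million parameters, float32 assumed.
size = state_bytes(n_tasks=3, n_params=25_000_000)
print(f"{size / 1e6:.0f} MB")  # 3 * 25,000,000 * 4 bytes = 300 MB
```

The cost grows linearly with both the number of tasks and the number of parameters.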

Note

The weight update is adapted from the official implementation and LibMTL. Unlike those implementations, which initialize task weights to 1, we follow the paper and initialize them to 1/m so that they always lie on the probability simplex.

reset()[source]

Clears all state so that the next call starts from uniform weights and re-enters warmup.

Return type:

None

__call__(matrix, /)[source]

Computes the vector of weights from the input matrix and applies all registered hooks.

Parameters:
  • matrix (Tensor) – The matrix from which the weights must be extracted.

Return type:

Tensor