ExcessMTL
- class torchjd.aggregation.ExcessMTL(robust_step_size=1.0, n_warmup_steps=0)
StatefulWeightedAggregator from Robust Multi-Task Learning with Excess Risks (ICML 2024).
At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven by per-task excess risk estimates. See ExcessMTLWeighting for details on the algorithm and state management.
- Parameters:
robust_step_size (float) – Step size \(\eta_\alpha\) for the exponentiated weight update. Must be positive.
n_warmup_steps (int) – Number of forward calls during which weights stay uniform (\([1/m, \ldots, 1/m]\)) and gradient statistics are collected. The baseline excess risk is then set to the average excess risk observed during warmup. When 0 (default), the first call’s excess risk is used immediately as the baseline, matching the behavior of the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting statistics for 3 full epochs, i.e. n_warmup_steps = 3 * len(dataloader).
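As a rough illustration of the 3-epoch recommendation, the sketch below computes a warmup length from a dataset size and batch size. It assumes the aggregator is called exactly once per batch; the helper name warmup_steps_for_epochs and the example numbers are hypothetical and not part of torchjd.

```python
import math


def warmup_steps_for_epochs(
    n_samples: int, batch_size: int, n_epochs: int = 3, drop_last: bool = False
) -> int:
    """Number of per-batch calls made during n_epochs full passes over the data.

    Hypothetical helper: mirrors how len(dataloader) is derived from the
    dataset size and batch size, assuming one aggregator call per batch.
    """
    if drop_last:
        # The final, incomplete batch is discarded.
        steps_per_epoch = n_samples // batch_size
    else:
        # The final, incomplete batch still counts as one step.
        steps_per_epoch = math.ceil(n_samples / batch_size)
    return n_epochs * steps_per_epoch


# Example numbers (assumed): 10,000 samples, batches of 64, 3 warmup epochs.
n_warmup_steps = warmup_steps_for_epochs(n_samples=10_000, batch_size=64)
```

The resulting integer would then be passed as n_warmup_steps. If the training loop calls the aggregator more or less often than once per batch, the count needs to be scaled accordingly.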
- class torchjd.aggregation.ExcessMTLWeighting(robust_step_size=1.0, n_warmup_steps=0)
StatefulWeighting[Matrix] from Robust Multi-Task Learning with Excess Risks (ICML 2024).
At each call, task weights are updated via an exponentiated gradient step (Equation 9) driven by per-task excess risk estimates. The excess risk for task \(i\) is approximated via a second-order Taylor expansion (Equations 6-7).
- Parameters:
robust_step_size (float) – Step size \(\eta_\alpha\) for the exponentiated weight update. Must be positive.
n_warmup_steps (int) – Number of forward calls during which weights stay uniform (\([1/m, \ldots, 1/m]\)) and gradient statistics are collected. The baseline excess risk is then set to the average excess risk observed during warmup. When 0 (default), the first call’s excess risk is used immediately as the baseline, matching the behavior of the official implementation and LibMTL. The paper (Appendix C.1) recommends collecting statistics for 3 full epochs, i.e. n_warmup_steps = 3 * len(dataloader).
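The general shape of an exponentiated gradient step with a warmup baseline can be sketched in plain Python as follows. This is a generic illustration and not the torchjd implementation: the function names are hypothetical, the excess risk values are made up, and dividing each excess risk by its baseline is an assumption about how the baseline enters the update. The exact form of Equation 9 is not reproduced here.

```python
import math


def average_baseline(warmup_excess_risks):
    """Per-task mean of the excess risks observed during warmup (hypothetical helper)."""
    n_calls = len(warmup_excess_risks)
    n_tasks = len(warmup_excess_risks[0])
    return [
        sum(call[i] for call in warmup_excess_risks) / n_calls
        for i in range(n_tasks)
    ]


def exponentiated_gradient_step(weights, excess_risks, baseline, step_size=1.0):
    """One multiplicative weight update followed by renormalization.

    Assumption: each excess risk is divided by its baseline so that tasks
    with different scales are comparable.
    """
    relative = [e / b for e, b in zip(excess_risks, baseline)]
    # Tasks with larger relative excess risk are scaled up more strongly.
    unnormalized = [w * math.exp(step_size * r) for w, r in zip(weights, relative)]
    # Dividing by the sum keeps the weights on the probability simplex.
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


m = 3
weights = [1.0 / m] * m  # uniform start, as during warmup

# Made-up excess risks from two warmup calls; their mean is the baseline.
baseline = average_baseline([[1.0, 3.0, 2.0], [3.0, 1.0, 2.0]])

# Made-up excess risks at the first post-warmup call.
new_weights = exponentiated_gradient_step(
    weights, excess_risks=[2.0, 4.0, 1.0], baseline=baseline
)
```

In this toy run the second task has the largest excess risk relative to its baseline, so it receives the largest weight, while the weights still sum to one.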
Warning
The state tensor \(S \in \mathbb{R}^{m \times n}\) accumulates squared gradients across calls, where \(n\) is the total number of model parameters. For large models this can be a significant memory cost. Call reset() between experiments.
Note
The weight update is adapted from the official implementation and LibMTL. Unlike those implementations, which initialize task weights to 1, we follow the paper and initialize them to 1/m so that they always lie on the probability simplex.
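To get a feel for the memory cost mentioned in the warning about the state tensor, the size of an \(m \times n\) tensor can be estimated with a few lines of arithmetic. The helper name state_memory_bytes is hypothetical, and the 4 bytes per element is an assumption (float32 storage); the actual dtype and device of the state depend on the gradients passed in.

```python
def state_memory_bytes(n_tasks: int, n_params: int, bytes_per_element: int = 4) -> int:
    """Approximate size of an (n_tasks x n_params) state tensor.

    Hypothetical helper; bytes_per_element=4 assumes float32 storage.
    """
    return n_tasks * n_params * bytes_per_element


# Example numbers (assumed): 3 tasks and a model with 25 million parameters.
size_bytes = state_memory_bytes(n_tasks=3, n_params=25_000_000)
size_gib = size_bytes / 2**30
print(f"{size_gib:.2f} GiB")
```

The cost grows linearly in both the number of tasks and the number of parameters, so it is roughly equivalent to keeping \(m\) extra copies of the model's parameters in memory.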