CR-MOGM
- class torchjd.aggregation.CRMOGMWeighting(weighting, alpha=0.9, initial_weights=None)
StatefulWeighting that wraps another Weighting and stabilises the weights it produces with an exponential moving average (EMA) across calls. This is the weight-smoothing modifier from On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond (NeurIPS 2022).

Let \(\hat{\lambda}_k\) be the weights returned by the wrapped weighting at step \(k\). The smoothed weights returned by CRMOGMWeighting are:

\[\lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k\]

where \(\lambda_0\) is initial_weights if provided, otherwise \(\lambda_0 = \hat{\lambda}_1\) (so that the first smoothed output equals \(\hat{\lambda}_1\) regardless of \(\alpha\)).

To build an Aggregator from a wrapped weighting, pass it to the appropriate aggregator subclass (WeightedAggregator or GramianWeightedAggregator). The following example shows how to instantiate a Gramian-based weighted aggregator whose Gramian weighting is wrapped by CR-MOGM.
from torchjd.aggregation import CRMOGMWeighting, GramianWeightedAggregator, UPGradWeighting

aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))
The following example shows how to instantiate a matrix-based weighted aggregator whose weighting is wrapped by CR-MOGM.
from torchjd.aggregation import CRMOGMWeighting, MeanWeighting, WeightedAggregator

aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))
Note that MeanWeighting is used here only for the sake of the example: with the default initial_weights=None, the exponential moving average of constant weights always equals those weights, so wrapping with CRMOGMWeighting has no effect.

This weighting is stateful: it keeps \(\lambda_{k-1}\) across calls. Use reset() to restart the smoothing from the initial state. Note that calling reset() also resets the wrapped weighting if it is Stateful.

- Parameters:
  - weighting (Weighting[TypeVar(_T, bound=Tensor, contravariant=True)]) – The wrapped weighting whose output is smoothed.
  - alpha (float) – EMA coefficient on the previous weights. alpha=0 disables smoothing (CRMOGMWeighting returns weighting's output verbatim) and alpha=1 freezes the weights at their initial value \(\lambda_0\). The default of 0.9 follows the usual EMA convention (analogous to Adam's \(\beta_1\)).
  - initial_weights (Tensor | None) – Optional tensor to use as \(\lambda_0\). If None (default), \(\lambda_0\) is set to \(\hat{\lambda}_1\) on the first forward call, making the first smoothed output equal to \(\hat{\lambda}_1\).
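To make the recurrence and the parameter semantics concrete, here is a minimal pure-Python sketch of the smoothing step. It is not torchjd's implementation: the class name EMAWeights is hypothetical, weights are plain lists of floats rather than Tensors, and the raw weights \(\hat{\lambda}_k\) are passed in directly instead of being computed by a wrapped weighting. It mirrors the documented behaviour: \(\lambda_0 = \hat{\lambda}_1\) when no initial weights are given, alpha=0 returns the raw weights, alpha=1 keeps \(\lambda_0\), and reset() restarts from the initial state.

```python
from typing import List, Optional, Sequence


class EMAWeights:
    """Hypothetical sketch of CR-MOGM's EMA smoothing (not torchjd's code).

    Weights are plain Python lists; the raw weights that the wrapped
    weighting would produce are passed in directly on each call.
    """

    def __init__(
        self,
        alpha: float = 0.9,
        initial_weights: Optional[Sequence[float]] = None,
    ) -> None:
        self.alpha = alpha
        # Keep a private copy of lambda_0 so that reset() can restore it.
        self._initial = None if initial_weights is None else list(initial_weights)
        self.reset()

    def reset(self) -> None:
        # Restart from the initial state: lambda_0 is either the user-supplied
        # initial weights, or unset (then filled in by the next call).
        self._prev = None if self._initial is None else list(self._initial)

    def __call__(self, raw: Sequence[float]) -> List[float]:
        raw = list(raw)
        if self._prev is None:
            # Default lambda_0 = hat_lambda_1, so the first output equals raw.
            self._prev = raw
        # lambda_k = alpha * lambda_{k-1} + (1 - alpha) * hat_lambda_k
        smoothed = [
            self.alpha * prev + (1.0 - self.alpha) * cur
            for prev, cur in zip(self._prev, raw)
        ]
        self._prev = smoothed
        return smoothed


ema = EMAWeights(alpha=0.9)
print(ema([1.0, 0.0]))  # first call: matches the raw weights
print(ema([0.0, 1.0]))  # approximately [0.9, 0.1]
ema.reset()
print(ema([0.2, 0.8]))  # after reset: matches the raw weights (up to float rounding)
```

The only state is the previous smoothed vector, so each call is a single elementwise blend; in the real class the raw weights would come from the wrapped weighting on every forward call.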
Note
alpha is a fixed float for simplicity. Corollary 1 of the paper recommends a schedule where \(\alpha_k\) starts near 0 and increases toward 1 as the learning rate decays. Update alpha between forward calls by assigning a new value to the alpha attribute.

The following example shows how to update alpha with the scheme suggested in the paper, when the aggregator is a Gramian-based weighted aggregator whose Gramian weighting is wrapped by CR-MOGM:
from torchjd.aggregation import (
    CRMOGMWeighting,
    GramianWeightedAggregator,
    UPGradWeighting,
)

aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

initial_lr = 0.1
current_lr = 0.05  # e.g. obtained from scheduler.get_last_lr()[0]

cr_mogm = aggregator.gramian_weighting
cr_mogm.alpha = 1 - current_lr / initial_lr
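The schedule alpha = 1 - current_lr / initial_lr from the example above can be previewed without torchjd. The sketch below makes some assumptions: it uses a hand-written cosine-annealing learning rate (a hypothetical choice; any decaying schedule would do), the helper names cosine_lr and crmogm_alpha are hypothetical, and the clamp to [0, 1] is an extra safeguard, not part of the documented scheme, for schedules (e.g. with warmup) whose learning rate can exceed the initial one.

```python
import math
from typing import List


def cosine_lr(step: int, total_steps: int, initial_lr: float) -> float:
    # Hypothetical decaying schedule: cosine annealing from initial_lr down to 0.
    return 0.5 * initial_lr * (1.0 + math.cos(math.pi * step / total_steps))


def crmogm_alpha(current_lr: float, initial_lr: float) -> float:
    # alpha = 1 - current_lr / initial_lr, clamped to [0, 1] as a safeguard
    # in case current_lr ever exceeds initial_lr (e.g. during warmup).
    return min(1.0, max(0.0, 1.0 - current_lr / initial_lr))


initial_lr = 0.1
total_steps = 10
alphas: List[float] = [
    crmogm_alpha(cosine_lr(step, total_steps, initial_lr), initial_lr)
    for step in range(total_steps + 1)
]
for step, alpha in enumerate(alphas):
    lr = cosine_lr(step, total_steps, initial_lr)
    print(f"step {step:2d}: lr = {lr:.4f}, alpha = {alpha:.3f}")
```

As the learning rate decays from initial_lr to 0, alpha rises from 0 to 1, so early steps follow the raw weights closely and later steps are increasingly smoothed. In a training loop, one would presumably set cr_mogm.alpha = crmogm_alpha(scheduler.get_last_lr()[0], initial_lr) after each scheduler.step(), before the next aggregation.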
Note
The usage example in the docstring above imports
WeightedAggregator / GramianWeightedAggregator from
torchjd.aggregation._aggregator_bases, which is a private module. These two
aggregator base classes are not currently part of the public torchjd.aggregation
namespace, so this private-module import is the only path that works today. Promoting
them to the public namespace is a separate decision left to the maintainers.