CR-MOGM

class torchjd.aggregation.CRMOGMWeighting(weighting, alpha=0.9, initial_weights=None)

Stateful Weighting that wraps another Weighting and stabilises the weights it produces with an exponential moving average (EMA) across calls. This is the weight-smoothing modifier from the paper "On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond" (NeurIPS 2022).

Let \(\hat{\lambda}_k\) be the weights returned by the wrapped weighting at step \(k\). The smoothed weights returned by CRMOGMWeighting are:

\[\lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k\]

where \(\lambda_0\) is initial_weights if provided, otherwise \(\lambda_0 = \hat{\lambda}_1\) (so that the first smoothed output equals \(\hat{\lambda}_1\) regardless of \(\alpha\)).
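To trace the recurrence by hand, here is a minimal sketch in plain Python. It is not the torchjd implementation: weights are represented as lists of floats rather than tensors, and the helper name smooth_weights is hypothetical.

```python
from typing import Optional, Sequence


def smooth_weights(
    raw_weights: Sequence[Sequence[float]],
    alpha: float = 0.9,
    initial_weights: Optional[Sequence[float]] = None,
) -> list[list[float]]:
    """Hypothetical helper: lambda_k = alpha * lambda_{k-1} + (1 - alpha) * hat_lambda_k.

    raw_weights[k - 1] stands for hat_lambda_k, the wrapped weighting's output at step k.
    """
    smoothed: list[list[float]] = []
    # lambda_0 is initial_weights if given; otherwise it is set to hat_lambda_1 at the first step.
    previous = None if initial_weights is None else list(initial_weights)
    for raw in raw_weights:
        if previous is None:
            previous = list(raw)
        current = [alpha * p + (1 - alpha) * r for p, r in zip(previous, raw)]
        smoothed.append(current)
        previous = current
    return smoothed


print(smooth_weights([[1.0, 0.0], [0.0, 1.0]], alpha=0.5))  # [[1.0, 0.0], [0.5, 0.5]]
```

The first output equals the first raw vector because \(\lambda_0 = \hat{\lambda}_1\) by default; passing initial_weights changes that first step.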

To obtain the corresponding Aggregator, pass the wrapped weighting to the matching aggregator subclass: WeightedAggregator for a matrix-based weighting, or GramianWeightedAggregator for a Gramian-based one.

The following example shows how to instantiate a Gramian-based weighted aggregator whose Gramian weighting is wrapped by CR-MOGM.

from torchjd.aggregation import CRMOGMWeighting, GramianWeightedAggregator, UPGradWeighting

aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

The following example shows how to instantiate a matrix-based weighted aggregator whose weighting is wrapped by CR-MOGM.

from torchjd.aggregation import CRMOGMWeighting, MeanWeighting, WeightedAggregator

aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))

Note that MeanWeighting is used here only for the sake of the example: it returns the same weights at every call, and with the default \(\lambda_0 = \hat{\lambda}_1\) the exponential moving average of a constant sequence equals that constant, so wrapping it in CRMOGMWeighting has no effect.
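A short derivation makes this precise. Assuming the wrapped weighting returns the same vector \(w\) at every step (\(\hat{\lambda}_k = w\) for all \(k\)), subtracting \(w\) from both sides of the recurrence gives

\[\lambda_k - w = \alpha \, (\lambda_{k-1} - w) = \alpha^k \, (\lambda_0 - w)\]

With the default \(\lambda_0 = \hat{\lambda}_1 = w\), the right-hand side is zero, so \(\lambda_k = w\) at every step. If initial_weights is set to a different vector, the smoothed weights instead approach \(w\) geometrically at rate \(\alpha\) (for \(\alpha < 1\)).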

This weighting is stateful: it keeps \(\lambda_{k-1}\) across calls. Use reset() to restart the smoothing from the initial state. Note that calling reset() will also reset the wrapped weighting if it is Stateful.
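The state handling can be illustrated with a minimal stateful sketch. The class EMASmoothedWeighting below is hypothetical and not torchjd's CRMOGMWeighting: it works on lists of floats, accepts any callable as the inner weighting, and, as an assumption standing in for torchjd's Stateful check, treats the inner weighting as stateful if it exposes a callable reset().

```python
from typing import Callable, Optional, Sequence


class EMASmoothedWeighting:
    """Hypothetical stateful EMA wrapper, sketched with lists of floats (not torchjd's class)."""

    def __init__(
        self,
        inner: Callable[[object], Sequence[float]],
        alpha: float = 0.9,
        initial_weights: Optional[Sequence[float]] = None,
    ) -> None:
        self.inner = inner
        self.alpha = alpha
        self.initial_weights = None if initial_weights is None else list(initial_weights)
        # State carried across calls: lambda_{k-1}, or None until the first call if no initial_weights.
        self._previous: Optional[list[float]] = (
            None if self.initial_weights is None else list(self.initial_weights)
        )

    def reset(self) -> None:
        # Restart from the initial state: initial_weights if given, else hat_lambda_1 on the next call.
        self._previous = None if self.initial_weights is None else list(self.initial_weights)
        # Assumption: the wrapped weighting is treated as stateful if it has a callable reset().
        inner_reset = getattr(self.inner, "reset", None)
        if callable(inner_reset):
            inner_reset()

    def __call__(self, stat: object) -> list[float]:
        raw = list(self.inner(stat))  # hat_lambda_k
        previous = raw if self._previous is None else self._previous
        current = [self.alpha * p + (1 - self.alpha) * r for p, r in zip(previous, raw)]
        self._previous = current  # becomes lambda_{k-1} for the next call
        return current


# Hypothetical inner weighting: returns a different vector on each call.
outputs = iter([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
smoother = EMASmoothedWeighting(lambda _stat: next(outputs), alpha=0.5)
print(smoother(None))  # [1.0, 0.0]
print(smoother(None))  # [0.5, 0.5]
smoother.reset()
print(smoother(None))  # [0.0, 1.0]
```

After reset(), the third call behaves like a first call again: with no initial_weights, \(\lambda_0\) is re-seeded from the new raw output, so the earlier history no longer influences the result.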

Parameters:
  • weighting (Weighting[TypeVar(_T, bound=Tensor, contravariant=True)]) – The wrapped weighting whose output is smoothed.

  • alpha (float) – EMA coefficient on the previous weights. alpha=0 disables smoothing (CRMOGMWeighting returns weighting’s output verbatim) and alpha=1 freezes the weights at their initial value. The default of 0.9 follows the usual EMA convention (analogous to Adam’s \(\beta_1\)).

  • initial_weights (Tensor | None) – Optional tensor to use as \(\lambda_0\). If None (default), \(\lambda_0\) is set to \(\hat{\lambda}_1\) on the first forward call, making the first smoothed output equal to \(\hat{\lambda}_1\).
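Unrolling the recurrence shows how alpha weighs past outputs and why its two extreme values behave as described above. By induction on \(k\):

\[\lambda_k = \alpha^k \, \lambda_0 + (1 - \alpha) \sum_{j=1}^{k} \alpha^{k-j} \, \hat{\lambda}_j\]

Each earlier output \(\hat{\lambda}_j\) therefore contributes with a coefficient that shrinks by a factor \(\alpha\) per step. With \(\alpha = 0\) (taking \(0^0 = 1\)), only the \(j = k\) term survives and \(\lambda_k = \hat{\lambda}_k\); with \(\alpha = 1\), the factor \(1 - \alpha\) removes the sum and \(\lambda_k = \lambda_0\).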

Note

alpha is a fixed float for simplicity. Corollary 1 of the paper recommends a schedule where \(\alpha_k\) starts near 0 and increases toward 1 as the learning rate decays. Update alpha between forward calls via the setter.

The following example shows how to update alpha with the suggested scheme from the paper, when the aggregator is a Gramian-based weighted aggregator whose Gramian weighting is wrapped by CR-MOGM:

from torchjd.aggregation import (
    CRMOGMWeighting,
    GramianWeightedAggregator,
    UPGradWeighting,
)

aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

initial_lr = 0.1
current_lr = 0.05  # e.g. obtained from lr_scheduler.get_last_lr()[0]

cr_mogm = aggregator.gramian_weighting
cr_mogm.alpha = 1 - current_lr / initial_lr
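The rule used in the example above, alpha = 1 - current_lr / initial_lr, can be wrapped in a small helper and evaluated over a whole schedule. This is a plain-Python sketch: the name alpha_from_lr, the clamping to [0, 1], and the halving step-decay schedule are illustrative assumptions, not part of torchjd or the paper.

```python
def alpha_from_lr(current_lr: float, initial_lr: float) -> float:
    """Hypothetical helper: alpha = 1 - current_lr / initial_lr, clamped to [0, 1]."""
    if initial_lr <= 0:
        raise ValueError("initial_lr must be positive")
    # Clamping (an assumption) keeps alpha valid if the learning rate ever rises above initial_lr.
    return min(1.0, max(0.0, 1.0 - current_lr / initial_lr))


# Illustrative step decay: the learning rate halves every epoch.
initial_lr = 0.1
for epoch in range(4):
    current_lr = initial_lr * 0.5**epoch
    alpha = alpha_from_lr(current_lr, initial_lr)
    print(f"epoch {epoch}: alpha = {alpha}")  # alpha goes 0.0, 0.5, 0.75, 0.875
```

As the learning rate decays, alpha starts at 0 and rises toward 1, matching the schedule described in the note. In a training loop, the resulting value would be assigned to cr_mogm.alpha before the next forward call, as the example above does.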

reset()

Clears the EMA state so that the next forward call restarts from the initial state. Also resets the wrapped weighting if it is Stateful.

Return type:

None

__call__(stat, /)

Computes the vector of weights from the input stat and applies all registered hooks.

Parameters:

stat (Tensor) – The stat from which the weights must be extracted.

Return type:

Tensor

Note

The usage examples above import WeightedAggregator and GramianWeightedAggregator from the public torchjd.aggregation namespace. If these two aggregator base classes are not exported there in your installed version, import them from the private module torchjd.aggregation._aggregator_bases instead; being private, that path is not a stable part of the API.