CAGrad

class torchjd.aggregation.CAGrad(c, norm_eps=0.0001)

A GramianWeightedAggregator implementing Algorithm 1 of "Conflict-Averse Gradient Descent for Multi-task Learning".

Parameters:
  • c (float) – The scale of the radius of the ball constraint. The radius is c times the norm of the average gradient.
  • norm_eps (float) – A small value used to avoid division by zero when normalizing.

Note

This aggregator requires optional dependencies. When they are not installed, instantiating it raises an ImportError with installation instructions. To install them, use pip install "torchjd[cagrad]".
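As an illustration of the note above, the following minimal sketch builds a CAGrad aggregator and falls back gracefully when the optional dependencies are missing. The helper name try_cagrad and the input values are hypothetical and chosen only for this example. It assumes that torch is the tensor library passed to the aggregator.

```python
def try_cagrad(rows, c=0.5):
    """Aggregate `rows` with CAGrad, or return None if dependencies are missing.

    `try_cagrad` is a hypothetical helper written for this example only.
    """
    try:
        import torch
        from torchjd.aggregation import CAGrad

        # Instantiation is where the ImportError is raised when the
        # optional dependencies are not installed.
        aggregator = CAGrad(c=c)
    except ImportError as error:
        print(f"CAGrad is unavailable: {error}")
        return None

    jacobian = torch.tensor(rows)  # one row per objective
    return aggregator(jacobian)    # one entry per parameter


result = try_cagrad([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
```

When the call succeeds, result is a vector with one entry per column of the input matrix. Otherwise the helper prints the error message, which contains the installation instructions, and returns None.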

__call__(*args, **kwargs)

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix – The Jacobian to aggregate.

Return type:

Any
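To make the roles of c and norm_eps more concrete, here is a dependency-free sketch of the computation for a Jacobian with exactly two rows, based on a reading of Algorithm 1 of the paper. It is not the library's implementation. A simple grid search stands in for a convex solver, the function names are hypothetical, and the use of norm_eps as a guard on the norm of the weighted gradient is an assumption.

```python
import math


def gramian(rows):
    """Return the Gramian G = J J^T of a matrix given as a list of rows."""
    return [[sum(a * b for a, b in zip(r, s)) for s in rows] for r in rows]


def cagrad_weights_two_rows(G, c, norm_eps=1e-4, steps=1000):
    """Approximate CAGrad weights for a 2x2 Gramian (hypothetical helper).

    Assumption: a grid over w = (t, 1 - t) replaces a convex solver.
    """
    m = 2
    # ||g_0||^2 for the average gradient g_0: sum of all Gramian entries / m^2.
    g0_sq = sum(G[i][j] for i in range(m) for j in range(m)) / m**2
    sqrt_phi = c * math.sqrt(max(g0_sq, 0.0))

    best_cost, best_w, best_norm = math.inf, None, 0.0
    for k in range(steps + 1):
        w = (k / steps, 1.0 - k / steps)
        # <g_w, g_0> = (1/m) * sum_i w_i * (sum of row i of G)
        dot = sum(w[i] * sum(G[i]) for i in range(m)) / m
        # ||g_w||^2 = w^T G w
        gw_sq = sum(w[i] * G[i][j] * w[j] for i in range(m) for j in range(m))
        gw_norm = math.sqrt(max(gw_sq, 0.0))
        cost = dot + sqrt_phi * gw_norm
        if cost < best_cost:
            best_cost, best_w, best_norm = cost, w, gw_norm

    # Start from the plain average and add the conflict-averse correction.
    weights = [1.0 / m] * m
    if best_norm >= norm_eps:  # skip the correction rather than divide by ~0
        weights = [u + (sqrt_phi / best_norm) * wi for u, wi in zip(weights, best_w)]
    return weights


def aggregate(rows, weights):
    """Return the weighted sum of the rows, one entry per column."""
    return [sum(w * r[j] for w, r in zip(weights, rows)) for j in range(len(rows[0]))]


J = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
weights = cagrad_weights_two_rows(gramian(J), c=0.5)
update = aggregate(J, weights)
print([round(x, 4) for x in update])  # [0.1835, 1.2041, 1.2041]
```

Because the weights depend on the Jacobian only through its Gramian, the weight computation and the final weighted sum of rows can be kept separate, as in this sketch.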

class torchjd.aggregation.CAGradWeighting(c, norm_eps=0.0001)

__call__(*args, **kwargs)

Call self as a function.

Return type:

Any