CAGrad

class torchjd.aggregation.CAGrad(c, norm_eps=0.0001)

GramianWeightedAggregator implementing CAGrad, as defined in Algorithm 1 of Conflict-Averse Gradient Descent for Multi-task Learning.
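For reference, the problem that Algorithm 1 addresses can be written as follows in the paper's notation: g_1, ..., g_m are the rows of the Jacobian, g_0 is their mean, and W is the probability simplex. This restates the paper's formulation. It does not describe how torchjd solves the problem internally.

```latex
% Primal: find the update d that maximizes the worst-case task improvement
% while staying in a ball of radius c \lVert g_0 \rVert around the mean gradient.
g_0 = \frac{1}{m} \sum_{i=1}^{m} g_i, \qquad
\max_{d} \; \min_{i} \; \langle g_i, d \rangle
\quad \text{s.t.} \quad \lVert d - g_0 \rVert \le c \, \lVert g_0 \rVert

% Dual over the simplex, with g_w = \sum_i w_i g_i and \phi = c^2 \lVert g_0 \rVert^2:
w^\star = \arg\min_{w \in \mathcal{W}} \; g_w^\top g_0 + \sqrt{\phi} \, \lVert g_w \rVert,
\qquad
d^\star = g_0 + \frac{\sqrt{\phi}}{\lVert g_{w^\star} \rVert} \, g_{w^\star}
```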

Parameters:
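  • c (float) – The scale of the radius of the ball constraint: the aggregated vector is constrained to lie within distance c·‖g₀‖ of the average gradient g₀ (the mean of the rows of the Jacobian).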

  • norm_eps (float) – A small value to avoid division by zero when normalizing.
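The sketch below illustrates how c and norm_eps could enter the computation, assuming the weights follow the paper's closed form d = g₀ + (√φ / ‖g_w‖) g_w with √φ = c·‖g₀‖. Everything is expressed through the Gramian G = J Jᵀ, so the aggregated vector would be Jᵀ times the returned weights. The helper cagrad_weights_from_dual is hypothetical, takes the dual solution w as given, and uses plain Python lists instead of tensors. Falling back to uniform weights when ‖g_w‖ < norm_eps is also an assumption, not documented torchjd behavior.

```python
import math


def cagrad_weights_from_dual(gramian, w, c, norm_eps=1e-4):
    """Hypothetical helper: turn a dual solution w (a point of the simplex) into CAGrad weights.

    gramian is G = J J^T as a list of lists; the aggregated vector is then J^T @ weights.
    """
    m = len(gramian)
    # ||g_0||^2 = (1/m^2) 1^T G 1, i.e. the mean of all Gramian entries.
    g0_norm = math.sqrt(max(sum(map(sum, gramian)) / m**2, 0.0))
    sqrt_phi = c * g0_norm  # radius of the ball around g_0
    # ||g_w||^2 = w^T G w
    gw_sq = sum(w[i] * gramian[i][j] * w[j] for i in range(m) for j in range(m))
    gw_norm = math.sqrt(max(gw_sq, 0.0))
    uniform = [1.0 / m] * m
    if gw_norm < norm_eps:
        # Assumed fallback: with no usable direction g_w, keep the mean gradient g_0.
        return uniform
    # d = g_0 + (sqrt(phi) / ||g_w||) g_w = J^T (1/m + (sqrt(phi) / ||g_w||) w)
    scale = sqrt_phi / gw_norm
    return [u + scale * wi for u, wi in zip(uniform, w)]
```

With c = 0 the radius vanishes and the weights are uniform, so the aggregation reduces to the mean gradient g₀.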

Note

This aggregator requires optional dependencies. When they are not installed, instantiating it raises an ImportError with installation instructions. To install them, use pip install "torchjd[cagrad]".

__call__(matrix, /)

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor
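As a GramianWeightedAggregator, CAGrad can be understood as a pipeline from the Jacobian J to its Gramian G = J Jᵀ, from G to one weight per row, and finally to the weighted combination Jᵀ·weights. The sketch below shows only that structure, on plain lists. The names aggregate_with_gramian_weighting and mean_weighting are hypothetical, and mean_weighting is merely a stand-in: in the paper's formulation, CAGrad's weights reduce to uniform weights when c = 0.

```python
def aggregate_with_gramian_weighting(matrix, weighting):
    """Hypothetical sketch of a Gramian-weighted aggregation: J -> G = J J^T -> weights -> J^T weights."""
    m, n = len(matrix), len(matrix[0])
    # Gramian of the Jacobian: G[i][j] = <g_i, g_j>, where g_i is row i of J.
    gramian = [[sum(a * b for a, b in zip(matrix[i], matrix[j])) for j in range(m)] for i in range(m)]
    weights = weighting(gramian)  # one weight per row of J
    # Weighted combination of the rows: J^T @ weights.
    return [sum(weights[i] * matrix[i][k] for i in range(m)) for k in range(n)]


def mean_weighting(gramian):
    """Stand-in weighting returning uniform weights (CAGrad's weights when c = 0, per the paper)."""
    m = len(gramian)
    return [1.0 / m] * m


print(aggregate_with_gramian_weighting([[1.0, 0.0], [-0.5, 1.0]], mean_weighting))  # [0.25, 0.5]
```

Because the weighting only ever sees G, any weighting that depends on the gradients solely through their pairwise inner products fits this pattern.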

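class torchjd.aggregation.CAGradWeighting(c, norm_eps=0.0001)

Weighting that computes, from a Gramian, the weights used by CAGrad. The parameters c and norm_eps have the same meaning as in CAGrad.

__call__(gramian, /)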

Computes the vector of weights from the input Gramian and applies all registered hooks.

Parameters:

gramian (Tensor) – The Gramian from which the weights must be extracted.

Return type:

Tensor
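One way to obtain the dual solution w* from the Gramian alone is sketched below. It minimizes F(w) = wᵀG(1/m) + c·‖g₀‖·√(wᵀGw) over the simplex with a Frank-Wolfe loop and a golden-section line search. This solver is chosen here for illustration and is not necessarily what torchjd uses; the name solve_cagrad_dual is hypothetical. Passing w* through the paper's closed form d = g₀ + (√φ / ‖g_w*‖) g_w* then gives the weights 1/m + (√φ / ‖g_w*‖) w*.

```python
import math


def solve_cagrad_dual(gramian, c, iters=500, tol=1e-9):
    """Hypothetical sketch: minimize F(w) = w^T G 1/m + c ||g_0|| sqrt(w^T G w) over the simplex."""
    m = len(gramian)
    # b_i = <g_i, g_0> = (G 1/m)_i
    b = [sum(row) / m for row in gramian]
    # sqrt(phi) = c ||g_0||, with ||g_0||^2 = (1/m) sum_i b_i
    sqrt_phi = c * math.sqrt(max(sum(b) / m, 0.0))

    def matvec(v):
        return [sum(g * x for g, x in zip(row, v)) for row in gramian]

    def objective(w):
        gw_norm = math.sqrt(max(sum(x * y for x, y in zip(w, matvec(w))), 0.0))
        return sum(x * y for x, y in zip(w, b)) + sqrt_phi * gw_norm

    def gradient(w):
        gw = matvec(w)  # <g_i, g_w> for each task i
        gw_norm = math.sqrt(max(sum(x * y for x, y in zip(w, gw)), 0.0))
        scale = sqrt_phi / gw_norm if gw_norm > 0 else 0.0
        # Component i equals <g_i, d(w)> with d(w) = g_0 + sqrt(phi) g_w / ||g_w||.
        return [bi + scale * gi for bi, gi in zip(b, gw)]

    def line_search(f, steps=60):
        # Golden-section search on [0, 1]; valid because F is convex along the segment.
        ratio = (math.sqrt(5) - 1) / 2
        lo, hi = 0.0, 1.0
        for _ in range(steps):
            x1, x2 = hi - ratio * (hi - lo), lo + ratio * (hi - lo)
            if f(x1) <= f(x2):
                hi = x2
            else:
                lo = x1
        return (lo + hi) / 2

    w = [1.0 / m] * m  # start from uniform weights, i.e. from g_0
    for _ in range(iters):
        grad = gradient(w)
        k = min(range(m), key=grad.__getitem__)  # simplex vertex e_k minimizing the linearization
        direction = [(1.0 if i == k else 0.0) - wi for i, wi in enumerate(w)]
        gap = -sum(g * d for g, d in zip(grad, direction))  # Frank-Wolfe duality gap
        if gap < tol:
            break
        gamma = line_search(lambda t: objective([wi + t * di for wi, di in zip(w, direction)]))
        w = [wi + gamma * di for wi, di in zip(w, direction)]
    return w
```

For the two conflicting gradients g₁ = (1, 0) and g₂ = (-0.5, 1) with c = 0.5, the optimum is interior, and at such a point the inner products ⟨g_i, d⟩ of both tasks with the resulting update coincide.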