CAGrad

class torchjd.aggregation.CAGrad(c, norm_eps=0.0001)

Aggregator as defined in Algorithm 1 of "Conflict-Averse Gradient Descent for Multi-task Learning".

Parameters:
  • c (float) – The scale of the radius of the ball constraint.

  • norm_eps (float) – A small value to avoid division by zero when normalizing.
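For context on what c controls, the optimization problem of the referenced paper can be written as below. This is a sketch recalled from the paper rather than something stated on this page: the symbols g_i (the m task gradients, i.e. the rows of the matrix to aggregate), g_0 (their average) and d (the aggregated update) are notation introduced here for illustration.

```latex
\max_{d \in \mathbb{R}^n} \; \min_{i \in \{1, \dots, m\}} \langle g_i, d \rangle
\quad \text{subject to} \quad
\lVert d - g_0 \rVert \le c \, \lVert g_0 \rVert,
\qquad
g_0 = \frac{1}{m} \sum_{i=1}^{m} g_i
```

Under this formulation, the ball is centered on g_0 and its radius is c times the norm of g_0, so with c = 0 the only feasible update is the plain average g_0, and larger values of c allow d to move further away from it.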

Note

This aggregator is not installed by default. When it is not installed, trying to import it should fail with the following error: ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'. To install it, run pip install torchjd[cagrad].
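Because the import can fail, code that treats CAGrad as optional may want to guard it. The following is a minimal sketch of that pattern, not something prescribed by the library; the name cagrad_available is hypothetical and only used for illustration.

```python
# Guarded import: CAGrad is only importable when the optional
# "cagrad" extra of torchjd has been installed.
try:
    from torchjd.aggregation import CAGrad
except ImportError as error:
    # Also reached when torchjd itself is missing, since
    # ModuleNotFoundError is a subclass of ImportError.
    CAGrad = None
    print(f"CAGrad is unavailable ({error}). Try: pip install torchjd[cagrad]")

# Hypothetical flag that the rest of the program can check before using CAGrad.
cagrad_available = CAGrad is not None
```

Downstream code can then test cagrad_available and fall back to another aggregator when it is False.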

class torchjd.aggregation.CAGradWeighting(c, norm_eps=0.0001)

Weighting giving the weights of CAGrad.

Parameters:
  • c (float) – The scale of the radius of the ball constraint.

  • norm_eps (float) – A small value to avoid division by zero when normalizing.

Note

This implementation differs from the official implementations in the way the underlying optimization problem is solved: it uses the CLARABEL solver through cvxpy rather than the scipy.optimize.minimize function.
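To illustrate how the two parameters might enter the final weights once the solver has returned a point w on the simplex, here is a dependency-free sketch. It follows the closed-form update of the referenced paper (average gradient plus a rescaled weighted gradient) and assumes that norm_eps is used as a threshold on the norm of the weighted gradient; it does not solve the optimization problem and is not the library's actual code. The function names cagrad_like_weights and _norm are hypothetical.

```python
import math


def _norm(vector):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(x * x for x in vector))


def cagrad_like_weights(gradients, w, c, norm_eps=1e-4):
    """Turn simplex weights w into final per-row weights (illustrative only).

    gradients: list of m gradient vectors (lists of floats of equal length).
    w: list of m non-negative floats summing to 1, assumed to come from a solver.
    """
    m = len(gradients)
    dim = len(gradients[0])

    # Average gradient g_0 and weighted gradient g_w.
    g_0 = [sum(g[k] for g in gradients) / m for k in range(dim)]
    g_w = [sum(w[i] * gradients[i][k] for i in range(m)) for k in range(dim)]

    # Radius of the ball constraint: c times the norm of g_0.
    radius = c * _norm(g_0)
    g_w_norm = _norm(g_w)

    # Assumed role of norm_eps: skip the normalization when g_w is (almost)
    # zero, and fall back to plain averaging.
    if g_w_norm < norm_eps:
        return [1.0 / m] * m

    scale = radius / g_w_norm
    return [1.0 / m + scale * w_i for w_i in w]


weights = cagrad_like_weights([[3.0, 0.0], [0.0, 4.0]], w=[1.0, 0.0], c=0.5)
fallback = cagrad_like_weights([[1.0, 0.0], [-1.0, 0.0]], w=[0.5, 0.5], c=0.5)
```

In the first call the weight of the first row grows beyond 1/m because w puts all its mass on it; in the second call the weighted gradient vanishes, so the guard returns uniform weights instead of dividing by zero.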