CAGrad

class torchjd.aggregation.CAGrad(c, norm_eps=0.0001)

Aggregator as defined in Algorithm 1 of "Conflict-Averse Gradient Descent for Multi-task Learning".

Parameters:
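  • c (float) – The scale of the radius of the ball constraint. The aggregated vector stays within a distance of c times the norm of the average of the rows from that average.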

  • norm_eps (float) – A small value to avoid division by zero when normalizing.

Example

Use CAGrad to aggregate a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import CAGrad
>>>
>>> A = CAGrad(c=0.5)
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.1835, 1.2041, 1.2041])
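
To see where these numbers come from, here is a minimal pure-Python sketch of Algorithm 1 for the special case of exactly two rows. It is an illustration only and is independent of the library's actual implementation: the function name cagrad_two_tasks, the steps parameter and the grid search over the weight t are assumptions made for this sketch, and the norm_eps guard is one plausible way of avoiding the division by zero mentioned above.

```python
import math


def cagrad_two_tasks(g1, g2, c, norm_eps=1e-4, steps=1000):
    """Illustrative CAGrad for two gradients given as lists of floats."""
    # Average gradient g0 and sqrt(phi) = c * ||g0||.
    g0 = [(a + b) / 2 for a, b in zip(g1, g2)]
    sqrt_phi = c * math.sqrt(sum(x * x for x in g0))

    def combo(t):
        # Convex combination g_w = t * g1 + (1 - t) * g2.
        return [t * a + (1 - t) * b for a, b in zip(g1, g2)]

    def objective(t):
        # F(w) = <g_w, g0> + sqrt(phi) * ||g_w||
        gw = combo(t)
        dot = sum(x * y for x, y in zip(gw, g0))
        return dot + sqrt_phi * math.sqrt(sum(x * x for x in gw))

    # Hypothetical solver choice: brute-force grid search over t in [0, 1].
    best_t = min((i / steps for i in range(steps + 1)), key=objective)
    gw = combo(best_t)
    gw_norm = math.sqrt(sum(x * x for x in gw))

    # Assumed guard: fall back to the average when g_w is (almost) zero.
    if gw_norm < norm_eps:
        return g0

    # Update direction: g0 + sqrt(phi) / ||g_w|| * g_w
    return [a + sqrt_phi / gw_norm * b for a, b in zip(g0, gw)]


result = cagrad_two_tasks([-4.0, 1.0, 1.0], [6.0, 1.0, 1.0], c=0.5)
print([round(x, 4) for x in result])  # prints [0.1835, 1.2041, 1.2041]
```

For this matrix the objective is minimized by putting all the weight on the first row, so the result is the average [1, 1, 1] plus the first row rescaled to length 0.5 times the norm of the average.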

Note

This aggregator is not installed by default. When it is not installed, trying to import it should fail with the following error: ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'. To install it, run pip install torchjd[cagrad].
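
Code that must also run where the optional extra is missing can guard the import. The following is a minimal sketch that assumes the ImportError behaviour described in this note; HAS_CAGRAD is a hypothetical flag name introduced for the example, not part of torchjd.

```python
try:
    from torchjd.aggregation import CAGrad
    HAS_CAGRAD = True
except ImportError:
    # Reached when CAGrad (or torchjd itself) cannot be imported.
    CAGrad = None
    HAS_CAGRAD = False

if not HAS_CAGRAD:
    print("CAGrad is unavailable; install it with: pip install torchjd[cagrad]")
```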