CAGrad

class torchjd.aggregation.cagrad.CAGrad(c, norm_eps=0.0001)

Aggregator as defined in Algorithm 1 of "Conflict-Averse Gradient Descent for Multi-task Learning".
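For reference, here is our reading of what Algorithm 1 computes, written for an input matrix whose rows are g_1, ..., g_m. The notation (g_0, g_w, the simplex Δ_m) is ours, not the library's. The primal problem maximizes the worst-case inner product with the rows inside a ball around their average. The algorithm solves the dual over the probability simplex and then maps the dual solution back to an update vector:

```latex
\begin{aligned}
&\max_{d}\ \min_{i}\ \langle g_i, d\rangle
  \quad \text{s.t.} \quad \lVert d - g_0 \rVert \le c\,\lVert g_0 \rVert,
  \qquad g_0 = \frac{1}{m}\sum_{i=1}^{m} g_i, \quad g_w = \sum_{i=1}^{m} w_i\, g_i, \\
&w^\star \in \operatorname*{arg\,min}_{w \in \Delta_m}\ g_w^\top g_0 + c\,\lVert g_0 \rVert\,\lVert g_w \rVert, \\
&d = g_0 + \frac{c\,\lVert g_0 \rVert}{\lVert g_{w^\star} \rVert}\, g_{w^\star}.
\end{aligned}
```

With c = 0 the update reduces to the plain average g_0. Larger c allows d to move further from g_0 toward the rows it conflicts with.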

Parameters:
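  • c (float) – The scale of the radius of the ball constraint. The result is searched within a ball centered on the average of the rows of the input matrix, with radius c times the norm of that average.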

  • norm_eps (float) – A small value to avoid division by zero when normalizing.
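To show how c and norm_eps could enter the computation, here is a minimal NumPy sketch of Algorithm 1. It is not torchjd's implementation. The function names `cagrad_sketch` and `project_to_simplex` are hypothetical. The dual problem over the simplex is solved with plain projected gradient descent, which is a simple stand-in for a proper convex solver. In this sketch, norm_eps only guards the divisions by the norm of the weighted combination of rows. How torchjd uses it internally is an assumption we do not reproduce.

```python
import numpy as np


def project_to_simplex(v: np.ndarray) -> np.ndarray:
    """Euclidean projection of v onto the probability simplex (sort-based method)."""
    u = np.sort(v)[::-1]
    cumsum = np.cumsum(u)
    ks = np.arange(1, len(v) + 1)
    # Last index where the sorted entry stays positive after the shift.
    rho = np.nonzero(u - (cumsum - 1.0) / ks > 0)[0][-1]
    theta = (cumsum[rho] - 1.0) / (rho + 1)
    return np.maximum(v - theta, 0.0)


def cagrad_sketch(J: np.ndarray, c: float, norm_eps: float = 1e-4, steps: int = 2000) -> np.ndarray:
    """Hypothetical re-implementation of CAGrad (Algorithm 1), not torchjd's code."""
    m = J.shape[0]
    K = J @ J.T                         # Gramian: K[i, j] = <g_i, g_j>
    avg = np.full(m, 1.0 / m)           # uniform weights
    g0 = J.T @ avg                      # average of the rows
    sqrt_phi = c * np.linalg.norm(g0)   # radius of the ball constraint
    lr = 1.0 / max(np.linalg.norm(K), norm_eps)  # step size scaled to the Gramian

    # Projected gradient descent on F(w) = w^T K avg + sqrt_phi * sqrt(w^T K w).
    w = avg.copy()
    for t in range(steps):
        gw_norm = np.sqrt(max(w @ K @ w, 0.0))
        grad = K @ avg
        if gw_norm > norm_eps:
            grad = grad + sqrt_phi * (K @ w) / gw_norm
        w = project_to_simplex(w - lr / np.sqrt(t + 1) * grad)

    # Map the dual solution back to the update d = g0 + sqrt_phi / ||g_w|| * g_w.
    gw = J.T @ w
    gw_norm = np.linalg.norm(gw)
    if gw_norm < norm_eps:
        return g0
    return g0 + (sqrt_phi / gw_norm) * gw


J = np.array([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
print(cagrad_sketch(J, c=0.5))  # ≈ [0.1835 1.2041 1.2041]
```

On the matrix from the example, the sketch reaches the same values that torchjd prints. Setting c=0.0 returns the plain row average.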

Example

Use CAGrad to aggregate a matrix.

>>> import warnings
>>> warnings.filterwarnings("ignore")
>>>
>>> from torch import tensor
>>> from torchjd.aggregation import CAGrad
>>>
>>> A = CAGrad(c=0.5)
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.1835, 1.2041, 1.2041])
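As a sanity check, the NumPy snippet below uses only the matrix and the rounded output of the example above. It shows that the result lies on the boundary of the ball of radius c·‖g0‖ around the row average g0. It also shows that the result has a positive inner product with both rows, whereas g0 itself conflicts with the first row.

```python
import numpy as np

# Matrix and (rounded) output copied from the example above.
J = np.array([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
d = np.array([0.1835, 1.2041, 1.2041])
c = 0.5

g0 = J.mean(axis=0)  # average of the rows: [1, 1, 1]

# Distance of the result from g0 versus the ball radius c * ||g0||.
print(np.linalg.norm(d - g0), c * np.linalg.norm(g0))  # both ≈ 0.866

# Inner products of each row with g0 and with the CAGrad result.
print(J @ g0)  # ≈ [-2, 8]: g0 conflicts with the first row
print(J @ d)   # ≈ [1.674, 3.509]: no conflict with either row
```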