CAGrad
class torchjd.aggregation.cagrad.CAGrad(c, norm_eps=0.0001)
Aggregator as defined in Algorithm 1 of Conflict-Averse Gradient Descent for Multi-task Learning.
Parameters:
Example
Use CAGrad to aggregate a matrix.
>>> import warnings
>>> warnings.filterwarnings("ignore")
>>>
>>> from torch import tensor
>>> from torchjd.aggregation import CAGrad
>>>
>>> A = CAGrad(c=0.5)
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.1835, 1.2041, 1.2041])
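For intuition about where the numbers in this example come from, the following is a minimal pure-Python sketch of the update for the special case of exactly two gradients. It is not the torchjd implementation. It follows one reading of Algorithm 1: take the average gradient g0, find the convex combination g_w of the rows that minimizes g_w·g0 + c‖g0‖‖g_w‖, and return g0 + c‖g0‖ g_w/‖g_w‖. The function name `cagrad_two_tasks` and the ternary search are assumptions made for illustration, and the sketch ignores `norm_eps`.

```python
import math


def cagrad_two_tasks(g1, g2, c, iters=200):
    """Illustrative CAGrad update for two gradients (hypothetical helper)."""
    # Average gradient g0 and the radius c * ||g0|| of the search ball.
    g0 = [(x + y) / 2 for x, y in zip(g1, g2)]
    radius = c * math.sqrt(sum(x * x for x in g0))

    def g_w(a):
        # Convex combination of the two gradients with weights (a, 1 - a).
        return [a * x + (1 - a) * y for x, y in zip(g1, g2)]

    def objective(a):
        gw = g_w(a)
        dot = sum(x * y for x, y in zip(gw, g0))
        return dot + radius * math.sqrt(sum(x * x for x in gw))

    # The objective is convex in a, so a ternary search on [0, 1] suffices.
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        m1 = lo + (hi - lo) / 3
        m2 = hi - (hi - lo) / 3
        if objective(m1) < objective(m2):
            hi = m2
        else:
            lo = m1

    gw = g_w((lo + hi) / 2)
    gw_norm = math.sqrt(sum(x * x for x in gw))
    if gw_norm == 0.0:
        # Degenerate case: fall back to the plain average.
        return g0
    return [x + radius * y / gw_norm for x, y in zip(g0, gw)]


d = cagrad_two_tasks([-4.0, 1.0, 1.0], [6.0, 1.0, 1.0], c=0.5)
print([round(x, 4) for x in d])  # [0.1835, 1.2041, 1.2041]
```

With the two rows of J and c=0.5, the search puts all the weight on the first row, and the result matches the tensor shown in the example. A general version for more than two rows would need a solver over the full probability simplex instead of a one-dimensional search.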