CAGrad
class torchjd.aggregation.CAGrad(c, norm_eps=0.0001)
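Aggregator as defined in Algorithm 1 of Conflict-Averse Gradient Descent for Multi-task Learning.
- Parameters:
  - c – The scale of the radius of the ball constraint.
  - norm_eps – A small value to avoid division by zero when normalizing.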
Example
Use CAGrad to aggregate a matrix.
>>> from torch import tensor
>>> from torchjd.aggregation import CAGrad
>>>
>>> A = CAGrad(c=0.5)
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.1835, 1.2041, 1.2041])
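To make Algorithm 1 more concrete, here is a minimal NumPy sketch of the CAGrad aggregation step, written from the paper's dual formulation rather than from torchjd's source. With g0 the average of the rows of J and g_w = Σ w_i g_i, it minimizes g_w·g0 + c‖g0‖‖g_w‖ over the probability simplex, then returns d = g0 + (c‖g0‖ / ‖g_w‖) g_w. The names `project_onto_simplex` and `cagrad_sketch` are hypothetical, and the plain projected gradient descent (fixed step size, fixed iteration count) is a simple stand-in for a proper convex solver. On the matrix from the example above, the sketch returns approximately [0.1835, 1.2041, 1.2041].

```python
import numpy as np


def project_onto_simplex(v: np.ndarray) -> np.ndarray:
    """Euclidean projection of v onto the probability simplex (sort-based method)."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u)
    k = np.arange(1, len(v) + 1)
    # Largest index for which the shifted entry stays positive.
    rho = np.nonzero(u - (css - 1.0) / k > 0)[0][-1]
    theta = (css[rho] - 1.0) / (rho + 1)
    return np.maximum(v - theta, 0.0)


def cagrad_sketch(J: np.ndarray, c: float, eps: float = 1e-8, n_iter: int = 1000) -> np.ndarray:
    """Hypothetical helper: aggregate the rows of J (one gradient per task) with CAGrad.

    Illustrative sketch of the paper's Algorithm 1, not torchjd's implementation.
    """
    K = J.shape[0]
    G = J @ J.T                       # Gramian: G[i, j] = <g_i, g_j>
    mean_w = np.full(K, 1.0 / K)      # weights that produce the average gradient g0
    b = G @ mean_w                    # b[i] = <g_i, g0>
    g0_norm = np.sqrt(mean_w @ G @ mean_w)
    radius = c * g0_norm              # radius of the ball around g0

    # Minimize F(w) = <g_w, g0> + radius * ||g_w|| over the simplex
    # using projected gradient descent (assumed solver, chosen for brevity).
    w = mean_w.copy()
    lr = 1.0 / (np.linalg.norm(G, 2) + eps)
    for _ in range(n_iter):
        gw_norm = np.sqrt(w @ G @ w + eps)
        grad = b + radius * (G @ w) / gw_norm
        w = project_onto_simplex(w - lr * grad)

    # d = g0 + (radius / ||g_w||) * g_w, expressed as weights on the rows of J.
    lam = radius / np.sqrt(w @ G @ w + eps)
    weights = mean_w + lam * w
    return weights @ J
```

Everything except the final product only needs the Gramian J Jᵀ, so the aggregation can be phrased as choosing one weight per row of J. With c=0 the ball has radius zero and the result reduces to the plain mean of the rows.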
Note
This aggregator is not installed by default. When it is not installed, trying to import it should result in the following error:
ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'
To install it, use pip install torchjd[cagrad].
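Because CAGrad is an optional extra, code that must also run without it can guard the import. This is a minimal sketch that relies only on the ImportError behavior described in the note; `load_cagrad` is a hypothetical helper name, not part of torchjd.

```python
def load_cagrad():
    """Hypothetical helper: return the CAGrad class, or None if it cannot be imported."""
    try:
        from torchjd.aggregation import CAGrad
    except ImportError:
        # Raised when the 'cagrad' extra is missing; ModuleNotFoundError
        # (torchjd itself not installed) is a subclass and is caught too.
        return None
    return CAGrad
```

A caller can check for None and, for example, tell the user to run pip install torchjd[cagrad] instead of failing with a bare ImportError.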