CAGrad
- class torchjd.aggregation.CAGrad(c, norm_eps=0.0001)
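Aggregator as defined in Algorithm 1 of Conflict-Averse Gradient Descent for Multi-task Learning.
- Parameters:
  - c: The scale of the radius of the ball constraint. The update d is restricted to ‖d − g_0‖ ≤ c‖g_0‖, where g_0 is the average of the task gradients.
  - norm_eps: A small value to avoid division by zero when normalizing.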
Note
This aggregator is not installed by default. When it is not installed, trying to import it should result in the following error:
ImportError: cannot import name 'CAGrad' from 'torchjd.aggregation'
To install it, use:
pip install torchjd[cagrad]
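For reference, the direction computed in Algorithm 1 can be summarized as below. The notation is chosen for this summary and may differ from the paper's: g_1, …, g_m are the task gradients, g_0 is their average, g_w is their combination with weights w taken from the probability simplex 𝒲, and d is the resulting update direction. The first line is the max-min problem CAGrad targets; the second is the problem over the simplex that is actually solved; the third recovers d from its solution.

```latex
\begin{aligned}
& g_0 = \frac{1}{m}\sum_{i=1}^{m} g_i, \qquad g_w = \sum_{i=1}^{m} w_i\, g_i, \qquad \phi = c^2 \lVert g_0 \rVert^2 \\
& \max_{d}\; \min_{i}\; \langle g_i, d \rangle \quad \text{s.t.} \quad \lVert d - g_0 \rVert \le \sqrt{\phi} \\
& w^\star \in \arg\min_{w \in \mathcal{W}} \; g_w^\top g_0 + \sqrt{\phi}\, \lVert g_w \rVert \\
& d^\star = g_0 + \frac{\sqrt{\phi}}{\lVert g_{w^\star} \rVert}\, g_{w^\star}
\end{aligned}
```

With c = 0 the constraint collapses to d = g_0, so CAGrad reduces to averaging the task gradients.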
- class torchjd.aggregation.CAGradWeighting(c, norm_eps=0.0001)
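Weighting giving the weights of CAGrad.
- Parameters:
  - c: The scale of the radius of the ball constraint, as in CAGrad.
  - norm_eps: A small value to avoid division by zero when normalizing, as in CAGrad.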
Note
This implementation differs from the official implementations in how the underlying optimization problem is solved: it uses the CLARABEL solver of cvxpy rather than scipy.optimize.minimize.
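To make the link between the weights and the CAGrad update concrete, here is a minimal pure-Python sketch for the special case of two tasks. It assumes the weights are λ_i = 1/m + (√φ/‖g_w*‖) w*_i, which is what you get by expanding the update d = g_0 + (√φ/‖g_w*‖) g_w* of Algorithm 1 as a combination of the task gradients. torchjd's actual weighting may differ in details such as normalization. The helper cagrad_weights_2tasks is hypothetical and not part of torchjd. It only uses the Gramian (the pairwise inner products ⟨g_i, g_j⟩), and it replaces the CLARABEL solve with a ternary search over w = (t, 1 − t). That substitution is valid only because the objective is convex and, with two tasks, the simplex is a segment.

```python
import math


def cagrad_weights_2tasks(gramian, c, norm_eps=1e-4, iters=200):
    """Hypothetical helper (not part of torchjd): CAGrad-style weights for two tasks.

    gramian[i][j] must hold the inner product <g_i, g_j> of the task gradients.
    Returns [lambda_1, lambda_2] so that the update is d = lambda_1 * g_1 + lambda_2 * g_2.
    """
    m = 2
    # ||g_0||^2 for the average gradient g_0 is the sum of all Gramian entries / m^2.
    g0_norm_sq = sum(sum(row) for row in gramian) / m**2
    sqrt_phi = c * math.sqrt(max(g0_norm_sq, 0.0))

    def terms(t):
        w = (t, 1.0 - t)  # a point of the two-task simplex
        # g_w^T g_0 = (1/m) * sum_ij w_i <g_i, g_j>
        linear = sum(w[i] * gramian[i][j] for i in range(m) for j in range(m)) / m
        # ||g_w||^2 = w^T G w
        gw_norm_sq = sum(w[i] * gramian[i][j] * w[j] for i in range(m) for j in range(m))
        return w, linear, math.sqrt(max(gw_norm_sq, 0.0))

    def objective(t):
        _, linear, gw_norm = terms(t)
        return linear + sqrt_phi * gw_norm

    # The objective is convex in t, so ternary search on [0, 1] converges to a minimizer.
    # Assumption: this stands in for the CLARABEL solve and only works for two tasks.
    lo, hi = 0.0, 1.0
    for _ in range(iters):
        left = lo + (hi - lo) / 3
        right = hi - (hi - lo) / 3
        if objective(left) <= objective(right):
            hi = right
        else:
            lo = left
    w, _, gw_norm = terms((lo + hi) / 2)

    # Expand d = g_0 + (sqrt_phi / ||g_w||) g_w into per-task weights.
    weights = [1.0 / m] * m
    if gw_norm >= norm_eps:
        weights = [weights[i] + (sqrt_phi / gw_norm) * w[i] for i in range(m)]
    return weights


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Usage on two conflicting gradients.
g1, g2 = [1.0, 0.0], [-0.5, 1.0]
gramian = [[dot(g1, g1), dot(g1, g2)], [dot(g2, g1), dot(g2, g2)]]
lam = cagrad_weights_2tasks(gramian, c=0.5)
d = [lam[0] * a + lam[1] * b for a, b in zip(g1, g2)]
```

With c = 0 the weights reduce to the plain average [0.5, 0.5]. For c > 0 (and ‖g_w*‖ above norm_eps), the update lands on the boundary ‖d − g_0‖ = c‖g_0‖, and its smallest inner product with a task gradient is larger than that of g_0.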