CAGrad
class torchjd.aggregation.CAGrad(c, norm_eps=0.0001)
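A GramianWeightedAggregator implementing CAGrad as defined in Algorithm 1 of Conflict-Averse Gradient Descent for Multi-task Learning.
Parameters:
- c: the constant c of Algorithm 1, which limits how far the aggregated update d may move from the average gradient g₀ (‖d − g₀‖ ≤ c‖g₀‖).
- norm_eps: defaults to 0.0001.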
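The weighting of Algorithm 1 can be illustrated with a short NumPy sketch. This is an illustrative approximation, not torchjd's implementation. It assumes the task gradients g_i are the rows of a Jacobian J. It solves the inner problem, minimising F(w) = g_wᵀg₀ + √φ‖g_w‖ over the probability simplex with φ = c²‖g₀‖², using plain projected gradient descent. That solver is a swapped-in choice, and the names `cagrad_sketch` and `project_to_simplex`, the step size, the iteration count and the `eps` guard are all hypothetical. Any rescaling of the output is omitted. Every quantity can be written through the Gramian G = J Jᵀ (g_wᵀg₀ = wᵀGu and ‖g_w‖² = wᵀGw, with u the uniform weights), which is presumably why CAGrad fits the GramianWeightedAggregator interface.

```python
import numpy as np


def project_to_simplex(v):
    """Euclidean projection of v onto the probability simplex (hypothetical helper)."""
    s = np.sort(v)[::-1]
    css = np.cumsum(s)
    idx = np.arange(1, len(v) + 1)
    cond = s - (css - 1.0) / idx > 0  # true on a prefix; the first entry is always true
    rho = idx[cond][-1]
    theta = (css[cond][-1] - 1.0) / rho
    return np.maximum(v - theta, 0.0)


def cagrad_sketch(jacobian, c, steps=500, lr=0.1, eps=1e-8):
    """Illustrative CAGrad aggregation (hypothetical, not torchjd's code).

    jacobian: array of shape (k, n), one task gradient g_i per row.
    Returns d = g0 + sqrt(phi) / ||g_w|| * g_w, with phi = c^2 * ||g0||^2.
    """
    J = np.asarray(jacobian, dtype=float)
    k = J.shape[0]
    G = J @ J.T                      # Gramian: G[i, j] = <g_i, g_j>
    u = np.full(k, 1.0 / k)          # uniform weights, so g0 = J^T u
    g0_norm = np.sqrt(max(u @ G @ u, 0.0))
    sqrt_phi = c * g0_norm

    # Minimise F(w) = w^T G u + sqrt_phi * sqrt(w^T G w) over the simplex
    # with projected gradient descent; eps avoids division by zero.
    w = u.copy()
    for _ in range(steps):
        gw_norm = np.sqrt(max(w @ G @ w, 0.0) + eps)
        grad = G @ u + sqrt_phi * (G @ w) / gw_norm
        w = project_to_simplex(w - lr * grad)

    # d = J^T (u + lam * w), with lam = sqrt_phi / ||g_w||
    gw_norm = np.sqrt(max(w @ G @ w, 0.0))
    lam = sqrt_phi / (gw_norm + eps)
    return (u + lam * w) @ J
```

For example, with two orthogonal unit gradients and c = 0.5, symmetry keeps w at the uniform weights, so `cagrad_sketch([[1.0, 0.0], [0.0, 1.0]], c=0.5)` returns approximately (1 + c)·g₀ = [0.75, 0.75].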
Note
This aggregator requires optional dependencies. When they are not installed, instantiating it raises an ImportError with installation instructions. To install them, use pip install "torchjd[cagrad]".