PCGrad

class torchjd.aggregation.pcgrad.PCGrad
Aggregator as defined in Algorithm 1 of Gradient Surgery for Multi-Task Learning.

Example
Use PCGrad to aggregate a matrix.
>>> from torch import tensor
>>> from torchjd.aggregation import PCGrad
>>>
>>> A = PCGrad()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.5848, 3.8012, 3.8012])
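For intuition, here is a minimal from-scratch sketch of the projection step in Algorithm 1, written in plain Python so it runs without PyTorch. It is not torchjd's implementation: the helper names `dot` and `pcgrad` and the `seed` parameter are hypothetical. Each row of the matrix is treated as one task's gradient. Whenever a (partially projected) row conflicts with another original row, meaning their inner product is negative, its component along that row is removed. The projected rows are then summed.

```python
import random
from typing import List, Optional

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


def pcgrad(rows: List[Vector], seed: Optional[int] = None) -> Vector:
    """Hypothetical PCGrad sketch: project away conflicts, then sum the rows."""
    rng = random.Random(seed)
    projected = []
    for i, g_i in enumerate(rows):
        g_pc = list(g_i)  # working copy of task i's gradient
        others = [j for j in range(len(rows)) if j != i]
        rng.shuffle(others)  # Algorithm 1 visits the other tasks in random order
        for j in others:
            g_j = rows[j]  # always project onto the ORIGINAL gradient of task j
            inner = dot(g_pc, g_j)
            if inner < 0:  # conflict: remove the component of g_pc along g_j
                coef = inner / dot(g_j, g_j)  # g_j is nonzero whenever inner < 0
                g_pc = [a - coef * b for a, b in zip(g_pc, g_j)]
        projected.append(g_pc)
    # Sum the projected gradients coordinate-wise.
    return [sum(col) for col in zip(*projected)]


J = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
print([round(x, 4) for x in pcgrad(J, seed=0)])  # [0.5848, 3.8012, 3.8012]
```

With only two rows the random visiting order has no effect, and the sketch reproduces the doctest output: the two rows have inner product -22, and the exact aggregated values are 100/171 ≈ 0.5848 and 650/171 ≈ 3.8012.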