PCGrad

class torchjd.aggregation.pcgrad.PCGrad

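Aggregator as defined in Algorithm 1 of Gradient Surgery for Multi-Task Learning. Each row of the input matrix is projected onto the normal plane of every other row it conflicts with (negative inner product), and the projected rows are summed.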
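To make the algorithm concrete, here is a minimal pure-Python sketch of Algorithm 1 as described in the paper. It is not torchjd's implementation. The function names `dot` and `pcgrad` and the `seed` parameter are hypothetical and exist only for this illustration. Following the paper, each row is projected against the other (original) rows in a random order, and only when the inner product is negative.

```python
import random


def dot(u, v):
    # Inner product of two equal-length vectors.
    return sum(a * b for a, b in zip(u, v))


def pcgrad(rows, seed=None):
    """Hypothetical sketch of PCGrad (Algorithm 1): sum of the projected rows."""
    rng = random.Random(seed)
    n = len(rows)
    result = [0.0] * len(rows[0])
    for i, g_i in enumerate(rows):
        g_pc = list(g_i)  # start from a copy of the i-th row
        others = [j for j in range(n) if j != i]
        rng.shuffle(others)  # the paper visits the other tasks in random order
        for j in others:
            g_j = rows[j]
            inner = dot(g_pc, g_j)
            if inner < 0:  # conflict: remove the component of g_pc along g_j
                coef = inner / dot(g_j, g_j)
                g_pc = [a - coef * b for a, b in zip(g_pc, g_j)]
        # Accumulate the projected row into the aggregated vector.
        result = [r + a for r, a in zip(result, g_pc)]
    return result


print([round(x, 4) for x in pcgrad([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])])
```

Applied to the matrix used in the example, this sketch should print `[0.5848, 3.8012, 3.8012]`. With only two rows the shuffle has no effect; with more rows the result can depend on the visiting order.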

Example

Use PCGrad to aggregate a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import PCGrad
>>>
>>> A = PCGrad()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.5848, 3.8012, 3.8012])
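The output above can be checked by hand. Let g_1 = (-4, 1, 1) and g_2 = (6, 1, 1) be the rows of J. Their inner product is negative, so each row is projected onto the normal plane of the other, and with two rows the random order in Algorithm 1 does not matter:

```latex
\begin{aligned}
g_1 \cdot g_2 &= -24 + 1 + 1 = -22, \qquad \|g_1\|^2 = 18, \qquad \|g_2\|^2 = 38, \\
g_1^{PC} &= g_1 - \frac{g_1 \cdot g_2}{\|g_2\|^2}\, g_2
          = g_1 + \tfrac{11}{19}\, g_2
          = \left(-\tfrac{10}{19},\ \tfrac{30}{19},\ \tfrac{30}{19}\right), \\
g_2^{PC} &= g_2 - \frac{g_2 \cdot g_1}{\|g_1\|^2}\, g_1
          = g_2 + \tfrac{11}{9}\, g_1
          = \left(\tfrac{10}{9},\ \tfrac{20}{9},\ \tfrac{20}{9}\right), \\
g_1^{PC} + g_2^{PC} &= \left(\tfrac{100}{171},\ \tfrac{650}{171},\ \tfrac{650}{171}\right)
          \approx (0.5848,\ 3.8012,\ 3.8012).
\end{aligned}
```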