IMTL-G

class torchjd.aggregation.imtl_g.IMTLG

Aggregator generalizing the method described in "Towards Impartial Multi-task Learning". This generalization also supports matrices in which some rows are linearly dependent.

Example

Use IMTL-G to aggregate a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import IMTLG
>>>
>>> A = IMTLG()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.0767, 1.0000, 1.0000])
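
The result above can be reproduced by hand. The sketch below is not taken from the torchjd source. It assumes the aggregation is a weighted sum of the rows of J, with weights proportional to pinv(G) d, where G = J Jᵀ is the Gramian and d holds the row norms, normalized to sum to 1. The function name imtl_g_sketch and the fallback for weights that sum to zero are hypothetical. The pseudo-inverse is the ingredient that lets this formulation tolerate linearly dependent rows.

```python
import torch


def imtl_g_sketch(J: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-derivation of IMTL-G; not the torchjd implementation."""
    G = J @ J.T                         # Gramian of the rows of J
    d = torch.sqrt(torch.diagonal(G))   # Euclidean norm of each row
    v = torch.linalg.pinv(G) @ d        # pseudo-inverse copes with a singular G
    total = v.sum()
    if total.abs() < 1e-12:
        # Assumed safeguard: return a zero vector when the weights cannot be normalized.
        return torch.zeros(J.shape[1], dtype=J.dtype)
    weights = v / total                 # weights sum to 1
    return weights @ J                  # weighted sum of the rows


J = torch.tensor([[-4., 1., 1.], [6., 1., 1.]])
g = imtl_g_sketch(J)

# Projection of g onto each unit-normalized row of J.
projections = (J / J.norm(dim=1, keepdim=True)) @ g
print(g)
print(projections)
```

For this matrix, g matches the output of A(J) shown above up to rounding, and the two entries of projections coincide. Equal projections onto every normalized row is the impartiality condition that IMTL-G is built around.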