Aligned-MTL

class torchjd.aggregation.aligned_mtl.AlignedMTL(pref_vector=None)

Aggregator as defined in Algorithm 1 of the paper "Independent Component Alignment for Multi-Task Learning".
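The balancing step of Algorithm 1 can be sketched in plain PyTorch. This is a minimal sketch based on our reading of the paper, not torchjd's source. The function name aligned_mtl_sketch is hypothetical. The sketch assumes the following: the rows of J are the task gradients; the Gram matrix M = J J^T is eigendecomposed; numerically-zero eigenvalues are discarded; the balance matrix B = sqrt(λ_min) V Λ^(-1/2) V^T is applied before weighting by the preference vector; and that vector defaults to uniform weights 1/m.

```python
from typing import Optional

import torch


def aligned_mtl_sketch(J: torch.Tensor, pref_vector: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical helper: aggregate the rows of J (one gradient per task) Aligned-MTL style."""
    m = J.shape[0]
    # Assumption: without a preference vector, every task gets weight 1/m.
    w = torch.full((m,), 1.0 / m, dtype=J.dtype) if pref_vector is None else pref_vector

    # Gram matrix of the task gradients (m x m) and its eigendecomposition.
    M = J @ J.T
    eigvals, eigvecs = torch.linalg.eigh(M)

    # Keep only eigenvalues above a numerical-rank tolerance so 1/sqrt stays finite.
    tol = eigvals.max() * max(M.shape) * torch.finfo(M.dtype).eps
    keep = eigvals > tol
    eigvals, eigvecs = eigvals[keep], eigvecs[:, keep]

    # Balance matrix B = sqrt(lambda_min) * V diag(1/sqrt(lambda)) V^T.
    B = eigvals.min().sqrt() * eigvecs @ torch.diag(eigvals.rsqrt()) @ eigvecs.T

    # Weighted sum of the balanced task gradients: J^T (B w).
    return J.T @ (B @ w)


J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
print(aligned_mtl_sketch(J))                               # tensor([0.2133, 0.9673, 0.9673])
print(aligned_mtl_sketch(J, torch.tensor([0.75, 0.25])))  # custom preference over the two tasks
```

With the default uniform weights, the sketch reproduces the output of the example below on the same matrix. Passing a preference vector changes how much each balanced task gradient contributes to the sum.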

Parameters:

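pref_vector (Tensor | None) – The preference vector used to weight the balanced rows of the matrix, with one entry per row (task). If None, every row gets the same weight 1/m, where m is the number of rows.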

Example

Use AlignedMTL to aggregate a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import AlignedMTL
>>>
>>> A = AlignedMTL()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.2133, 0.9673, 0.9673])
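The output above can be reproduced by hand, which also shows what the balancing does. This is a plain PyTorch sketch, not torchjd's code. It assumes our reading of Algorithm 1: the balance matrix is B = sqrt(λ_min) M^(-1/2), with M = J J^T, and the default weights are uniform. After balancing, the two task gradients should be orthogonal with equal squared norm λ_min. Their plain average should match A(J).

```python
import torch

J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
M = J @ J.T                              # Gram matrix of the two task gradients
eigvals, eigvecs = torch.linalg.eigh(M)  # eigenvalues in ascending order; M has full rank here
B = eigvals[0].sqrt() * eigvecs @ torch.diag(eigvals.rsqrt()) @ eigvecs.T

balanced = B @ J              # rows are the balanced task gradients
print(balanced @ balanced.T)  # close to eigvals[0] * I: orthogonal rows with equal norms
print(eigvals[0])             # smallest eigenvalue of M, about 3.8339
print(balanced.mean(dim=0))   # uniform average of the balanced rows, matching A(J) above
```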

Note

This implementation was adapted from the official Aligned-MTL implementation.