Aligned-MTL
class torchjd.aggregation.aligned_mtl.AlignedMTL(pref_vector=None)
Aggregator as defined in Algorithm 1 of Independent Component Alignment for Multi-Task Learning.

Example
Use AlignedMTL to aggregate a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import AlignedMTL
>>>
>>> A = AlignedMTL()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.2133, 0.9673, 0.9673])
Note
This implementation was adapted from the official implementation.
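To make the aggregation rule more concrete, the following is a minimal sketch in plain PyTorch of the computation behind Algorithm 1, written for illustration only. It is not the torchjd source, and the function name aligned_mtl_aggregate is hypothetical. It assumes that each row of J is one task's gradient, that the task-by-task Gramian M = J J^T is eigendecomposed to build the balance matrix B = sqrt(lambda_min) V diag(lambda^{-1/2}) V^T over the non-negligible eigenpairs, and that when pref_vector is None a uniform vector of weights 1/m is used. The weights B @ pref_vector are then applied to the rows of J.

```python
from typing import Optional

import torch


def aligned_mtl_aggregate(J: torch.Tensor, pref_vector: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Hypothetical sketch of Aligned-MTL aggregation (not the torchjd API).

    J has one row per task gradient; returns a single aggregated vector.
    """
    m = J.shape[0]
    if pref_vector is None:
        # Assumption: uniform preference vector (1/m, ..., 1/m) by default.
        pref_vector = torch.full((m,), 1.0 / m, dtype=J.dtype, device=J.device)

    # Gramian of the task gradients: M = J J^T, an m x m symmetric PSD matrix.
    M = J @ J.T

    # Eigendecomposition M = V diag(lambda) V^T (eigenvalues in ascending order).
    lambda_, V = torch.linalg.eigh(M)

    # Drop eigenpairs below a numerical tolerance to keep only the effective rank.
    tol = lambda_.max() * m * torch.finfo(J.dtype).eps
    keep = lambda_ > tol
    lambda_, V = lambda_[keep], V[:, keep]

    # Balance transformation B = sqrt(lambda_min) * V diag(lambda^{-1/2}) V^T.
    sigma_inv = torch.diag(1.0 / lambda_.sqrt())
    B = lambda_.min().sqrt() * V @ sigma_inv @ V.T

    # Task weights, then the aggregated vector J^T alpha.
    alpha = B @ pref_vector
    return J.T @ alpha


J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
print(aligned_mtl_aggregate(J))
```

By a hand calculation, this sketch returns approximately [0.2133, 0.9673, 0.9673] for the matrix in the example, which is consistent with the uniform default assumed above. When M has full rank, B reduces to sqrt(lambda_min) M^{-1/2}. This rescales the task gradients so that their Gramian becomes well conditioned before the preference weights are applied.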