Aligned-MTL
- class torchjd.aggregation.AlignedMTL(pref_vector=None, scale_mode='min')
Aggregator as defined in Algorithm 1 of Independent Component Alignment for Multi-Task Learning.
- Parameters:
  - pref_vector (Tensor | None) – The preference vector to use. If not provided, defaults to \(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\).
  - scale_mode (Literal['min', 'median', 'rmse']) – The scaling mode used to build the balance transformation. "min" uses the smallest eigenvalue (default), "median" uses the median eigenvalue, and "rmse" uses the mean eigenvalue (as in the original implementation).
Note
This implementation was adapted from the official implementation in the SamsungLabs/MTL repository, which is no longer available at the time of writing.
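To make the two parameters more concrete, here is a minimal standard-library sketch of what the documented defaults and scaling modes select. The helper names `default_pref_vector` and `select_scale` are hypothetical and are not part of torchjd, the eigenvalues are made-up numbers, and how the library handles the median of an even number of eigenvalues is not specified here.

```python
from statistics import mean, median


def default_pref_vector(m):
    """Uniform preference vector [1/m, ..., 1/m] (hypothetical helper).

    Mirrors the documented default used when pref_vector is None.
    """
    return [1.0 / m] * m


def select_scale(eigenvalues, scale_mode="min"):
    """Return the eigenvalue that a scale_mode refers to (hypothetical helper)."""
    if scale_mode == "min":
        return min(eigenvalues)  # smallest eigenvalue (documented default)
    if scale_mode == "median":
        return median(eigenvalues)  # median eigenvalue
    if scale_mode == "rmse":
        return mean(eigenvalues)  # mean eigenvalue, per the description above
    raise ValueError(f"Unknown scale_mode: {scale_mode!r}")


# Made-up eigenvalues, for illustration only.
eigenvalues = [4.0, 1.0, 0.25]
for mode in ("min", "median", "rmse"):
    print(mode, select_scale(eigenvalues, mode))

print(default_pref_vector(4))
```

The sketch only shows which eigenvalue each mode refers to; building the balance transformation from it is left to the library.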
- class torchjd.aggregation.AlignedMTLWeighting(pref_vector=None, scale_mode='min')
Weighting giving the weights of AlignedMTL.
- Parameters:
  - pref_vector (Tensor | None) – The preference vector to use. If not provided, defaults to \(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\).
  - scale_mode (Literal['min', 'median', 'rmse']) – The scaling mode used to build the balance transformation. "min" uses the smallest eigenvalue (default), "median" uses the median eigenvalue, and "rmse" uses the mean eigenvalue (as in the original implementation).
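As a rough sketch of how the weights relate to the two parameters, based on our reading of Algorithm 1 of the cited paper rather than on the library's source, the computation can be summarized as follows. The notation is an assumption: \(J \in \mathbb{R}^{m \times n}\) is the matrix whose rows are the \(m\) gradients, \(\Lambda\) is the diagonal matrix of the strictly positive eigenvalues of \(M\), \(V\) holds the corresponding eigenvectors, \(\sigma\) is the eigenvalue selected by scale_mode, and \(w\) is pref_vector.

\[
M = J J^T = V \Lambda V^T, \qquad
B = \sqrt{\sigma} \, V \Lambda^{-1/2} V^T, \qquad
\alpha = B w
\]

Under these assumptions, \(B\) is the balance transformation, \(\alpha \in \mathbb{R}^m\) is the vector of weights, and the aggregated vector is \(J^T \alpha\).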