Aligned-MTL

class torchjd.aggregation.AlignedMTL(pref_vector=None, scale_mode='min')

Aggregator as defined in Algorithm 1 of the paper "Independent Component Alignment for Multi-Task Learning".
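As a reading aid, here is a paraphrase of Algorithm 1 in this page's notation. It is not a quotation of the paper. Let \(J \in \mathbb{R}^{m \times n}\) be the matrix whose rows are the \(m\) gradients, \(p\) the preference vector, \(R\) the rank of the gramian \(J J^T\), and \(V \in \mathbb{R}^{m \times R}\) the eigenvectors of its \(R\) non-zero eigenvalues \(\lambda_1, \dots, \lambda_R\). With the default scale_mode="min":

\[
J J^T = V \Lambda V^T, \qquad \Lambda = \operatorname{diag}(\lambda_1, \dots, \lambda_R), \qquad \lambda_{\min} = \min_i \lambda_i
\]
\[
B = \sqrt{\lambda_{\min}}\, V \Lambda^{-1/2} V^T, \qquad w = B\,p, \qquad \mathcal{A}(J) = J^T w = (BJ)^T p
\]

Because \(V^T V = I_R\), this gives \((BJ)(BJ)^T = \lambda_{\min} V V^T\). When \(J J^T\) has full rank, this reduces to \(\lambda_{\min} I\): the balanced gradients (the rows of \(BJ\)) are orthogonal and all have squared norm \(\lambda_{\min}\).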

Parameters:
  • pref_vector (Tensor | None) – The preference vector to use. If not provided, defaults to \(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\), where \(m\) is the number of rows of the input matrix.

  • scale_mode (Literal['min', 'median', 'rmse']) – The scaling mode used to build the balance transformation: "min" uses the smallest eigenvalue of the gramian of the input matrix (default), "median" uses its median eigenvalue, and "rmse" uses its mean eigenvalue (as in the original implementation).
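The effect of pref_vector and scale_mode can be sketched in NumPy. This is a hypothetical re-implementation for illustration, not torchjd's code. The helper names balance_transformation and aligned_mtl_weights, the rank-truncation tolerance, and the identity fallback for an all-zero gramian are all assumptions. The sketch builds B = sqrt(s) V diag(λ)^(-1/2) V^T from the eigendecomposition of the gramian, where s is the eigenvalue statistic selected by scale_mode, and returns the weights B @ pref_vector.

```python
import numpy as np


def balance_transformation(gramian, scale_mode="min"):
    """Hypothetical helper: AlignedMTL-style balance matrix B for an (m, m) gramian."""
    gramian = np.asarray(gramian, dtype=float)
    m = gramian.shape[0]
    eigvals, eigvecs = np.linalg.eigh(gramian)  # eigenvalues in ascending order
    # Drop numerically zero eigenvalues (rank truncation); this tolerance is an assumption.
    tol = eigvals.max() * m * np.finfo(float).eps
    keep = eigvals > tol
    if not keep.any():
        return np.eye(m)  # assumed fallback when the gramian is all zeros
    lam, vecs = eigvals[keep], eigvecs[:, keep]
    if scale_mode == "min":
        scale = lam.min()
    elif scale_mode == "median":
        scale = np.median(lam)  # NumPy averages the two middle values for an even count
    elif scale_mode == "rmse":
        scale = lam.mean()
    else:
        raise ValueError(f"unknown scale_mode: {scale_mode!r}")
    # Dividing column j of V by sqrt(lam_j) gives V diag(lam)^(-1/2).
    return np.sqrt(scale) * (vecs / np.sqrt(lam)) @ vecs.T


def aligned_mtl_weights(gramian, pref_vector=None, scale_mode="min"):
    """Hypothetical helper: weights w = B @ p, with p uniform (1/m) by default."""
    m = np.asarray(gramian).shape[0]
    if pref_vector is None:
        pref_vector = np.full(m, 1.0 / m)
    return balance_transformation(gramian, scale_mode) @ np.asarray(pref_vector, dtype=float)


# Usage: two gradients (rows) of three parameters each.
J = np.array([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
G = J @ J.T
w = aligned_mtl_weights(G)
update = J.T @ w  # combine the rows of J with the weights
B = balance_transformation(G)
# For a full-rank G, the balanced gramian B G B equals lambda_min * I.
print(np.allclose(B @ G @ B, np.linalg.eigvalsh(G).min() * np.eye(2)))  # True
```

In this sketch the balance matrices for the three modes differ only by the scalar factor sqrt(s). Choosing "median" or "rmse" over "min" therefore rescales the aggregated vector without changing its direction.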

Note

This implementation was adapted from the official implementation of SamsungLabs/MTL, which is no longer available at the time of writing.

__call__(matrix, /)

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor

class torchjd.aggregation.AlignedMTLWeighting(pref_vector=None, scale_mode='min')

Weighting giving the weights of AlignedMTL.

Parameters:
  • pref_vector (Tensor | None) – The preference vector to use. If not provided, defaults to \(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\), where \(m\) is the number of rows of the input gramian.

  • scale_mode (Literal['min', 'median', 'rmse']) – The scaling mode used to build the balance transformation: "min" uses the smallest eigenvalue of the input gramian (default), "median" uses its median eigenvalue, and "rmse" uses its mean eigenvalue (as in the original implementation).

__call__(gramian, /)

Computes the vector of weights from the input gramian and applies all registered hooks.

Parameters:

gramian (Tensor) – The gramian from which the weights must be extracted.

Return type:

Tensor