MGDA

class torchjd.aggregation.MGDA(epsilon=0.001, max_iters=100)[source]

Aggregator performing the gradient aggregation step of "Multiple-gradient descent algorithm (MGDA) for multiobjective optimization". The implementation is based on Algorithm 2 of "Multi-Task Learning as Multi-Objective Optimization".

Parameters:
  • epsilon (float) – The stopping threshold: the optimization loop ends as soon as the step size \(\hat{\gamma}\) falls below this value.

  • max_iters (int) – The maximum number of iterations of the optimization loop.

__call__(matrix, /)[source]

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor

class torchjd.aggregation.MGDAWeighting(epsilon=0.001, max_iters=100)[source]

Weighting giving the weights of MGDA.

Parameters:
  • epsilon (float) – The stopping threshold: the optimization loop ends as soon as the step size \(\hat{\gamma}\) falls below this value.

  • max_iters (int) – The maximum number of iterations of the optimization loop.

__call__(gramian, /)[source]

Computes the vector of weights from the input gramian and applies all registered hooks.

Parameters:

gramian (Tensor) – The gramian from which the weights must be extracted.

Return type:

Tensor
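To illustrate how epsilon and max_iters interact, here is a dependency-free sketch of a Frank-Wolfe loop in the spirit of Algorithm 2 of "Multi-Task Learning as Multi-Objective Optimization". It is not torchjd's code: the function name mgda_weights_sketch and the use of nested lists instead of tensors are assumptions made for illustration.

```python
def mgda_weights_sketch(gramian, epsilon=0.001, max_iters=100):
    """Hypothetical helper: minimize alpha^T G alpha over the simplex."""
    m = len(gramian)
    alpha = [1.0 / m] * m  # start from uniform weights

    for _ in range(max_iters):
        # G @ alpha, computed row by row.
        g_alpha = [sum(g * w for g, w in zip(row, alpha)) for row in gramian]

        # Vertex of the simplex with the smallest linearized objective.
        t = min(range(m), key=lambda i: g_alpha[i])

        a = g_alpha[t]                                  # alpha^T G e_t
        b = sum(w * v for w, v in zip(alpha, g_alpha))  # alpha^T G alpha
        c = gramian[t][t]                               # e_t^T G e_t

        # Closed-form line search between alpha and e_t, clipped to [0, 1].
        if c <= a:
            gamma = 1.0
        elif b <= a:
            gamma = 0.0
        else:
            gamma = (b - a) / (b + c - 2.0 * a)

        # Move toward the chosen vertex by the step size gamma.
        alpha = [(1.0 - gamma) * w for w in alpha]
        alpha[t] += gamma

        # epsilon: stop once the step size becomes negligible.
        if gamma < epsilon:
            break

    return alpha


# Gramian of the illustrative rows [-4, 1, 1] and [6, 1, 1].
gramian = [[18.0, -22.0], [-22.0, 38.0]]
weights = mgda_weights_sketch(gramian)
print(weights)
```

In this sketch, max_iters caps the number of loop iterations, while epsilon ends the loop early once an update barely moves the weights. With max_iters=0 the loop never runs and the uniform starting weights are returned.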