MGDA

class torchjd.aggregation.mgda.MGDA(epsilon=0.001, max_iters=100)

Aggregator performing the gradient aggregation step of the multiple-gradient descent algorithm (MGDA) for multiobjective optimization. The implementation is based on Algorithm 2 of "Multi-Task Learning as Multi-Objective Optimization".
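
For context, the min-norm problem that Algorithm 2 of the cited paper addresses can be written as below. The notation is chosen here for illustration and is not taken from the torchjd documentation: \(g_1, \dots, g_m\) denote the rows of the matrix to aggregate and \(\alpha\) the weights assigned to them.

\[
\min_{\alpha \in \mathbb{R}^m} \; \Big\| \sum_{i=1}^{m} \alpha_i g_i \Big\|_2^2
\quad \text{subject to} \quad \sum_{i=1}^{m} \alpha_i = 1, \qquad \alpha_i \ge 0 \ \text{for all } i.
\]

Under this reading, the aggregated vector is the weighted sum \(\sum_i \alpha_i g_i\) evaluated at the weights the solver returns.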

Parameters:
  • epsilon (float) – Threshold on the step size \(\hat{\gamma}\): the optimization loop stops as soon as \(\hat{\gamma}\) falls below this value.

  • max_iters (int) – The maximum number of iterations of the optimization loop.
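
To illustrate how the two parameters interact, here is a minimal pure-Python sketch of a Frank-Wolfe loop in the spirit of Algorithm 2 of the cited paper. It is not the torchjd source code: the helper names `gramian`, `mgda_weights`, and `aggregate` are hypothetical, and the loop works on plain lists rather than tensors. It assumes the weights start uniform, that each iteration moves toward the single row that minimizes the current objective, and that the step size `gamma` plays the role of \(\hat{\gamma}\).

```python
def gramian(rows):
    """Return G = J J^T for a matrix J given as a list of rows."""
    return [[sum(a * b for a, b in zip(r, s)) for s in rows] for r in rows]


def mgda_weights(G, epsilon=0.001, max_iters=100):
    """Frank-Wolfe sketch: weights of the min-norm convex combination of rows."""
    m = len(G)
    alpha = [1.0 / m] * m  # assumption: start from uniform weights
    for _ in range(max_iters):
        # G @ alpha; its smallest entry picks the vertex e_t to move toward.
        G_alpha = [sum(G[i][j] * alpha[j] for j in range(m)) for i in range(m)]
        t = min(range(m), key=G_alpha.__getitem__)

        a = G_alpha[t]                                    # alpha^T G e_t
        b = sum(alpha[i] * G_alpha[i] for i in range(m))  # alpha^T G alpha
        c = G[t][t]                                       # e_t^T G e_t

        # Closed-form line search on the segment between alpha and e_t.
        if c <= a:
            gamma = 1.0
        elif b <= a:
            gamma = 0.0
        else:
            gamma = (b - a) / (b + c - 2.0 * a)

        alpha = [(1.0 - gamma) * w for w in alpha]
        alpha[t] += gamma

        if gamma < epsilon:  # step size below the threshold: stop
            break
    return alpha


def aggregate(rows, weights):
    """Weighted sum of the rows."""
    return [sum(w * r[k] for w, r in zip(weights, rows)) for k in range(len(rows[0]))]


J = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
weights = mgda_weights(gramian(J))
print(weights)                # approximately [0.6, 0.4]
print(aggregate(J, weights))  # approximately [0.0, 1.0, 1.0]
```

In this sketch a larger `epsilon` ends the loop earlier with coarser weights, while `max_iters` caps the work when the step size never drops below the threshold.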

Example

Use MGDA to aggregate a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import MGDA
>>>
>>> A = MGDA()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([1.1921e-07, 1.0000e+00, 1.0000e+00])