MGDA

class torchjd.aggregation.mgda.MGDA(epsilon=0.001, max_iters=100)
Aggregator performing the gradient aggregation step of the Multiple-Gradient Descent Algorithm (MGDA) for multi-objective optimization. The implementation is based on Algorithm 2 of "Multi-Task Learning as Multi-Objective Optimization" (Sener and Koltun, 2018).

Parameters:
- epsilon (float, default 0.001)
- max_iters (int, default 100)
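As an illustration only, the pure-Python sketch below solves the problem MGDA is built on: find weights alpha in the probability simplex that minimize ||sum_i alpha_i g_i||^2, where g_i are the rows of the matrix J. It uses a Frank-Wolfe iteration with exact line search, which is the kind of solver Algorithm 2 of the paper relies on. The names `mgda_weights` and `mgda_aggregate` are hypothetical. Treating `epsilon` as a threshold on the Frank-Wolfe step size and `max_iters` as the iteration cap is an assumption. This is not torchjd's source code.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def dot(x: Vector, y: Vector) -> float:
    return sum(a * b for a, b in zip(x, y))


def mgda_weights(jacobian: Matrix, epsilon: float = 0.001, max_iters: int = 100) -> Vector:
    """Hypothetical Frank-Wolfe solver for min over the simplex of ||J^T alpha||^2."""
    m = len(jacobian)
    # Gramian G[i][j] = g_i . g_j: the solver only needs inner products of the rows.
    gram = [[dot(gi, gj) for gj in jacobian] for gi in jacobian]
    alpha = [1.0 / m] * m  # start at the centre of the simplex
    for _ in range(max_iters):
        g_alpha = [dot(row, alpha) for row in gram]  # (G alpha)_i = g_i . u, u = J^T alpha
        t = min(range(m), key=lambda i: g_alpha[i])  # Frank-Wolfe vertex (steepest row)
        a = g_alpha[t]           # u . g_t
        b = dot(alpha, g_alpha)  # u . u
        c = gram[t][t]           # g_t . g_t
        # Exact line search for the min-norm point on the segment (1 - gamma) u + gamma g_t.
        if c <= a:
            gamma = 1.0
        elif b <= a:
            gamma = 0.0
        else:
            gamma = (b - a) / (b + c - 2.0 * a)
        alpha = [(1.0 - gamma) * w for w in alpha]
        alpha[t] += gamma
        if gamma < epsilon:  # assumed stopping rule: the step has become negligible
            break
    return alpha


def mgda_aggregate(jacobian: Matrix, epsilon: float = 0.001, max_iters: int = 100) -> Vector:
    """Return J^T alpha, the min-norm convex combination of the rows of J."""
    alpha = mgda_weights(jacobian, epsilon, max_iters)
    n_cols = len(jacobian[0])
    return [sum(alpha[i] * jacobian[i][k] for i in range(len(jacobian))) for k in range(n_cols)]
```

For the matrix used in this page's example, `mgda_aggregate([[-4., 1., 1.], [6., 1., 1.]])` converges after one step to weights close to (0.6, 0.4) and a direction close to (0, 1, 1). Working on the Gramian rather than on J directly keeps each iteration's cost independent of the number of parameters once G is formed.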
Example
Use MGDA to aggregate a matrix.
>>> from torch import tensor
>>> from torchjd.aggregation import MGDA
>>>
>>> A = MGDA()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([1.1921e-07, 1.0000e+00, 1.0000e+00])
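The output can be checked by hand, assuming the aggregator returns the minimum-norm convex combination of the rows of J, as in the original MGDA formulation. With two rows, that combination is the point of smallest norm on the segment between them:

```latex
\begin{aligned}
g_1 &= (-4,\ 1,\ 1), \qquad g_2 = (6,\ 1,\ 1), \\
d(\alpha) &= \alpha\, g_1 + (1 - \alpha)\, g_2 = (6 - 10\alpha,\ 1,\ 1), \qquad \alpha \in [0, 1], \\
\|d(\alpha)\|^2 &= (6 - 10\alpha)^2 + 2, \\
\alpha^\star &= 0.6 \quad\Rightarrow\quad d(\alpha^\star) = (0,\ 1,\ 1).
\end{aligned}
```

The printed first entry, 1.1921e-07, is about 2^{-23}, which is consistent with float32 round-off (the default dtype of `torch.tensor` for these inputs) rather than a true nonzero component.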