Nash-MTL

class torchjd.aggregation.nash_mtl.NashMTL(n_tasks, max_norm=1.0, update_weights_every=1, optim_niter=20)

Aggregator as proposed in Algorithm 1 of Multi-Task Learning as a Bargaining Game.

Parameters:

- n_tasks (int) – The number of tasks, corresponding to the number of rows in the provided matrices.
- max_norm (float) – Maximum value of the norm of \(A^T w\).
- update_weights_every (int) – Determines how often the weighting is actually recomputed. A larger value means that the same weights are re-used for more calls to the aggregator.
- optim_niter (int) – The number of iterations of the underlying optimization process.
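For context, the weighting in the cited paper is the Nash bargaining solution among tasks. A sketch of the objective, using the paper's notation (which may differ from this implementation): with \(G\) the matrix whose columns are the per-task gradients \(g_i\), the weights \(\alpha\) solve

```latex
% Sketch of the Nash bargaining objective from the cited paper;
% notation is the paper's, not necessarily the implementation's.
\alpha^\star = \arg\max_{\alpha > 0} \sum_{i} \log\left(\alpha^\top G^\top g_i\right),
\qquad \text{with first-order condition} \qquad
G^\top G \, \alpha = 1/\alpha \;\; \text{(element-wise)}.
```

The `optim_niter` parameter bounds the iterations of the solver that approximates this fixed point.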
Example
Use NashMTL to aggregate a matrix.
>>> import warnings
>>> warnings.filterwarnings("ignore")
>>>
>>> from torch import tensor
>>> from torchjd.aggregation import NashMTL
>>>
>>> A = NashMTL(n_tasks=2)
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([0.0542, 0.7061, 0.7061])
Warning
This implementation was adapted from the official implementation, which has some flaws. Use with caution.
Warning

The aggregator is stateful: its output depends not only on the input matrix but also on its internal state, which is shaped by previously seen matrices. It should be reset between experiments.