Random
class torchjd.aggregation.random.Random
Aggregator that computes a random combination of the rows of the provided matrices, as defined in algorithm 2 of Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning.
Example
Compute several random combinations of the rows of a matrix.
>>> from torch import tensor
>>> from torchjd.aggregation import Random
>>>
>>> A = Random()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([-2.6229, 1.0000, 1.0000])
>>>
>>> A(J)
tensor([5.3976, 1.0000, 1.0000])
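The behaviour in the example can be illustrated with a small dependency-free sketch. It assumes that the random weights are obtained by drawing one standard-normal value per row and passing these values through a softmax, so that the weights are non-negative and sum to 1; this is an assumption made for illustration and not necessarily how torchjd implements Random. The helpers softmax, random_weights and random_aggregate are hypothetical names, and plain Python lists stand in for tensors.

```python
import math
import random


def softmax(xs: list[float]) -> list[float]:
    """Numerically stable softmax: returns non-negative weights that sum to 1."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def random_weights(n_rows: int, rng: random.Random) -> list[float]:
    """Hypothetical helper (assumption): one standard-normal draw per row, mapped through softmax."""
    return softmax([rng.gauss(0.0, 1.0) for _ in range(n_rows)])


def random_aggregate(matrix: list[list[float]], rng: random.Random) -> list[float]:
    """Hypothetical helper: weighted sum of the rows of `matrix` using fresh random weights."""
    weights = random_weights(len(matrix), rng)
    n_cols = len(matrix[0])
    return [sum(w * row[j] for w, row in zip(weights, matrix)) for j in range(n_cols)]


rng = random.Random(0)  # seeded only so that repeated runs of this sketch are reproducible
J = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
print(random_aggregate(J, rng))  # a new random combination is drawn on every call
print(random_aggregate(J, rng))
```

Under this assumption, each call returns a different combination. Because the weights sum to 1, the last two coordinates are always 1 (up to floating-point rounding), while the first coordinate lies between -4 and 6. This matches the pattern in the outputs of the example.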