Random

class torchjd.aggregation.random.Random

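Aggregator that computes a random combination of the rows of the provided matrix, as defined in Algorithm 2 of "Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning". New random weights are drawn at every call, so repeated calls on the same matrix generally return different results.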
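The idea can be illustrated with a minimal sketch in plain Python, without torch. It assumes that the random weights are obtained by drawing one standard-normal value per row and normalizing them with a softmax, so that they are positive and sum to 1. This is an illustration of random weighting under that assumption, not torchjd's implementation. The helpers random_weights and random_aggregate are hypothetical names that are not part of torchjd.

```python
import math
import random


def random_weights(m, rng):
    # Assumption: one standard-normal draw per row, normalized with a softmax.
    z = [rng.gauss(0.0, 1.0) for _ in range(m)]
    z_max = max(z)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(v - z_max) for v in z]
    total = sum(exps)
    return [e / total for e in exps]  # positive weights that sum to 1


def random_aggregate(matrix, rng=None):
    """Hypothetical helper: random convex combination of the rows of `matrix`."""
    if rng is None:
        rng = random.Random()
    weights = random_weights(len(matrix), rng)
    n_cols = len(matrix[0])
    # Weighted sum of the rows, computed column by column.
    return [sum(w * row[j] for w, row in zip(weights, matrix)) for j in range(n_cols)]


J = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
rng = random.Random(0)  # seeded only to make this sketch reproducible
print(random_aggregate(J, rng))
print(random_aggregate(J, rng))  # new weights are drawn, so the result differs
```

Because the weights are positive and sum to 1, each result is a convex combination of the rows. Here, the first entry always lies between -4 and 6 and the other two entries are always 1.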

Example

Compute several random combinations of the rows of a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import Random
>>>
>>> A = Random()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([-2.6229,  1.0000,  1.0000])
>>>
>>> A(J)
tensor([5.3976, 1.0000, 1.0000])
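The shape of these outputs follows from a short derivation. It assumes that the two random weights w_1 and w_2 are non-negative. That they sum to 1 is confirmed by the constant 1.0000 in the last two columns:

```latex
A(J) = w_1 \begin{bmatrix} -4 & 1 & 1 \end{bmatrix}
     + w_2 \begin{bmatrix} 6 & 1 & 1 \end{bmatrix}
     = \begin{bmatrix} 10\,w_2 - 4 & 1 & 1 \end{bmatrix},
\qquad w_1 = 1 - w_2 .
```

The first entry therefore ranges over [-4, 6] as w_2 varies over [0, 1]. The two calls above correspond to w_2 = (-2.6229 + 4) / 10 ≈ 0.1377 and w_2 = (5.3976 + 4) / 10 ≈ 0.9398.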