Random

class torchjd.aggregation.Random

Aggregator that computes a random combination of the rows of the provided matrix, as defined in Algorithm 2 of Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning.
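The following is a minimal, dependency-free sketch of this idea. It assumes, following the RLW paper, that the random weights are a softmax over i.i.d. standard normal draws: one positive weight is drawn per row, and the rows are summed with those weights. The helper names `random_weights` and `random_aggregate` are hypothetical, and the real aggregator operates on `torch` tensors rather than Python lists.

```python
import math
import random


def random_weights(m, rng=random):
    """Draw m positive weights that sum to 1.

    Assumption (hypothetical helper): softmax of i.i.d. N(0, 1) samples.
    """
    logits = [rng.gauss(0.0, 1.0) for _ in range(m)]
    shift = max(logits)  # subtract the max before exp for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def random_aggregate(matrix, rng=random):
    """Combine the rows of `matrix` (a list of equal-length rows) with random weights."""
    weights = random_weights(len(matrix), rng)  # one weight per row
    n_cols = len(matrix[0])
    # Column j of the result is the weighted sum of column j over all rows.
    return [sum(w * row[j] for w, row in zip(weights, matrix)) for j in range(n_cols)]


# Usage: aggregate a 2 x 3 "Jacobian" (one row per objective) with a seeded generator.
jacobian = [[1.0, 2.0, 3.0], [-1.0, 0.5, 4.0]]
aggregated = random_aggregate(jacobian, random.Random(0))
```

Because the weights are positive and sum to 1, the result is a convex combination of the rows: each coordinate lies between the smallest and largest entry of the corresponding column.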

__call__(matrix, /)

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor

class torchjd.aggregation.RandomWeighting

Weighting that generates positive random weights at each call.
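As an illustration of the per-call behaviour, here is a small sketch of a callable weighting. The class name `RandomWeightingSketch` is hypothetical, and drawing the weights as a softmax of standard normal samples is an assumption borrowed from the RLW paper; the sketch uses plain Python lists instead of `torch` tensors. Only the number of rows of the input matrix matters, and every call produces a fresh draw.

```python
import math
import random


class RandomWeightingSketch:
    """Hypothetical stand-in for a random weighting.

    Ignores the matrix values and returns fresh positive weights,
    one per row, on every call.
    """

    def __init__(self, seed=None):
        self._rng = random.Random(seed)

    def __call__(self, matrix):
        m = len(matrix)  # one weight per row of the matrix
        # Assumption: softmax of i.i.d. standard normal draws, so every weight
        # is strictly positive and the weights sum to 1.
        logits = [self._rng.gauss(0.0, 1.0) for _ in range(m)]
        shift = max(logits)  # stabilise exp
        exps = [math.exp(x - shift) for x in logits]
        total = sum(exps)
        return [e / total for e in exps]


weighting = RandomWeightingSketch(seed=0)
matrix = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
first = weighting(matrix)
second = weighting(matrix)  # a new draw, which generally differs from `first`
```

Since the weights do not depend on the matrix entries, this weighting is cheap to compute regardless of the number of columns.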

__call__(matrix, /)

Computes the vector of weights from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The matrix from which the weights must be extracted.

Return type:

Tensor