Mean

class torchjd.aggregation.Mean

Aggregator that averages the rows of the input matrix, returning their mean as a single vector.

__call__(matrix, /)

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor
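As a rough illustration of what averaging the rows means, the sketch below reproduces the arithmetic in plain Python rather than calling torchjd. The helper name mean_of_rows and the example values are assumptions made for this illustration; they are not part of the library. On a PyTorch tensor, the same arithmetic is matrix.mean(dim=0).

```python
def mean_of_rows(matrix):
    """Average the rows of a matrix given as a list of equal-length rows.

    Hypothetical helper for illustration only; not torchjd's implementation.
    """
    m = len(matrix)
    if m == 0:
        raise ValueError("matrix must have at least one row")
    # zip(*matrix) iterates over columns; each column is averaged over the m rows.
    return [sum(column) / m for column in zip(*matrix)]


# Example Jacobian with m = 2 rows and 3 columns (values chosen for illustration).
jacobian = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
aggregation = mean_of_rows(jacobian)
print(aggregation)  # [1.0, 1.0, 1.0]
```

The result has one entry per column of the input, so an m × n matrix is reduced to a vector of length n.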

class torchjd.aggregation.MeanWeighting

Weighting that gives the weights \(\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m\), where \(m\) is the number of rows of the input matrix.

__call__(matrix, /)

Computes the vector of weights from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The matrix from which the weights must be extracted.

Return type:

Tensor
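The uniform weights can be sketched in plain Python to show that combining the rows with weights 1/m gives the row average. The helpers mean_weights and weighted_combination, and the example values, are hypothetical names introduced for this illustration; the code does not call torchjd and is not its implementation.

```python
def mean_weights(m):
    """Return the uniform weight vector [1/m, ..., 1/m] of length m."""
    if m <= 0:
        raise ValueError("m must be positive")
    return [1.0 / m] * m


def weighted_combination(matrix, weights):
    """Return the sum over rows of weights[i] * matrix[i], column by column."""
    return [
        sum(w * x for w, x in zip(weights, column))
        for column in zip(*matrix)
    ]


# Example matrix with m = 2 rows (values chosen for illustration).
matrix = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
weights = mean_weights(len(matrix))
combined = weighted_combination(matrix, weights)

print(weights)   # [0.5, 0.5]
print(combined)  # [1.0, 1.0, 1.0]
```

The weights depend only on the number of rows m, not on the entries of the matrix, and they sum to 1.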