Constant

class torchjd.aggregation.Constant(weights)

Aggregator that computes a linear combination of the rows of the provided matrix, using constant, pre-determined weights.

Parameters:

weights (Tensor) – The weights associated with the rows of each input matrix.
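As a rough illustration of the arithmetic, assume a matrix J with m rows and a vector of m weights w. The aggregation is then the weighted sum of the rows, w_1 * J[0] + w_2 * J[1] + ... + w_m * J[m-1]; for a 1-D `weights` tensor of length m and an m×n `matrix` tensor, PyTorch computes the same thing as `weights @ matrix`. The pure-Python sketch below reproduces this on nested lists. `constant_aggregate` is a hypothetical helper, not part of torchjd, and its one-weight-per-row check is an illustrative assumption.

```python
from typing import List, Sequence


def constant_aggregate(weights: Sequence[float], matrix: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical helper: weighted sum of the rows of `matrix` using fixed `weights`."""
    # Assumption for this sketch: exactly one weight per row.
    if len(weights) != len(matrix):
        raise ValueError("expected one weight per row of the matrix")
    n_cols = len(matrix[0]) if matrix else 0
    result = [0.0] * n_cols
    for w, row in zip(weights, matrix):
        # Accumulate w * row into the result, column by column.
        for j, value in enumerate(row):
            result[j] += w * value
    return result


# Two rows (e.g. gradients of two objectives) over three parameters.
jacobian = [[1.0, 2.0, 3.0],
            [3.0, 0.0, -1.0]]
print(constant_aggregate([0.5, 0.5], jacobian))  # → [2.0, 1.0, 1.0]
```

With equal weights of 0.5 on two rows, the result is simply the mean of the rows.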

__call__(matrix, /)

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor
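Because each row of a Jacobian is the gradient of one objective, aggregating it with constant weights w gives, by linearity of differentiation, the gradient of the weighted sum w_1 * f_1 + w_2 * f_2 + .... The pure-Python check below illustrates this identity on two toy objectives differentiated by hand. The objectives `f1` and `f2` and all helper functions are made up for illustration; none of them is torchjd code.

```python
from typing import List, Sequence


def constant_aggregate(weights: Sequence[float], matrix: Sequence[Sequence[float]]) -> List[float]:
    # Hypothetical helper: column-wise sum of weights[i] * matrix[i][j].
    n_cols = len(matrix[0]) if matrix else 0
    return [sum(w * row[j] for w, row in zip(weights, matrix)) for j in range(n_cols)]


def grad_f1(x0: float, x1: float) -> List[float]:
    # f1(x) = x0**2 + x1
    return [2.0 * x0, 1.0]


def grad_f2(x0: float, x1: float) -> List[float]:
    # f2(x) = 3*x0 - x1
    return [3.0, -1.0]


def grad_weighted_sum(x0: float, x1: float) -> List[float]:
    # f1 + 0.5*f2 = x0**2 + 1.5*x0 + 0.5*x1, differentiated by hand
    return [2.0 * x0 + 1.5, 0.5]


x = (1.0, 2.0)
jacobian = [grad_f1(*x), grad_f2(*x)]  # one row per objective
print(constant_aggregate([1.0, 0.5], jacobian))  # → [3.5, 0.5]
print(grad_weighted_sum(*x))                     # → [3.5, 0.5]
```

In other words, a constant aggregation reduces Jacobian descent to ordinary gradient descent on a fixed weighted sum of the objectives.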

class torchjd.aggregation.ConstantWeighting(weights)

Weighting that returns constant, pre-determined weights.

Parameters:

weights (Tensor) – The weights to return at each call.
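The defining property of a constant weighting is that its output never depends on the input matrix. The sketch below shows that behaviour with a hypothetical `ConstantWeightingSketch` class on plain lists. It is not torchjd's implementation, and any shape validation the library may perform is not reproduced here.

```python
from typing import List, Sequence


class ConstantWeightingSketch:
    """Hypothetical stand-in for a constant weighting: the output ignores the input."""

    def __init__(self, weights: Sequence[float]) -> None:
        # Store a private copy so later mutation of the caller's list has no effect.
        self._weights = list(weights)

    def __call__(self, matrix: Sequence[Sequence[float]]) -> List[float]:
        # The matrix is accepted only to match the weighting interface; it is ignored.
        # A fresh list is returned so callers cannot alter the stored weights.
        return list(self._weights)


weighting = ConstantWeightingSketch([0.2, 0.8])
print(weighting([[1.0, 2.0], [3.0, 4.0]]))    # → [0.2, 0.8]
print(weighting([[-5.0, 0.0], [10.0, 7.0]]))  # → [0.2, 0.8]
```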

__call__(matrix, /)

Computes the vector of weights from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The matrix from which the weights must be extracted.

Return type:

Tensor
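Conceptually, the two classes fit together: an aggregator can first extract a vector of weights from the matrix and then return the weighted sum of its rows. Assuming that decomposition, aggregating with `Constant(weights)` gives the same result as computing weights with `ConstantWeighting(weights)` and then combining the rows. The pure-Python sketch below composes the two steps; `constant_weighting` and `weighted_aggregator` are hypothetical helpers, not torchjd's source.

```python
from typing import Callable, List, Sequence

Matrix = Sequence[Sequence[float]]


def constant_weighting(weights: Sequence[float]) -> Callable[[Matrix], List[float]]:
    # Hypothetical: returns a weighting function that ignores its input matrix.
    fixed = list(weights)
    return lambda matrix: list(fixed)


def weighted_aggregator(weighting: Callable[[Matrix], List[float]]) -> Callable[[Matrix], List[float]]:
    # Hypothetical: turns a weighting into an aggregator (weights first, then a weighted row sum).
    def aggregate(matrix: Matrix) -> List[float]:
        w = weighting(matrix)
        n_cols = len(matrix[0]) if matrix else 0
        return [sum(w_i * row[j] for w_i, row in zip(w, matrix)) for j in range(n_cols)]

    return aggregate


aggregator = weighted_aggregator(constant_weighting([1.0, -1.0]))
print(aggregator([[4.0, 1.0], [1.0, 3.0]]))  # → [3.0, -2.0]
```

Separating the weighting from the final combination makes it easy to swap in a weighting whose output does depend on the matrix while keeping the same row-combination step.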