Flattening

class torchjd.aggregation.Flattening(weighting)

GeneralizedWeighting that flattens the generalized Gramian into a square matrix, extracts a vector of weights from it using a Weighting, and returns these weights reshaped to match the leading half of the generalized Gramian's shape.

For instance, when applied to a generalized Gramian of shape [2, 3, 3, 2], it flattens it into a square Gramian matrix of shape [6, 6], applies the weighting to it to obtain a vector of weights of shape [6], and then returns this vector reshaped into a matrix of shape [2, 3].
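The shape bookkeeping in this example can be sketched in plain PyTorch. This is a minimal sketch, not torchjd's implementation. It assumes that the generalized Gramian stores its trailing dimensions in reverse order. Under that assumption, the [2, 3, 3, 2] Gramian is obtained as J · J^T, where J is a Jacobian of shape [2, 3, n] and J^T reverses every dimension of J. The helpers `flatten_generalized_gramian`, `diagonal_weighting` and `flattening` are hypothetical names used only for illustration.

```python
import torch
from torch import Tensor


def flatten_generalized_gramian(generalized_gramian: Tensor) -> Tensor:
    """Flatten a Gramian of shape `shape + reversed(shape)` into a square matrix.

    Assumption (not taken from torchjd): entry [i1, ..., ik, jk, ..., j1] pairs
    the row index (i1, ..., ik) with the column index (j1, ..., jk).
    """
    k = generalized_gramian.ndim // 2
    shape = generalized_gramian.shape[:k]
    n = shape.numel()
    # Reverse the trailing k dimensions so both halves use the same index order,
    # then merge each half into a single dimension.
    dims = list(range(k)) + list(range(2 * k - 1, k - 1, -1))
    return generalized_gramian.permute(dims).reshape(n, n)


def diagonal_weighting(gramian: Tensor) -> Tensor:
    """Hypothetical weighting for illustration: returns the Gramian's diagonal."""
    return torch.diagonal(gramian)


def flattening(weighting, generalized_gramian: Tensor) -> Tensor:
    """Flatten, apply a square-Gramian weighting, and reshape the weights back."""
    k = generalized_gramian.ndim // 2
    shape = generalized_gramian.shape[:k]
    square_gramian = flatten_generalized_gramian(generalized_gramian)  # [6, 6] here
    weights_vector = weighting(square_gramian)  # [6] here
    return weights_vector.reshape(shape)  # [2, 3] here


torch.manual_seed(0)
jacobian = torch.randn(2, 3, 5)  # hypothetical Jacobian: 2 x 3 values, 5 parameters
# J · J^T with J^T reversing all dims of J: shape [2, 3, 3, 2]
generalized_gramian = torch.tensordot(jacobian, jacobian.permute(2, 1, 0), dims=1)
weights = flattening(diagonal_weighting, generalized_gramian)
print(weights.shape)  # torch.Size([2, 3])
```

With this diagonal weighting, each returned weight equals the squared norm of the corresponding row of the Jacobian. That property gives a quick check that the flattening pairs row and column indices consistently.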

Parameters:

weighting (Weighting) – The weighting to apply to the flattened (square) Gramian matrix.

__call__(generalized_gramian, /)

Computes the tensor of weights from the input generalized Gramian and applies all registered hooks.

Parameters:

generalized_gramian (Tensor) – The tensor from which the weights must be extracted.

Return type:

Tensor
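The sketch below assumes that Flattening follows torch.nn.Module semantics. Under that assumption, `__call__` runs `forward` and then any registered forward hooks, so the call path can be modeled as follows. This is an illustration, not torchjd's code. `FlatteningSketch`, `mean_weighting` and `record_shape_hook` are hypothetical names, and the reversed-trailing-dimensions layout of the generalized Gramian is also an assumption.

```python
import torch
from torch import Tensor, nn


class FlatteningSketch(nn.Module):
    """Hypothetical stand-in for Flattening, modeled as a torch.nn.Module."""

    def __init__(self, weighting):
        super().__init__()
        self.weighting = weighting

    def forward(self, generalized_gramian: Tensor) -> Tensor:
        k = generalized_gramian.ndim // 2
        shape = generalized_gramian.shape[:k]
        n = shape.numel()
        # Assumed layout: trailing dims stored in reverse order, so reverse them back.
        dims = list(range(k)) + list(range(2 * k - 1, k - 1, -1))
        square_gramian = generalized_gramian.permute(dims).reshape(n, n)
        return self.weighting(square_gramian).reshape(shape)


def mean_weighting(gramian: Tensor) -> Tensor:
    """Hypothetical weighting: uniform weights that ignore the Gramian's values."""
    m = gramian.shape[0]
    return torch.full((m,), 1.0 / m, dtype=gramian.dtype)


recorded_shapes = []


def record_shape_hook(module, inputs, output):
    # Forward hooks receive (module, inputs, output); returning None keeps the output.
    recorded_shapes.append(tuple(output.shape))


flattening = FlatteningSketch(mean_weighting)
flattening.register_forward_hook(record_shape_hook)

jacobian = torch.randn(2, 3, 5)
generalized_gramian = torch.tensordot(jacobian, jacobian.permute(2, 1, 0), dims=1)
weights = flattening(generalized_gramian)  # __call__ runs forward(), then the hook
print(recorded_shapes)  # [(2, 3)]
```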