Constant
class torchjd.aggregation.constant.Constant(weights)
Aggregator that computes a linear combination of the rows of the provided matrix, using constant, pre-determined weights.
Parameters:
weights (Tensor) – The weights associated with the rows of the input matrices.
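In symbols, for an input matrix J with rows J_1, …, J_m and a weight vector w = (w_1, …, w_m), the aggregation is the weighted sum of the rows shown below. The symbols are our own notation, not the library's, and the formula assumes the weight vector has exactly one entry per row of J.

```latex
\mathcal{A}(J) = \sum_{i=1}^{m} w_i \, J_i = w^\top J
```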
Example
Compute a linear combination of the rows of a matrix.
>>> from torch import tensor
>>> from torchjd.aggregation import Constant
>>>
>>> A = Constant(tensor([1., 2.]))
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([8., 3., 3.])
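To make the arithmetic of the example explicit, here is a minimal sketch in plain PyTorch that reproduces the same result without torchjd. The helper `constant_aggregate` is hypothetical and is not part of the torchjd API; it only assumes that the weights form a 1-D tensor with one entry per row of the matrix, and computes the weighted row sum as a vector-matrix product.

```python
import torch


def constant_aggregate(weights: torch.Tensor, matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not torchjd API): weighted sum of the rows of `matrix`."""
    # Assumption: exactly one weight per row of the matrix.
    if weights.dim() != 1 or matrix.dim() != 2 or weights.shape[0] != matrix.shape[0]:
        raise ValueError("weights must be a 1-D tensor with one entry per row of matrix")
    # weights @ matrix computes sum_i weights[i] * matrix[i, :]
    return weights @ matrix


weights = torch.tensor([1.0, 2.0])
J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
result = constant_aggregate(weights, J)
print(result)  # tensor([8., 3., 3.])
```

Row by row, this is 1 * (-4, 1, 1) + 2 * (6, 1, 1) = (8, 3, 3), which matches the output of `A(J)` in the example above.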