Constant

class torchjd.aggregation.constant.Constant(weights)

Aggregator that computes a linear combination of the rows of the provided matrix, using constant, pre-determined weights.
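The aggregation is a weighted sum of rows: given weights w and a matrix J with rows J_1, ..., J_m, the result is w_1 * J_1 + ... + w_m * J_m. The sketch below illustrates that arithmetic in plain Python, without torch. The helper `constant_aggregate` is hypothetical and is not torchjd's implementation. Its one-weight-per-row check is an assumption made for illustration.

```python
# Hypothetical helper illustrating the math behind Constant; not torchjd's implementation.
def constant_aggregate(weights, matrix):
    """Return sum_i weights[i] * matrix[i] for a matrix given as a list of rows."""
    # Assumption for this sketch: exactly one weight per row.
    if len(weights) != len(matrix):
        raise ValueError("expected one weight per row")
    n_cols = len(matrix[0]) if matrix else 0
    result = [0.0] * n_cols
    # Accumulate each row, scaled by its weight, into the result vector.
    for w, row in zip(weights, matrix):
        for j, value in enumerate(row):
            result[j] += w * value
    return result


print(constant_aggregate([1.0, 2.0], [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]))
# [8.0, 3.0, 3.0]
```

Because the weights are fixed in advance, the result does not depend on the values in the matrix beyond this weighted sum.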

Parameters:

weights (Tensor) – The weights associated with the rows of the input matrices, one weight per row.

Example

Compute a linear combination of the rows of a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import Constant
>>>
>>> A = Constant(tensor([1., 2.]))
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([8., 3., 3.])
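The output above can be checked by hand. With weights (1, 2), the aggregation scales the first row of J by 1 and the second row by 2, then adds them:

```latex
A(J) = 1 \cdot \begin{bmatrix} -4 & 1 & 1 \end{bmatrix}
     + 2 \cdot \begin{bmatrix} 6 & 1 & 1 \end{bmatrix}
     = \begin{bmatrix} -4 + 12 & 1 + 2 & 1 + 2 \end{bmatrix}
     = \begin{bmatrix} 8 & 3 & 3 \end{bmatrix}
```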