Sum

class torchjd.aggregation.sum.Sum

Aggregator that sums the rows of the input matrices.

Example

Sum the rows of a matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import Sum
>>>
>>> A = Sum()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([2., 2., 2.])
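The result above is the column-wise sum of J: each output entry adds up one column across all rows, so -4 + 6, 1 + 1 and 1 + 1 all give 2. The sketch below reproduces that arithmetic in plain PyTorch, assuming only that the aggregation is an element-wise sum over rows. `sum_rows` is a hypothetical helper for illustration, not part of torchjd, and it is not presented as the library's implementation.

```python
import torch


def sum_rows(matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not torchjd API): add the rows of a 2-D matrix into one vector."""
    if matrix.dim() != 2:
        raise ValueError("expected a 2-D matrix")
    # Reduce over dimension 0 (the rows), leaving one entry per column.
    return matrix.sum(dim=0)


J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])

# Same result, built row by row: start from zeros and add each row in turn.
manual = torch.zeros(J.shape[1])
for row in J:
    manual += row

print(sum_rows(J))  # tensor([2., 2., 2.])
print(torch.equal(sum_rows(J), manual))  # True
```

Because addition is order-independent, the result does not depend on the order of the rows in J.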