Mean

class torchjd.aggregation.mean.Mean

Aggregator that averages the rows of the input matrix.
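Written as a formula, this is the standard row mean. In the notation below (chosen here for illustration), J is an m × n matrix with rows j_1, …, j_m:

```latex
\mathrm{Mean}(J) = \frac{1}{m} \sum_{i=1}^{m} j_i \in \mathbb{R}^{n},
\qquad J = \begin{bmatrix} j_1 \\ \vdots \\ j_m \end{bmatrix} \in \mathbb{R}^{m \times n}
```

For the matrix in the example below, the first entry is (-4 + 6) / 2 = 1, and each of the other two entries is (1 + 1) / 2 = 1.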

Example

Average the rows of a matrix

>>> from torch import tensor
>>> from torchjd.aggregation import Mean
>>>
>>> A = Mean()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([1., 1., 1.])
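The sketch below reproduces the same reduction in plain Python so that the arithmetic is explicit. It assumes that Mean computes exactly the row average described above. The helper `mean_of_rows` is hypothetical and is not part of torchjd. In PyTorch, the equivalent reduction on a tensor `J` is `J.mean(dim=0)`.

```python
def mean_of_rows(matrix: list[list[float]]) -> list[float]:
    """Average the rows of a non-empty m x n matrix given as nested lists.

    Hypothetical helper for illustration only; not a torchjd API.
    """
    if not matrix:
        raise ValueError("matrix must have at least one row")
    n = len(matrix[0])
    if any(len(row) != n for row in matrix):
        raise ValueError("all rows must have the same length")
    m = len(matrix)
    # zip(*matrix) yields the columns; each output entry is a column sum divided by m.
    return [sum(column) / m for column in zip(*matrix)]


J = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
print(mean_of_rows(J))  # [1.0, 1.0, 1.0]
```

Every row has the same weight 1/m, so a single row with large entries can shift the result just as much as it would shift any ordinary average.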