Mean
class torchjd.aggregation.mean.Mean
Aggregator that averages the rows of the input matrices.

Example
Average the rows of a matrix
>>> from torch import tensor
>>> from torchjd.aggregation import Mean
>>>
>>> A = Mean()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([1., 1., 1.])
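The result above can be checked by hand. Each output entry is the average of one column across all rows: (-4 + 6) / 2 = 1, (1 + 1) / 2 = 1 and (1 + 1) / 2 = 1. The code below is a minimal, dependency-free sketch of that computation. It uses a hypothetical helper, `mean_rows`, which is not part of torchjd. It assumes that Mean does nothing more than average the rows, which corresponds to `J.mean(dim=0)` on a PyTorch tensor. It is not the library's implementation.

```python
# Hypothetical helper for illustration only; not part of torchjd.
# Sketch of the row averaging described above, assuming Mean is a plain
# column-wise mean over the rows of the input matrix.
def mean_rows(matrix: list[list[float]]) -> list[float]:
    """Return the average of the rows of `matrix` (a list of equal-length rows)."""
    if not matrix:
        raise ValueError("matrix must have at least one row")
    n_cols = len(matrix[0])
    if any(len(row) != n_cols for row in matrix):
        raise ValueError("all rows must have the same length")
    n_rows = len(matrix)
    # Sum each column across all rows, then divide by the number of rows.
    return [sum(row[j] for row in matrix) / n_rows for j in range(n_cols)]


J = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
print(mean_rows(J))  # [1.0, 1.0, 1.0]
```

The helper rejects ragged input explicitly. A tensor is always rectangular, so this check has no direct counterpart in the tensor version.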