Trimmed Mean

class torchjd.aggregation.trimmed_mean.TrimmedMean(trim_number)

Aggregator for adversarial federated learning that, in each column of the input matrix, removes the most extreme values and then averages the rows of what remains, as defined in Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates.
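The coordinate-wise trimmed mean can be sketched in a few lines of PyTorch. This is a minimal illustration, not torchjd's actual implementation. The function name `trimmed_mean` is hypothetical, and so is the choice to raise a `ValueError` when too few rows remain after trimming.

```python
import torch
from torch import Tensor


def trimmed_mean(matrix: Tensor, trim_number: int) -> Tensor:
    """Hypothetical sketch of a coordinate-wise trimmed mean (not torchjd's code)."""
    n_rows = matrix.shape[0]
    # Assumption: at least one value per column must survive the trimming.
    if n_rows <= 2 * trim_number:
        raise ValueError("the matrix needs more than 2 * trim_number rows")
    # Sort each column independently, in ascending order along the row dimension.
    sorted_matrix, _ = torch.sort(matrix, dim=0)
    # Drop the trim_number smallest and the trim_number largest values of each column.
    kept = sorted_matrix[trim_number : n_rows - trim_number]
    # Average the remaining rows.
    return kept.mean(dim=0)
```

Because each column is sorted independently, the surviving values in one row of `kept` may come from different rows of the input. This is why the method is described as trimming per column rather than discarding whole rows.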

Parameters:

trim_number (int) – The number of largest values, and also the number of smallest values, to remove from each column of the input matrix (so 2 * trim_number values are removed from each column in total).
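As a worked formula, write k for trim_number and m for the number of rows, and assume m > 2k so that every column keeps at least one value. For each column j, let x_{(1),j} ≤ … ≤ x_{(m),j} be that column's entries in sorted order. Entry j of the output is then the mean of the middle m − 2k sorted values:

```latex
\mathrm{TrimmedMean}(J)_j \;=\; \frac{1}{m - 2k} \sum_{i = k+1}^{m - k} x_{(i),j}
```

With k = 0 nothing is removed, and this reduces to the ordinary column-wise mean.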

Example

Remove the maximum and the minimum value from each column of the matrix, then average the rows of the remaining matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import TrimmedMean
>>>
>>> A = TrimmedMean(trim_number=1)
>>> J = tensor([
...     [ 1e11,     3],
...     [    1, -1e11],
...     [-1e10,  1e10],
...     [    2,     2],
... ])
>>>
>>> A(J)
tensor([1.5000, 2.5000])
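To see why trimming matters here, the sketch below uses plain PyTorch (not the torchjd API) to compare the ordinary column-wise mean of the same matrix with the result obtained by sorting each column and dropping its first and last entries.

```python
import torch

# Same matrix as in the example above.
J = torch.tensor([
    [1e11, 3.0],
    [1.0, -1e11],
    [-1e10, 1e10],
    [2.0, 2.0],
])

# Plain column-wise mean: dominated by the extreme entries (roughly +2.25e10 and -2.25e10).
plain_mean = J.mean(dim=0)

# Trimmed mean with trim_number=1: sort each column, then drop its first and last value.
# Column 0 sorts to [-1e10, 1, 2, 1e11] and keeps [1, 2]; column 1 sorts to
# [-1e11, 2, 3, 1e10] and keeps [2, 3].
sorted_J, _ = torch.sort(J, dim=0)
trimmed = sorted_J[1:-1].mean(dim=0)

print(plain_mean)
print(trimmed)
```

The ordinary mean is pulled to values of order 1e10, while the trimmed version recovers 1.5 and 2.5, matching the output of `A(J)` above.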