Trimmed Mean

class torchjd.aggregation.trimmed_mean.TrimmedMean(trim_number)
Aggregator for adversarial federated learning that removes the largest and smallest values in each column of the input matrix before averaging its rows, as defined in Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates.

Parameters:
trim_number (int) – The number of maximum and minimum values to remove from each column of the input matrix (note that 2 * trim_number values are removed from each column).
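To make the column-wise trimming concrete, here is a minimal pure-Python sketch of a coordinate-wise trimmed mean. It is not torchjd's implementation: the helper name trimmed_mean is hypothetical, and the ValueError raised when 2 * trim_number is not smaller than the number of rows is an assumption of this sketch, not documented torchjd behavior. A tensor version could sort along dimension 0 and slice the same range.

```python
from statistics import fmean


def trimmed_mean(matrix: list[list[float]], trim_number: int) -> list[float]:
    """Coordinate-wise trimmed mean of the rows of ``matrix``.

    Hypothetical helper for illustration only; not part of torchjd.
    For each column, drop the ``trim_number`` smallest and the
    ``trim_number`` largest entries, then average the remaining ones.
    """
    n_rows = len(matrix)
    # Assumption of this sketch: at least one value per column must survive.
    if trim_number < 0 or 2 * trim_number >= n_rows:
        raise ValueError("need 0 <= 2 * trim_number < number of rows")
    result = []
    for column in zip(*matrix):  # iterate over the columns of the matrix
        kept = sorted(column)[trim_number : n_rows - trim_number]
        result.append(fmean(kept))
    return result


J = [
    [1e11, 3],
    [1, -1e11],
    [-1e10, 1e10],
    [2, 2],
]
print(trimmed_mean(J, trim_number=1))  # [1.5, 2.5]
```

Sorting each column and slicing keeps the logic obviously correct; the huge entries such as 1e11 never reach the average, which is the robustness property the aggregator relies on.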
Example
Remove the maximum and the minimum value from each column of the matrix J, then average the rows of the remaining matrix.

>>> from torch import tensor
>>> from torchjd.aggregation import TrimmedMean
>>>
>>> A = TrimmedMean(trim_number=1)
>>> J = tensor([
...     [ 1e11,     3],
...     [    1, -1e11],
...     [-1e10,  1e10],
...     [    2,     2],
... ])
>>>
>>> A(J)
tensor([1.5000, 2.5000])
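The doctest output can be checked by hand. With trim_number = 1, the smallest and largest entry of each column of J are dropped and the two remaining entries are averaged:

```latex
\begin{aligned}
\text{column 1: } & \{10^{11},\ 1,\ -10^{10},\ 2\}
  \;\xrightarrow{\text{drop } -10^{10},\ 10^{11}}\; \{1,\ 2\}
  \;\to\; \tfrac{1 + 2}{2} = 1.5 \\
\text{column 2: } & \{3,\ -10^{11},\ 10^{10},\ 2\}
  \;\xrightarrow{\text{drop } -10^{11},\ 10^{10}}\; \{2,\ 3\}
  \;\to\; \tfrac{2 + 3}{2} = 2.5
\end{aligned}
```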