Trimmed Mean

class torchjd.aggregation.TrimmedMean(trim_number)

Aggregator for adversarial federated learning that trims the most extreme values in each column of the input matrix before averaging its rows, as defined in Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates.
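Written out, this is a coordinate-wise trimmed mean. The formula below is a sketch of that definition using notation introduced only for illustration: the input matrix is $J \in \mathbb{R}^{m \times n}$, $b$ stands for trim_number, and $J_{(1)j} \le \dots \le J_{(m)j}$ are the entries of column $j$ sorted in increasing order. It assumes $m > 2b$, so that at least one value per column survives the trimming.

```latex
v_j = \frac{1}{m - 2b} \sum_{i=b+1}^{m-b} J_{(i)j}, \qquad j = 1, \dots, n
```

The result $v = (v_1, \dots, v_n)$ is one vector with as many entries as $J$ has columns.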

Parameters:

trim_number (int) – The number of largest values, and the number of smallest values, to remove from each column of the input matrix (so 2 * trim_number values are removed from each column in total).
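To make the role of trim_number concrete, here is a minimal pure-Python sketch of the coordinate-wise trimmed mean. It is an illustration, not torchjd's implementation: the function name trimmed_mean is hypothetical, plain lists of floats stand in for a Tensor, and the check that the matrix has more than 2 * trim_number rows is an assumption of this sketch.

```python
from typing import List


def trimmed_mean(matrix: List[List[float]], trim_number: int) -> List[float]:
    """Hypothetical sketch: coordinate-wise trimmed mean of the rows of `matrix`."""
    if trim_number < 0:
        raise ValueError("trim_number must be non-negative")
    n_rows = len(matrix)
    # Assumption of this sketch: at least one value per column must survive trimming.
    if n_rows <= 2 * trim_number:
        raise ValueError("matrix needs more than 2 * trim_number rows")
    n_cols = len(matrix[0])
    result = []
    for j in range(n_cols):
        # Sort column j, then drop its trim_number smallest and trim_number largest entries.
        column = sorted(row[j] for row in matrix)
        kept = column[trim_number : n_rows - trim_number]
        # Average the remaining n_rows - 2 * trim_number values.
        result.append(sum(kept) / len(kept))
    return result


jacobian = [
    [1.0, -2.0],
    [2.0, 100.0],  # an outlying value in column 1
    [3.0, 0.0],
    [-50.0, 2.0],  # an outlying value in column 0
]
print(trimmed_mean(jacobian, trim_number=1))  # prints [1.5, 1.0]
```

Because trimming is done per column, the two outliers are discarded even though they come from different rows; a plain mean of the rows would instead give [-11.0, 25.0].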

__call__(matrix, /)

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor