Krum

class torchjd.aggregation.krum.Krum(n_byzantine, n_selected=1)

Aggregator for adversarial federated learning, as defined in "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent".

Parameters:
  • n_byzantine (int) – The number of rows of the input matrix that can come from an adversarial source.

  • n_selected (int) – The number of rows selected for aggregation. The default of 1 corresponds to plain Krum; larger values correspond to Multi-Krum. Defaults to 1.
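
To illustrate how the two parameters interact, here is a minimal pure-Python sketch of the scoring rule described in the paper. It is not the torchjd implementation: the function name multi_krum, the use of squared Euclidean distances, and the input validation are assumptions made for this illustration. Each row is scored by the sum of its distances to its n - n_byzantine - 2 nearest other rows, and the n_selected rows with the lowest scores are averaged.

```python
def multi_krum(rows, n_byzantine, n_selected=1):
    """Illustrative (Multi-)Krum on a list of equal-length float lists.

    Hypothetical helper for this page only; not part of torchjd.
    """
    n = len(rows)
    n_closest = n - n_byzantine - 2  # neighbours that count towards each score
    if n_closest < 1 or not 1 <= n_selected <= n:
        raise ValueError("not enough rows for the requested parameters")

    def sq_dist(u, v):
        # Squared Euclidean distance (assumption: follows the paper's scoring).
        return sum((a - b) ** 2 for a, b in zip(u, v))

    scores = []
    for i, row in enumerate(rows):
        dists = sorted(sq_dist(row, other) for j, other in enumerate(rows) if j != i)
        scores.append(sum(dists[:n_closest]))

    # Keep the n_selected rows with the lowest scores and average them.
    selected = sorted(range(n), key=scores.__getitem__)[:n_selected]
    dim = len(rows[0])
    return [sum(rows[i][k] for i in selected) / n_selected for k in range(dim)]


rows = [
    [0.0, 0.0],
    [1.0, 0.0],
    [0.0, 2.0],
    [1.0, 1.0],
    [50.0, -50.0],  # outlier standing in for an adversarial row
]

krum_result = multi_krum(rows, n_byzantine=1)                 # [1.0, 0.0]
multi_result = multi_krum(rows, n_byzantine=1, n_selected=4)  # [0.5, 0.75]
```

With n_selected=1 the sketch returns a single input row unchanged; with n_selected=4 it averages the four lowest-scoring rows, which leaves the outlier out.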

Example

Use Multi-Krum to aggregate a matrix that contains one adversarial row.

>>> from torch import tensor
>>> from torchjd.aggregation import Krum
>>>
>>> A = Krum(n_byzantine=1, n_selected=4)
>>> J = tensor([
...     [1.,     1., 1.],
...     [1.,     0., 1.],
...     [75., -666., 23.],  # adversarial row
...     [1.,     2., 3.],
...     [2.,     0., 1.],
... ])
>>>
>>> A(J)
tensor([1.2500, 0.7500, 1.5000])
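
The result above is consistent with Multi-Krum leaving out the adversarial row and averaging the other four. The following plain-Python check reproduces the arithmetic, assuming the selected rows are combined by a simple mean; the variable names are only for illustration.

```python
J = [
    [1.0, 1.0, 1.0],
    [1.0, 0.0, 1.0],
    [75.0, -666.0, 23.0],  # adversarial row
    [1.0, 2.0, 3.0],
    [2.0, 0.0, 1.0],
]

# Drop the adversarial row (index 2) and average the remaining rows column-wise.
kept = [row for i, row in enumerate(J) if i != 2]
mean = [sum(col) / len(kept) for col in zip(*kept)]
print(mean)  # [1.25, 0.75, 1.5]
```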