Krum

class torchjd.aggregation.Krum(n_byzantine, n_selected=1)

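Aggregator for adversarial federated learning, as defined in "Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent". Each row of the input matrix is scored by the sum of its squared Euclidean distances to its n - n_byzantine - 2 nearest other rows, where n is the number of rows. The n_selected rows with the lowest scores are then averaged.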

Parameters:
  • n_byzantine (int) – The maximum number of rows of the input matrix that may come from an adversarial (Byzantine) source.

  • n_selected (int) – The number of lowest-scoring rows to select and average, as in Multi-Krum. Defaults to 1, which recovers the original Krum rule.
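The plain-Python sketch below illustrates how n_byzantine and n_selected enter the Multi-Krum rule from the paper. It is a minimal reference illustration, not torchjd's implementation: it works on lists of floats instead of Tensors, the function names `squared_distance` and `multi_krum` are hypothetical, and breaking ties between equal scores by row index is an assumption of this sketch.

```python
from typing import List

Row = List[float]


def squared_distance(u: Row, v: Row) -> float:
    # Squared Euclidean distance between two rows.
    return sum((a - b) ** 2 for a, b in zip(u, v))


def multi_krum(matrix: List[Row], n_byzantine: int, n_selected: int = 1) -> Row:
    """Hypothetical sketch of (Multi-)Krum following the paper's definition."""
    n = len(matrix)
    # Each score sums the distances to this many nearest other rows.
    n_closest = n - n_byzantine - 2
    if n_closest < 1 or not 1 <= n_selected <= n:
        raise ValueError("need n - n_byzantine - 2 >= 1 and 1 <= n_selected <= n")

    scores = []
    for i, row in enumerate(matrix):
        # Distances from row i to every other row, smallest first.
        distances = sorted(squared_distance(row, other) for j, other in enumerate(matrix) if j != i)
        scores.append(sum(distances[:n_closest]))

    # The n_selected rows with the lowest scores (sorted() is stable, so ties go to the lower index).
    selected = sorted(range(n), key=lambda i: scores[i])[:n_selected]

    # Average the selected rows coordinate by coordinate.
    return [sum(matrix[i][k] for i in selected) / n_selected for k in range(len(matrix[0]))]


# Four nearby rows and one far-away (adversarial-looking) row.
jacobian = [[1.0, 1.0], [1.0, 2.0], [2.0, 1.0], [2.0, 2.0], [100.0, 100.0]]
print(multi_krum(jacobian, n_byzantine=1, n_selected=4))  # [1.5, 1.5]
print(multi_krum(jacobian, n_byzantine=1))  # [1.0, 1.0]
```

With n_byzantine=1 each score sums the two smallest distances, so the outlying last row gets a very large score and is never selected.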

__call__(matrix, /)

Computes the aggregation from the input matrix and applies all registered hooks.

Parameters:

matrix (Tensor) – The Jacobian to aggregate.

Return type:

Tensor

class torchjd.aggregation.KrumWeighting(n_byzantine, n_selected=1)

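Weighting giving the weights of Krum: each of the n_selected rows chosen by Krum gets weight 1 / n_selected, and every other row gets weight 0.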

Parameters:
  • n_byzantine (int) – The maximum number of rows of the input matrix that may come from an adversarial (Byzantine) source.

  • n_selected (int) – The number of lowest-scoring rows to select and average, as in Multi-Krum. Defaults to 1, which recovers the original Krum rule.
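The Krum scores depend on the input only through squared distances between rows, and those can be recovered from inner products, which is why the weights can be computed from the gramian alone. The derivation below is a sketch based on the paper's definitions; it assumes the gramian is G = J J^T for a Jacobian J with n rows j_1, ..., j_n, and writes f for n_byzantine and m for n_selected.

```latex
\begin{aligned}
G &= J J^\top, \qquad d_{ik} = \lVert j_i - j_k \rVert^2 = G_{ii} + G_{kk} - 2 G_{ik}, \\
s_i &= \sum_{k \in N_i} d_{ik}, \qquad
N_i = \text{the } n - f - 2 \text{ indices } k \neq i \text{ with the smallest } d_{ik}, \\
S &= \text{the } m \text{ indices } i \text{ with the smallest } s_i, \qquad
w_i = \begin{cases} 1/m & \text{if } i \in S, \\ 0 & \text{otherwise,} \end{cases} \\
\operatorname{Krum}(J) &= w^\top J = \frac{1}{m} \sum_{i \in S} j_i .
\end{aligned}
```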

__call__(gramian, /)

Computes the vector of weights from the input gramian and applies all registered hooks.

Parameters:

gramian (Tensor) – The gramian of the Jacobian, i.e. the matrix of pairwise inner products of its rows, from which the weights are extracted.

Return type:

Tensor
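A weighting that only sees the gramian can reproduce Krum's selection by using G[i][i] + G[k][k] - 2 * G[i][k] in place of explicit row distances. The plain-Python sketch below does this. It is a reference illustration under stated assumptions, not torchjd's KrumWeighting implementation: it uses lists instead of Tensors, the helper names `krum_weights` and `gramian_of` are hypothetical, and ties between equal scores are broken by row index.

```python
from typing import List

Matrix = List[List[float]]


def gramian_of(matrix: Matrix) -> Matrix:
    # G[i][k] is the inner product of row i and row k.
    return [[sum(a * b for a, b in zip(u, v)) for v in matrix] for u in matrix]


def krum_weights(gramian: Matrix, n_byzantine: int, n_selected: int = 1) -> List[float]:
    """Hypothetical sketch: Krum weights computed from a gramian G = J J^T only."""
    n = len(gramian)
    n_closest = n - n_byzantine - 2
    if n_closest < 1 or not 1 <= n_selected <= n:
        raise ValueError("need n - n_byzantine - 2 >= 1 and 1 <= n_selected <= n")

    scores = []
    for i in range(n):
        # ||j_i - j_k||^2 = G_ii + G_kk - 2 G_ik for every other row k, smallest first.
        distances = sorted(gramian[i][i] + gramian[k][k] - 2 * gramian[i][k] for k in range(n) if k != i)
        scores.append(sum(distances[:n_closest]))

    # Lowest scores win; sorted() is stable, so ties go to the lower index.
    selected = set(sorted(range(n), key=lambda i: scores[i])[:n_selected])
    # Uniform weight over the selected rows, zero for every other row.
    return [1 / n_selected if i in selected else 0.0 for i in range(n)]


jacobian = [[1.0, 1.0], [1.0, 2.0], [2.0, 1.0], [2.0, 2.0], [100.0, 100.0]]
weights = krum_weights(gramian_of(jacobian), n_byzantine=1, n_selected=4)
print(weights)  # [0.25, 0.25, 0.25, 0.25, 0.0]
```

Combining these weights with the rows of the Jacobian gives [1.5, 1.5], the average of the four selected rows, which is what the Krum aggregator returns for the same input under these assumptions.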