TorchJD

Getting Started

  • Installation
  • Examples
    • Basic Usage
    • Instance-Wise Risk Minimization (IWRM)
    • Partial Jacobian Descent for IWRM
    • Multi-Task Learning (MTL)
    • Instance-Wise Multi-Task Learning (IWMTL)
    • Recurrent Neural Network (RNN)
    • Monitoring aggregations
    • PyTorch Lightning Integration
    • Automatic Mixed Precision (AMP)

API Reference

  • autogram
    • Engine
  • autojac
    • backward
    • mtl_backward
  • aggregation
    • UPGrad
    • Aligned-MTL
    • CAGrad
    • ConFIG
    • Constant
    • DualProj
    • Flattening
    • GradDrop
    • IMTL-G
    • Krum
    • Mean
    • MGDA
    • Nash-MTL
    • PCGrad
    • Random
    • Sum
    • Trimmed Mean

Trimmed Mean

class torchjd.aggregation.TrimmedMean(trim_number)

Aggregator for adversarial federated learning that computes a coordinate-wise trimmed mean of the rows of the input matrix: in each column, the most extreme values are removed and the remaining values are averaged, as defined in Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates.

Parameters:

trim_number (int) – The number of largest values, and also the number of smallest values, to remove from each column of the input matrix. In total, 2 * trim_number values are removed from each column.
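The trimming rule can be illustrated with a minimal plain-Python sketch, assuming the input matrix is given as a list of equal-length rows. This is not TorchJD's implementation (which operates on PyTorch tensors); the function name trimmed_mean and the check that the matrix has more than 2 * trim_number rows are hypothetical choices made for this illustration.

```python
def trimmed_mean(matrix: list[list[float]], trim_number: int) -> list[float]:
    """Coordinate-wise trimmed mean of the rows of `matrix` (illustrative sketch).

    For each column, the `trim_number` smallest and `trim_number` largest
    values are dropped, and the remaining values are averaged.
    """
    if trim_number < 0:
        raise ValueError("trim_number must be non-negative")
    n_rows = len(matrix)
    # Hypothetical check: at least one value must survive trimming in each column.
    if n_rows <= 2 * trim_number:
        raise ValueError("matrix needs more than 2 * trim_number rows")

    n_cols = len(matrix[0])
    result = []
    for j in range(n_cols):
        # Sort the j-th column, then cut trim_number values from each end.
        column = sorted(row[j] for row in matrix)
        kept = column[trim_number : n_rows - trim_number]
        result.append(sum(kept) / len(kept))
    return result


J = [
    [1e11, 3.0],
    [1.0, -1e11],
    [-1e10, 1e10],
    [2.0, 2.0],
]
print(trimmed_mean(J, trim_number=1))  # [1.5, 2.5]
```

With trim_number=1, each column loses its single largest and single smallest value, so the huge entries are discarded even though they come from different rows in different columns; only the two middle values of each column are averaged.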

Copyright © Valerian Rey, Pierre Quinton
Made with Sphinx and @pradyunsg's Furo