TorchJD


ConFIG

class torchjd.aggregation.ConFIG(pref_vector=None)

Aggregator as defined in Equation 2 of ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks.
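For reference, Equation 2 can be written as follows for m gradients g_1, ..., g_m stacked as the rows of a matrix. This is our own transcription and notation, not the paper's: 𝒰(x) = x / ‖x‖ normalizes a vector, U⁺ is the Moore-Penrose pseudo-inverse of U, and 1_m is the all-ones vector, which the pref_vector parameter presumably replaces when it is given.

```latex
\begin{aligned}
U &= \begin{bmatrix} \mathcal{U}(g_1) & \cdots & \mathcal{U}(g_m) \end{bmatrix}^{\top} \in \mathbb{R}^{m \times n}, \\
g_u &= \mathcal{U}\!\left( U^{+} \mathbf{1}_m \right), \\
g_{\mathrm{ConFIG}} &= \Big( \sum_{i=1}^{m} g_i^{\top} g_u \Big)\, g_u .
\end{aligned}
```

When the normalized gradients are linearly independent, U U⁺ = I_m, so every normalized gradient has the same positive inner product with g_u. This is what makes the resulting update conflict-free: it does not increase any of the individual objectives to first order.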

Parameters:

pref_vector (Tensor | None) – The preference vector used to weight the rows of the matrix being aggregated. If not provided, every row gets an equal weight of 1.
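The following NumPy sketch shows one way to compute this aggregation for a matrix whose rows are gradients. It assumes, as the parameter description suggests, that pref_vector takes the place of the all-ones vector in Equation 2. The function name config_aggregate is hypothetical and not part of TorchJD; the library itself works on PyTorch tensors, and its actual code may differ in details such as the handling of zero rows.

```python
from __future__ import annotations

import numpy as np


def config_aggregate(matrix: np.ndarray, pref_vector: np.ndarray | None = None) -> np.ndarray:
    """Hypothetical NumPy sketch of ConFIG's Equation 2 (not TorchJD's code).

    Each row of `matrix` is one gradient; the result is a single update vector.
    """
    matrix = np.asarray(matrix, dtype=float)
    m, n = matrix.shape
    # Assumption: pref_vector replaces the all-ones vector of Equation 2.
    weights = np.ones(m) if pref_vector is None else np.asarray(pref_vector, dtype=float)

    # Normalize every row to unit length; all-zero rows are left as zeros.
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    units = np.divide(matrix, norms, out=np.zeros_like(matrix), where=norms > 0)

    # Minimum-norm direction x with units @ x == weights (exact if rows are independent).
    direction = np.linalg.pinv(units) @ weights
    direction_norm = np.linalg.norm(direction)
    if direction_norm == 0:
        return np.zeros(n)
    unit_direction = direction / direction_norm

    # Length: sum of the projections of the original gradients onto the direction.
    length = np.sum(matrix @ unit_direction)
    return length * unit_direction


# Usage: two partially conflicting gradients in R^3.
J = np.array([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
update = config_aggregate(J)
print(J @ update)  # both inner products are positive
```

Normalizing the direction before rescaling means the preference vector only steers where the update points; its magnitude comes from the projections of the raw gradients onto that direction.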

Note

This implementation was adapted from the official implementation.

Copyright © Valerian Rey, Pierre Quinton