GradDrop

class torchjd.aggregation.GradDrop(f=<function _identity>, leak=None)

Aggregator that applies the gradient combination steps of GradDrop, as defined in lines 10 to 15 of Algorithm 1 of the paper "Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout".
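The following equations give our reading of those combination steps. The notation is adapted here: g_1, ..., g_m are the rows of the input matrix, so there is one row per objective's gradient. All operations are elementwise, and f and ℓ_i correspond to the f and leak parameters documented below. Treat this as a summary, not a verbatim copy of the paper's algorithm.

```latex
\begin{aligned}
P &= \tfrac{1}{2}\left(1 + \frac{\sum_{i=1}^{m} g_i}{\sum_{i=1}^{m} \lvert g_i \rvert}\right)
  && \text{(Gradient Positive Sign Purity, in } [0, 1]) \\
U &\sim \mathcal{U}(0, 1)
  && \text{(one uniform sample per coordinate)} \\
M_i &= \mathbb{1}[f(P) > U] \odot \mathbb{1}[g_i > 0] + \mathbb{1}[f(P) < U] \odot \mathbb{1}[g_i < 0] \\
\text{output} &= \sum_{i=1}^{m} \bigl(\ell_i + (1 - \ell_i)\, M_i\bigr) \odot g_i
\end{aligned}
```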

Parameters:
  • f (Callable) – The function to apply to the Gradient Positive Sign Purity (GPSP). It should be monotonically increasing. Defaults to the identity.

  • leak (Tensor | None) – The tensor of leak values, one per row of the input matrix. Each value determines how much of that row is allowed to leak through the sign-dropout mask. Defaults to None, which means no leak.
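To show how f and leak enter the computation, here is a minimal plain-Python sketch of the combination steps. It is written from the paper's description and is not torchjd's implementation. The function name graddrop_combine and the extra u argument, which fixes the uniform draws so results are reproducible, are hypothetical. Rows of grads play the role of the per-objective gradients. As in the documented defaults, f is the identity and leak is all zeros.

```python
import random
from typing import Callable, Optional, Sequence


def graddrop_combine(
    grads: Sequence[Sequence[float]],
    f: Callable[[float], float] = lambda p: p,
    leak: Optional[Sequence[float]] = None,
    u: Optional[Sequence[float]] = None,
) -> list[float]:
    """Hypothetical sketch: combine per-objective gradients (rows) GradDrop-style.

    `u` holds one uniform sample per coordinate. It is drawn at random when
    omitted; passing it explicitly (an assumption of this sketch) makes the
    output deterministic.
    """
    m, n = len(grads), len(grads[0])
    leak = [0.0] * m if leak is None else list(leak)  # None means no leak
    u = [random.random() for _ in range(n)] if u is None else list(u)

    out = [0.0] * n
    for j in range(n):
        column = [grads[i][j] for i in range(m)]
        total_abs = sum(abs(x) for x in column)
        if total_abs == 0.0:
            continue  # every gradient is zero at this coordinate, so the output stays 0
        # Gradient Positive Sign Purity: 1 if all entries are positive, 0 if all negative.
        purity = 0.5 * (1.0 + sum(column) / total_abs)
        fp = f(purity)
        keep_positive = fp > u[j]
        keep_negative = fp < u[j]
        for i, g in enumerate(column):
            # Mask M_i: keep g only if its sign is the one selected for this coordinate.
            mask = 1.0 if (keep_positive and g > 0) or (keep_negative and g < 0) else 0.0
            # The leak value lets a fraction of g through regardless of the mask.
            out[j] += (leak[i] + (1.0 - leak[i]) * mask) * g
    return out


grads = [[1.0, -2.0, 3.0], [-1.0, -1.0, 1.0]]
print(graddrop_combine(grads, u=[0.25, 0.5, 0.5]))  # [1.0, -3.0, 4.0]
```

In this example, the first coordinate has mixed signs (purity 0.5), so the draw u = 0.25 selects the positive entry and drops the negative one. The other two coordinates have consistent signs and pass through unchanged. Setting every leak value to 1 disables dropping entirely and returns the plain sum of the rows. In GradDrop proper, the uniform draw is fresh on every call, so the same input can yield different outputs.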