GradDrop

class torchjd.aggregation.graddrop.GradDrop(f=<function _identity>, leak=None)

Aggregator that applies the gradient combination steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of "Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout".

Parameters:
  • f (Callable) – The function to apply to the Gradient Positive Sign Purity. It should be monotonically increasing. Defaults to identity.

  • leak (Tensor | None) – The tensor of leak values, determining how much each row is allowed to leak through. Defaults to None, which means no leak.
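The combination step described above can be sketched in plain Python. This is a minimal, hypothetical illustration that follows Algorithm 1 of the paper. It is not torchjd's implementation.

The function name `graddrop_combine` and its optional `u` argument are assumptions introduced here. Passing `u` fixes the uniform samples so that results are reproducible. The sketch processes one coordinate (column) at a time:

1. Compute the Gradient Positive Sign Purity P of the column.
2. Draw U ~ Uniform(0, 1).
3. Keep the positive entries if f(P) > U, and the negative entries if f(P) < U.
4. Let each row additionally leak through with weight `leak[i]`.

```python
import random
from typing import Callable, Optional, Sequence


def graddrop_combine(
    matrix: Sequence[Sequence[float]],
    f: Callable[[float], float] = lambda p: p,
    leak: Optional[Sequence[float]] = None,
    u: Optional[Sequence[float]] = None,
) -> list[float]:
    """Hypothetical sketch of GradDrop's sign-dropout combination.

    `matrix` is a non-empty list of rows (one gradient per row). `u`, if given,
    fixes the per-coordinate uniform samples; otherwise they are drawn at random.
    """
    m, n = len(matrix), len(matrix[0])
    leak = [0.0] * m if leak is None else list(leak)
    # One uniform threshold per coordinate, unless the caller fixes them.
    u = [random.random() for _ in range(n)] if u is None else list(u)

    result = [0.0] * n
    for j in range(n):
        column = [row[j] for row in matrix]
        abs_sum = sum(abs(x) for x in column)
        # Gradient Positive Sign Purity: 1 if all entries are positive, 0 if all
        # are negative. An all-zero column contributes nothing, so 0.5 is used
        # here only to avoid dividing by zero.
        p = 0.5 * (1.0 + sum(column) / abs_sum) if abs_sum > 0 else 0.5
        fp = f(p)
        for i, x in enumerate(column):
            # Keep positive entries when f(P) > U, negative ones when f(P) < U.
            keep = (fp > u[j] and x > 0) or (fp < u[j] and x < 0)
            mask = 1.0 if keep else 0.0
            # A row with leak 1 always passes through; with leak 0 only if kept.
            result[j] += (leak[i] + (1.0 - leak[i]) * mask) * x
    return result
```

Two example calls on the matrix from the example, with different fixed thresholds:

- `graddrop_combine([[-4., 1., 1.], [6., 1., 1.]], u=[0.5, 0.5, 0.5])` returns `[6.0, 2.0, 2.0]`.
- `graddrop_combine([[-4., 1., 1.], [6., 1., 1.]], u=[0.7, 0.7, 0.7])` returns `[-4.0, 2.0, 2.0]`.

Setting every leak value to 1 disables the dropout entirely, so the result is the plain column sum.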

Example

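Use GradDrop to aggregate a matrix. GradDrop is stochastic, so the first coordinate of the output can vary between runs, because its column contains entries of both signs.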

>>> from torch import tensor
>>> from torchjd.aggregation import GradDrop
>>>
>>> A = GradDrop()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([6., 2., 2.])
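The output of this example can be checked by hand. The derivation below assumes the following:

- J_{ij} denotes entry j of row i.
- The Gradient Positive Sign Purity of column j follows the paper's definition.
- f is the default identity.
- U_j is drawn uniformly from [0, 1).

Under these assumptions, the positive entries of column j survive with probability f(P_j).

```latex
\begin{aligned}
P_j &= \frac{1}{2}\left(1 + \frac{\sum_i J_{ij}}{\sum_i |J_{ij}|}\right) \\
P_1 &= \frac{1}{2}\left(1 + \frac{-4 + 6}{4 + 6}\right) = 0.6 \\
P_2 = P_3 &= \frac{1}{2}\left(1 + \frac{1 + 1}{1 + 1}\right) = 1
\end{aligned}
```

Column 1 therefore yields 6 with probability 0.6 and -4 with probability 0.4. Columns 2 and 3 always yield 1 + 1 = 2. So tensor([6., 2., 2.]) is the more likely of the two possible results; the other is tensor([-4., 2., 2.]).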