GradDrop
class torchjd.aggregation.graddrop.GradDrop(f=<function _identity>, leak=None)
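Aggregator that applies the gradient combination steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of "Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout".
Parameters:
- f: defaults to the identity function (per the signature above).
- leak: defaults to None (per the signature above).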
Example
Use GradDrop to aggregate a matrix.
>>> from torch import tensor
>>> from torchjd.aggregation import GradDrop
>>>
>>> A = GradDrop()
>>> J = tensor([[-4., 1., 1.], [6., 1., 1.]])
>>>
>>> A(J)
tensor([6., 2., 2.])
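To illustrate the kind of combination step the description refers to, here is a minimal pure-Python sketch of gradient sign dropout without leak. It is not the torchjd implementation: the function name `graddrop_sketch`, its `rng` parameter, and the use of nested lists instead of tensors are assumptions made for illustration only. For each column, it computes a positive sign purity, draws a uniform sample, and keeps either the positive or the negative entries before summing them.

```python
import random


def graddrop_sketch(matrix, f=lambda p: p, rng=None):
    """Hypothetical sketch of sign dropout on a list-of-lists matrix (no leak)."""
    if rng is None:
        rng = random.Random()
    result = []
    for j in range(len(matrix[0])):
        column = [row[j] for row in matrix]
        total = sum(column)
        abs_total = sum(abs(v) for v in column)
        # Positive sign purity: 1.0 if all entries are positive, 0.0 if all are negative.
        # An all-zero column is given 0.5 here (an assumption; the result is 0 either way).
        purity = 0.5 if abs_total == 0 else 0.5 * (1.0 + total / abs_total)
        u = rng.random()  # uniform sample in [0, 1)
        if f(purity) > u:
            kept = [v for v in column if v > 0]  # keep only the positive entries
        else:
            kept = [v for v in column if v < 0]  # keep only the negative entries
        result.append(sum(kept))
    return result


J = [[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]
print(graddrop_sketch(J, rng=random.Random(0)))
```

In this sketch, the second and third columns of `J` contain only positive entries, so they always sum to 2.0. The first column has mixed signs with purity 0.6, so its entry is either 6.0 or -4.0 depending on the random draw.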