GradVac¶
- class torchjd.aggregation.GradVac(beta=0.5, eps=1e-08)[source]¶
StatefulAggregator implementing the aggregation step of Gradient Vaccine (GradVac) from Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight).

For each task \(i\), the order in which the other tasks \(j\) are visited is drawn at random. For each pair \((i, j)\), the cosine similarity \(\phi_{ij}\) between the (possibly already modified) gradient of task \(i\) and the original gradient of task \(j\) is compared to an EMA target \(\hat{\phi}_{ij}\). When \(\phi_{ij} < \hat{\phi}_{ij}\), a closed-form correction adds a scaled copy of \(g_j\) to \(g_i^{(\mathrm{PC})}\). The EMA is then updated as \(\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}\). The aggregated vector is the sum of the modified rows.
This aggregator is stateful: it keeps \(\hat{\phi}\) across calls. Use reset() when the number of tasks or dtype changes.

- Parameters:
Note
For each task \(i\), the order of the other tasks \(j\) is shuffled independently using the global PyTorch RNG (torch.randperm). Seed it with torch.manual_seed if you need reproducibility.
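The per-pair correction described above can be sketched in plain PyTorch. This is a minimal illustration, not torchjd's implementation: the function name gradvac_step and the in-place phi_hat state are hypothetical, and the correction weight follows the closed form from the GradVac paper under the assumption that gradients are stacked as rows of a matrix.

```python
import torch

def gradvac_step(grads, phi_hat, beta=0.5, eps=1e-8):
    # grads: [m, n] matrix with one task gradient per row.
    # phi_hat: [m, m] EMA of pairwise cosine similarities, updated in place.
    m = grads.shape[0]
    out = grads.clone()
    for i in range(m):
        # Visit the other tasks in random order (global PyTorch RNG).
        for j in torch.randperm(m).tolist():
            if j == i:
                continue
            gi, gj = out[i], grads[j]
            phi = torch.dot(gi, gj) / (gi.norm() * gj.norm() + eps)
            target = phi_hat[i, j]
            if phi < target:
                # Closed-form correction: raises cos(g_i, g_j)
                # up to the EMA target by adding a scaled copy of g_j.
                w = gi.norm() * (target * (1 - phi**2).sqrt()
                                 - phi * (1 - target**2).sqrt()) \
                    / (gj.norm() * (1 - target**2).sqrt() + eps)
                out[i] = gi + w * gj
            # EMA update of the similarity target.
            phi_hat[i, j] = (1 - beta) * target + beta * phi
    return out.sum(dim=0)
```

With two exactly opposed gradients and \(\hat{\phi}\) initialized at zero, each row's conflicting component is cancelled and the sum is close to zero; call torch.manual_seed beforehand to make the visiting order reproducible.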
- class torchjd.aggregation.GradVacWeighting(beta=0.5, eps=1e-08)[source]¶
StatefulWeighting giving the weights of GradVac. All required quantities (gradient norms, cosine similarities, and their updates after the vaccine correction) are derived purely from the Gramian, without needing the full Jacobian. If \(g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k\), then:
\[\|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}\]

where \(G\) is the Gramian. The correction \(g_i^{(\mathrm{PC})} \mathrel{+}= w g_j\) then becomes \(c_{ij} \mathrel{+}= w\), and the updated dot products follow immediately.
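The Gramian identities above can be checked numerically. In the sketch below, J, c, and the weight w are illustrative values (not torchjd internals): the explicit gradients are materialized only to verify that the Gramian-based quantities match them.

```python
import torch

torch.manual_seed(0)
J = torch.randn(3, 10)   # Jacobian: one task gradient per row
G = J @ J.T              # Gramian

# Coefficients c with g_i^(PC) = sum_k c_ik g_k; identity means "unmodified".
c = torch.eye(3)
i, j, w = 0, 1, 0.7      # a correction g_i^(PC) += w * g_j ...
c[i, j] += w             # ... is just c_ij += w in coefficient space

g_pc = c @ J             # explicit modified gradients, for comparison only

# Squared norm from the Gramian matches the explicit computation.
assert torch.allclose(c[i] @ G @ c[i], g_pc[i] @ g_pc[i], atol=1e-4)
# Dot product with an original gradient, also from the Gramian alone.
assert torch.allclose(c[i] @ G[:, j], g_pc[i] @ J[j], atol=1e-4)
```

Because every quantity the algorithm needs reduces to products with \(G\), the weighting only ever touches an \(m \times m\) matrix, regardless of the number of parameters.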
This weighting is stateful: it keeps \(\hat{\phi}\) across calls. Use reset() when the number of tasks or dtype changes.

- Parameters: