GradVac

class torchjd.aggregation.GradVac(beta=0.5, eps=1e-08)[source]

Stateful Aggregator implementing the aggregation step of Gradient Vaccine (GradVac) from Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight).

For each task \(i\), the order in which the other tasks \(j\) are visited is drawn at random. Starting from \(g_i^{(\mathrm{PC})} = g_i\), for each pair \((i, j)\) the cosine similarity \(\phi_{ij}\) between the (possibly already modified) gradient \(g_i^{(\mathrm{PC})}\) of task \(i\) and the original gradient \(g_j\) of task \(j\) is compared to an EMA target \(\hat{\phi}_{ij}\). When \(\phi_{ij} < \hat{\phi}_{ij}\), a closed-form correction adds a scaled copy of \(g_j\) to \(g_i^{(\mathrm{PC})}\), chosen so that the resulting cosine similarity equals \(\hat{\phi}_{ij}\). The EMA is then updated with \(\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}\). The aggregated vector is the sum of the modified rows.
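The pairwise update above can be sketched in plain Python. This is a minimal illustration of the paper's closed-form correction under the description given here, not the library's implementation; the function name `gradvac_pair_step` and the list-based vectors are hypothetical.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def gradvac_pair_step(g_i, g_j, phi_hat, beta=0.5, eps=1e-8):
    """One (i, j) correction: nudge g_i toward g_j up to the EMA target phi_hat.

    Returns the corrected g_i and the updated EMA target.
    """
    n_i = math.sqrt(dot(g_i, g_i))
    n_j = math.sqrt(dot(g_j, g_j))
    phi = dot(g_i, g_j) / (n_i * n_j + eps)
    if phi < phi_hat:
        # Closed-form coefficient that makes the new cosine similarity
        # between g_i and g_j equal to phi_hat.
        w = n_i * (phi_hat * math.sqrt(1 - phi ** 2)
                   - phi * math.sqrt(1 - phi_hat ** 2)) \
            / (n_j * math.sqrt(1 - phi_hat ** 2) + eps)
        g_i = [a + w * b for a, b in zip(g_i, g_j)]
    phi_hat = (1 - beta) * phi_hat + beta * phi  # EMA update
    return g_i, phi_hat
```

For instance, orthogonal gradients \(g_i = (1, 0)\) and \(g_j = (0, 1)\) with target \(\hat{\phi}_{ij} = 0.5\) are corrected so that their cosine similarity becomes exactly \(0.5\), after which the target decays toward the observed similarity.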

This aggregator is stateful: it keeps \(\hat{\phi}\) across calls. Use reset() when the number of tasks or dtype changes.

Parameters:
  • beta (float) – EMA decay for \(\hat{\phi}\).

  • eps (float) – Small positive constant added to denominators to avoid division by zero.

Note

For each task \(i\), the order of other tasks \(j\) is shuffled independently using the global PyTorch RNG (torch.randperm). Seed it with torch.manual_seed if you need reproducibility.

reset()[source]

Clears EMA state so the next forward starts from zero targets.

Return type:

None

class torchjd.aggregation.GradVacWeighting(beta=0.5, eps=1e-08)[source]

Stateful Weighting computing the weights of GradVac.

All required quantities (gradient norms, cosine similarities, and their updates after the vaccine correction) are derived purely from the Gramian, without needing the full Jacobian. If \(g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k\), then:

\[\|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}\]

where \(G\) is the Gramian. The correction \(g_i^{(\mathrm{PC})} \mathrel{+}= w g_j\) then becomes \(c_{ij} \mathrel{+}= w\), and the updated dot products follow immediately.
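These identities can be checked numerically with a small hand-made example; the gradients and coefficients below are arbitrary illustrations, not values produced by the library.

```python
# Gradients g_k as rows, and coefficients c_i with g_i^(PC) = sum_k c_ik g_k.
gs = [[1.0, 2.0, 0.0], [0.0, 1.0, -1.0], [3.0, 0.0, 1.0]]
c = [0.5, -1.0, 2.0]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

# Gramian: G_kl = g_k . g_l
G = [[dot(g_k, g_l) for g_l in gs] for g_k in gs]

# Modified gradient computed explicitly from the full rows.
g_pc = [sum(c[k] * gs[k][d] for k in range(3)) for d in range(3)]

# ||g_i^(PC)||^2 from the rows vs. from the Gramian: c_i G c_i^T.
sq_norm_direct = dot(g_pc, g_pc)
sq_norm_gramian = sum(c[k] * G[k][l] * c[l] for k in range(3) for l in range(3))

# g_i^(PC) . g_j from the rows vs. from the Gramian: c_i G_{:,j} (here j = 1).
dp_direct = dot(g_pc, gs[1])
dp_gramian = sum(c[k] * G[k][1] for k in range(3))
```

Both pairs of quantities agree, which is why the weighting can run the vaccine procedure on the Gramian alone.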

This weighting is stateful: it keeps \(\hat{\phi}\) across calls. Use reset() when the number of tasks or dtype changes.

Parameters:
  • beta (float) – EMA decay for \(\hat{\phi}\).

  • eps (float) – Small positive constant added to denominators to avoid division by zero.

reset()[source]

Clears EMA state so the next forward starts from zero targets.

Return type:

None