DWA
- class torchjd.scalarization.DWA(temperature=2.0)
StatefulScalarizer that combines the input tensor of values using Dynamic Weight Average (DWA), proposed in End-to-End Multi-Task Learning with Attention.
DWA weights each value by how quickly its loss has been decreasing relative to the others. At epoch \(t\), the current batch’s values are combined as
\[\sum_k \lambda_k(t)\, \ell_k, \qquad \lambda_k(t) = \frac{K \exp(w_k(t-1) / T)}{\sum_i \exp(w_i(t-1) / T)}, \qquad w_k(t-1) = \frac{L_k(t-1)}{L_k(t-2)}\]
where:
\(\ell_k\) is the \(k\)-th value being scalarized (typically the current batch’s loss for task k);
\(L_k(t)\) is the \(k\)-th value averaged over epoch \(t\) (used only for the weights);
\(w_k(t-1)\) is the relative descending rate: the ratio of average losses over the two previous epochs;
\(T\) is the temperature; a larger \(T\) makes the weights more uniform;
\(K\) is the number of values (e.g. the number of tasks); the factor \(K\) keeps \(\sum_k \lambda_k = K\).
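The weight formula above can be sketched in a few lines of plain Python. This is only an illustration of the arithmetic, not torchjd's implementation: the helper name dwa_weights and its arguments are hypothetical, and the inputs are assumed to be the per-task average losses \(L_k(t-1)\) and \(L_k(t-2)\).

```python
import math


def dwa_weights(prev_avg, prev_prev_avg, temperature=2.0):
    """Hypothetical helper (not part of torchjd): compute lambda_k(t)."""
    # Relative descending rates w_k(t-1) = L_k(t-1) / L_k(t-2).
    rates = [p / pp for p, pp in zip(prev_avg, prev_prev_avg)]
    # Softmax over w_k / T, rescaled by K so that the weights sum to K.
    exps = [math.exp(r / temperature) for r in rates]
    total = sum(exps)
    k = len(rates)
    return [k * e / total for e in exps]


# Task 0 halved its average loss; task 1 only went from 1.0 to 0.9.
weights = dwa_weights([0.5, 0.9], [1.0, 1.0])
# Same losses with a much larger temperature.
flat = dwa_weights([0.5, 0.9], [1.0, 1.0], temperature=100.0)
```

In this sketch the weights sum to \(K = 2\), the task whose loss decreased more slowly (task 1) receives the larger weight, and raising the temperature shrinks the gap between the two weights.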
The weights depend only on the average losses of the two previous epochs, so no gradient flows through them. Each call returns the scalarization and adds the current batch’s losses to the running loss sums of the current epoch.
step() must then be called once at the end of each epoch to finalize that epoch’s average loss and roll the history forward. During the first two epochs, before two averages are available, the weights are uniform.
- Parameters:
temperature (float) – The temperature \(T\). Must be strictly positive. Larger values make the weights more uniform. The paper uses 2.0.
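The per-batch accumulation and the per-epoch step() described above can be pictured with a toy stand-in written in plain Python. Everything here is hypothetical: ToyDWA, its attributes, and its weights method are illustrative names, not torchjd's API, and the sketch assumes that "uniform" means every weight equals 1 (so that the weights still sum to \(K\)).

```python
import math


class ToyDWA:
    """Hypothetical stand-in for the bookkeeping; not torchjd's implementation."""

    def __init__(self, temperature=2.0):
        self.temperature = temperature
        self.history = []  # Average losses of the last (at most) two finished epochs.
        self.sums = None   # Running per-task loss sums for the current epoch.
        self.count = 0     # Number of batches seen in the current epoch.

    def weights(self, k):
        if len(self.history) < 2:
            return [1.0] * k  # Assumed uniform weights until two averages exist.
        older, newer = self.history
        rates = [n / o for n, o in zip(newer, older)]
        exps = [math.exp(r / self.temperature) for r in rates]
        total = sum(exps)
        return [k * e / total for e in exps]

    def __call__(self, losses):
        # Accumulate this batch's losses into the current epoch's sums.
        if self.sums is None:
            self.sums = [0.0] * len(losses)
        self.sums = [s + l for s, l in zip(self.sums, losses)]
        self.count += 1
        # Return the weighted sum, using weights from previous epochs only.
        return sum(w * l for w, l in zip(self.weights(len(losses)), losses))

    def step(self):
        # Finalize this epoch's averages and keep only the two most recent epochs.
        averages = [s / self.count for s in self.sums]
        self.history = (self.history + [averages])[-2:]
        self.sums, self.count = None, 0


toy = ToyDWA()
# Three epochs of one batch each, with two tasks.
for epoch_batches in ([[1.0, 1.0]], [[0.5, 0.9]], [[0.4, 0.8]]):
    for batch in epoch_batches:
        combined = toy(batch)
    toy.step()
```

In this toy version the first two epochs use uniform weights, and only the third epoch's batches are weighted by the ratio of the two stored averages. Calling step() more or less than once per epoch would shift which averages the ratio compares, which is why it belongs at the end of the epoch loop.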
The following example shows how to train a model with DWA. The scalarizer is called on every batch, and step() is called once at the end of each epoch.

>>> import torch
>>> from torch.nn import Linear
>>>
>>> from torchjd.scalarization import DWA
>>>
>>> model = Linear(3, 2)
>>> scalarizer = DWA()
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>>
>>> for epoch in range(3):
...     for _ in range(4):  # Iterate over the batches of the epoch.
...         features = torch.randn(8, 3)
...         losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
...         loss = scalarizer(losses)
...         optimizer.zero_grad()
...         loss.backward()
...         optimizer.step()
...     scalarizer.step()  # Roll the epoch history once, at the end of the epoch.
Note
DWA weights each value by the ratio of its losses over consecutive epochs, which the paper defines as a descending rate in the range \((0, +\infty)\). The losses are therefore expected to keep a consistent, nonzero sign across epochs (they need not be positive).
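A small arithmetic check makes the sign requirement concrete. The helper below is hypothetical (it is not a torchjd function) and simply evaluates the ratio \(L_k(t-1) / L_k(t-2)\) for a single task.

```python
def descending_rate(prev_avg, prev_prev_avg):
    """Hypothetical helper: the ratio w_k(t-1) for one task."""
    return prev_avg / prev_prev_avg


# Both averages positive: the ratio is positive.
rate_positive = descending_rate(0.5, 1.0)
# Both averages negative: the ratio is still positive.
rate_negative = descending_rate(-2.0, -1.0)
# Sign flips between epochs: the ratio is negative, outside (0, +inf).
rate_flipped = descending_rate(-0.5, 1.0)
```

Only the last case leaves the range \((0, +\infty)\), and a zero average in the denominator would make the ratio undefined.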