UW

class torchjd.scalarization.UW(shape)

Stateful Scalarizer that combines the input tensor of values using learned per-task uncertainties. UW is short for Uncertainty Weighting, the method proposed in Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics.

Each value \(L_i\) is assigned a learnable log-variance \(s_i\), and the values are combined as

\[\sum_i \left( \frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i \right)\]

where:

  • \(L_i\) is the \(i\)-th value (typically the loss of task \(i\));

  • \(s_i = \log \sigma_i^2\) is the learnable log-variance of task \(i\).
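A minimal sketch of the formula above in plain PyTorch, written for illustration only: the helper `uw_combine` is hypothetical and is not part of torchjd's API. It shows that with \(s_i = 0\) the combination reduces to half the plain sum of the values.

```python
import torch


def uw_combine(losses: torch.Tensor, log_vars: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not torchjd's API): apply
    # 0.5 * exp(-s_i) * L_i + 0.5 * s_i to every entry, then sum everything.
    return (0.5 * torch.exp(-log_vars) * losses + 0.5 * log_vars).sum()


losses = torch.tensor([2.0, 0.5])  # illustrative values for L_1, L_2
log_vars = torch.zeros(2)          # s_i = 0, i.e. sigma_i^2 = 1

# 0.5 * 2.0 + 0.5 * 0.5 = 1.25: half of the plain sum of the losses.
at_init = uw_combine(losses, log_vars)
```

With `log_vars = losses.log()`, each term becomes \(\frac{1}{2} + \frac{1}{2}\log L_i\), which for these two losses sums to 1.0.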

Following the paper, the log-variance \(s_i\) is learned rather than the variance \(\sigma_i^2\) directly: this is numerically more stable (the combination never divides by zero) and keeps \(s_i\) unconstrained, since \(e^{-s_i}\) is always positive. The \(s_i\) are stored as an nn.Parameter, so the parameters of this scalarizer must be passed to the optimizer to be learned jointly with the model.
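To make the last point concrete, here is a sketch using a hypothetical `MinimalUW` module, a simplified stand-in written for this illustration and not torchjd's implementation. It keeps one log-variance per value in an `nn.Parameter` and shows that the log-variances receive gradients either way, but only move when the optimizer is also given `scalarizer.parameters()`.

```python
import torch
from torch import nn


class MinimalUW(nn.Module):
    # Hypothetical stand-in written for illustration; not torchjd's UW.
    def __init__(self, n: int):
        super().__init__()
        # One unconstrained log-variance s_i per value, initialized to 0.
        self.log_vars = nn.Parameter(torch.zeros(n))

    def forward(self, losses: torch.Tensor) -> torch.Tensor:
        # exp(-s_i) is positive for every real s_i: no division, no constraint needed.
        return (0.5 * torch.exp(-self.log_vars) * losses + 0.5 * self.log_vars).sum()


torch.manual_seed(0)
model = nn.Linear(3, 2)
scalarizer = MinimalUW(2)
features = torch.randn(8, 3)

# Optimizer that (wrongly) only sees the model parameters.
model_only = torch.optim.SGD(model.parameters(), lr=0.1)
model_only.zero_grad()
scalarizer(model(features).pow(2).mean(dim=0)).backward()
model_only.step()
frozen = scalarizer.log_vars.detach().clone()  # still all zeros

# Optimizer that sees both the model and the scalarizer parameters.
joint = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1)
joint.zero_grad()
scalarizer(model(features).pow(2).mean(dim=0)).backward()
joint.step()
learned = scalarizer.log_vars.detach().clone()  # moved away from zero
```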

Parameters:

shape (int | Sequence[int]) – The shape of the values to scalarize, used to create one log-variance per value. An int n is interpreted as the shape (n,).
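A sketch of how the `shape` argument could map to log-variances, assuming a tensor of zeros as the initial value. The helper `log_var_shape` is hypothetical and introduced only for this illustration. With a two-dimensional shape, every entry of the values tensor gets its own \(s_i\) and the formula applies elementwise before the final sum.

```python
from __future__ import annotations

from collections.abc import Sequence

import torch


def log_var_shape(shape: int | Sequence[int]) -> tuple[int, ...]:
    # Hypothetical helper: an int n becomes (n,); a sequence is used as given.
    return (shape,) if isinstance(shape, int) else tuple(shape)


log_vars_1d = torch.zeros(log_var_shape(3))       # 3 log-variances, shape (3,)
log_vars_2d = torch.zeros(log_var_shape([2, 4]))  # 8 log-variances, shape (2, 4)

# Values of shape (2, 4): one log-variance per entry, then sum over all 8 terms.
values = torch.ones(2, 4)
combined = (0.5 * torch.exp(-log_vars_2d) * values + 0.5 * log_vars_2d).sum()  # 8 * 0.5 = 4.0
```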

The following example shows how to train a model with Uncertainty Weighting, as described in the paper.

>>> import torch
>>> from torch.nn import Linear
>>>
>>> from torchjd.scalarization import UW
>>>
>>> model = Linear(3, 2)
>>> scalarizer = UW(2)  # Move to the right device with e.g. UW(2).to(device="cuda")
>>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1)
>>>
>>> features = torch.randn(8, 3)
>>> # Compute some dummy losses just for the sake of the example
>>> losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
>>> loss = scalarizer(losses)
>>> loss.backward()
>>> optimizer.step()

Note

The log-variances are initialized to 0 (i.e. \(\sigma_i^2 = 1\)), which gives uniform weights at the start of training. The paper reports that the result is robust to this initialization. (LibMTL initializes them to -0.5 instead.)
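For a fixed \(L_i > 0\), setting the derivative of \(\frac{1}{2} e^{-s_i} L_i + \frac{1}{2} s_i\) with respect to \(s_i\) to zero gives \(e^{-s_i} = 1/L_i\), i.e. \(s_i = \log L_i\), so the effective weight \(\frac{1}{2} e^{-s_i}\) tends toward \(1/(2 L_i)\): larger losses end up with smaller weights. The sketch below uses illustrative constant losses and plain SGD on the log-variances alone; it starts from the zero initialization (weights of 0.5) and checks that \(s_i\) converges to \(\log L_i\). It is only an illustration, since during real training the \(L_i\) change at every step.

```python
import torch

losses = torch.tensor([4.0, 0.25])             # illustrative, held constant
log_vars = torch.zeros(2, requires_grad=True)  # zero initialization: s_i = 0
initial_weights = 0.5 * torch.exp(-log_vars.detach())  # tensor([0.5, 0.5])

optimizer = torch.optim.SGD([log_vars], lr=0.5)
for _ in range(500):
    optimizer.zero_grad()
    # Only the log-variances are optimized here; the losses stay fixed.
    (0.5 * torch.exp(-log_vars) * losses + 0.5 * log_vars).sum().backward()
    optimizer.step()

final_weights = 0.5 * torch.exp(-log_vars.detach())  # close to 1 / (2 * L_i)
```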

reset()

Resets the internal state.

Return type:

None

__call__(values, /)

Computes the scalar value from the input tensor of values and applies all registered hooks.

Parameters:

values (Tensor) – The tensor of values to scalarize. May be of any shape.

Return type:

Tensor