IMTL-L

class torchjd.scalarization.IMTLL(shape)

Stateful Scalarizer that combines the input tensor of values using learned per-task scales. IMTL-L is the loss-balancing variant of Impartial Multi-Task Learning, proposed in Towards Impartial Multi-Task Learning.

Each value \(L_i\) is assigned a learnable scale parameter \(s_i\), and the values are combined as

\[\sum_i \left( e^{s_i} L_i - s_i \right)\]

where:

  • \(L_i\) is the \(i\)-th value (typically the loss of task \(i\));

  • \(s_i\) is the learnable scale parameter of task \(i\).
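The formula can be sketched in plain PyTorch as follows. This is not torchjd's implementation. The class name `SketchIMTLL` and its `scales` attribute are hypothetical and chosen only to illustrate the computation.

```python
import torch
from torch import nn


class SketchIMTLL(nn.Module):
    """Hypothetical stand-in for IMTLL: computes sum_i (exp(s_i) * L_i - s_i)."""

    def __init__(self, n: int) -> None:
        super().__init__()
        # One learnable scale s_i per value, initialized to 0.
        self.scales = nn.Parameter(torch.zeros(n))

    def forward(self, values: torch.Tensor) -> torch.Tensor:
        # exp(s_i) rescales L_i; subtracting s_i regularizes the scale.
        return (torch.exp(self.scales) * values - self.scales).sum()


losses = torch.tensor([2.0, 0.5])
scalarizer = SketchIMTLL(2)
result = scalarizer(losses)  # equals 2.0 + 0.5 = 2.5, since every s_i starts at 0
```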

The factor \(e^{s_i}\) rescales each loss: for a fixed \(L_i > 0\), the term \(e^{s_i} L_i - s_i\) is minimized where \(e^{s_i} L_i = 1\), so learning the scales pushes all scaled losses toward a common magnitude. The \(- s_i\) term is a regularizer that prevents the trivial solution \(s_i \to -\infty\). The \(s_i\) are stored as an nn.Parameter, so the parameters of this scalarizer must be passed to the optimizer to be learned jointly with the model.
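As a sketch of how the scales are learned, the snippet below runs plain gradient descent on the scales alone, with made-up losses held fixed (in real training the losses change too). Setting the derivative \(e^{s_i} L_i - 1\) to zero gives \(s_i^* = -\log L_i\), and the check confirms that each scale settles there. It uses a raw tensor of scales rather than torchjd's IMTLL.

```python
import torch

# Arbitrary positive losses, held fixed so that only the scales are learned.
losses = torch.tensor([4.0, 0.25, 1.0])
scales = torch.zeros(3, requires_grad=True)  # s_i initialized to 0
optimizer = torch.optim.SGD([scales], lr=0.1)

for _ in range(500):
    optimizer.zero_grad()
    objective = (torch.exp(scales) * losses - scales).sum()
    objective.backward()  # d/ds_i = exp(s_i) * L_i - 1
    optimizer.step()

# Gradient descent converges to the closed-form minimizer s_i* = -log(L_i).
print(scales.detach())
# At that point every scaled loss exp(s_i) * L_i is close to 1.
print(torch.exp(scales.detach()) * losses)
```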

Although IMTL-L is derived without any distributional assumption (unlike UW, which is derived from Gaussian/Laplace likelihoods), it is almost equivalent to UW: this scalarization equals \(2\,\mathrm{UW}\) evaluated at the negated parameter. The two therefore differ only by a constant factor of two and the sign convention of the learned parameter, and they share the same relative per-task weighting and the same optima (up to that sign flip).
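For reference, the relation can be checked term by term under one common parameterization of UW, assumed here to be \(\mathrm{UW}(s) = \sum_i \left( \tfrac{1}{2} e^{-s_i} L_i + \tfrac{1}{2} s_i \right)\) with \(s_i = \log \sigma_i^2\). If UW is written differently, the constants change accordingly.

\[2\,\mathrm{UW}(-s) = 2 \sum_i \left( \tfrac{1}{2} e^{s_i} L_i - \tfrac{1}{2} s_i \right) = \sum_i \left( e^{s_i} L_i - s_i \right)\]

Under this assumption, the gradient with respect to the model parameters is \(\sum_i e^{s_i} \nabla L_i\) for IMTL-L and \(\tfrac{1}{2} \sum_i e^{s_i} \nabla L_i\) for \(\mathrm{UW}(-s)\), so both weight the tasks in the same proportions.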

The complementary gradient-balancing variant (IMTL-G) is provided as the IMTLG aggregator.

Parameters:

shape (int | Sequence[int]) – The shape of the values to scalarize, used to create one scale per value. An int n is interpreted as the shape (n,).
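To illustrate the shape parameter, the sketch below builds one zero-initialized scale per entry of a multi-dimensional shape and applies the formula elementwise. The helper `make_scales` is hypothetical and only mirrors the documented behavior (an int `n` becomes `(n,)`). torchjd's internal handling may differ.

```python
from __future__ import annotations

from collections.abc import Sequence

import torch
from torch import nn


def make_scales(shape: int | Sequence[int]) -> nn.Parameter:
    """Hypothetical helper: one scale s_i per value, initialized to 0."""
    if isinstance(shape, int):
        shape = (shape,)  # an int n is interpreted as the shape (n,)
    return nn.Parameter(torch.zeros(tuple(shape)))


scales = make_scales((2, 3))  # six scales for a 2x3 grid of values
values = torch.rand(2, 3) + 0.1  # positive dummy losses
# The formula is applied elementwise, then summed over every entry.
result = (torch.exp(scales) * values - scales).sum()
```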

The following example shows a training step with IMTL-L, the loss-balancing variant of Impartial Multi-Task Learning described in the paper.

>>> import torch
>>> from torch.nn import Linear
>>>
>>> from torchjd.scalarization import IMTLL
>>>
>>> model = Linear(3, 2)
>>> scalarizer = IMTLL(2)  # Move to the right device with e.g. IMTLL(2).to(device="cuda")
>>> optimizer = torch.optim.SGD([*model.parameters(), *scalarizer.parameters()], lr=0.1)
>>>
>>> features = torch.randn(8, 3)
>>> # Compute some dummy losses just for the sake of the example
>>> losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
>>> loss = scalarizer(losses)
>>> loss.backward()
>>> optimizer.step()
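The example above performs a single update. Over several iterations, the gradients must also be cleared at each step, since PyTorch accumulates them across backward calls. The loop below is a self-contained sketch that does not import torchjd: a raw `nn.Parameter` named `scales` stands in for the scales held by `IMTLL(2)`, and the scalarization is written out by hand (an assumption made only so the snippet runs on its own).

```python
import torch
from torch import nn
from torch.nn import Linear

torch.manual_seed(0)
model = Linear(3, 2)
# Hypothetical stand-in for the learnable scales of IMTLL(2), initialized to 0.
scales = nn.Parameter(torch.zeros(2))
# The scales are optimized jointly with the model parameters.
optimizer = torch.optim.SGD([*model.parameters(), scales], lr=0.1)

for step in range(20):
    features = torch.randn(8, 3)
    losses = model(features).pow(2).mean(dim=0)  # one loss per output dimension
    # Hand-written scalarization: sum_i (exp(s_i) * L_i - s_i).
    loss = (torch.exp(scales) * losses - scales).sum()
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()
```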

Note

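The scales are initialized to 0, so at the start of training the scalarization reduces to the plain sum of the values (since \(e^0 = 1\)). As in the paper, IMTL-L is meant to balance positive losses: if \(L_i \le 0\), the term \(e^{s_i} L_i - s_i\) decreases without bound as \(s_i\) grows, so the scale \(s_i\) diverges.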
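A quick numeric check of the per-task term \(e^{s} L - s\) for a positive, a zero and a negative loss. The loss values and the helper `imtll_term` are made up for illustration.

```python
import math


def imtll_term(s: float, loss: float) -> float:
    """Hypothetical helper: the per-task IMTL-L term exp(s) * L - s."""
    return math.exp(s) * loss - s


grid = [0.0, 2.0, 4.0, 8.0]
for loss in (1.0, 0.0, -0.5):
    # For L > 0 the term eventually grows with s; for L <= 0 it keeps decreasing.
    print(loss, [round(imtll_term(s, loss), 3) for s in grid])
```

Only the positive loss yields a term that rises again for large \(s\), which is what guarantees a finite minimizer.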

reset()

Resets the internal state.

Return type:

None

__call__(values, /)

Computes the scalar value from the input tensor of values and applies all registered hooks.

Parameters:

values (Tensor) – The tensor of values to scalarize. May be of any shape.

Return type:

Tensor