FAMO

class torchjd.scalarization.FAMO(shape, min_losses=None, lr=0.025, weight_decay=0.001)

Stateful Scalarizer that combines the input tensor of values using Fast Adaptive Multitask Optimization (FAMO), proposed in the paper FAMO: Fast Adaptive Multitask Optimization.

FAMO decreases all task losses at an approximately equal rate while using only the loss values, so it never needs the per-task gradients. The values are combined as

\[c \sum_i z_i \log(\ell_i - b_i + \epsilon), \qquad z = \mathrm{softmax}(w), \qquad c = \left( \sum_i \frac{z_i}{\ell_i - b_i + \epsilon} \right)^{-1}\]

where:

  • \(\ell_i\) is the \(i\)-th value (typically the loss of task \(i\));

  • \(b_i\) is the lower bound on the \(i\)-th loss (the min_losses parameter, 0 by default);

  • \(w_i\) is the task-weighting logit of task \(i\), learned internally by FAMO;

  • \(z = \mathrm{softmax}(w)\) are the task weights;

  • \(c\) is a normalization constant (treated as a constant in the backward pass) that makes the resulting update a convex combination of the task gradients;

  • \(\epsilon\) is a small positive constant for numerical stability.

Backpropagating this scalarized loss gives FAMO’s balanced update direction for the model.
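To make the role of \(c\) concrete, here is a minimal plain-Python sketch of the formula above. It is not torchjd's code: the function names `softmax` and `famo_scalarize` are hypothetical, and the value `eps=1e-8` is an assumption, since the text only says that \(\epsilon\) is small and positive. The sketch also returns the coefficient that each task gradient receives when \(c\) is held constant, namely \(c \, z_i / (\ell_i - b_i + \epsilon)\).

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def famo_scalarize(losses, logits, min_losses=None, eps=1e-8):
    """Return (scalarized value, per-task gradient coefficients).

    Hypothetical helper mirroring the formula in the text; eps is an assumed value.
    """
    if min_losses is None:
        min_losses = [0.0] * len(losses)
    z = softmax(logits)
    shifted = [l - b + eps for l, b in zip(losses, min_losses)]
    # Normalization constant c = (sum_i z_i / (l_i - b_i + eps))^-1.
    c = 1.0 / sum(z_i / s for z_i, s in zip(z, shifted))
    value = c * sum(z_i * math.log(s) for z_i, s in zip(z, shifted))
    # With c held constant, d(value)/d(l_i) = c * z_i / (l_i - b_i + eps).
    coeffs = [c * z_i / s for z_i, s in zip(z, shifted)]
    return value, coeffs


value, coeffs = famo_scalarize(losses=[2.0, 0.5], logits=[0.0, 0.0])
```

The coefficients are non-negative and sum to 1 by construction, which is what makes the resulting update a convex combination of the task gradients. In this example the task with the smaller loss receives the larger coefficient, because the logarithm weights relative rather than absolute changes.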

The task-weighting logits \(w\) are not learned through that backward pass. Instead, after the model has been updated, call update() with the losses recomputed on the same batch. update() measures how much each loss changed across the step,

\[\delta_i = \log(\ell_i^{\text{before}} - b_i + \epsilon) - \log(\ell_i^{\text{after}} - b_i + \epsilon),\]

and takes an Adam step on \(w\) in that direction. FAMO owns this Adam optimizer internally (configured by lr and weight_decay), so you only call the scalarizer and then update(); there is no second optimizer to manage.
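As a worked illustration of \(\delta\), the following plain-Python sketch evaluates the formula above for two tasks. The helper name `log_loss_change` is hypothetical and `eps=1e-8` is an assumed value. The sketch covers only the measurement of \(\delta\), not the Adam step that FAMO then takes on \(w\).

```python
import math


def log_loss_change(before, after, min_losses=None, eps=1e-8):
    """Per-task delta_i = log(before_i - b_i + eps) - log(after_i - b_i + eps).

    Hypothetical helper; eps is an assumed value.
    """
    if min_losses is None:
        min_losses = [0.0] * len(before)
    return [
        math.log(lb - b + eps) - math.log(la - b + eps)
        for lb, la, b in zip(before, after, min_losses)
    ]


# Task 0 halves its loss across the model step; task 1 does not move.
delta = log_loss_change(before=[2.0, 0.5], after=[1.0, 0.5])
```

A positive \(\delta_i\) means that loss \(i\) decreased across the step. Here \(\delta_0\) is approximately \(\log 2\) and \(\delta_1\) is 0, so the two tasks did not improve at the same rate, which is the imbalance that the update of \(w\) is meant to correct.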

Parameters:
  • shape (int | Sequence[int]) – The shape of the values to scalarize, used to create one task-weighting logit per value. An int n is interpreted as the shape (n,).

  • min_losses (Tensor | None) – The per-task lower bound \(b\) subtracted from the values before the logarithm. If provided, it must have the shape given by shape. If None, zeros are used, in which case the values must be strictly positive.

  • lr (float) – Learning rate of the internal Adam that learns the task-weighting logits. Must be non-negative. The paper uses 0.025.

  • weight_decay (float) – Weight decay of the internal Adam, i.e. the paper’s regularization coefficient on the logits. Must be non-negative. Defaults to 1e-3 (as in the paper’s Algorithm 2 and in LibMTL); the official implementation uses 1e-5.

The following example shows how to run one training iteration of a model with FAMO. The losses are recomputed on the same batch after the model step so that update() can adjust the task weights.

>>> import torch
>>> from torch.nn import Linear
>>>
>>> from torchjd.scalarization import FAMO
>>>
>>> model = Linear(3, 2)
>>> scalarizer = FAMO(2)  # Move to the right device with e.g. FAMO(2).to(device="cuda")
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>>
>>> features = torch.randn(8, 3)
>>> losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
>>> loss = scalarizer(losses)
>>> optimizer.zero_grad()
>>> loss.backward()
>>> optimizer.step()
>>>
>>> # Recompute the losses on the same batch, after the model update.
>>> new_losses = model(features).pow(2).mean(dim=0)
>>> scalarizer.update(new_losses)  # Updates the task weights internally.

Note

FAMO takes the logarithm of \(\ell_i - b_i\), so each value must stay strictly above its lower bound \(b_i\) (the paper assumes non-negative losses). With the default min_losses of zeros, this means the values must be strictly positive. This precondition is not enforced.
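Since the precondition is not enforced, a caller may want to check it before scalarizing. The following is a minimal sketch of such a guard in plain Python. The function `check_strictly_above` is a hypothetical helper, not part of torchjd, and it works on sequences of floats (for a 1-D tensor, `tensor.tolist()` yields such a sequence).

```python
def check_strictly_above(values, min_losses=None):
    """Raise ValueError unless every value is strictly above its lower bound.

    Hypothetical helper; with min_losses=None the bounds default to zeros.
    """
    values = list(values)
    bounds = [0.0] * len(values) if min_losses is None else list(min_losses)
    if len(bounds) != len(values):
        raise ValueError("values and min_losses must have the same length")
    for i, (v, b) in enumerate(zip(values, bounds)):
        if not v > b:  # Also rejects NaN, because NaN > b is False.
            raise ValueError(
                f"value {i} is {v}, which is not strictly above its bound {b}"
            )


check_strictly_above([0.3, 1.2])  # Passes silently.
```

Running such a check on every step costs a device-to-host copy when the losses live on a GPU, so it is probably best kept for debugging.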

Note

This implementation was adapted from the official implementation.

update(values, /)

Updates the task-weighting logits from the change in losses across the model update, by taking one step of the internal Adam optimizer. Call it after the scalarizer has been called on the batch’s losses, and pass it the losses recomputed on the same batch after the model step.

Return type:

None

reset()

Resets the internal state.

Return type:

None

__call__(values, /)

Computes the scalar value from the input tensor of values and applies all registered hooks.

Parameters:

values (Tensor) – The tensor of values to scalarize. May be of any shape.

Return type:

Tensor