FAMO
- class torchjd.scalarization.FAMO(shape, min_losses=None, lr=0.025, weight_decay=0.001)
StatefulScalarizer that combines the input tensor of values using Fast Adaptive Multitask Optimization (FAMO), proposed in "FAMO: Fast Adaptive Multitask Optimization". FAMO decreases all task losses at an approximately equal rate while using only the loss values, so it never needs the per-task gradients. The values are combined as
\[c \sum_i z_i \log(\ell_i - b_i + \epsilon), \qquad z = \mathrm{softmax}(w), \qquad c = \left( \sum_i \frac{z_i}{\ell_i - b_i + \epsilon} \right)^{-1}\]
where:
\(\ell_i\) is the \(i\)-th value (typically the loss of task \(i\));
\(b_i\) is the lower bound on the \(i\)-th loss (the min_losses parameter, 0 by default);
\(w_i\) is the task-weighting logit of task \(i\), learned internally by FAMO;
\(z = \mathrm{softmax}(w)\) are the task weights;
\(c\) is a normalization constant (treated as a constant in the backward pass) that makes the resulting update a convex combination of the task gradients;
\(\epsilon\) is a small positive constant for numerical stability.
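The definitions above can be checked numerically with a small plain-Python sketch. This is an illustration only, not torchjd's implementation: the function name famo_value is hypothetical, and the value eps=1e-8 is an assumption, since the text only says that \(\epsilon\) is small. Besides the scalarized value, the sketch returns the coefficients \(c \, z_i / (\ell_i - b_i + \epsilon)\) that multiply each task gradient when \(c\) is held constant, which sum to one.

```python
import math


def famo_value(losses, logits, min_losses=None, eps=1e-8):
    """Hypothetical helper: evaluate the FAMO formula on plain floats."""
    if min_losses is None:
        min_losses = [0.0] * len(losses)

    # z = softmax(w), shifted by the max logit for numerical stability.
    top = max(logits)
    exps = [math.exp(w - top) for w in logits]
    total = sum(exps)
    z = [e / total for e in exps]

    # l_i - b_i + eps, the argument of each logarithm.
    shifted = [l - b + eps for l, b in zip(losses, min_losses)]

    # Normalization constant c.
    c = 1.0 / sum(zi / s for zi, s in zip(z, shifted))

    # Scalarized value: c * sum_i z_i * log(l_i - b_i + eps).
    value = c * sum(zi * math.log(s) for zi, s in zip(z, shifted))

    # With c treated as a constant, the gradient of the value is
    # sum_i coefficients[i] * grad(l_i).
    coefficients = [c * zi / s for zi, s in zip(z, shifted)]
    return value, coefficients


value, coefficients = famo_value(losses=[1.0, 4.0], logits=[0.0, 0.0])
```

With equal logits, the task with the smaller loss receives the larger coefficient, because each coefficient is proportional to \(z_i / (\ell_i - b_i + \epsilon)\).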
Backpropagating this scalarized loss gives FAMO’s balanced update direction for the model.
The task-weighting logits \(w\) are not learned through that backward pass. Instead, after the model has been updated, call update() with the losses recomputed on the same batch. It measures how much each loss changed across the step,
\[\delta_i = \log(\ell_i^{\text{before}} - b_i + \epsilon) - \log(\ell_i^{\text{after}} - b_i + \epsilon),\]
and takes an Adam step on \(w\) in that direction. FAMO owns this Adam internally (configured by lr and weight_decay), so you only call the scalarizer and then update(); there is no second optimizer to manage.
- Parameters:
shape (int | Sequence[int]) – The shape of the values to scalarize, used to create one task-weighting logit per value. An int n is interpreted as the shape (n,).
min_losses (Tensor | None) – The per-task lower bound \(b\) subtracted from the values before the logarithm. If provided, it must have the shape given by shape. If None, zeros are used, in which case the values must be strictly positive.
lr (float) – Learning rate of the internal Adam that learns the task-weighting logits. Must be non-negative. The paper uses 0.025.
weight_decay (float) – Weight decay of the internal Adam, i.e. the paper’s regularization coefficient on the logits. Must be non-negative. Defaults to 1e-3 (as in the paper’s Algorithm 2 and in LibMTL); the official implementation uses 1e-5.
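The quantity \(\delta_i\) measured by update(), and the roles of lr and weight_decay, can be sketched in plain Python. Several things here are assumptions rather than a description of torchjd's internals: the name sketch_logit_update is hypothetical, eps=1e-8 is assumed, the gradient on the logits is taken to be the softmax vector-Jacobian product of \(\delta\) (the update rule described in the FAMO paper), and a plain gradient step with weight decay stands in for the internal Adam.

```python
import math


def softmax(logits):
    top = max(logits)
    exps = [math.exp(w - top) for w in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sketch_logit_update(logits, before, after, min_losses=None,
                        lr=0.025, weight_decay=1e-3, eps=1e-8):
    """Hypothetical helper: one simplified update of the task logits."""
    if min_losses is None:
        min_losses = [0.0] * len(logits)

    # delta_i = log(l_before - b + eps) - log(l_after - b + eps):
    # positive when the loss of task i decreased across the model step.
    delta = [
        math.log(lb - b + eps) - math.log(la - b + eps)
        for lb, la, b in zip(before, after, min_losses)
    ]

    # Vector-Jacobian product of softmax with delta (assumed form):
    # grad_j = z_j * (delta_j - sum_i z_i * delta_i).
    z = softmax(logits)
    weighted_mean = sum(zi * di for zi, di in zip(z, delta))
    grad = [zi * (di - weighted_mean) for zi, di in zip(z, delta)]

    # Plain gradient step with weight decay, standing in for Adam.
    new_logits = [w - lr * (g + weight_decay * w) for w, g in zip(logits, grad)]
    return delta, grad, new_logits


delta, grad, new_logits = sketch_logit_update(
    logits=[0.0, 0.0], before=[1.0, 4.0], after=[0.5, 3.6]
)
```

In this run, the first loss fell faster in log terms than the second, so under these assumptions the first logit decreases and the second increases, shifting weight toward the task that improved less.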
The following example shows how to do one iteration of training of a model with FAMO. The losses are recomputed on the same batch after the model step so that update() can adjust the weights.
>>> import torch
>>> from torch.nn import Linear
>>>
>>> from torchjd.scalarization import FAMO
>>>
>>> model = Linear(3, 2)
>>> scalarizer = FAMO(2)  # Move to the right device with e.g. FAMO(2).to(device="cuda")
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
>>>
>>> features = torch.randn(8, 3)
>>> losses = model(features).pow(2).mean(dim=0)  # One loss per output dimension.
>>> loss = scalarizer(losses)
>>> optimizer.zero_grad()
>>> loss.backward()
>>> optimizer.step()
>>>
>>> # Recompute the losses on the same batch, after the model update.
>>> new_losses = model(features).pow(2).mean(dim=0)
>>> scalarizer.update(new_losses)  # Updates the task weights internally.
Note
FAMO takes the logarithm of \(\ell_i - b_i\), so each value must stay strictly above its lower bound \(b_i\) (the paper assumes non-negative losses). With the default min_losses of zeros, this means the values must be strictly positive. This precondition is not enforced.
Note
This implementation was adapted from the official implementation.
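Because the precondition on the values is not enforced, a caller may want to guard against it before scalarizing. The following is a minimal sketch of such a check on plain floats. The helper name check_strictly_above is hypothetical and not part of torchjd; with tensors, the values could first be converted with tolist().

```python
def check_strictly_above(losses, min_losses=None):
    """Hypothetical guard: raise if any loss is not strictly above its bound."""
    if min_losses is None:
        min_losses = [0.0] * len(losses)

    # `not l > b` is also true for NaN, so NaN losses are reported as well.
    offending = [
        i for i, (l, b) in enumerate(zip(losses, min_losses)) if not l > b
    ]
    if offending:
        raise ValueError(
            f"losses at indices {offending} are not strictly above their lower bounds"
        )


check_strictly_above([0.3, 1.2])  # Passes silently.

try:
    check_strictly_above([0.0, 1.0])  # The first loss equals its bound of 0.
    rejected = False
except ValueError:
    rejected = True
```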
- update(values, /)
Updates the task-weighting logits from the change in losses across the model update, by taking one step of the internal Adam. Must be called after the scalarizer has been called on the batch’s losses, with the losses recomputed on the same batch after the model step.
- Return type: