COSMOS

class torchjd.scalarization.COSMOS(lambda_, weights)

Scalarizer that combines the input tensor of values using the COSMOS scalarization, proposed in Scalable Pareto Front Approximation for Deep Multi-Objective Learning.

It returns a linear scalarization regularized by the cosine similarity between the values and the preference vector:

\[\sum_i r_i L_i - \lambda \frac{\sum_i r_i L_i}{\lVert r \rVert \, \lVert L \rVert},\]

where:

  • \(L_i\) is the \(i\)-th input value (the \(i\)-th objective);

  • \(r_i\) is its preference weight (the weights parameter);

  • \(\lambda\) is the cosine-similarity penalty coefficient (the lambda_ parameter);

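  • the subtracted term is \(\lambda \cos(r, L)\). Since the objective is minimized, subtracting it rewards aligning the vector of values with the preference direction; this is what spreads the solutions obtained for different preference vectors along the approximated Pareto front.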
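To make the formula concrete, here is a minimal from-scratch sketch in PyTorch. It is not torchjd's implementation: the function name `cosmos_objective` is hypothetical, and the sketch assumes that the sums and norms run over every element of the tensors.

```python
import torch


def cosmos_objective(values: torch.Tensor, weights: torch.Tensor, lambda_: float) -> torch.Tensor:
    """Hypothetical re-implementation of the COSMOS formula (not torchjd's code)."""
    if values.shape != weights.shape:
        raise ValueError("weights must have the same shape as values")
    # Linear scalarization: sum_i r_i L_i over every element.
    linear = (weights * values).sum()
    # cos(r, L): the same sum divided by the Euclidean norms of the flattened tensors.
    cosine = linear / (torch.linalg.vector_norm(weights) * torch.linalg.vector_norm(values))
    return linear - lambda_ * cosine


values = torch.tensor([3.0, 4.0])   # L: two objective values
weights = torch.tensor([0.5, 0.5])  # r: equal preference for both objectives
# Linear term 3.5, cosine about 0.990, so the objective is about 2.510.
print(cosmos_objective(values, weights, lambda_=1.0))
# With lambda_ = 0 only the linear term (3.5) remains.
print(cosmos_objective(values, weights, lambda_=0.0))
```

The second call illustrates the remark on `lambda_` below: with a zero coefficient the cosine term vanishes and the result is a plain weighted sum.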

Parameters:
  • lambda_ (float) – The cosine-similarity penalty coefficient \(\lambda\). Must be non-negative. A value of 0 reduces COSMOS to a plain linear scalarization. The paper uses values ranging from 0.01 to 8 depending on the dataset, with no single best value.

  • weights (Tensor) – The preference vector \(r\) applied to the values. It must have the same shape as the values passed at call time. To approximate the whole Pareto front rather than a single trade-off, re-sample it from a Dirichlet distribution and reassign it before every call, as in the paper. For example, for m objectives: cosmos.weights = torch.distributions.Dirichlet(torch.ones(m)).sample(). A concentration of ones gives the uniform distribution over the probability simplex; a concentration smaller than one spreads the samples toward the corners of the simplex.
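The re-sampling described for `weights` can be sketched as follows, assuming m = 3 objectives and a loop of three steps (both arbitrary). The assignment to a torchjd `COSMOS` instance is left as a comment so the sketch only needs PyTorch; the last lines compare how strongly the samples lean toward a corner of the simplex under the two concentrations mentioned above.

```python
import torch

torch.manual_seed(0)
m = 3  # number of objectives (assumed for illustration)

# Concentration of ones: uniform over the probability simplex.
uniform = torch.distributions.Dirichlet(torch.ones(m))
# Concentration below one: samples tend to sit near the corners of the simplex.
sparse = torch.distributions.Dirichlet(torch.full((m,), 0.2))

for step in range(3):
    weights = uniform.sample()  # fresh preference vector r for this step
    # With torchjd, this is where one would write: cosmos.weights = weights
    print(step, weights, weights.sum())  # entries are non-negative and sum to 1

uniform_samples = uniform.sample((1000,))
corner_samples = sparse.sample((1000,))
# Average size of the largest coordinate: larger means closer to a corner.
print(uniform_samples.max(dim=1).values.mean(), corner_samples.max(dim=1).values.mean())
```

Sampling a new vector at every step is what lets a single model see many trade-offs during training; a fixed vector would only target one point of the front.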

Note

The full COSMOS method also conditions the model on the preference vector by concatenating it to the input; that is a modeling choice left to the user. This scalarizer only implements the objective.
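One possible way to do this conditioning for tabular inputs is sketched below. The class `PreferenceConditionedMLP`, its layer sizes, and the choice to concatenate the same r to every sample's features are assumptions for illustration; none of this is part of torchjd.

```python
import torch
from torch import nn


class PreferenceConditionedMLP(nn.Module):
    """Hypothetical model that receives the preference vector r alongside its input."""

    def __init__(self, in_features: int, num_objectives: int, out_features: int, hidden: int = 32) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features + num_objectives, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
        # Repeat the single preference vector for every sample in the batch,
        # then concatenate it to the input features.
        r_batch = r.unsqueeze(0).expand(x.shape[0], -1)
        return self.net(torch.cat([x, r_batch], dim=1))


torch.manual_seed(0)
model = PreferenceConditionedMLP(in_features=5, num_objectives=2, out_features=3)
x = torch.randn(8, 5)  # batch of 8 samples with 5 features each
r = torch.distributions.Dirichlet(torch.ones(2)).sample()
out = model(x, r)
print(out.shape)  # torch.Size([8, 3])
```

The same r passed to the model should also be the one used as the scalarizer's weights for that step, so that the model learns which trade-off it is being asked for.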

__call__(values, /)

Computes the scalar value from the input tensor of values and applies all registered hooks.

Parameters:

values (Tensor) – The tensor of values to scalarize. May be of any shape, as long as it matches the shape of weights.

Return type:

Tensor
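As a sanity check of these call semantics, the sketch below re-implements the documented formula with a hypothetical `cosmos_objective` helper (not torchjd's code), applies it to a 2x3 tensor of values, and compares the autograd gradient with the closed-form derivative of the objective with respect to \(L\). It assumes that the sums and norms run over all elements, so any shape works as long as weights matches it.

```python
import torch


def cosmos_objective(values: torch.Tensor, weights: torch.Tensor, lambda_: float) -> torch.Tensor:
    # Hypothetical re-implementation: sums and norms run over all elements.
    linear = (weights * values).sum()
    cosine = linear / (torch.linalg.vector_norm(weights) * torch.linalg.vector_norm(values))
    return linear - lambda_ * cosine


torch.manual_seed(0)
values = torch.rand(2, 3, requires_grad=True)  # 6 objectives laid out as a 2x3 grid
weights = torch.distributions.Dirichlet(torch.ones(6)).sample().reshape(2, 3)
lambda_ = 0.5

loss = cosmos_objective(values, weights, lambda_)
print(loss.dim())  # 0: a scalar tensor, ready for backward()
loss.backward()

# Closed-form gradient with respect to L:
#   r - lambda * (r / (|r| |L|) - (r . L) L / (|r| |L|^3))
with torch.no_grad():
    L = values
    r_norm = torch.linalg.vector_norm(weights)
    L_norm = torch.linalg.vector_norm(L)
    dot = (weights * L).sum()
    expected = weights - lambda_ * (weights / (r_norm * L_norm) - dot * L / (r_norm * L_norm**3))
print(torch.allclose(values.grad, expected, atol=1e-5))  # True
```

The closed form shows how the cosine term reshapes the plain gradient r: it removes part of r and adds a component along L, which pushes the vector of values toward the direction of r.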