Constant

class torchjd.scalarization.Constant(weights)

Scalarizer that combines the input tensor of values into a single scalar, using constant, pre-determined weights.

Parameters:

weights (Tensor) – The weights to apply to the values. Must have the same shape as the values passed at call time.
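To make the combination concrete, here is a minimal sketch in plain PyTorch. It assumes the values are combined as a weighted sum: each value is multiplied by the weight at the same position, and the products are summed to a 0-dim tensor. `constant_scalarize` is a hypothetical helper written for illustration. It is not part of torchjd.

```python
import torch


def constant_scalarize(values: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for Constant(weights)(values).

    Assumption: the combination is a weighted sum of the values.
    """
    # Documented contract: weights must have the same shape as the values.
    if weights.shape != values.shape:
        raise ValueError(
            f"weights shape {tuple(weights.shape)} does not match "
            f"values shape {tuple(values.shape)}"
        )
    # Element-wise product, then reduce over every dimension to a 0-dim tensor.
    return (weights * values).sum()


# Example: two losses weighted 0.75 and 0.25.
losses = torch.tensor([2.0, 4.0])
weights = torch.tensor([0.75, 0.25])
loss = constant_scalarize(losses, weights)
print(loss)  # tensor(2.5000)
```

Because the weights are fixed when the object is built, the relative importance of each value never changes between calls.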

__call__(values, /)

Computes the scalar value from the input tensor of values and applies all registered hooks.

Parameters:

values (Tensor) – The tensor of values to scalarize. May be of any shape, as long as it matches the shape of weights.

Return type:

Tensor
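The mention of registered hooks suggests that scalarizers follow the `torch.nn.Module` calling convention. Under that convention, `__call__` runs `forward` and then any forward hooks. The sketch below assumes exactly that. `ConstantSketch` is a hypothetical stand-in written for illustration, not torchjd's implementation, and it also assumes the weighted-sum combination.

```python
import torch
from torch import nn


class ConstantSketch(nn.Module):
    """Hypothetical scalarizer mirroring the documented behaviour of Constant.

    Assumption: it is an nn.Module, so calling the instance runs forward()
    and then any registered forward hooks.
    """

    def __init__(self, weights: torch.Tensor):
        super().__init__()
        # A buffer moves with the module across devices/dtypes but is not trained.
        self.register_buffer("weights", weights)

    def forward(self, values: torch.Tensor, /) -> torch.Tensor:
        if values.shape != self.weights.shape:
            raise ValueError("values must have the same shape as weights")
        # Assumed combination: weighted sum reduced to a 0-dim tensor.
        return (self.weights * values).sum()


scalarizer = ConstantSketch(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))
seen = []


def record_output(module, inputs, output):
    # Forward hooks receive the module, the positional inputs (as a tuple)
    # and the output; returning None leaves the output unchanged.
    seen.append(output.item())


handle = scalarizer.register_forward_hook(record_output)
result = scalarizer(torch.ones(2, 2))  # __call__ runs forward, then the hook
handle.remove()
print(result)  # tensor(10.)
```

Here the values tensor is 2-D, which illustrates that any shape is accepted provided it matches the weights.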