aggregation

When doing Jacobian descent, the Jacobian matrix has to be aggregated into a single vector, which is then stored in the .grad fields of the model parameters. This aggregation is the role of the Aggregator.

When using the autogram engine, we instead need to extract a vector of weights from the Gramian of the Jacobian. This is the role of the Weighting.

Note

Most aggregators rely on computing the Gramian of the Jacobian, extracting a vector of weights from this Gramian using a Weighting, and then combining the rows of the Jacobian using these weights. For all such aggregators, we provide both the Aggregator interface (to be used in autojac) and the Weighting interface (to be used in autogram). For the remaining aggregators, we only provide the Aggregator interface, since they are not compatible with autogram.
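The three steps described in this note (compute the Gramian, extract weights, combine the rows) can be sketched in plain PyTorch. This is a minimal illustration, not torchjd's implementation: `aggregate_with_weighting` and `mean_weighting` are hypothetical names, and the uniform weighting is a stand-in chosen only because its result is easy to check by hand.

```python
import torch
from torch import Tensor


def mean_weighting(gramian: Tensor) -> Tensor:
    """Hypothetical weighting: ignores the Gramian's values and returns uniform weights 1/m."""
    m = gramian.shape[0]
    return torch.full((m,), 1.0 / m, dtype=gramian.dtype)


def aggregate_with_weighting(jacobian: Tensor, weighting) -> Tensor:
    """Aggregate an [m, n] Jacobian into an [n] vector through its Gramian."""
    gramian = jacobian @ jacobian.T  # [m, m] matrix of pairwise row dot products
    weights = weighting(gramian)  # [m] vector of weights
    return weights @ jacobian  # weighted sum of the rows of the Jacobian, shape [n]


jacobian = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
aggregation = aggregate_with_weighting(jacobian, mean_weighting)
print(aggregation)  # tensor([1., 1., 1.])
```

With the uniform weighting, the result is simply the row mean of the Jacobian. Aggregators built on this pattern differ only in how the weights are computed from the Gramian.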

Aggregators and Weightings are callables that take a Jacobian matrix or a Gramian matrix as input, respectively. The following example shows how to use UPGrad either to aggregate a Jacobian (of shape [m, n], where m is the number of objectives and n is the number of parameters), or to obtain the weights from the Gramian of the Jacobian (of shape [m, m]).

>>> from torch import tensor
>>> from torchjd.aggregation import UPGrad, UPGradWeighting
>>>
>>> aggregator = UPGrad()
>>> jacobian = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
>>> aggregation = aggregator(jacobian)
>>> aggregation
tensor([0.2929, 1.9004, 1.9004])
>>> weighting = UPGradWeighting()
>>> gramian = jacobian @ jacobian.T
>>> weights = weighting(gramian)
>>> weights
tensor([1.1109, 0.7894])
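As a sanity check on the numbers above, combining the rows of the Jacobian with the printed weights recovers the printed aggregation, up to the four-decimal rounding of the displayed values. The sketch below uses only PyTorch and hard-codes the values shown in the example.

```python
import torch

jacobian = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
# Values copied from the printed outputs of the example (rounded to 4 decimals)
weights = torch.tensor([1.1109, 0.7894])
aggregation = torch.tensor([0.2929, 1.9004, 1.9004])

# Expected entries: [[18, -22], [-22, 38]]
gramian = jacobian @ jacobian.T

recombined = weights @ jacobian  # weighted sum of the two rows
print(torch.allclose(recombined, aggregation, atol=1e-3))  # True
```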

When dealing with a more general tensor of objectives of shape [m_1, ..., m_k] (i.e. not necessarily a simple vector), the Jacobian has shape [m_1, ..., m_k, n], and its Gramian, called a generalized Gramian, has shape [m_1, ..., m_k, m_k, ..., m_1]. A GeneralizedWeighting extracts a tensor of weights (of shape [m_1, ..., m_k]) from such a generalized Gramian. The simplest GeneralizedWeighting is Flattening: it “flattens” the generalized Gramian into a square Gramian matrix (of shape [m_1 * ... * m_k, m_1 * ... * m_k]), applies an ordinary Weighting to it to obtain a vector of weights, and returns these weights reshaped into a tensor of shape [m_1, ..., m_k].
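To make the index order concrete, the sketch below builds a generalized Gramian from a Jacobian of shape [2, 3, n] with torch.einsum, then flattens it by reversing the order of its last k = 2 dimensions and reshaping. The permutation is our reading of the [m_1, ..., m_k, m_k, ..., m_1] layout described above, not a copy of torchjd's Flattening code, and `flatten_generalized_gramian` is a hypothetical helper.

```python
import torch
from torch import Tensor


def flatten_generalized_gramian(generalized_gramian: Tensor) -> Tensor:
    """Turn a [m_1, ..., m_k, m_k, ..., m_1] tensor into a square [M, M] matrix, M = m_1 * ... * m_k."""
    k = generalized_gramian.ndim // 2
    # Reverse the last k dimensions so rows and columns are both indexed by (i_1, ..., i_k)
    order = list(range(k)) + list(range(2 * k - 1, k - 1, -1))
    square = generalized_gramian.permute(order)
    size = square.shape[:k].numel()
    return square.reshape(size, size)


torch.manual_seed(0)
jacobian = torch.randn(2, 3, 5, dtype=torch.float64)  # objectives of shape [2, 3], n = 5
# generalized_gramian[a, b, c, d] = <jacobian[a, b], jacobian[d, c]>
generalized_gramian = torch.einsum("abn,dcn->abcd", jacobian, jacobian)
print(generalized_gramian.shape)  # torch.Size([2, 3, 3, 2])

# Flattening should match the ordinary Gramian of the Jacobian reshaped to [6, 5]
flat_jacobian = jacobian.reshape(6, 5)
flat_gramian = flatten_generalized_gramian(generalized_gramian)
print(torch.allclose(flat_gramian, flat_jacobian @ flat_jacobian.T))  # True
```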

>>> from torch import ones
>>> from torchjd.aggregation import Flattening, UPGradWeighting
>>>
>>> weighting = Flattening(UPGradWeighting())
>>> # Generate a generalized Gramian filled with ones, for the sake of the example
>>> generalized_gramian = ones((2, 3, 3, 2))
>>> weights = weighting(generalized_gramian)
>>> weights
tensor([[0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667]])
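The behaviour of Flattening described above (flatten, apply a matrix weighting, reshape) can be mimicked with a few tensor operations. The sketch below is a hypothetical re-implementation for illustration only: `flattening_sketch` and `mean_weighting` are our own names, and the uniform weighting stands in for UPGradWeighting. On the all-ones generalized Gramian used in the example, it also returns weights of 1/6 ≈ 0.1667 with shape [2, 3].

```python
import torch
from torch import Tensor


def mean_weighting(gramian: Tensor) -> Tensor:
    """Hypothetical matrix weighting: uniform weights 1/m for an [m, m] Gramian."""
    m = gramian.shape[0]
    return torch.full((m,), 1.0 / m, dtype=gramian.dtype)


def flattening_sketch(weighting, generalized_gramian: Tensor) -> Tensor:
    """Apply a matrix weighting to a generalized Gramian; reshape the weights to [m_1, ..., m_k]."""
    k = generalized_gramian.ndim // 2
    shape = generalized_gramian.shape[:k]
    # Reverse the trailing k dimensions, then merge rows and columns into a square matrix
    order = list(range(k)) + list(range(2 * k - 1, k - 1, -1))
    square = generalized_gramian.permute(order).reshape(shape.numel(), shape.numel())
    weights = weighting(square)  # vector of length m_1 * ... * m_k
    return weights.reshape(shape)


weights = flattening_sketch(mean_weighting, torch.ones((2, 3, 3, 2)))
print(weights.shape)  # torch.Size([2, 3])
```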

Abstract base classes

class torchjd.aggregation.Aggregator

Abstract base class for all aggregators. It has the role of aggregating matrices of dimension \(m \times n\) into row vectors of dimension \(n\).
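The contract stated here ([m, n] in, [n] out) can be illustrated with a small abstract base class. This is a hypothetical stand-in written with Python's abc module, not torchjd's Aggregator, whose subclassing hooks are not described on this page; `MatrixAggregatorSketch` and `MeanAggregatorSketch` are our own names.

```python
from abc import ABC, abstractmethod

import torch
from torch import Tensor


class MatrixAggregatorSketch(ABC):
    """Hypothetical base class: maps an [m, n] matrix to an [n] vector."""

    def __call__(self, matrix: Tensor) -> Tensor:
        if matrix.ndim != 2:
            raise ValueError(f"Expected a matrix, got shape {tuple(matrix.shape)}")
        result = self.aggregate(matrix)
        # Enforce the output contract: one value per column (parameter)
        if result.shape != (matrix.shape[1],):
            raise RuntimeError(f"Expected shape {(matrix.shape[1],)}, got {tuple(result.shape)}")
        return result

    @abstractmethod
    def aggregate(self, matrix: Tensor) -> Tensor: ...


class MeanAggregatorSketch(MatrixAggregatorSketch):
    """Averages the rows of the matrix."""

    def aggregate(self, matrix: Tensor) -> Tensor:
        return matrix.mean(dim=0)


aggregator = MeanAggregatorSketch()
print(aggregator(torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])))  # tensor([1., 1., 1.])
```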

class torchjd.aggregation.Weighting

Abstract base class for all weighting methods. It has the role of extracting a vector of weights of dimension \(m\) from some statistic of a matrix of dimension \(m \times n\), generally its Gramian, of dimension \(m \times m\).
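As an example of a weighting that only needs the Gramian: its diagonal holds the squared norms of the rows of the Jacobian, so weights that normalize each row can be read off without access to the Jacobian itself. The sketch below is a hypothetical weighting chosen for illustration, not one provided by torchjd; `inverse_norm_weighting` is our own name, and it assumes no row of the Jacobian is zero.

```python
import torch
from torch import Tensor


def inverse_norm_weighting(gramian: Tensor) -> Tensor:
    """Hypothetical weighting: w_i = 1 / ||J_i||, read from the diagonal G_ii = ||J_i||^2."""
    return gramian.diagonal().sqrt().reciprocal()


jacobian = torch.tensor([[3.0, 4.0], [0.0, 2.0]])
gramian = jacobian @ jacobian.T  # diagonal is [25, 4]
weights = inverse_norm_weighting(gramian)
print(weights)  # tensor([0.2000, 0.5000])
print(weights @ jacobian)  # sum of the unit-norm rows: tensor([0.6000, 1.8000])
```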

class torchjd.aggregation.GeneralizedWeighting

Abstract base class for all weightings that operate on generalized Gramians. It has the role of extracting a tensor of weights of dimension \(m_1 \times \dots \times m_k\) from a generalized Gramian of dimension \(m_1 \times \dots \times m_k \times m_k \times \dots \times m_1\).