aggregation
When doing Jacobian descent, the Jacobian matrix has to be aggregated into a vector to store in the
.grad fields of the model parameters. The Aggregator is responsible for these aggregations.
When using the autogram engine, we instead need to extract a vector of weights from the Gramian of
the Jacobian. The Weighting is responsible for this.
Note
Most aggregators rely on computing the Gramian of the Jacobian, extracting a vector of weights
from this Gramian using a Weighting, and then combining the rows of the Jacobian using these
weights. For all of them, we provide both the Aggregator interface (to be used in autojac) and
the Weighting interface (to be used in autogram). For the rest, we only provide the Aggregator
interface – they are not compatible with autogram.
Aggregators and Weightings are callables that take a Jacobian matrix or a Gramian matrix as input,
respectively. The following example shows how to use UPGrad to either aggregate a Jacobian (of
shape [m, n], where m is the number of objectives and n is the number of parameters), or obtain
the weights from the Gramian of the Jacobian (of shape [m, m]).
>>> from torch import tensor
>>> from torchjd.aggregation import UPGrad, UPGradWeighting
>>>
>>> aggregator = UPGrad()
>>> jacobian = tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
>>> aggregation = aggregator(jacobian)
>>> aggregation
tensor([0.2929, 1.9004, 1.9004])
>>> weighting = UPGradWeighting()
>>> gramian = jacobian @ jacobian.T
>>> weights = weighting(gramian)
>>> weights
tensor([1.1109, 0.7894])
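Since UPGrad belongs to the aggregators that combine the rows of the Jacobian with the weights
extracted from its Gramian, the aggregation above can be recovered (up to numerical precision)
from these weights. A small illustrative check, reusing the tensors from the example:
>>> from torch import allclose
>>>
>>> combined = weights @ jacobian  # weighted combination of the rows of the Jacobian
>>> allclose(combined, aggregation, atol=1e-4)
True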
When dealing with a more general tensor of objectives, of shape [m_1, ..., m_k] (i.e. not
necessarily a simple vector), the Jacobian will be of shape [m_1, ..., m_k, n], and its Gramian
will be called a generalized Gramian, of shape [m_1, ..., m_k, m_k, ..., m_1]. One can use a
GeneralizedWeighting to extract a tensor of weights (of shape [m_1, ..., m_k]) from such a
generalized Gramian. The simplest GeneralizedWeighting is Flattening: it simply “flattens” the
generalized Gramian into a square Gramian matrix (of shape [m_1 * ... * m_k, m_1 * ... * m_k]),
applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of
weights.
>>> from torch import ones
>>> from torchjd.aggregation import Flattening, UPGradWeighting
>>>
>>> weighting = Flattening(UPGradWeighting())
>>> # Generate a generalized Gramian filled with ones, for the sake of the example
>>> generalized_gramian = ones((2, 3, 3, 2))
>>> weights = weighting(generalized_gramian)
>>> weights
tensor([[0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667]])
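In this example, every entry of the generalized Gramian is equal, so the same weights can also be
obtained by flattening it into a 6 x 6 matrix by hand and applying the wrapped weighting directly.
This is only a rough sketch: the exact index ordering used by Flattening may differ from a plain
reshape, but that does not matter here since all entries are equal.
>>> flat_gramian = generalized_gramian.reshape(6, 6)  # all-ones 6 x 6 Gramian matrix
>>> flat_weights = UPGradWeighting()(flat_gramian)  # vector of 6 weights
>>> flat_weights.reshape(2, 3)
tensor([[0.1667, 0.1667, 0.1667],
        [0.1667, 0.1667, 0.1667]])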
Abstract base classes
- class torchjd.aggregation.Aggregator
Abstract base class for all aggregators. It has the role of aggregating matrices of dimension \(m \times n\) into row vectors of dimension \(n\).
- class torchjd.aggregation.Weighting
Abstract base class for all weighting methods. It has the role of extracting a vector of weights of dimension \(m\) from some statistic of a matrix of dimension \(m \times n\), generally its Gramian, of dimension \(m \times m\).
- class torchjd.aggregation.GeneralizedWeighting
Abstract base class for all weightings that operate on generalized Gramians. It has the role of extracting a tensor of weights of dimension \(m_1 \times \dots \times m_k\) from a generalized Gramian of dimension \(m_1 \times \dots \times m_k \times m_k \times \dots \times m_1\).
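Because every Aggregator is a callable mapping a Jacobian matrix of shape [m, n] to a vector of
shape [n], code can be written against the abstract base class and reused with any concrete
aggregator. A minimal sketch (the helper function and its name are only illustrative, not part of
the library):
>>> from torch import Tensor, tensor
>>> from torchjd.aggregation import Aggregator, UPGrad
>>>
>>> def aggregate_all(aggregator: Aggregator, jacobians: list[Tensor]) -> list[Tensor]:
...     # Apply the same aggregator to each Jacobian matrix in the list.
...     return [aggregator(jacobian) for jacobian in jacobians]
...
>>> jacobians = [tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]]), tensor([[1.0, 0.0], [0.0, 1.0]])]
>>> aggregations = aggregate_all(UPGrad(), jacobians)
>>> [a.shape for a in aggregations]  # one vector of weights per Jacobian, of matching width
[torch.Size([3]), torch.Size([2])]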