jac_to_grad

torchjd.autojac.jac_to_grad(tensors, /, aggregator, *, retain_jac=False, optimize_gramian_computation=False)[source]

Aggregates the Jacobians stored in the .jac fields of tensors and accumulates the result into their .grad fields.

Parameters:
  • tensors (Iterable[Tensor]) – The tensors whose .jac fields should be aggregated. All Jacobians must have the same first dimension (e.g. number of losses).

  • aggregator (Aggregator) – The aggregator used to reduce the Jacobians into gradients. If it uses a Weighting to combine the rows of the Jacobians, jac_to_grad will also return the computed weights.

  • retain_jac (bool) – Whether to preserve the .jac fields of the tensors after they have been used. Defaults to False.

  • optimize_gramian_computation (bool) – When the aggregator computes its weights from the Gramian of the Jacobian, it is possible to skip the concatenation of the Jacobians and instead compute the Gramian as the sum of the Gramians of the individual Jacobians. This saves memory (up to 50%) but can be slightly slower (up to 15%) on CUDA. We advise trying this optimization if memory is an issue for you. Defaults to False.

Return type:

Tensor | None

Note

When optimize_gramian_computation=False, this function starts by “flattening” the .jac fields into matrices (i.e. flattening all of their dimensions except the first one), then concatenates those matrices into a combined Jacobian matrix. The aggregator is applied to this matrix, returning a combined gradient vector that is then split and reshaped to fit into the .grad fields of the tensors.
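The flatten–concatenate–aggregate–split pipeline described above can be sketched with plain PyTorch. This is a minimal illustration, not torchjd's implementation: the shapes are arbitrary, and a simple mean over rows stands in for a real Aggregator.

```python
import torch

# Jacobians of 3 losses w.r.t. two parameter tensors of shapes (2, 2) and (5,),
# as they would appear in the .jac fields.
jac_a = torch.randn(3, 2, 2)
jac_b = torch.randn(3, 5)

# 1) Flatten every dimension except the first.
mat_a = jac_a.reshape(3, -1)  # shape (3, 4)
mat_b = jac_b.reshape(3, -1)  # shape (3, 5)

# 2) Concatenate into one combined Jacobian matrix.
combined = torch.cat([mat_a, mat_b], dim=1)  # shape (3, 9)

# 3) Aggregate the rows into one combined gradient vector.
#    (A plain mean here, standing in for an actual Aggregator.)
grad_vector = combined.mean(dim=0)  # shape (9,)

# 4) Split and reshape to match the original parameter shapes,
#    as would be stored in the .grad fields.
grad_a, grad_b = grad_vector.split([4, 5])
grad_a = grad_a.reshape(2, 2)
print(grad_a.shape, grad_b.shape)  # torch.Size([2, 2]) torch.Size([5])
```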

Note

When optimize_gramian_computation=True, this function iteratively computes and sums the Gramian of each individual .jac field. The inner weighting of the aggregator then extracts weights from the resulting Gramian; these weights are used to compute a linear combination of the rows of each .jac field, which is stored in the corresponding .grad field. This is mathematically equivalent to the approach with optimize_gramian_computation=False, but saves memory by never holding the concatenated Jacobian matrix in memory.
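The identity this relies on is that the Gramian of a column-wise concatenation equals the sum of the per-block Gramians: if \(J = [J_1 \; J_2]\), then \(J J^\top = J_1 J_1^\top + J_2 J_2^\top\). A quick numerical check (illustrative only):

```python
import torch

torch.manual_seed(0)
j1 = torch.randn(3, 4)  # flattened Jacobian block for one tensor
j2 = torch.randn(3, 5)  # flattened Jacobian block for another tensor

# Gramian of the concatenated Jacobian...
j = torch.cat([j1, j2], dim=1)
gramian_full = j @ j.T

# ...equals the sum of the per-block Gramians, so the concatenated
# matrix never needs to be materialized.
gramian_sum = j1 @ j1.T + j2 @ j2.T

print(torch.allclose(gramian_full, gramian_sum))  # True
```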

Example

This example shows how to use jac_to_grad after a call to backward:

>>> import torch
>>>
>>> from torchjd.aggregation import UPGrad
>>> from torchjd.autojac import backward, jac_to_grad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are function of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2])  # param now has a .jac field
>>> weights = jac_to_grad([param], UPGrad())  # param now has a .grad field
>>> param.grad
tensor([0.5000, 2.5000])
>>> weights
tensor([0.5000, 0.5000])

The .grad field of param now contains the aggregation (by UPGrad) of the Jacobian of \(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\) with respect to param. In this case, the weights used to combine the Jacobian are equal because there was no conflict.
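The values above can be verified with plain torch.autograd, without torchjd. The rows of the Jacobian are \(\nabla y_1 = [-1, 1]\) and \(\nabla y_2 = 2 \cdot \text{param} = [2, 4]\); since their inner product is positive, the gradients do not conflict, and equal weights of 0.5 reproduce the .grad value from the example:

```python
import torch

param = torch.tensor([1., 2.], requires_grad=True)
y1 = torch.tensor([-1., 1.]) @ param
y2 = (param ** 2).sum()

# Rows of the Jacobian, computed independently with plain autograd.
row1 = torch.autograd.grad(y1, param, retain_graph=True)[0]  # [-1., 1.]
row2 = torch.autograd.grad(y2, param)[0]                     # [2., 4.]

# Positive inner product: no conflict between the two gradients.
print(row1 @ row2 > 0)          # tensor(True)

# Equal weights of 0.5 give the same .grad as in the example.
print(0.5 * row1 + 0.5 * row2)  # tensor([0.5000, 2.5000])
```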