backward

torchjd.autojac.backward.backward(tensors, inputs, A, retain_graph=False, parallel_chunk_size=None)

Computes the Jacobian of all values in tensors with respect to all inputs, aggregates it with A, and accumulates the result in the .grad fields of the inputs.

Parameters:
  • tensors (Union[Sequence[Tensor], Tensor]) – The tensor or tensors to differentiate. Should be non-empty. The Jacobian matrices will have one row for each value of each of these tensors.

  • inputs (Iterable[Tensor]) – The tensors with respect to which the Jacobian must be computed. These must have their requires_grad flag set to True.

  • A (Aggregator) – Aggregator used to reduce the Jacobian into a vector.

  • retain_graph (bool) – If False, the graph used to compute the grad will be freed. Defaults to False.

  • parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of tensors are differentiated in parallel at once. If set to 1, the coordinates are differentiated sequentially. A larger value results in faster differentiation but higher memory usage. Defaults to None. If parallel_chunk_size is not large enough to differentiate all tensors simultaneously, retain_graph has to be set to True, as illustrated in the sketch below.

Return type:

None
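
A minimal sketch of chunked differentiation, reusing the setup of the example below: with parallel_chunk_size=1, only one scalar is differentiated per chunk, so retain_graph must be set to True for the graph to survive between chunks. The accumulated gradient is the same as in the example below.

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # Differentiate one scalar at a time; the graph must be kept between chunks
>>> backward([y1, y2], [param], A=UPGrad(), retain_graph=True, parallel_chunk_size=1)
>>>
>>> param.grad
tensor([0.5000, 2.5000])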

Example

The following code snippet showcases a simple usage of backward.

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2], [param], A=UPGrad())
>>>
>>> param.grad
tensor([0.5000, 2.5000])

The .grad field of param now contains the aggregation of the Jacobian of \(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\) with respect to param.
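
To make this aggregation explicit, the sketch below builds the same Jacobian row by row with torch.autograd.grad and applies the aggregator to it directly (assuming, as the description of the A parameter suggests, that torchjd aggregators can be called on a matrix to reduce it into a vector). This mirrors what backward computes before accumulating the result into param.grad.

>>> import torch
>>>
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # One Jacobian row per scalar: dy1/dparam and dy2/dparam
>>> rows = [torch.autograd.grad(y, [param], retain_graph=True)[0] for y in (y1, y2)]
>>> jacobian = torch.stack(rows)
>>> jacobian
tensor([[-1.,  1.],
        [ 2.,  4.]])
>>> UPGrad()(jacobian)
tensor([0.5000, 2.5000])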

Warning

backward relies on torch.vmap in a way that is not compatible with compiled functions. The arguments of backward should therefore not come from a compiled model. See https://github.com/pytorch/pytorch/issues/138422 for the status of this issue.