backward

torchjd.autojac.backward(tensors, jac_tensors=None, inputs=None, retain_graph=False, parallel_chunk_size=None)[source]

Computes the Jacobians of tensors with respect to inputs, left-multiplied by jac_tensors (or identity if jac_tensors is None), and accumulates the results in the .jac fields of the inputs.
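
In other words, denoting by \(A\) the matrix formed by jac_tensors and by \(t\) the values of tensors (flattened and concatenated), each input \(x\) accumulates \(A \, \frac{\partial t}{\partial x}\) into its .jac field; with the default \(A = I\), this is simply the Jacobian \(\frac{\partial t}{\partial x}\).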

Parameters:
  • tensors (Sequence[Tensor] | Tensor) – The tensor or tensors to differentiate. Should be non-empty.

  • jac_tensors (Sequence[Tensor] | Tensor | None) – The initial Jacobians to backpropagate, analogous to the grad_tensors parameter of torch.autograd.backward(). If provided, it must have the same structure as tensors, and each tensor in jac_tensors must match the shape of the corresponding tensor in tensors, with an extra leading dimension representing the number of rows of the resulting Jacobian (e.g. the number of losses). All tensors in jac_tensors must have the same first dimension. If None, defaults to the identity matrix. In this case, the standard Jacobian of tensors is computed, with one row for each value in tensors.

  • inputs (Iterable[Tensor] | None) – The tensors with respect to which the Jacobians must be computed. These must have their requires_grad flag set to True. If not provided, defaults to the leaf tensors that were used to compute the tensors parameter (see the sketch after this parameter list).

  • retain_graph (bool) – If False, the graph used to compute the Jacobians will be freed after the backward pass. Defaults to False.

  • parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of tensors will be differentiated in parallel at once. If set to 1, all coordinates will be differentiated sequentially. A larger value results in faster differentiation, but also higher memory usage. Defaults to None.

Return type:

None
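
For illustration, here is a rough sketch combining the inputs and retain_graph parameters described above. The tensors and values are arbitrary, and the expected contents of the .jac fields are indicated in comments, following from the parameter descriptions rather than from a doctest run.

>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> p1 = torch.tensor([1., 2.], requires_grad=True)
>>> p2 = torch.tensor(3., requires_grad=True)
>>> y1 = (p1 * p2).sum()
>>> y2 = (p1 ** 2).sum()
>>>
>>> # Only differentiate with respect to p1; keep the graph so that a second
>>> # backward pass through the same graph remains possible.
>>> backward([y1, y2], inputs=[p1], retain_graph=True)
>>>
>>> # p1.jac should now hold the Jacobian of (y1, y2) with respect to p1,
>>> # while p2 gets no .jac field since it was excluded from inputs.
>>> # A second call is possible because retain_graph=True kept the graph alive.
>>> backward([y1, y2], inputs=[p2])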

Example

This example shows a simple usage of backward.

>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2])
>>>
>>> param.jac
tensor([[-1.,  1.],
        [ 2.,  4.]])

The .jac field of param now contains the Jacobian of \(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\) with respect to param.
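
To see where these numbers come from: writing \(p = (p_1, p_2)\), we have \(y_1 = -p_1 + p_2\) and \(y_2 = p_1^2 + p_2^2\), so the Jacobian is \(\begin{bmatrix} -1 & 1 \\ 2p_1 & 2p_2 \end{bmatrix}\), which evaluates to \(\begin{bmatrix} -1 & 1 \\ 2 & 4 \end{bmatrix}\) at \(p = (1, 2)\).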

Example

This is the same example as before, except that we explicitly specify jac_tensors as the rows of the identity matrix (which is equivalent to using the default None).

>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> J1 = torch.tensor([1.0, 0.0])
>>> J2 = torch.tensor([0.0, 1.0])
>>>
>>> backward([y1, y2], jac_tensors=[J1, J2])
>>>
>>> param.jac
tensor([[-1.,  1.],
        [ 2.,  4.]])

Instead of using the identity jac_tensors, you can backpropagate some Jacobians obtained by a call to torchjd.autojac.jac() on a later part of the computation graph.
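
As a rough sketch of this pattern, one can backpropagate a hand-written downstream Jacobian (standing in here for one obtained from torchjd.autojac.jac(), whose interface is not detailed on this page) through an intermediate tensor. The expected result is given in a comment and follows from the chain rule together with the description at the top of this page.

>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> z = 3 * param  # intermediate tensor, with dz/dparam = 3 * I
>>>
>>> # Jacobian of two downstream losses with respect to z, written by hand
>>> # (in practice it could come from differentiating the later part of the graph).
>>> J_z = torch.tensor([[1., 0.], [1., 1.]])
>>>
>>> backward(z, jac_tensors=J_z)
>>>
>>> # By the chain rule, param.jac should now equal J_z @ (dz/dparam),
>>> # i.e. [[3., 0.], [3., 3.]].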

Warning

To differentiate in parallel, backward relies on torch.vmap, which has some limitations: for instance, it does not work on the output of compiled functions, when some tensors have retains_grad=True, or when using an RNN on CUDA. If you experience issues with backward, try setting parallel_chunk_size=1 to avoid relying on torch.vmap.
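
As a minimal sketch of that workaround, the first example above can be run with sequential differentiation as follows.

>>> import torch
>>>
>>> from torchjd.autojac import backward
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # Differentiate the coordinates one at a time instead of through torch.vmap.
>>> backward([y1, y2], parallel_chunk_size=1)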