backward

torchjd.autojac.backward.backward(tensors, aggregator, inputs=None, retain_graph=False, parallel_chunk_size=None)

Computes the Jacobian of all values in tensors with respect to all inputs, aggregates it into a vector with the provided aggregator, and accumulates the result in the .grad fields of the inputs.

Parameters:
  • tensors (Union[Sequence[Tensor], Tensor]) – The tensor or tensors to differentiate. Should be non-empty. The Jacobian matrices will have one row for each value of each of these tensors.

  • aggregator (Aggregator) – Aggregator used to reduce the Jacobian into a vector.

  • inputs (Optional[Iterable[Tensor]]) – The tensors with respect to which the Jacobian must be computed. These must have their requires_grad flag set to True. If not provided, defaults to the leaf tensors that were used to compute the tensors parameter. A usage sketch is given after the return type below.

  • retain_graph (bool) – If False, the graph used to compute the grad will be freed. Defaults to False.

  • parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of tensors will be differentiated in parallel at once. If set to 1, all coordinates will be differentiated sequentially. A larger value results in faster differentiation, but also higher memory usage. Defaults to None.

Return type:

None
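
The following sketch (not taken from the library's own examples) illustrates the inputs parameter; it assumes that leaf tensors omitted from inputs are simply left untouched, so their .grad stays None.

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> a = torch.tensor([1., 2.], requires_grad=True)
>>> b = torch.tensor(3., requires_grad=True)
>>> y1 = (a * b).sum()   # depends on both a and b
>>> y2 = (a ** 2).sum()  # depends on a only
>>>
>>> # Restrict differentiation to a: only a.grad is accumulated.
>>> backward([y1, y2], UPGrad(), inputs=[a])
>>>
>>> a.grad is None, b.grad is None
(False, True)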

Example

The following code snippet showcases a simple usage of backward.

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2], UPGrad())
>>>
>>> param.grad
tensor([0.5000, 2.5000])

The .grad field of param now contains the aggregation of the Jacobian of \(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\) with respect to param.
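
To make the numbers concrete, the Jacobian in this example is \(\begin{bmatrix} -1 & 1 \\ 2 & 4 \end{bmatrix}\), whose rows are the gradients of \(y_1\) and \(y_2\) with respect to param. Because these rows do not conflict (their inner product is positive), the aggregation happens to coincide with their simple mean, \(\begin{bmatrix} 0.5 & 2.5 \end{bmatrix}\), which is the value stored in param.grad.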

Warning

To differentiate in parallel, backward relies on torch.vmap, which has some limitations: for instance, it does not work on the output of compiled functions, when some tensors have retains_grad=True, or when using an RNN on CUDA. If you experience issues with backward, try setting parallel_chunk_size=1 to avoid relying on torch.vmap.
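
A minimal sketch of this fallback, adapted from the example above (retain_graph=True is passed as a precaution so that the graph stays available while the coordinates are differentiated one by one; the aggregated result is expected to be unchanged):

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # Differentiate one coordinate at a time instead of using torch.vmap.
>>> backward([y1, y2], UPGrad(), retain_graph=True, parallel_chunk_size=1)
>>>
>>> param.grad
tensor([0.5000, 2.5000])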