backward
- torchjd.autojac.backward.backward(tensors, inputs, A, retain_graph=False, parallel_chunk_size=None)
Computes the Jacobian of all values in tensors with respect to all inputs, aggregates it with A, and accumulates the result in the .grad fields of the inputs.
- Parameters:
  - tensors (Union[Sequence[Tensor], Tensor]) – The tensor or tensors to differentiate. Should be non-empty. The Jacobian matrices will have one row for each value of each of these tensors.
  - inputs (Iterable[Tensor]) – The tensors with respect to which the Jacobian must be computed. These must have their requires_grad flag set to True.
  - A (Aggregator) – Aggregator used to reduce the Jacobian into a vector.
  - retain_graph (bool) – If False, the graph used to compute the grad will be freed. Defaults to False.
  - parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of tensors will be differentiated in parallel at once. If set to 1, all coordinates will be differentiated sequentially. A larger value results in faster differentiation, but also higher memory usage. Defaults to None. If parallel_chunk_size is not large enough to differentiate all tensors simultaneously, retain_graph has to be set to True (see the sketch after this parameter block).
- Return type: None
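As a hedged illustration of the interplay between parallel_chunk_size and retain_graph described above (this snippet is not part of the original example), the sketch below differentiates the two scalars one chunk at a time, which requires keeping the graph alive between chunks:

>>> import torch
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # Two scalars to differentiate, but only one per chunk: the graph must be
>>> # retained between the sequential backward chunks.
>>> backward([y1, y2], [param], A=UPGrad(), retain_graph=True, parallel_chunk_size=1)
>>>
>>> param.grad
tensor([0.5000, 2.5000])

The chunking only affects how the backward pass is scheduled, so the accumulated .grad is the same as with the default parallel_chunk_size=None.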
Example
The following code snippet showcases a simple usage of backward.

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2], [param], A=UPGrad())
>>>
>>> param.grad
tensor([0.5000, 2.5000])
The .grad field of param now contains the aggregation of the Jacobian of \(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\) with respect to param.
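To make this connection explicit, the following sketch (not part of the original documentation) recomputes the Jacobian row by row with torch.autograd.grad and aggregates it directly, assuming that an Aggregator such as UPGrad can be called on a Jacobian matrix; the result should match param.grad from the example above.

>>> import torch
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # One gradient per scalar output forms one row of the Jacobian
>>> row1 = torch.autograd.grad(y1, param)[0]  # tensor([-1., 1.])
>>> row2 = torch.autograd.grad(y2, param)[0]  # tensor([2., 4.])
>>> jacobian = torch.stack([row1, row2])
>>> UPGrad()(jacobian)
tensor([0.5000, 2.5000])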
Warning
backward relies on a usage of torch.vmap that is not compatible with compiled functions. The arguments of backward should thus not come from a compiled model. Check https://github.com/pytorch/pytorch/issues/138422 for the status of this issue.
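As a minimal hedged sketch of this caveat (the model and losses below are made up for illustration), compute the tensors to differentiate with the eager model rather than with a torch.compile-d version of it:

>>> import torch
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> model = torch.nn.Linear(3, 2)
>>> x = torch.randn(8, 3)
>>> out = model(x)  # eager forward pass: its outputs can be passed to backward
>>> losses = [out[:, 0].mean(), out[:, 1].mean()]
>>> backward(losses, model.parameters(), A=UPGrad())
>>>
>>> # By contrast, differentiating outputs of torch.compile(model)(x) may fail,
>>> # because backward relies on torch.vmap (see the issue linked above).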