backward
- torchjd.backward(tensors, aggregator, inputs=None, retain_graph=False, parallel_chunk_size=None)
Computes the Jacobian of all values in tensors with respect to all inputs. Aggregates it using the provided aggregator and accumulates the result in the .grad fields of the inputs.
- Parameters:
  - tensors (Sequence[Tensor] | Tensor) – The tensor or tensors to differentiate. Should be non-empty. The Jacobian matrices will have one row for each value of each of these tensors.
  - aggregator (Aggregator) – Aggregator used to reduce the Jacobian into a vector.
  - inputs (Iterable[Tensor] | None) – The tensors with respect to which the Jacobian must be computed. These must have their requires_grad flag set to True. If not provided, defaults to the leaf tensors that were used to compute the tensors parameter. An explicit usage is sketched after the example below.
  - retain_graph (bool) – If False, the graph used to compute the grad will be freed. Defaults to False.
  - parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of tensors will be differentiated in parallel at once. If set to 1, all coordinates will be differentiated sequentially. A larger value results in faster differentiation, but also higher memory usage. Defaults to None.
- Return type:
  None
Example
The following code snippet showcases a simple usage of backward.

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> backward([y1, y2], UPGrad())
>>>
>>> param.grad
tensor([0.5000, 2.5000])
The .grad field of param now contains the aggregation of the Jacobian of \(\begin{bmatrix}y_1 \\ y_2\end{bmatrix}\) with respect to param.

Warning
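When only some of the leaf tensors should receive a .grad, the inputs parameter can be passed explicitly. The following is a minimal sketch of this (the tensors p1, p2, z1 and z2 are illustrative, not part of the API); only the tensors listed in inputs are expected to have their .grad field populated, the other leaves keeping .grad equal to None.

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> p1 = torch.tensor([1., 2.], requires_grad=True)
>>> p2 = torch.tensor([3., 4.], requires_grad=True)
>>> z1 = (p1 * p2).sum()
>>> z2 = (p1 + p2).sum()
>>>
>>> # Restrict the differentiation (and the .grad accumulation) to p1
>>> backward([z1, z2], UPGrad(), inputs=[p1])
>>>
>>> p1.grad.shape
torch.Size([2])
>>> p2.grad is None
True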
To differentiate in parallel, backward relies on torch.vmap, which has some limitations: for instance, it does not work on the output of compiled functions, when some tensors have retains_grad=True, or when using an RNN on CUDA. If you experience issues with backward, try using parallel_chunk_size=1 to avoid relying on torch.vmap.
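A minimal sketch of this workaround, reusing the setup of the example above, is given below. Note that retain_graph=True is passed as well, since differentiating the scalars in several sequential chunks may require keeping the graph between chunks (whether this is strictly necessary can depend on the version). The resulting .grad should match the example above, as the chunk size only changes how the Jacobian is computed, not its value.

>>> import torch
>>>
>>> from torchjd import backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # Differentiate one scalar at a time instead of relying on torch.vmap
>>> backward([y1, y2], UPGrad(), retain_graph=True, parallel_chunk_size=1)
>>>
>>> param.grad
tensor([0.5000, 2.5000])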