jac

- torchjd.autojac.jac(outputs, inputs=None, jac_outputs=None, retain_graph=False, parallel_chunk_size=None)
Computes the Jacobians of outputs with respect to inputs, left-multiplied by jac_outputs (or by the identity if jac_outputs is None), and returns the result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to an input t has shape [m] + t.shape, where m is the number of rows of the Jacobian (the leading dimension of jac_outputs, or the total number of values in outputs if jac_outputs is None).

- Parameters:
  - outputs (Sequence[Tensor] | Tensor) – The tensor or tensors to differentiate. Should be non-empty.
  - inputs (Iterable[Tensor] | None) – The tensors with respect to which the Jacobian must be computed. These must have their requires_grad flag set to True. If not provided, defaults to the leaf tensors that were used to compute the outputs parameter.
  - jac_outputs (Sequence[Tensor] | Tensor | None) – The initial Jacobians to backpropagate, analogous to the grad_outputs parameter of torch.autograd.grad(). If provided, it must have the same structure as outputs, and each tensor in jac_outputs must match the shape of the corresponding tensor in outputs, with an extra leading dimension representing the number of rows of the resulting Jacobian (e.g. the number of losses). If None, defaults to the identity matrix; in this case, the standard Jacobian of outputs is computed, with one row for each value in outputs.
  - retain_graph (bool) – If False, the graph used to compute the grad will be freed. Defaults to False.
  - parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of outputs will be differentiated in parallel at once. If set to 1, all coordinates will be differentiated sequentially. A larger value results in faster differentiation, but also higher memory usage. Defaults to None.
- Return type: tuple[Tensor, ...]
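The following is a small illustrative sketch (not from the original reference) of the inputs default: when inputs is omitted, the Jacobian is computed with respect to the leaf tensors that were used to compute outputs.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)  # leaf tensor requiring grad
>>> y = (param ** 2).sum()
>>> # inputs is omitted, so differentiation is done with respect to param
>>> jac([y])
(tensor([[2., 4.]]),)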
Note
The only difference between this function and torchjd.autojac.backward() is that it returns the Jacobians as a tuple, while torchjd.autojac.backward() stores them in the .jac fields of the inputs.
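For comparison, here is a minimal sketch of the same idea with torchjd.autojac.backward. This is an assumption-laden illustration, not taken from the reference: it assumes that backward accepts the same outputs and inputs arguments as jac, as the note above implies.

>>> import torch
>>>
>>> from torchjd.autojac import backward, jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # jac returns the Jacobians as a tuple (retain_graph=True keeps the graph
>>> # alive so that it can be differentiated a second time below)
>>> jacobians = jac([y1, y2], [param], retain_graph=True)
>>> # Assumption: backward takes the same arguments, but stores the Jacobian
>>> # in the .jac field of each input instead of returning it
>>> backward([y1, y2], [param])
>>> torch.equal(param.jac, jacobians[0])
True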
Example

The following example shows how to use jac.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> jacobians = jac([y1, y2], [param])
>>>
>>> jacobians
(tensor([[-1.,  1.],
        [ 2.,  4.]]),)
Example
The following example shows how to compute Jacobians, combine them into a single Jacobian matrix, and compute its Gramian.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> weight = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)  # shape: [2, 2]
>>> bias = torch.tensor([0.5, -0.5], requires_grad=True)  # shape: [2]
>>> # Compute arbitrary quantities that are functions of weight and bias
>>> input_vec = torch.tensor([1., -1.])
>>> y1 = weight @ input_vec + bias  # shape: [2]
>>> y2 = (weight ** 2).sum() + (bias ** 2).sum()  # shape: [] (scalar)
>>>
>>> jacobians = jac([y1, y2], [weight, bias])  # shapes: [3, 2, 2], [3, 2]
>>> jacobian_matrices = tuple(J.flatten(1) for J in jacobians)  # shapes: [3, 4], [3, 2]
>>> combined_jacobian_matrix = torch.concat(jacobian_matrices, dim=1)  # shape: [3, 6]
>>> gramian = combined_jacobian_matrix @ combined_jacobian_matrix.T  # shape: [3, 3]
>>> gramian
tensor([[  3.,   0.,  -1.],
        [  0.,   3.,  -3.],
        [ -1.,  -3., 122.]])
The obtained Gramian is a symmetric matrix containing the dot products between all pairs of gradients. It's a strong indicator of gradient norm (the diagonal elements are the squared norms of the gradients) and of conflict (a negative off-diagonal value means that the corresponding gradients conflict). In fact, most aggregators base their decision entirely on the Gramian.

In this case, we can see that the first two gradients (those of y1) both have a squared norm of 3, while the third gradient (that of y2) has a squared norm of 122. The first two gradients are exactly orthogonal (they have an inner product of 0), but they both conflict with the third gradient (inner products of -1 and -3, respectively).
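Continuing from the example above, here is a small follow-up sketch (not part of the original reference) showing how the gradient norms and pairwise cosine similarities can be read directly off the Gramian:

>>> norms = gramian.diag().sqrt()  # gradient norms: square roots of the diagonal
>>> norms
tensor([ 1.7321,  1.7321, 11.0454])
>>> cosines = gramian / norms.outer(norms)  # cosine similarity between each pair of gradients
>>> cosines
tensor([[ 1.0000,  0.0000, -0.0523],
        [ 0.0000,  1.0000, -0.1568],
        [-0.0523, -0.1568,  1.0000]])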
Example
This example shows how to apply the chain rule using the jac_outputs parameter to compute the Jacobian in two steps.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> x = torch.tensor([1., 2.], requires_grad=True)
>>> # Compose functions: x -> h -> y
>>> h = x ** 2
>>> y1 = h.sum()
>>> y2 = torch.tensor([1., -1.]) @ h
>>>
>>> # Step 1: Compute d[y1,y2]/dh
>>> jac_h = jac([y1, y2], [h])[0]  # Shape: [2, 2]
>>>
>>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
>>> jac_x = jac(h, [x], jac_outputs=jac_h)[0]
>>>
>>> jac_x
tensor([[ 2.,  4.],
        [ 2., -4.]])
This two-step computation is equivalent to directly computing jac([y1, y2], [x]).
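For reference, here is a minimal sketch (not part of the original example) of that direct computation, rebuilding the graph from scratch (the graph from the example above is freed after differentiation, since retain_graph defaults to False); it yields the same Jacobian as the two-step version.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> x = torch.tensor([1., 2.], requires_grad=True)
>>> h = x ** 2
>>> y1 = h.sum()
>>> y2 = torch.tensor([1., -1.]) @ h
>>>
>>> # Direct computation of d[y1,y2]/dx in a single call
>>> jac([y1, y2], [x])[0]
tensor([[ 2.,  4.],
        [ 2., -4.]])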
Warning

To differentiate in parallel, jac relies on torch.vmap, which has some limitations: for instance, it does not work on the output of compiled functions, when some tensors have retains_grad=True, or when using an RNN on CUDA. If you experience issues with jac, try using parallel_chunk_size=1 to avoid relying on torch.vmap.
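As a hedged illustration of this workaround (not from the original reference), the first example can be rerun with sequential differentiation; according to the parameter description, only the way the backward pass is chunked changes, not the result.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> # Differentiate one scalar at a time instead of relying on torch.vmap
>>> jac([y1, y2], [param], parallel_chunk_size=1)
(tensor([[-1.,  1.],
        [ 2.,  4.]]),)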