jac

torchjd.autojac.jac(outputs, inputs=None, jac_outputs=None, retain_graph=False, parallel_chunk_size=None)[source]

Computes the Jacobians of outputs with respect to inputs, left-multiplied by jac_outputs (or by the identity if jac_outputs is None), and returns the result as a tuple, with one Jacobian per input tensor. The returned Jacobian with respect to an input t has shape [m] + t.shape, where m is the number of rows of the resulting Jacobian: the size of the extra leading dimension of jac_outputs if it is provided, or the total number of values in outputs otherwise.

Parameters:
  • outputs (Sequence[Tensor] | Tensor) – The tensor or tensors to differentiate. Should be non-empty.

  • inputs (Iterable[Tensor] | None) – The tensors with respect to which the Jacobian must be computed. These must have their requires_grad flag set to True. If not provided, defaults to the leaf tensors that were used to compute the outputs parameter.

  • jac_outputs (Sequence[Tensor] | Tensor | None) – The initial Jacobians to backpropagate, analogous to the grad_outputs parameter of torch.autograd.grad(). If provided, it must have the same structure as outputs, and each tensor in jac_outputs must match the shape of the corresponding tensor in outputs, with an extra leading dimension representing the number of rows of the resulting Jacobian (e.g. the number of losses). If None, it defaults to the identity matrix: in this case, the standard Jacobian of outputs is computed, with one row for each value in outputs.

  • retain_graph (bool) – If False, the graph used to compute the Jacobians will be freed. Defaults to False.

  • parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of outputs will be differentiated in parallel at once. If set to 1, all coordinates will be differentiated sequentially. A larger value results in faster differentiation, but also higher memory usage. Defaults to None.

Return type:

tuple[Tensor, ...]
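
To make the shape convention concrete, here is a small sketch: for an input of shape [2, 3] and outputs containing 2 values in total, the returned Jacobian has shape [2, 2, 3].

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> param = torch.randn(2, 3, requires_grad=True)
>>> output = (param ** 2).sum(dim=1)  # 2 values in total, so m = 2
>>>
>>> jac(output, [param])[0].shape  # [m] + param.shape
torch.Size([2, 2, 3])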

Note

The only difference between this function and torchjd.autojac.backward() is that it returns the Jacobians as a tuple, while torchjd.autojac.backward() stores them in the .jac fields of the inputs.
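
As a minimal sketch of this difference, assuming torchjd.autojac.backward() accepts the same outputs and inputs arguments as jac (see its own documentation for the exact signature):

>>> import torch
>>>
>>> from torchjd.autojac import backward, jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> jacobian = jac([y1, y2], [param], retain_graph=True)[0]  # returned explicitly
>>> backward([y1, y2], inputs=[param])  # assumed call: stores the result in param.jac
>>> torch.equal(jacobian, param.jac)
True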

Example

The following example shows how to use jac.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> # Compute arbitrary quantities that are functions of param
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> jacobians = jac([y1, y2], [param])
>>>
>>> jacobians
(tensor([[-1.,  1.],
        [ 2.,  4.]]),)

Example

The following example shows how to compute Jacobians, combine them into a single Jacobian matrix, and compute its Gramian.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> weight = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)  # shape: [2, 2]
>>> bias = torch.tensor([0.5, -0.5], requires_grad=True)  # shape: [2]
>>> # Compute arbitrary quantities that are functions of weight and bias
>>> input_vec = torch.tensor([1., -1.])
>>> y1 = weight @ input_vec + bias  # shape: [2]
>>> y2 = (weight ** 2).sum() + (bias ** 2).sum()  # shape: [] (scalar)
>>>
>>> jacobians = jac([y1, y2], [weight, bias])  # shapes: [3, 2, 2], [3, 2]
>>> jacobian_matrices = tuple(J.flatten(1) for J in jacobians)  # shapes: [3, 4], [3, 2]
>>> combined_jacobian_matrix = torch.concat(jacobian_matrices, dim=1)  # shape: [3, 6]
>>> gramian = combined_jacobian_matrix @ combined_jacobian_matrix.T  # shape: [3, 3]
>>> gramian
tensor([[  3.,   0.,  -1.],
        [  0.,   3.,  -3.],
        [ -1.,  -3., 122.]])

The obtained Gramian is a symmetric matrix containing the dot products between all pairs of gradients. It captures both the gradient norms (its diagonal elements are the squared norms of the gradients) and the conflicts between gradients (a negative off-diagonal value means that the corresponding pair of gradients conflicts). In fact, most aggregators base their decision entirely on the Gramian.

In this case, we can see that the first two gradients (those of y1) both have a squared norm of 3, while the third gradient (that of y2) has a squared norm of 122. The first two gradients are exactly orthogonal (their inner product is 0), but each of them conflicts with the third gradient (inner products of -1 and -3, respectively).
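
These quantities can also be read off the Gramian programmatically; the following continuation of the example above is a small sketch using only standard PyTorch operations:

>>> gramian.diagonal().sqrt()  # norms of the three gradients
tensor([ 1.7321,  1.7321, 11.0454])
>>> gramian < 0.0  # True wherever a pair of gradients conflicts
tensor([[False, False,  True],
        [False, False,  True],
        [ True,  True, False]])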

Example

This example shows how to apply the chain rule via the jac_outputs parameter, computing the Jacobian in two steps.

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> x = torch.tensor([1., 2.], requires_grad=True)
>>> # Compose functions: x -> h -> y
>>> h = x ** 2
>>> y1 = h.sum()
>>> y2 = torch.tensor([1., -1.]) @ h
>>>
>>> # Step 1: Compute d[y1,y2]/dh
>>> jac_h = jac([y1, y2], [h])[0]  # Shape: [2, 2]
>>>
>>> # Step 2: Use chain rule to compute d[y1,y2]/dx = (d[y1,y2]/dh) @ (dh/dx)
>>> jac_x = jac(h, [x], jac_outputs=jac_h)[0]
>>>
>>> jac_x
tensor([[ 2.,  4.],
        [ 2., -4.]])

This two-step computation is equivalent to directly computing jac([y1, y2], [x]).
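
A quick way to check this equivalence is to rebuild the computation (the graph used above has already been freed) and compare the directly computed Jacobian with jac_x from the two-step computation:

>>> x = torch.tensor([1., 2.], requires_grad=True)
>>> h = x ** 2
>>> y1 = h.sum()
>>> y2 = torch.tensor([1., -1.]) @ h
>>>
>>> direct_jac_x = jac([y1, y2], [x])[0]
>>> torch.equal(direct_jac_x, jac_x)
True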

Warning

To differentiate in parallel, jac relies on torch.vmap, which has some limitations: for instance, it does not work on the output of compiled functions, when some tensors have retains_grad=True, or when using an RNN on CUDA. If you experience issues with jac, try using parallel_chunk_size=1 to avoid relying on torch.vmap.
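
For instance, the first example can be reproduced without torch.vmap by differentiating the output values one at a time (a minimal sketch of this workaround; it is typically slower than the fully parallel version):

>>> import torch
>>>
>>> from torchjd.autojac import jac
>>>
>>> param = torch.tensor([1., 2.], requires_grad=True)
>>> y1 = torch.tensor([-1., 1.]) @ param
>>> y2 = (param ** 2).sum()
>>>
>>> jac([y1, y2], [param], parallel_chunk_size=1)  # sequential differentiation
(tensor([[-1.,  1.],
        [ 2.,  4.]]),)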