Engine

class torchjd.autogram.Engine(modules, batch_dim)[source]

Engine to compute the Gramian of the Jacobian of some tensor with respect to the direct parameters of all provided modules. It is based on Algorithm 3 of Jacobian Descent For Multi-Objective Optimization but goes even further:

  • It works for any computation graph (not just sequential models).

  • It is optimized for batched computations (as long as batch_dim is specified).

  • It supports any shape of tensor to differentiate (not just a vector of losses). For more details about this, look at Engine.compute_gramian().

As explained in Section 6 of Jacobian Descent For Multi-Objective Optimization, most Aggregators combine the rows of the Jacobian using some weights that depend only on the Gramian of the Jacobian. Because of that, the typical usage of the autogram engine is to directly compute this Gramian, extract weights from it (with a Weighting), and use those weights to backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using torchjd.autojac.backward().

Parameters:
  • modules (Iterable[Module]) – A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian.

  • batch_dim (int | None) – The batch dimension of the output tensor, if any. If the modules work with batches and process each batch element independently, many intermediate Jacobians are block-diagonal (sparse), which enables a substantial memory optimization: a squashed Jacobian is backpropagated instead.

Example

Train a model using Gramian-based Jacobian descent.

import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGradWeighting
from torchjd.autogram import Engine

# Generate data (8 batches of 16 examples of dim 5) for the sake of the example
inputs = torch.randn(8, 16, 5)
targets = torch.randn(8, 16)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters(), lr=0.1)

criterion = MSELoss(reduction="none")  # Important to use reduction="none"
weighting = UPGradWeighting()

# Create the engine before the backward pass, and only once.
engine = Engine(model.modules(), batch_dim=0)

for input, target in zip(inputs, targets):
    output = model(input).squeeze(dim=1)  # shape: [16]
    losses = criterion(output, target)  # shape: [16]

    optimizer.zero_grad()
    gramian = engine.compute_gramian(losses)  # shape: [16, 16]
    weights = weighting(gramian)  # shape: [16]
    losses.backward(weights)
    optimizer.step()

This is equivalent to simply calling torchjd.autojac.backward(losses, UPGrad()). However, since the full Jacobian never has to be materialized in memory, the Gramian-based approach is often much more memory-efficient, and thus typically faster.
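For comparison, here is a minimal sketch of the equivalent autojac-based training loop, reusing the model, optimizer, criterion, inputs and targets defined above, and assuming UPGrad and backward are importable as shown:

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward

aggregator = UPGrad()

for input, target in zip(inputs, targets):
    output = model(input).squeeze(dim=1)  # shape: [16]
    losses = criterion(output, target)  # shape: [16]

    optimizer.zero_grad()
    backward(losses, aggregator)  # materializes the full Jacobian and aggregates it
    optimizer.step()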

Warning

When providing a non-None batch_dim, all provided modules must respect a few conditions:

  • They should treat the elements of the batch independently. Most common layers do, but BatchNorm, for example, does not (it computes a mean and standard deviation over the elements of the batch).

  • Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of tensors, or any nesting of those structures), but each of these tensors must be batched along its first dimension. Transformers and RNNs are thus not supported yet. This is only an implementation limitation, so it should be fixed soon (please open an issue if you need this prioritized).

  • They should not perform in-place operations on tensors (for instance you should not use track_running_stats=True in normalization layers).

  • They should not have side effects during the forward pass (since their forward pass will be called twice, the side effects could be different from what’s expected).

  • If they have some randomness during the forward pass, they should not have direct trainable parameters. It is, however, perfectly fine for random modules to have child modules that have trainable parameters, so if you have a random module with some direct parameters, a simple fix is to wrap these parameters into a child module (see the sketch after the note below).

If you’re building your own architecture, respecting those criteria should be quite easy. However, if you’re using an existing architecture, you may have to modify it to make it compatible with the autogram engine. For instance, you may want to replace BatchNorm2d layers with GroupNorm or InstanceNorm2d layers, as sketched below.
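As an illustration, here is a minimal sketch of one way to perform such a replacement; the replace_batchnorm helper is purely hypothetical and not part of torchjd:

import torch.nn as nn

def replace_batchnorm(module: nn.Module) -> None:
    """Recursively replace every BatchNorm2d child with an equivalent GroupNorm."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # One group per channel makes GroupNorm behave like an affine InstanceNorm2d.
            setattr(module, name, nn.GroupNorm(child.num_features, child.num_features))
        else:
            replace_batchnorm(child)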

The alternative is to use batch_dim=None, but this is not recommended since it greatly increases memory usage and thus typically slows down computation.

Note

For maximum efficiency, modules should ideally not contain both direct trainable parameters and child modules, especially if those direct parameters are used before the child modules. You can always wrap those direct parameters into another child module to avoid the slowdown, as sketched below.
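As an illustration of both this note and the randomness condition above, here is a minimal sketch of wrapping a direct parameter into a child module; the Scale and ScaledBlock names are purely illustrative:

import torch
from torch.nn import Linear, Module, Parameter, ReLU

class Scale(Module):
    """Child module that holds what would otherwise be a direct parameter."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight

class ScaledBlock(Module):
    """Parent module with no direct trainable parameters, only child modules."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = Scale(dim)  # instead of self.weight = Parameter(torch.ones(dim))
        self.linear = Linear(dim, dim)
        self.relu = ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.linear(self.scale(x)))

Since ScaledBlock itself has no direct trainable parameters, passing its modules to the engine does not incur the slowdown described above.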

compute_gramian(output)[source]

Computes the Gramian of the Jacobian of output with respect to the direct parameters of all modules.

Parameters:
  • output (Tensor) – The tensor of arbitrary shape to differentiate. The shape of the returned Gramian depends on the shape of this output.

Return type:
  Tensor

Note

This function doesn’t require output to be a vector. For example, if output is a matrix of shape \([m_1, m_2]\), its Jacobian \(J\) with respect to the parameters will be of shape \([m_1, m_2, n]\), where \(n\) is the number of parameters in the model. This is what we call a generalized Jacobian. The corresponding Gramian \(G = J J^\top\) will be of shape \([m_1, m_2, m_2, m_1]\). This is what we call a generalized Gramian. The number of dimensions of the returned generalized Gramian will always be twice that of the output.

A few examples:
  • 0D (scalar) output: 0D Gramian (this can be used to efficiently compute the squared norm of the gradient of output).

  • 1D (vector) output: 2D Gramian (this is the standard setting of Jacobian descent).

  • 2D (matrix) output: 4D Gramian (this can be used for Instance-Wise Multi-Task Learning (IWMTL), as each sample in the batch has one loss per task; a sketch of this case follows the list).

  • etc.
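To make the 2D case concrete, here is a minimal sketch (with an arbitrary toy model and data) of computing the generalized Gramian of a matrix of per-example, per-task losses:

import torch
from torch.nn import Linear, MSELoss

from torchjd.autogram import Engine

model = Linear(5, 3)  # 3 tasks, so 3 losses per example
engine = Engine(model.modules(), batch_dim=0)

inputs = torch.randn(16, 5)  # batch of 16 examples of dim 5
targets = torch.randn(16, 3)
criterion = MSELoss(reduction="none")

losses = criterion(model(inputs), targets)  # shape: [16, 3]
gramian = engine.compute_gramian(losses)  # generalized Gramian, shape: [16, 3, 3, 16]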