Engine
- class torchjd.autogram.Engine(modules, batch_dim)
Engine to compute the Gramian of the Jacobian of some tensor with respect to the direct parameters of all provided modules. It is based on Algorithm 3 of Jacobian Descent For Multi-Objective Optimization but goes even further:
- It works for any computation graph (not just sequential models).
- It is optimized for batched computations (as long as batch_dim is specified).
- It supports any shape of tensor to differentiate (not just a vector of losses). For more details about this, look at Engine.compute_gramian().
As explained in Section 6 of Jacobian Descent For Multi-Objective Optimization, most Aggregators combine the rows of the Jacobian using weights that depend only on the Gramian of the Jacobian. Because of that, the typical usage of the autogram engine is to compute this Gramian directly, extract weights from it (with a Weighting), and use those weights to backpropagate the losses. This is equivalent to doing a step of standard Jacobian descent using torchjd.autojac.backward().

Parameters:
- modules (Iterable[Module]) – A collection of modules whose direct (non-recursive) parameters will contribute to the Gramian of the Jacobian.
- batch_dim (int | None) – If the modules work with batches and process each batch element independently, then many intermediary Jacobians are sparse (block-diagonal), which allows for a substantial memory optimization by backpropagating a squashed Jacobian instead. This parameter indicates the batch dimension of the output tensor, if any.
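As a quick illustration of the batch_dim parameter, here is a minimal sketch using a toy model (whether batch_dim=None is appropriate depends on your modules and on the tensor you differentiate; see the warning below):

import torch
from torch.nn import Linear, ReLU, Sequential

from torchjd.autogram import Engine

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))

# Per-sample losses batched along dim 0: the engine can exploit the
# block-diagonal structure of the intermediary Jacobians.
engine = Engine(model.modules(), batch_dim=0)

# Output not batched (or modules mixing batch elements): no batch optimization.
# This is still correct but typically uses much more memory.
engine_unbatched = Engine(model.modules(), batch_dim=None)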
Example
Train a model using Gramian-based Jacobian descent.
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import UPGradWeighting
from torchjd.autogram import Engine

# Generate data (8 batches of 16 examples of dim 5) for the sake of the example
inputs = torch.randn(8, 16, 5)
targets = torch.randn(8, 16)

model = Sequential(Linear(5, 4), ReLU(), Linear(4, 1))
optimizer = SGD(model.parameters())
criterion = MSELoss(reduction="none")  # Important to use reduction="none"
weighting = UPGradWeighting()

# Create the engine before the backward pass, and only once.
engine = Engine(model.modules(), batch_dim=0)

for input, target in zip(inputs, targets):
    output = model(input).squeeze(dim=1)  # shape: [16]
    losses = criterion(output, target)  # shape: [16]

    optimizer.zero_grad()
    gramian = engine.compute_gramian(losses)  # shape: [16, 16]
    weights = weighting(gramian)  # shape: [16]
    losses.backward(weights)
    optimizer.step()
This is equivalent to just calling torchjd.autojac.backward(losses, UPGrad()). However, since the Jacobian never has to be entirely in memory, it is often much more memory-efficient, and thus typically faster, to use the Gramian-based approach.
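For comparison, here is a minimal sketch of that equivalent autojac-based loop, continuing from the example above (same model, data, criterion, and optimizer) and assuming that UPGrad is the aggregator corresponding to UPGradWeighting:

from torchjd.aggregation import UPGrad
from torchjd.autojac import backward

aggregator = UPGrad()

for input, target in zip(inputs, targets):
    output = model(input).squeeze(dim=1)  # shape: [16]
    losses = criterion(output, target)  # shape: [16]

    optimizer.zero_grad()
    # The full Jacobian of the losses is materialized and aggregated here.
    backward(losses, aggregator)
    optimizer.step()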
Warning

When providing a non-None batch_dim, all provided modules must respect a few conditions:
- They should treat the elements of the batch independently. Most common layers respect this, but BatchNorm, for example, does not (it computes some average and standard deviation over the elements of the batch).
- Their inputs and outputs can be any PyTree (tensor, tuple or list of tensors, dict of tensors, or any nesting of those structures), but each of these tensors must be batched on its first dimension. Transformers and RNNs are thus not supported yet. This is only an implementation issue, so it should be fixed soon (please open an issue if you need extra focus on this).
- They should not perform in-place operations on tensors (for instance, you should not use track_running_stats=True in normalization layers).
- They should not have side effects during the forward pass (since their forward pass will be called twice, the side effects could be different from what's expected).
- If they have some randomness during the forward pass, they should not have direct trainable parameters. It is, however, perfectly fine for random modules to have child modules that have trainable parameters, so if you have a random module with some direct parameters, a simple fix is to wrap these parameters into a child module (see the sketch after the note below).
If you're building your own architecture, respecting those criteria should be quite easy. However, if you're using an existing architecture, you may have to modify it to make it compatible with the autogram engine. For instance, you may want to replace BatchNorm2d layers with GroupNorm or InstanceNorm2d layers, as sketched below.
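A hypothetical helper (not part of torchjd) for such a replacement could look like the following; the choice of num_groups is arbitrary, and the replacement layers are freshly initialized:

from torch import nn

def replace_batchnorm2d(module: nn.Module) -> None:
    """Recursively replace every BatchNorm2d child with a GroupNorm layer."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            # num_groups=1 normalizes over all channels and is valid for any channel count.
            setattr(module, name, nn.GroupNorm(num_groups=1, num_channels=child.num_features))
        else:
            replace_batchnorm2d(child)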
The alternative is to use batch_dim=None, but this is not recommended, since it greatly increases memory usage and thus typically slows down computation.

Note
For maximum efficiency, modules should ideally not contain both direct trainable parameters and child modules, especially if those direct trainable parameters are used before the child modules. You can always wrap those direct trainable parameters into another child module to avoid the slow-down.
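The following minimal sketch (with hypothetical module names) illustrates the wrapping fix mentioned in the warning and in the note above: the scaling parameter lives in a dedicated child module instead of being a direct parameter of Block.

import torch
from torch import nn

class Scale(nn.Module):
    """Child module that owns what would otherwise be a direct parameter."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.weight

class Block(nn.Module):
    """Has no direct trainable parameters: they all live in child modules."""

    def __init__(self, dim: int):
        super().__init__()
        self.scale = Scale(dim)  # instead of `self.weight = nn.Parameter(...)`
        self.linear = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.scale(x))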
- compute_gramian(output)

Computes the Gramian of the Jacobian of output with respect to the direct parameters of all modules.

Parameters:
- output (Tensor) – The tensor of arbitrary shape to differentiate. The shape of the returned Gramian depends on the shape of this output.

Return type: Tensor
Note
This function doesn't require output to be a vector. For example, if output is a matrix of shape \([m_1, m_2]\), its Jacobian \(J\) with respect to the parameters will be of shape \([m_1, m_2, n]\), where \(n\) is the number of parameters in the model. This is what we call a generalized Jacobian. The corresponding Gramian \(G = J J^\top\) will be of shape \([m_1, m_2, m_2, m_1]\). This is what we call a generalized Gramian. The number of dimensions of the returned generalized Gramian will always be twice that of the output.

A few examples:
- 0D (scalar) output: 0D Gramian (this can be used to efficiently compute the squared norm of the gradient of output).
- 1D (vector) output: 2D Gramian (this is the standard setting of Jacobian descent).
- 2D (matrix) output: 4D Gramian (this can be used for Instance-Wise Multi-Task Learning (IWMTL), as each sample in the batch has one loss per task).
- etc.