Automatic Mixed Precision (AMP)

In some cases, to save memory and reduce computation time, you may want to use automatic mixed precision. Since the torch.amp.GradScaler class already supports scaling multiple losses, combining TorchJD with AMP is straightforward. As usual, the forward pass should be wrapped in a torch.autocast context, and the loss (in our case, the losses) should preferably be scaled with a GradScaler to avoid gradient underflow. The following example shows the resulting code for a multi-task learning use case.

import torch
from torch.amp import GradScaler
from torch.nn import Sequential, Linear, ReLU, MSELoss
from torch.optim import SGD

from torchjd import mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]
scaler = GradScaler(device="cpu")
loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    with torch.autocast(device_type="cpu", dtype=torch.float16):
        features = shared_module(input)
        output1 = task1_module(features)
        output2 = task2_module(features)
        loss1 = loss_fn(output1, target1)
        loss2 = loss_fn(output2, target2)

    # Scale both losses together to avoid gradient underflow in float16.
    scaled_losses = scaler.scale([loss1, loss2])
    optimizer.zero_grad()
    mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
    scaler.step(optimizer)  # Unscales the gradients, then steps the optimizer.
    scaler.update()  # Updates the scale factor for the next iteration.

Hint

Within the torch.autocast context, some operations may be performed in float16. For those operations, the tensors saved for the backward pass will also be of float16 type. However, the Jacobian computed by mtl_backward will be of type float32, so the .grad fields of the model parameters will also be of type float32. This is in line with the behavior of PyTorch, which would also compute all gradients in float32.
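
For instance (a minimal sketch reusing shared_module from the example above), you can check the dtypes after one iteration of the loop:

# The parameters and their .grad fields are both float32, even though some
# of the forward computations ran in float16 under autocast.
print(shared_module[0].weight.dtype)       # torch.float32
print(shared_module[0].weight.grad.dtype)  # torch.float32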

Note

torchjd.backward can be similarly combined with AMP.
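
For instance, the following sketch adapts the loop above to a single model whose output is used by two losses. It reuses the scaler, aggregator, loss function and data from the example above; the model definition is illustrative, and the exact signature of torchjd.backward (here assumed to take the list of losses and an aggregator keyword, mirroring mtl_backward) should be checked against its documentation.

from torchjd import backward

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
optimizer = SGD(model.parameters(), lr=0.1)

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    with torch.autocast(device_type="cpu", dtype=torch.float16):
        output = model(input)
        loss1 = loss_fn(output, target1)
        loss2 = loss_fn(output, target2)

    # Scale both losses before differentiation, exactly as in the
    # mtl_backward example, then differentiate and aggregate them.
    scaled_losses = scaler.scale([loss1, loss2])
    optimizer.zero_grad()
    backward(scaled_losses, aggregator=aggregator)
    scaler.step(optimizer)
    scaler.update()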