Automatic Mixed Precision (AMP)
In some cases, to save memory and reduce computation time, you may want to use automatic mixed precision. Since the torch.amp.GradScaler class already works on multiple losses, it's straightforward to combine TorchJD and AMP. As usual, the forward pass should be wrapped in a torch.autocast context, and the losses should preferably be scaled with a GradScaler to avoid gradient underflow. The following example shows the resulting code for a multi-task learning use case.
import torch
from torch.amp import GradScaler
from torch.nn import Sequential, Linear, ReLU, MSELoss
from torch.optim import SGD
from torchjd import mtl_backward
from torchjd.aggregation import UPGrad
shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]
scaler = GradScaler(device="cpu")
loss_fn = MSELoss()
optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task
for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    with torch.autocast(device_type="cpu", dtype=torch.float16):
        features = shared_module(input)
        output1 = task1_module(features)
        output2 = task2_module(features)
        loss1 = loss_fn(output1, target1)
        loss2 = loss_fn(output2, target2)
    scaled_losses = scaler.scale([loss1, loss2])
    optimizer.zero_grad()
    mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
    scaler.step(optimizer)
    scaler.update()
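Note that scaler.step(optimizer) unscales the .grad fields filled by mtl_backward before stepping, and skips the step entirely if it detects inf or NaN entries. If you need the unscaled gradients before the step, you can call scaler.unscale_(optimizer) explicitly. Here is a minimal sketch of the end of the loop body, assuming you also want gradient clipping (which is not part of the example above):
    # Sketch: replacement for the end of the loop body above. Clipping is an
    # assumption added for illustration, not part of the original example.
    mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
    scaler.unscale_(optimizer)  # .grad fields now hold the unscaled gradients
    torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)  # max_norm chosen arbitrarily
    scaler.step(optimizer)  # step() detects that the gradients are already unscaled
    scaler.update()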
Hint
Within the torch.autocast context, some operations may be performed in float16. For those operations, the tensors saved for the backward pass will also be of float16 type. However, the Jacobian computed by mtl_backward will be of type float32, so the .grad fields of the model parameters will also be of type float32. This is in line with the behavior of PyTorch, which would also compute all gradients in float32.
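To see this in practice, a quick check (not part of the original example) can be added right after the call to mtl_backward in the loop above, assuming the linear layers indeed ran in float16 under autocast:
    # Sketch: dtype sanity check, placed just after mtl_backward in the loop.
    assert features.dtype == torch.float16  # activations computed under autocast
    assert shared_module[0].weight.grad.dtype == torch.float32  # gradients in full precision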
Note
torchjd.backward can be similarly combined with AMP.
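For instance, here is a sketch of the same recipe with torchjd.backward, for a single model whose two losses are computed from the same output. The call backward(scaled_losses, aggregator) is assumed to accept the sequence of (scaled) losses and the aggregator, analogously to mtl_backward above; check the torchjd.backward documentation for the exact argument names.
import torch
from torch.amp import GradScaler
from torch.nn import Linear, MSELoss
from torch.optim import SGD
from torchjd import backward
from torchjd.aggregation import UPGrad

model = Linear(10, 1)
scaler = GradScaler(device="cpu")
loss_fn = MSELoss()
optimizer = SGD(model.parameters(), lr=0.1)
aggregator = UPGrad()

input = torch.randn(16, 10)  # A batch of 16 random input vectors of length 10
target1 = torch.randn(16, 1)  # A batch of 16 targets for the first objective
target2 = torch.randn(16, 1)  # A batch of 16 targets for the second objective

with torch.autocast(device_type="cpu", dtype=torch.float16):
    output = model(input)
    loss1 = loss_fn(output, target1)
    loss2 = loss_fn(output, target2)

scaled_losses = scaler.scale([loss1, loss2])
optimizer.zero_grad()
backward(scaled_losses, aggregator)  # assumed signature: losses then aggregator
scaler.step(optimizer)
scaler.update()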