PyTorch Lightning Integration
To use Jacobian descent with TorchJD in a LightningModule, you need to turn off automatic optimization by setting automatic_optimization to False, and to customize the training_step method so that it calls the appropriate TorchJD method (backward or mtl_backward).
The following code example demonstrates a basic multi-task learning setup using a LightningModule that calls mtl_backward at each training iteration.
import torch
from lightning import LightningModule, Trainer
from lightning.pytorch.utilities.types import OptimizerLRScheduler
from torch.nn import Linear, ReLU, Sequential
from torch.nn.functional import mse_loss
from torch.optim import Adam
from torch.utils.data import DataLoader, TensorDataset
from torchjd import mtl_backward
from torchjd.aggregation import UPGrad
class Model(LightningModule):
    """Two-task regression model trained with Jacobian descent through TorchJD.

    A shared feature extractor feeds two task-specific linear heads. Automatic
    optimization is disabled so that ``training_step`` can drive the backward
    pass itself with TorchJD's ``mtl_backward`` instead of relying on
    Lightning's default backward/step handling.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: 10 input features -> 3 shared features.
        self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
        # One single-output head per task, both reading the shared features.
        self.task1_head = Linear(3, 1)
        self.task2_head = Linear(3, 1)
        # Manual optimization: Lightning will not call backward()/step() for us,
        # which is what lets training_step use mtl_backward instead.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx) -> None:
        """Run one Jacobian-descent update on ``batch``.

        Args:
            batch: An ``(inputs, target1, target2)`` triple, as produced by the
                TensorDataset-backed loader used with this model.
            batch_idx: Index of the batch within the epoch (unused here).
        """
        # Named ``inputs`` (not ``input``) to avoid shadowing the builtin.
        inputs, target1, target2 = batch

        features = self.feature_extractor(inputs)
        output1 = self.task1_head(features)
        output2 = self.task2_head(features)

        loss1 = mse_loss(output1, target1)
        loss2 = mse_loss(output2, target2)

        opt = self.optimizers()
        # Clear gradients left over from the previous iteration.
        opt.zero_grad()
        # Back-propagate both task losses through the shared features,
        # combining the per-task gradients with the UPGrad aggregator in place
        # of a plain loss.backward() call.
        mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())
        opt.step()

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Return an Adam optimizer (lr=1e-3) over all model parameters."""
        optimizer = Adam(self.parameters(), lr=1e-3)
        return optimizer
# Instantiate the multi-task model defined above.
model = Model()

# Synthetic data: 8 pre-made batches, each holding 16 random input vectors of
# length 10, together with one scalar target per task for every vector.
inputs = torch.randn(8, 16, 10)
task1_targets = torch.randn(8, 16, 1)
task2_targets = torch.randn(8, 16, 1)

# DataLoader uses its default batch_size=1, so each yielded element is one of
# the pre-made batches with an extra leading dimension of size 1
# (e.g. inputs of shape (1, 16, 10)).
dataset = TensorDataset(inputs, task1_targets, task2_targets)
train_loader = DataLoader(dataset)

# Minimal CPU-only trainer for a single epoch, with checkpointing, logging and
# the progress bar all switched off.
trainer = Trainer(
    accelerator="cpu",
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
    enable_progress_bar=False,
)
trainer.fit(model=model, train_dataloaders=train_loader)
Warning
This integration does not handle automatic scaling in low-precision settings, and there is currently no easy fix.
Warning
TorchJD is incompatible with compiled models, so you must ensure that your model is not compiled.