Instance-Wise Risk Minimization (IWRM)

This example shows how to use TorchJD to minimize the vector of per-instance losses. This learning paradigm, called IWRM, is multi-objective, as opposed to the usual empirical risk minimization (ERM), which seeks to minimize the average loss.
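
Concretely, denoting by \ell_i(\theta) the loss of a model with parameters \theta on the i-th of the n training instances, the two objectives can be sketched as follows (an informal formulation; the precise definition is given in the paper referenced in the hint below):

\text{ERM:} \quad \min_\theta \; \frac{1}{n} \sum_{i=1}^{n} \ell_i(\theta)

\text{IWRM:} \quad \min_\theta \; \bigl[ \ell_1(\theta), \dots, \ell_n(\theta) \bigr]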

Hint

A proper definition of IWRM and its empirical results on some deep learning tasks are available in Jacobian Descent For Multi-Objective Optimization.

For the sake of the example, we generate a fake dataset consisting of 8 batches of 16 random input vectors of dimension 10, together with their corresponding scalar labels. We train a very simple regression model to predict the label from the corresponding input. To minimize the average loss, we use stochastic gradient descent (SGD), where each gradient is computed from the average loss over a batch of data. To minimize the per-instance losses, we use stochastic sub-Jacobian descent (SSJD), where each Jacobian matrix consists of one gradient per loss. In this example, we use UPGrad to aggregate these matrices into a single update direction.

ERM with SGD

import torch
from torch.nn import (
    MSELoss,
    Sequential,
    Linear,
    ReLU
)
from torch.optim import SGD

# Fake dataset: 8 batches of 16 examples, with inputs of dimension 10 and scalar labels
X = torch.randn(8, 16, 10)
Y = torch.randn(8, 16, 1)

model = Sequential(
    Linear(10, 5),
    ReLU(),
    Linear(5, 1)
)
loss_fn = MSELoss()

params = model.parameters()
optimizer = SGD(params, lr=0.1)


for x, y in zip(X, Y):
    y_hat = model(x)
    loss = loss_fn(y_hat, y)  # average loss over the batch
    optimizer.zero_grad()
    loss.backward()  # fills each parameter's .grad with the gradient of the average loss
    optimizer.step()

IWRM with SSJD

import torch
from torch.nn import (
    MSELoss,
    Sequential,
    Linear,
    ReLU
)
from torch.optim import SGD

from torchjd import backward
from torchjd.aggregation import UPGrad

# Same fake dataset as in the ERM example
X = torch.randn(8, 16, 10)
Y = torch.randn(8, 16, 1)

model = Sequential(
    Linear(10, 5),
    ReLU(),
    Linear(5, 1)
)
loss_fn = MSELoss(reduction='none')  # keep one loss per instance instead of averaging them

params = list(model.parameters())  # materialized as a list so it can be reused by backward below
optimizer = SGD(params, lr=0.1)
A = UPGrad()  # aggregator combining the rows of the Jacobian

for x, y in zip(X, Y):
    y_hat = model(x)
    losses = loss_fn(y_hat, y)  # one loss per instance in the batch
    optimizer.zero_grad()
    backward(losses, params, A)  # fills each parameter's .grad with the aggregation of the per-instance gradients
    optimizer.step()

Note that in both cases, we use the torch.optim.SGD optimizer to update the parameters of the model in the opposite direction of their .grad field. The difference lies in how this field is computed: loss.backward() fills it with the gradient of the average loss, whereas torchjd's backward fills it with the result of aggregating, with UPGrad, the per-instance gradients that form the rows of the Jacobian.
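
To make the role of the aggregator more tangible, the Jacobian that SSJD works with can also be materialized by hand and aggregated directly. The sketch below reuses model, loss_fn, X, Y, params and A from the IWRM snippet; it is only a conceptual illustration (it assumes that aggregators such as UPGrad can be called directly on a matrix of gradients), not how torchjd computes things internally.

import torch

# Build the Jacobian explicitly: one flattened gradient per instance-wise loss.
x, y = X[0], Y[0]
losses = loss_fn(model(x), y).squeeze(1)  # 16 per-instance losses

rows = []
for loss in losses:
    grads = torch.autograd.grad(loss, params, retain_graph=True)
    rows.append(torch.cat([g.reshape(-1) for g in grads]))
jacobian = torch.stack(rows)  # shape: [16, total number of parameters]

# Assumption: UPGrad instances can be applied directly to such a matrix,
# returning the single update direction that would populate the .grad fields.
aggregated = A(jacobian)
print(aggregated.shape)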