Recurrent Neural Network (RNN)
When training recurrent neural networks for sequence modelling, we can easily obtain one loss per element of the output sequences. If the gradients of these losses are likely to conflict, Jacobian descent can be leveraged to enhance optimization.
import torch
from torch.nn import RNN
from torch.optim import SGD

from torchjd import backward
from torchjd.aggregation import UPGrad

rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
optimizer = SGD(rnn.parameters(), lr=0.1)
aggregator = UPGrad()

inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.

for input, target in zip(inputs, targets):
    output, _ = rnn(input)  # output is of shape [5, 3, 20].
    losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
    optimizer.zero_grad()
    backward(losses, aggregator, parallel_chunk_size=1)  # Aggregate the Jacobian of the 5 losses with UPGrad.
    optimizer.step()
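For comparison, standard gradient descent would reduce the per-element losses to a single scalar before backpropagating. The following is a minimal sketch of that baseline, reusing the rnn, optimizer, inputs and targets defined above; the backward call from torchjd replaces this scalar reduction by differentiating all 5 losses and aggregating the resulting Jacobian with UPGrad.

for input, target in zip(inputs, targets):
    output, _ = rnn(input)
    losses = ((output - target) ** 2).mean(dim=[1, 2])
    optimizer.zero_grad()
    losses.mean().backward()  # Average the 5 losses into a scalar and backpropagate it.
    optimizer.step()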
Note
At the time of writing, there seems to be an incompatibility between torch.vmap and torch.nn.RNN when running on CUDA (see this issue for more info), so we advise setting parallel_chunk_size to 1 to avoid using torch.vmap. To improve performance, you can check whether parallel_chunk_size=None (maximal parallelization) works on your side.
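A simple way to check this is to try one differentiation step with parallel_chunk_size=None and fall back to sequential differentiation if it fails. The sketch below reuses the objects from the example above and assumes the failure surfaces as a RuntimeError, which is how torch.vmap issues typically manifest:

chunk_size = None  # Try maximal parallelization first.
try:
    output, _ = rnn(inputs[0])
    losses = ((output - targets[0]) ** 2).mean(dim=[1, 2])
    optimizer.zero_grad()
    backward(losses, aggregator, parallel_chunk_size=None)
except RuntimeError:
    chunk_size = 1  # torch.vmap is not usable in this setup: differentiate sequentially.
optimizer.zero_grad()  # Discard any gradient accumulated during the check.

You can then pass chunk_size as the parallel_chunk_size argument in the training loop.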