Grouping
Aggregation can be performed independently on groups of parameters, at different granularities. The Gradient Vaccine paper introduces four strategies to partition the parameters:
Together (baseline): one group covering all parameters. Corresponds to the whole_model strategy in the paper.
Per network: one group per top-level sub-network (e.g. encoder and decoder separately). Corresponds to the enc_dec strategy in the paper.
Per layer: one group per leaf module of the network. Corresponds to the all_layer strategy in the paper.
Per tensor: one group per individual parameter tensor. Corresponds to the all_matrix strategy in the paper.
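To illustrate why the choice of partition matters, here is a small sketch in plain Python, independent of TorchJD. The gradient values and the two block names are made up for the illustration. It compares the cosine similarity of two task gradients computed over the full vector with the similarities computed inside each group.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Hypothetical flattened gradients of two tasks over four shared parameters.
g1 = [1.0, 0.0, -1.0, 2.0]
g2 = [1.0, 0.0, 1.0, -2.0]

# Hypothetical partition: the first two entries form one group, the last two another.
groups = {"block_a": slice(0, 2), "block_b": slice(2, 4)}

# "Together" granularity: one similarity over the full vectors.
whole = cosine(g1, g2)

# Finer granularity: one similarity per group.
per_group = {name: cosine(g1[s], g2[s]) for name, s in groups.items()}

print(f"whole: {whole:.3f}")
for name, value in per_group.items():
    print(f"{name}: {value:.3f}")
```

With these values, the full vectors have a similarity of about -2/3, while the two groups give about 1 and -1 respectively: the conflict is confined to the second group, which a single global similarity cannot show.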
In TorchJD, grouping is achieved by calling jac_to_grad() once per group
after backward() or mtl_backward(), with a dedicated
aggregator instance per group. For Stateful aggregators, each instance
maintains its own state independently (e.g. the EMA \(\hat{\phi}\) state in
GradVac, matching the per-block targets from the original paper).
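The point about independent state can be sketched with a toy class in plain Python. This is not GradVac's implementation: the EmaSimilarity class, its beta parameter, the group names and the observed similarity values are all hypothetical, and the update is assumed to be a simple exponential moving average.

```python
class EmaSimilarity:
    """Toy stand-in for a stateful aggregator (hypothetical, not TorchJD code)."""

    def __init__(self, beta: float = 0.5) -> None:
        self.beta = beta
        self.phi_hat = 0.0  # EMA state, starting at zero

    def update(self, phi: float) -> float:
        # Assumed EMA rule: blend the previous state with the new observation.
        self.phi_hat = (1 - self.beta) * self.phi_hat + self.beta * phi
        return self.phi_hat


# One tracker per group, mirroring one aggregator instance per group.
trackers = {"encoder": EmaSimilarity(), "decoder": EmaSimilarity()}

# Hypothetical cosine similarities observed in each group over two steps.
observed = {"encoder": [0.8, 0.6], "decoder": [-0.4, -0.2]}

for name, phis in observed.items():
    for phi in phis:
        trackers[name].update(phi)

for name, tracker in trackers.items():
    print(f"{name}: {tracker.phi_hat:.2f}")
```

Because each group has its own instance, the encoder's running estimate is not affected by the decoder's observations. Sharing one instance across groups would mix the two histories into a single state.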
Note
The grouping is orthogonal to the choice between
backward() and mtl_backward(). Those functions
determine which parameters receive Jacobians; grouping then determines how those Jacobians
are partitioned for aggregation.
Note
The examples below use GradVac, but the same pattern applies to
any Aggregator.
1. Together
A single Aggregator instance aggregates all shared parameters
together. Cosine similarities are computed between the full task gradient vectors.
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD
from torchjd.aggregation import GradVac
from torchjd.autojac import jac_to_grad, mtl_backward
encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_head, task2_head = Linear(3, 1), Linear(3, 1)
optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
loss_fn = MSELoss()
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
aggregator = GradVac()
for x, y1, y2 in zip(inputs, t1, t2):
    features = encoder(x)
    loss1 = loss_fn(task1_head(features), y1)
    loss2 = loss_fn(task2_head(features), y2)
    mtl_backward([loss1, loss2], features=features)
    # Single group: all encoder parameters go through one aggregator.
    jac_to_grad(encoder.parameters(), aggregator)
    optimizer.step()
    optimizer.zero_grad()
2. Per network
One Aggregator instance per top-level sub-network. Here the model
is split into an encoder and a decoder; cosine similarities are computed separately within each.
Passing features=dec_out to mtl_backward() causes both sub-networks
to receive Jacobians, which are then aggregated independently.
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD
from torchjd.aggregation import GradVac
from torchjd.autojac import jac_to_grad, mtl_backward
encoder = Sequential(Linear(10, 5), ReLU())
decoder = Sequential(Linear(5, 3), ReLU())
task1_head, task2_head = Linear(3, 1), Linear(3, 1)
optimizer = SGD([*encoder.parameters(), *decoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
loss_fn = MSELoss()
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
encoder_aggregator = GradVac()
decoder_aggregator = GradVac()
for x, y1, y2 in zip(inputs, t1, t2):
    enc_out = encoder(x)
    dec_out = decoder(enc_out)
    loss1 = loss_fn(task1_head(dec_out), y1)
    loss2 = loss_fn(task2_head(dec_out), y2)
    mtl_backward([loss1, loss2], features=dec_out)
    # Two groups: each sub-network is aggregated by its own instance.
    jac_to_grad(encoder.parameters(), encoder_aggregator)
    jac_to_grad(decoder.parameters(), decoder_aggregator)
    optimizer.step()
    optimizer.zero_grad()
3. Per layer
One Aggregator instance per leaf module. Cosine similarities are
computed per-layer between the task gradients.
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD
from torchjd.aggregation import GradVac
from torchjd.autojac import jac_to_grad, mtl_backward
encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_head, task2_head = Linear(3, 1), Linear(3, 1)
optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
loss_fn = MSELoss()
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
leaf_layers = [m for m in encoder.modules() if list(m.parameters()) and not list(m.children())]
aggregators = [GradVac() for _ in leaf_layers]
for x, y1, y2 in zip(inputs, t1, t2):
    features = encoder(x)
    loss1 = loss_fn(task1_head(features), y1)
    loss2 = loss_fn(task2_head(features), y2)
    mtl_backward([loss1, loss2], features=features)
    # One group per leaf layer, each with its own aggregator instance.
    for layer, aggregator in zip(leaf_layers, aggregators):
        jac_to_grad(layer.parameters(), aggregator)
    optimizer.step()
    optimizer.zero_grad()
4. Per parameter
One Aggregator instance per individual parameter tensor. Cosine
similarities are computed per-tensor between the task gradients (e.g. weights and biases of each
layer are treated as separate groups).
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD
from torchjd.aggregation import GradVac
from torchjd.autojac import jac_to_grad, mtl_backward
encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_head, task2_head = Linear(3, 1), Linear(3, 1)
optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
loss_fn = MSELoss()
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
shared_params = list(encoder.parameters())
aggregators = [GradVac() for _ in shared_params]
for x, y1, y2 in zip(inputs, t1, t2):
    features = encoder(x)
    loss1 = loss_fn(task1_head(features), y1)
    loss2 = loss_fn(task2_head(features), y2)
    mtl_backward([loss1, loss2], features=features)
    # One group per parameter tensor, each with its own aggregator instance.
    for param, aggregator in zip(shared_params, aggregators):
        jac_to_grad([param], aggregator)
    optimizer.step()
    optimizer.zero_grad()
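As a quick comparison of the four granularities, the sketch below counts how many groups (and therefore how many aggregator instances) each strategy would produce. It uses only PyTorch. The model is a hypothetical trunk, slightly deeper than the ones above so that the four counts differ, and the variable names are illustrative.

```python
from torch.nn import Linear, ReLU, Sequential

# Hypothetical shared trunk, split into two sub-networks for illustration.
encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 5), ReLU())
decoder = Sequential(Linear(5, 3), ReLU())
trunk = Sequential(encoder, decoder)

# Together: a single group holding every shared parameter.
together = [list(trunk.parameters())]

# Per network: one group per top-level sub-network.
per_network = [list(net.parameters()) for net in (encoder, decoder)]

# Per layer: one group per leaf module that owns parameters (ReLU owns none).
per_layer = [
    list(m.parameters())
    for m in trunk.modules()
    if list(m.parameters()) and not list(m.children())
]

# Per parameter: one group per tensor (weights and biases separately).
per_tensor = [[p] for p in trunk.parameters()]

group_counts = {
    "together": len(together),
    "per_network": len(per_network),
    "per_layer": len(per_layer),
    "per_tensor": len(per_tensor),
}
print(group_counts)
```

Every partition covers the same six parameter tensors. Only the number of groups changes, from one for the whole trunk to six at the per-parameter level.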