mtl_backward¶
- torchjd.autojac.mtl_backward.mtl_backward(losses, features, aggregator, tasks_params=None, shared_params=None, retain_graph=False, parallel_chunk_size=None)¶
In the context of Multi-Task Learning (MTL), we often have a shared feature extractor followed by several task-specific heads. A loss can then be computed for each task.
This function computes the gradient of each task-specific loss with respect to its task-specific parameters and accumulates it in their .grad fields. Then, it computes the Jacobian of all losses with respect to the shared parameters, aggregates it and accumulates the result in their .grad fields.

- Parameters:
  - losses (Sequence[Tensor]) – The task losses. The Jacobian matrix will have one row per loss.
  - features (Union[Sequence[Tensor], Tensor]) – The last shared representation used for all tasks, as given by the feature extractor. Should be non-empty.
  - aggregator (Aggregator) – Aggregator used to reduce the Jacobian into a vector.
  - tasks_params (Optional[Sequence[Iterable[Tensor]]]) – The parameters of each task-specific head. Their requires_grad flags must be set to True. If not provided, the parameters considered for each task will default to the leaf tensors that are in the computation graph of its loss, but that were not used to compute the features.
  - shared_params (Optional[Iterable[Tensor]]) – The parameters of the shared feature extractor. The Jacobian matrix will have one column for each value in these tensors. Their requires_grad flags must be set to True. If not provided, defaults to the leaf tensors that are in the computation graph of the features.
  - retain_graph (bool) – If False, the graph used to compute the grad will be freed. Defaults to False.
  - parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates will be differentiated in parallel at once. If set to 1, all coordinates will be differentiated sequentially. A larger value results in faster differentiation, but also higher memory usage. Defaults to None. If parallel_chunk_size is not large enough to differentiate all tensors simultaneously, retain_graph has to be set to True.
- Return type:
  None

Example

A usage example of mtl_backward is provided in Multi-Task Learning (MTL).
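As a rough orientation, here is a minimal sketch of a single training step with a shared feature extractor and two task heads. It assumes that mtl_backward is re-exported at the top level of torchjd and that UPGrad is available in torchjd.aggregation; module names such as task1_module are purely illustrative, and exact import paths may differ between versions:

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd import mtl_backward          # assumed top-level re-export
    from torchjd.aggregation import UPGrad    # assumed aggregator location

    # Shared feature extractor followed by two task-specific heads.
    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    params = [
        *shared_module.parameters(),
        *task1_module.parameters(),
        *task2_module.parameters(),
    ]

    loss_fn = MSELoss()
    optimizer = SGD(params, lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(16, 10)        # a batch of 16 input vectors
    task1_targets = torch.randn(16, 1)
    task2_targets = torch.randn(16, 1)

    features = shared_module(inputs)
    loss1 = loss_fn(task1_module(features), task1_targets)
    loss2 = loss_fn(task2_module(features), task2_targets)

    optimizer.zero_grad()
    # Fills the .grad fields: per-task gradients for the head parameters, and
    # the aggregated Jacobian (reduced by the aggregator) for the shared ones.
    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
    optimizer.step()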
Note

shared_params should contain no parameter in common with tasks_params. The different tasks may have some parameters in common. In this case, the sum of the gradients with respect to those parameters will be accumulated into their .grad fields.
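To illustrate this note, the following sketch (same assumed imports as above) uses a hypothetical common_head module reused by both tasks, so its parameters are task-specific parameters of both tasks:

    import torch
    from torch.nn import Linear

    from torchjd import mtl_backward          # assumed top-level re-export
    from torchjd.aggregation import UPGrad    # assumed aggregator location

    feature_extractor = Linear(10, 5)
    common_head = Linear(5, 1)  # hypothetical head shared by both tasks

    x = torch.randn(16, 10)
    features = feature_extractor(x)
    loss1 = common_head(features).mean()
    loss2 = common_head(features * 2.0).mean()

    mtl_backward(losses=[loss1, loss2], features=features, aggregator=UPGrad())

    # With the default tasks_params, common_head's parameters belong to both
    # tasks, so common_head.weight.grad now holds the sum of the gradients of
    # loss1 and loss2 with respect to that weight. The parameters of
    # feature_extractor receive the aggregated result instead.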
Warning

mtl_backward relies on a usage of torch.vmap that is not compatible with compiled functions. The arguments of mtl_backward should thus not come from a compiled model. Check https://github.com/pytorch/pytorch/issues/138422 for the status of this issue.
Warning

Because of a limitation of torch.vmap, tensors in the computation graph of the features parameter should not have their retains_grad flag set to True.