mtl_backward¶
- torchjd.autojac.mtl_backward(tensors, features, grad_tensors=None, tasks_params=None, shared_params=None, retain_graph=False, parallel_chunk_size=None)[source]¶
In the context of Multi-Task Learning (MTL), we often have a shared feature extractor followed by several task-specific heads. A loss can then be computed for each task.
This function computes the gradient of each task-specific tensor with respect to its task-specific parameters and accumulates it in their .grad fields. It also computes the Jacobian of all tensors with respect to the shared parameters and accumulates it in their .jac fields. These Jacobians have one row per task.
If the tensors are non-scalar, mtl_backward requires some initial gradients in grad_tensors. This makes it possible to compose mtl_backward with another function that computes the gradients with respect to the tensors (chain rule); a sketch of this composition is given in the Example section below.
- Parameters:
  - tensors (Sequence[Tensor]) – The task-specific tensors. If these are scalars (e.g. the losses produced by each task), no grad_tensors are needed. If these are non-scalar tensors, providing grad_tensors is necessary.
  - features (Sequence[Tensor] | Tensor) – The last shared representation used for all tasks, as given by the feature extractor. Should be non-empty.
  - grad_tensors (Sequence[Tensor] | None) – The initial gradients to backpropagate, analogous to the grad_tensors parameter of torch.autograd.backward(). If any of the tensors is non-scalar, grad_tensors must be provided, with the same length and shapes as tensors. Otherwise, this parameter is not needed and defaults to scalars of 1.
  - tasks_params (Sequence[Iterable[Tensor]] | None) – The parameters of each task-specific head. Their requires_grad flags must be set to True. If not provided, the parameters considered for each task default to the leaf tensors that are in the computation graph of its tensor but were not used to compute the features.
  - shared_params (Iterable[Tensor] | None) – The parameters of the shared feature extractor. Their requires_grad flags must be set to True. If not provided, defaults to the leaf tensors that are in the computation graph of the features.
  - retain_graph (bool) – If False, the graph used to compute the grad will be freed. Defaults to False.
  - parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of tensors are differentiated in parallel at once. If set to 1, all coordinates are differentiated sequentially. A larger value results in faster differentiation but higher memory usage. Defaults to None.
- Return type: None
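For concreteness, here is a minimal sketch of the behaviour described above. The toy modules, shapes and losses are illustrative assumptions, not part of the API; a complete example is linked in the Example section below.

```python
import torch
from torchjd.autojac import mtl_backward

# Illustrative toy setup: a shared feature extractor and two task-specific heads.
shared = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU())
head1 = torch.nn.Linear(5, 1)
head2 = torch.nn.Linear(5, 1)

x = torch.randn(16, 10)
features = shared(x)            # last shared representation, passed as features
loss1 = head1(features).mean()  # scalar loss of task 1
loss2 = head2(features).mean()  # scalar loss of task 2

# The losses are scalars, so no grad_tensors are needed. tasks_params and
# shared_params default to the heads' and the feature extractor's parameters.
mtl_backward(tensors=[loss1, loss2], features=features)

print(head1.weight.grad)     # gradient of loss1 with respect to head1's weight
print(shared[0].weight.jac)  # Jacobian with one row per task (here, 2 rows)
```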
Example
A usage example of mtl_backward is provided in Multi-Task Learning (MTL).
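In addition, here is a hedged sketch of the chain-rule composition via grad_tensors described above, under the assumption that each task's actual objective is the mean of a vector of per-sample losses. The setup and names are illustrative and not taken from the linked example.

```python
import torch
from torchjd.autojac import mtl_backward

batch_size = 8
shared = torch.nn.Linear(10, 5)
head1 = torch.nn.Linear(5, 1)
head2 = torch.nn.Linear(5, 1)

features = shared(torch.randn(batch_size, 10))
losses1 = head1(features).squeeze(-1) ** 2   # non-scalar: one loss per sample
losses2 = head2(features).squeeze(-1).abs()  # non-scalar: one loss per sample

# Since the tensors are non-scalar, grad_tensors must be provided, with the same
# length and shapes. Each entry is the gradient of the (scalar) task objective
# with respect to the corresponding tensor; for a mean, that is 1 / batch_size
# per coordinate.
mtl_backward(
    tensors=[losses1, losses2],
    features=features,
    grad_tensors=[
        torch.full((batch_size,), 1 / batch_size),
        torch.full((batch_size,), 1 / batch_size),
    ],
)
```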
Note

shared_params should contain no parameter in common with tasks_params. The different tasks may have some parameters in common. In this case, the sum of the gradients with respect to those parameters will be accumulated into their .grad fields.

Warning
To differentiate in parallel, mtl_backward relies on torch.vmap, which has some limitations: for instance, it does not work on the output of compiled functions, when some tensors have retains_grad=True, or when using an RNN on CUDA. If you experience issues with mtl_backward, try using parallel_chunk_size=1 to avoid relying on torch.vmap.
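As a sketch of the fallback suggested in this warning (the toy setup is again an illustrative assumption), passing parallel_chunk_size=1 makes the differentiation sequential and avoids torch.vmap entirely:

```python
import torch
from torchjd.autojac import mtl_backward

shared = torch.nn.Linear(10, 5)
heads = [torch.nn.Linear(5, 1) for _ in range(2)]

features = shared(torch.randn(8, 10))
losses = [head(features).mean() for head in heads]

# Differentiate the coordinates one by one instead of relying on torch.vmap.
mtl_backward(tensors=losses, features=features, parallel_chunk_size=1)
```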