mtl_backward

torchjd.autojac.mtl_backward.mtl_backward(losses, features, aggregator, tasks_params=None, shared_params=None, retain_graph=False, parallel_chunk_size=None)

In the context of Multi-Task Learning (MTL), we often have a shared feature extractor followed by several task-specific heads. A loss can then be computed for each task.

This function first computes the gradient of each task-specific loss with respect to that task's parameters and accumulates it into their .grad fields. It then computes the Jacobian of all losses with respect to the shared parameters, reduces it to a single vector with the aggregator, and accumulates the result into the .grad fields of the shared parameters.
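To make these steps concrete, here is a minimal sketch in plain PyTorch. It is not torchjd's implementation. The helper naive_mtl_backward is hypothetical, it differentiates the losses directly with respect to the shared parameters instead of going through the features, and it uses the mean of the Jacobian rows as a stand-in for the aggregator. The model sizes and data are also made up for illustration.

```python
import torch


def naive_mtl_backward(losses, tasks_params, shared_params):
    """Hypothetical helper mimicking the steps described above.

    The mean of the Jacobian rows stands in for the aggregator.
    """
    shared_params = list(shared_params)

    # Step 1: each loss is differentiated w.r.t. its own task parameters.
    for loss, params in zip(losses, tasks_params):
        params = list(params)
        grads = torch.autograd.grad(loss, params, retain_graph=True)
        for p, g in zip(params, grads):
            p.grad = g.clone() if p.grad is None else p.grad + g

    # Step 2: one Jacobian row per loss, one column per shared parameter value.
    rows = []
    for loss in losses:
        grads = torch.autograd.grad(loss, shared_params, retain_graph=True)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    jacobian = torch.stack(rows)  # shape: [num_losses, num_shared_values]

    # Step 3: reduce the Jacobian to one vector and accumulate it
    # into the .grad fields of the shared parameters.
    aggregated = jacobian.mean(dim=0)
    offset = 0
    for p in shared_params:
        g = aggregated[offset : offset + p.numel()].view_as(p)
        p.grad = g.clone() if p.grad is None else p.grad + g
        offset += p.numel()


# Tiny usage example: two scalar losses computed on top of a shared layer.
torch.manual_seed(0)
shared = torch.nn.Linear(4, 3)
head1, head2 = torch.nn.Linear(3, 1), torch.nn.Linear(3, 1)
x = torch.randn(5, 4)
features = torch.relu(shared(x))
loss1 = head1(features).pow(2).mean()
loss2 = (head2(features) - 1.0).pow(2).mean()
naive_mtl_backward(
    [loss1, loss2],
    [list(head1.parameters()), list(head2.parameters())],
    shared.parameters(),
)
```

With a mean stand-in, the shared update is simply the average of the per-task gradients; the point of a real aggregator is to replace that averaging with a smarter combination of the Jacobian rows.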

Parameters:
  • losses (Sequence[Tensor]) – The task losses. The Jacobian matrix will have one row per loss.

  • features (Union[Sequence[Tensor], Tensor]) – The last shared representation used for all tasks, as given by the feature extractor. Should be non-empty.

  • aggregator (Aggregator) – Aggregator used to reduce the Jacobian into a vector.

  • tasks_params (Optional[Sequence[Iterable[Tensor]]]) – The parameters of each task-specific head. Their requires_grad flags must be set to True. If not provided, the parameters considered for each task will default to the leaf tensors that are in the computation graph of its loss, but that were not used to compute the features.

  • shared_params (Optional[Iterable[Tensor]]) – The parameters of the shared feature extractor. The Jacobian matrix will have one column for each value in these tensors. Their requires_grad flags must be set to True. If not provided, defaults to the leaf tensors that are in the computation graph of the features.

  • retain_graph (bool) – If False, the graph used to compute the gradients will be freed. Set it to True if the same graph needs to be differentiated again later. Defaults to False.

  • parallel_chunk_size (int | None) – The number of scalars to differentiate simultaneously in the backward pass. If set to None, all coordinates of tensors will be differentiated in parallel at once. If set to 1, all coordinates will be differentiated sequentially. A larger value results in faster differentiation, but also higher memory usage. Defaults to None.

Return type:

None

Example

A usage example of mtl_backward is provided in Multi-Task Learning (MTL).

Note

shared_params should contain no parameter in common with tasks_params. The different tasks may have some parameters in common. In this case, the sum of the gradients with respect to those parameters will be accumulated into their .grad fields.

Warning

To differentiate in parallel, mtl_backward relies on torch.vmap, which has some limitations: for instance, it does not work on the output of compiled functions, when some tensors have retains_grad=True, or when using an RNN on CUDA. If you experience issues with mtl_backward, try using parallel_chunk_size=1 to avoid relying on torch.vmap.