Intro
I’m excited to share thoad (short for PyTorch High Order Automatic Differentiation), a Python-only library that computes arbitrary-order partial derivatives directly on a PyTorch computational graph. The package has been developed within a research project at Universidad Pontificia de Comillas (ICAI), and we are considering publishing an academic article in the future that reviews the mathematical details and the implementation design.
At its core, thoad takes a one-output, many-inputs view of the graph and pushes high-order derivatives back to the leaf tensors. Although a 1→N problem can be rewritten as 1→1 by concatenating flattened inputs, as in functional approaches such as jax.jet or functorch, thoad’s graph-aware formulation enables an optimization based on unifying independent dimensions (especially batch). This delivers asymptotically better scaling with respect to batch size. We compute derivatives vectorially rather than component by component, which is what makes a pure PyTorch implementation practical without resorting to custom C++ or CUDA.
The package is easy to maintain because it is written entirely in Python and uses PyTorch as its only dependency. The implementation stays at a high level and leans on PyTorch’s vectorized operations, which means no custom C++ or CUDA bindings, no build systems to manage, and fewer platform-specific issues. With a single dependency, upgrades and security reviews are simpler, continuous integration is lighter, and contributors can read and modify the code quickly. The UX follows PyTorch closely, so triggering a high-order backward pass feels like calling tensor.backward(). You can install from GitHub or PyPI and start immediately.
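For example, from PyPI (assuming the distribution is published under the same name as the package):
pip install thoad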
In our benchmarks, thoad outperforms torch.autograd for Hessian calculations even on CPU. See the notebook that reproduces the comparison: https://github.com/mntsx/thoad/blob/master/examples/benchmarks/benchmark_vs_torch_autograd.ipynb
The user experience has been one of our main concerns during development. thoad is designed to align closely with PyTorch’s interface philosophy, so running the high-order backward pass is practically indistinguishable from calling PyTorch’s own backward. When you need finer control, you can keep or reduce Schwarz symmetries, group variables to restrict mixed partials, and fetch the exact mixed derivative you need. Shapes and independence metadata are also exposed to keep interpretation straightforward.
USING THE PACKAGE
thoad exposes two primary interfaces for computing high-order derivatives:
- thoad.backward: a function-based interface that closely resembles torch.Tensor.backward. It provides a quick way to compute high-order gradients without needing to manage an explicit controller object, but it offers only the core functionality (derivative computation and storage).
- thoad.Controller: a class-based interface that wraps the output tensor’s subgraph in a controller object. In addition to performing the same high-order backward pass, it gives access to advanced features such as fetching specific mixed partials, inspecting batch-dimension optimizations, overriding backward-function implementations, retaining intermediate partials, and registering custom hooks.
thoad.backward
The thoad.backward function computes high-order partial derivatives of a given output tensor and stores them in each leaf tensor’s .hgrad attribute.
Arguments:
- tensor: A PyTorch tensor from which to start the backward pass. This tensor must require gradients and be part of a differentiable graph.
- order: A positive integer specifying the maximum order of derivatives to compute.
- gradient: A tensor with the same shape as tensor to seed the vector-Jacobian product (i.e., a custom upstream gradient). If omitted, the default seed is used.
- crossings: A boolean flag (default=False). If set to True, mixed partial derivatives (i.e., derivatives that involve more than one distinct leaf tensor) will be computed.
- groups: An iterable of disjoint groups of leaf tensors. When crossings=False, only those mixed partials whose participating leaf tensors all lie within a single group will be calculated. If crossings=True and groups is provided, a ValueError is raised (they are mutually exclusive).
- keep_batch: A boolean flag (default=False) that controls how output dimensions are organized in the computed gradients (a short sketch illustrating this follows the example below).
  - When keep_batch=False: Gradients are returned in a fully flattened form. The gradient tensor has a single “output” axis that lists every element of the original output tensor (flattened into one dimension), plus one axis per derivative order, each listing every element of the corresponding input (also flattened). For an N-th order derivative of a leaf tensor with input_numel elements and an output with output_numel elements, axis 1 indexes all output_numel outputs and axes 2...(N+1) each index all input_numel inputs.
  - When keep_batch=True: Gradients preserve both a flattened “output” axis and each original output dimension before any input axes. Axis 1 flattens all elements of the output tensor (size = output_numel); axes 2...(k+1) correspond exactly to each dimension of the output tensor (if the output has shape (d1, d2, ..., dk), these axes have sizes d1, d2, ..., dk); and axes (k+2)...(k+N+1) each flatten all input_numel elements of the leaf tensor, one axis per derivative order. However, if a particular output axis does not influence the gradient for a given leaf, that axis is not expanded and instead becomes a size-1 dimension. Only those output dimensions that actually affect a particular leaf’s gradient “spread” into the input axes; any untouched axes remain as 1, saving memory.
- keep_schwarz: A boolean flag (default=False). If True, symmetric (Schwarz) permutations are retained explicitly instead of being canonicalized/reduced, which is useful for debugging or inspecting non-reduced layouts.
Returns:
- An instance of thoad.Controller wrapping the same tensor and graph.
import torch
import thoad
from torch.nn import functional as F
#### Normal PyTorch workflow
X = torch.rand(size=(10,15), requires_grad=True)
Y = torch.rand(size=(15,20), requires_grad=True)
Z = F.scaled_dot_product_attention(query=X, key=Y.T, value=Y.T)
#### Call thoad backward
order = 2
thoad.backward(tensor=Z, order=order)
#### Checks
## check derivative shapes
for o in range(1, 1 + order):
assert X.hgrad[o - 1].shape == (Z.numel(), *(o * tuple(X.shape)))
assert Y.hgrad[o - 1].shape == (Z.numel(), *(o * tuple(Y.shape)))
## check first derivatives (jacobians)
fn = lambda x, y: F.scaled_dot_product_attention(x, y.T, y.T)
J = torch.autograd.functional.jacobian(fn, (X, Y))
assert torch.allclose(J[0].flatten(), X.hgrad[0].flatten(), atol=1e-6)
assert torch.allclose(J[1].flatten(), Y.hgrad[0].flatten(), atol=1e-6)
## check second derivatives (hessians)
fn = lambda x, y: F.scaled_dot_product_attention(x, y.T, y.T).sum()
H = torch.autograd.functional.hessian(fn, (X, Y))
assert torch.allclose(H[0][0].flatten(), X.hgrad[1].sum(0).flatten(), atol=1e-6)
assert torch.allclose(H[1][1].flatten(), Y.hgrad[1].sum(0).flatten(), atol=1e-6)
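To see the keep_batch layout described above, here is a minimal sketch that rebuilds the same attention graph and only prints shapes rather than asserting them, since the exact sizes depend on which output axes thoad detects as independent for each leaf:
#### keep_batch sketch (illustrative; shapes printed, not asserted)
A = torch.rand(size=(10, 15), requires_grad=True)
B = torch.rand(size=(15, 20), requires_grad=True)
W = F.scaled_dot_product_attention(query=A, key=B.T, value=B.T)
thoad.backward(tensor=W, order=2, keep_batch=True)
print(A.hgrad[0].shape, A.hgrad[1].shape)  ## output axes kept separate; uninfluential ones stay size 1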
thoad.Controller
The Controller class wraps a tensor’s backward subgraph in a controller object, performing the same core high-order backward pass as thoad.backward while exposing advanced customization, inspection, and override capabilities.
Instantiation
Use the constructor to create a controller for any tensor requiring gradients:
controller = thoad.Controller(tensor=GO) ## takes graph output tensor
- tensor: A PyTorch Tensor with requires_grad=True and a non-None grad_fn.
Properties
- .tensor → Tensor: The output tensor underlying this controller. Setter: replaces the tensor (after validation), rebuilds the internal computation graph, and invalidates any previously computed gradients.
- .compatible → bool: Indicates whether every backward function in the tensor’s subgraph has a supported high-order implementation. If False, some derivatives may fall back or be unavailable.
- .index → Dict[Type[torch.autograd.Function], Type[ExtendedAutogradFunction]]: A mapping from base PyTorch autograd.Function classes to thoad’s ExtendedAutogradFunction implementations. Setter: validates and injects your custom high-order extensions.
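A quick sketch of reading these properties before any backward call; the small graph below is only illustrative:
#### inspect controller properties (illustrative graph)
import torch
import thoad
a = torch.rand(size=(4, 3), requires_grad=True)
b = torch.rand(size=(3, 5), requires_grad=True)
c = (a @ b).sin()
controller = thoad.Controller(tensor=c)
print(controller.tensor is c)  ## the wrapped output tensor
print(controller.compatible)   ## True only if every backward function in the subgraph is supported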
Core Methods
.backward(order, gradient=None, crossings=False, groups=None, keep_batch=False, keep_schwarz=False) → None
Performs the high-order backward pass up to the specified derivative order, storing all computed partials in each leaf tensor’s .hgrad attribute.
- order (int > 0): maximum derivative order.
- gradient (Optional[Tensor]): custom upstream gradient with the same shape as controller.tensor.
- crossings (bool, default False): If True, mixed partial derivatives across different leaf tensors will be computed.
- groups (Optional[Iterable[Iterable[Tensor]]], default None): When crossings=False, restricts mixed partials to those whose leaf tensors all lie within a single group. If crossings=True and groups is provided, a ValueError is raised.
- keep_batch (bool, default False): controls whether independent output axes are kept separate (batched) or merged (flattened) in stored/retrieved gradients.
- keep_schwarz (bool, default False): if True, retains symmetric permutations explicitly (no Schwarz reduction).
.display_graph() → None
Prints a tree representation of the tensor’s backward subgraph. Supported nodes are shown normally; unsupported ones are annotated with (not supported).
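For instance, with a controller like the one instantiated above:
controller.display_graph()  ## prints the backward subgraph tree, flagging unsupported nodes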
.register_backward_hook(variables: Sequence[Tensor], hook: Callable) → None
Registers a user-provided hook to run during the backward pass whenever gradients for any of the specified leaf variables are computed.
- variables (Sequence[Tensor]): Leaf tensors to monitor.
- hook (Callable[[Tuple[Tensor, Tuple[Shape, ...], Tuple[Indep, ...]], dict[AutogradFunction, set[Tensor]]], Tuple[Tensor, Tuple[Shape, ...], Tuple[Indep, ...]]]): Receives the current (Tensor, shapes, indeps) plus contextual info, and must return the modified triple.
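As a sketch of the expected call pattern, the hook below is a hypothetical pass-through that just logs and returns its inputs unchanged; controller and X stand for a controller and a monitored leaf tensor as in the examples:
#### sketch: pass-through hook (assumed usage)
def logging_hook(triple, context):
    tensor, shapes, indeps = triple  ## current partial and its metadata
    print("partial computed with shape", tuple(tensor.shape))
    return tensor, shapes, indeps    ## return the (possibly modified) triple
controller.register_backward_hook(variables=(X,), hook=logging_hook)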
.require_grad_(variables: Sequence[Tensor]) → None
Marks the given leaf variables so that all intermediate partials involving them are retained, even if not required for the final requested gradients. Useful for inspecting or re-using higher-order intermediates.
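For example, to keep every intermediate partial that involves a leaf tensor X:
controller.require_grad_(variables=(X,))  ## intermediate partials involving X are retained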
.fetch_hgrad(variables: Sequence[Tensor], keep_batch: bool = False, keep_schwarz: bool = False) → Tuple[Tensor, Tuple[Tuple[Shape, ...], Tuple[Indep, ...], VPerm]]
Retrieves the precomputed high-order partial corresponding to the ordered sequence of leaf variables.
- variables (Sequence[Tensor]): the leaf tensors whose mixed partial you want.
- keep_batch (bool, default False): if True, each independent output axis remains a separate batch dimension in the returned tensor; if False, independent axes are distributed/merged into derivative dimensions.
- keep_schwarz (bool, default False): if True, returns derivatives retaining symmetric permutations explicitly.
Returns a pair:
- Gradient tensor: the computed partial derivatives, shaped according to output and input dimensions (respecting keep_batch/keep_schwarz).
- Metadata tuple:
  - Shapes (Tuple[Shape, ...]): the original shape of each leaf tensor.
  - Indeps (Tuple[Indep, ...]): for each variable, indicates which output axes remained independent (batch) vs. which were merged into derivative axes.
  - VPerm (Tuple[int, ...]): a permutation that maps the internal derivative layout to the requested variables order.
Use the combination of independent-dimension info and shapes to reshape or interpret the returned gradient tensor in your workflow.
import torch
import thoad
from torch.nn import functional as F
#### Normal PyTorch workflow
X = torch.rand(size=(10,15), requires_grad=True)
Y = torch.rand(size=(15,20), requires_grad=True)
Z = F.scaled_dot_product_attention(query=X, key=Y.T, value=Y.T)
#### Instantiate thoad controller and call backward
order = 2
controller = thoad.Controller(tensor=Z)
controller.backward(order=order, crossings=True)
#### Fetch Partial Derivatives
## fetch X and Y second-order derivatives
partial_XX, _ = controller.fetch_hgrad(variables=(X, X))
partial_YY, _ = controller.fetch_hgrad(variables=(Y, Y))
assert torch.allclose(partial_XX, X.hgrad[1])
assert torch.allclose(partial_YY, Y.hgrad[1])
## fetch cross derivatives
partial_XY, _ = controller.fetch_hgrad(variables=(X, Y))
partial_YX, _ = controller.fetch_hgrad(variables=(Y, X))
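Continuing the example, the metadata tuple returned by fetch_hgrad can be unpacked to interpret the layout; the variable names below are illustrative:
## unpack the metadata accompanying the cross partial
partial_XY, (shapes, indeps, vperm) = controller.fetch_hgrad(variables=(X, Y))
print(shapes)  ## original shapes of X and Y
print(indeps)  ## output axes that remained independent (batch) for each variable
print(vperm)   ## permutation mapping the internal layout to the requested (X, Y) order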
NOTE. A more detailed user guide with examples and feature walkthroughs is available in the notebook: https://github.com/mntsx/thoad/blob/master/examples/user_guide.ipynb
If you give it a try, I would love feedback on the API, corner cases, and models where you want better plug and play support.