<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: mntsx</title>
    <description>The latest articles on DEV Community by mntsx (@mntsx).</description>
    <link>https://dev.to/mntsx</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3476561%2F7b6798a1-f0d1-4897-b9da-a3545f13ffca.jpeg</url>
      <title>DEV Community: mntsx</title>
      <link>https://dev.to/mntsx</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/mntsx"/>
    <language>en</language>
    <item>
      <title>Introducing THOAD, High Order Derivatives for PyTorch Graphs</title>
      <dc:creator>mntsx</dc:creator>
      <pubDate>Tue, 02 Sep 2025 23:01:59 +0000</pubDate>
      <link>https://dev.to/mntsx/introducing-thoad-high-order-derivatives-for-pytorch-graphs-5a1j</link>
      <guid>https://dev.to/mntsx/introducing-thoad-high-order-derivatives-for-pytorch-graphs-5a1j</guid>
      <description>&lt;h2&gt;
  
  
  &lt;strong&gt;PRESENTING THE PACKAGE&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;I’m excited to share &lt;strong&gt;thoad&lt;/strong&gt; (short for PyTorch High Order Automatic Differentiation), a Python-only library that computes arbitrary-order partial derivatives directly on a PyTorch computational graph. The package has been developed within a research project at &lt;strong&gt;Universidad Pontificia de Comillas - ICAI&lt;/strong&gt;, and we are considering publishing a future academic article reviewing the mathematical details and the implementation design.&lt;/p&gt;

&lt;p&gt;At its core, thoad takes a one-output, many-inputs view of the graph and pushes high-order derivatives back to the leaf tensors. Although a 1→N problem can be rewritten as 1→1 by concatenating flattened inputs, as in functional approaches such as &lt;code&gt;jax.jet&lt;/code&gt; or &lt;code&gt;torch.func&lt;/code&gt;, thoad’s graph-aware formulation enables:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Working with smaller &lt;strong&gt;pieced external derivatives&lt;/strong&gt;
&lt;/li&gt;
&lt;li&gt;An optimization based on &lt;strong&gt;unifying independent dimensions&lt;/strong&gt; (e.g. batch).&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;These deliver &lt;strong&gt;asymptotically&lt;/strong&gt; better scaling with respect to derivative order and batch size, respectively.&lt;/p&gt;

&lt;p&gt;Additionally, we compute derivatives with a &lt;em&gt;vectorial&lt;/em&gt; approach rather than component by component, which makes our pure PyTorch implementation possible. Consequently, the implementation stays at a high level, written entirely in Python and using &lt;strong&gt;PyTorch&lt;/strong&gt; as its only dependency. Avoiding custom C++ or CUDA has a very positive impact on the long-term maintainability of the package.&lt;/p&gt;

&lt;p&gt;The package can be installed from &lt;strong&gt;GitHub&lt;/strong&gt; or &lt;strong&gt;PyPI&lt;/strong&gt;:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;GitHub: &lt;a href="https://github.com/mntsx/thoad" rel="noopener noreferrer"&gt;https://github.com/mntsx/thoad&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;PyPI: &lt;a href="https://pypi.org/project/thoad/" rel="noopener noreferrer"&gt;https://pypi.org/project/thoad/&lt;/a&gt;
&lt;/li&gt;
&lt;/ul&gt;
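
&lt;p&gt;For a standard setup, installing the published package from PyPI should suffice:&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;pip install thoad
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;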

&lt;p&gt;&lt;strong&gt;thoad&lt;/strong&gt;’s Hessian computation outperforms &lt;code&gt;torch.autograd&lt;/code&gt; on GPU by 10–100× for practical networks, remaining close to the performance of the &lt;code&gt;jax.jet&lt;/code&gt; implementation. See the notebook that reproduces the comparison: &lt;a href="https://github.com/mntsx/thoad/blob/master/examples/benchmarks/benchmark_vs_torch_autograd.ipynb" rel="noopener noreferrer"&gt;https://github.com/mntsx/thoad/blob/master/examples/benchmarks/benchmark_vs_torch_autograd.ipynb&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;thoad&lt;/strong&gt; is designed to align closely with PyTorch’s interface philosophy, so running the high order backward pass is practically indistinguishable from calling PyTorch’s own &lt;em&gt;backward&lt;/em&gt;. When you need finer control, you can keep or reduce Schwarz symmetries, group variables to restrict mixed partials, and fetch the exact mixed derivative you need. Shapes and independence metadata are also exposed to keep interpretation straightforward.&lt;/p&gt;

&lt;h2&gt;
  
  
  &lt;strong&gt;USING THE PACKAGE&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;thoad&lt;/strong&gt; exposes two primary interfaces for computing high-order derivatives:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;code&gt;thoad.backward&lt;/code&gt;: a function-based interface that closely resembles &lt;code&gt;torch.Tensor.backward&lt;/code&gt;. It provides a quick way to compute high-order gradients without needing to manage an explicit controller object, but it offers only the core functionality (derivative computation and storage).&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;thoad.Controller&lt;/code&gt;: a class-based interface that wraps the output tensor’s subgraph in a controller object. In addition to performing the same high-order backward pass, it gives access to advanced features such as fetching specific mixed partials, inspecting batch-dimension optimizations, overriding backward-function implementations, retaining intermediate partials, and registering custom hooks.
&lt;/li&gt;
&lt;/ol&gt;
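
&lt;p&gt;A minimal sketch of the two entry points (assuming &lt;code&gt;output&lt;/code&gt; is a graph output tensor that requires grad):&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;## function-based: a single call; returns a thoad.Controller
controller = thoad.backward(tensor=output, order=3)

## class-based: explicit controller, same core backward pass
controller = thoad.Controller(tensor=output)
controller.backward(order=3)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;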

&lt;h3&gt;
  
  
  &lt;strong&gt;thoad.backward&lt;/strong&gt;
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;thoad.backward&lt;/code&gt; function computes high-order partial derivatives of a given output tensor and stores them in each leaf tensor’s &lt;code&gt;.hgrad&lt;/code&gt; attribute.  &lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Arguments&lt;/strong&gt;:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;tensor&lt;/code&gt;: A PyTorch tensor from which to start the backward pass. This tensor must require gradients and be part of a differentiable graph.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;order&lt;/code&gt;: A positive integer specifying the maximum order of derivatives to compute.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;gradient&lt;/code&gt;: A tensor with the same shape as &lt;code&gt;tensor&lt;/code&gt; to seed the vector-Jacobian product (i.e., custom upstream gradient). If omitted, the default is used.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;crossings&lt;/code&gt;: A boolean flag (default=&lt;code&gt;False&lt;/code&gt;). If set to &lt;code&gt;True&lt;/code&gt;, mixed partial derivatives (i.e., derivatives that involve more than one distinct leaf tensor) will be computed.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;groups&lt;/code&gt;: An iterable of disjoint groups of leaf tensors. When &lt;code&gt;crossings=False&lt;/code&gt;, only those mixed partials whose participating leaf tensors all lie within a single group will be calculated (a short sketch follows the example code below). If &lt;code&gt;crossings=True&lt;/code&gt; and &lt;code&gt;groups&lt;/code&gt; is provided, a &lt;em&gt;ValueError&lt;/em&gt; will be raised (they are mutually exclusive).&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;keep_batch&lt;/code&gt;: A boolean flag (default=&lt;code&gt;False&lt;/code&gt;) that controls how output dimensions are organized in the computed gradients.&lt;/li&gt;
&lt;li&gt;
&lt;p&gt;&lt;strong&gt;When &lt;code&gt;keep_batch=False&lt;/code&gt;:&lt;/strong&gt;&lt;br&gt;
The derivative keeps a single flattened "primal" axis first, followed by each original partial shape, in differentiation order. Concretely:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;A single "primal" axis that contains every element of the graph output tensor (flattened into one dimension).&lt;/li&gt;
&lt;li&gt;A group of axes per derivative order, each matching the shape of the corresponding differentiation target.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;For an N-th order derivative of a leaf tensor with &lt;code&gt;input_numel&lt;/code&gt; elements and an output with &lt;code&gt;output_numel&lt;/code&gt; elements, the derivative shape is:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Axis 1:&lt;/strong&gt; indexes all &lt;code&gt;output_numel&lt;/code&gt; outputs&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Axes 2…(N+1):&lt;/strong&gt; each indexes all &lt;code&gt;input_numel&lt;/code&gt; inputs&lt;/li&gt;
&lt;/ul&gt;


&lt;/li&gt;

&lt;li&gt;

&lt;p&gt;&lt;strong&gt;When &lt;code&gt;keep_batch=True&lt;/code&gt;:&lt;/strong&gt;&lt;br&gt;
The derivative shape follows the same ordering as in the previous case, but keeps a series of "independent dimensions" immediately after the "primal" axis. With k independent dimensions and an N-th order derivative:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Axis 1&lt;/strong&gt; flattens all elements of the output tensor (size = &lt;code&gt;output_numel&lt;/code&gt;).&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Axes 2…(k+1)&lt;/strong&gt; correspond to dimensions shared by multiple input tensors and treated independently throughout the graph. These are dimensions that are only operated on element-wise (e.g. batch dimensions).&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Axes (k+2)…(k+N+1)&lt;/strong&gt; each flatten all &lt;code&gt;input_numel&lt;/code&gt; elements of the leaf tensor, one axis per derivative order.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;code&gt;keep_schwarz&lt;/code&gt;: A boolean flag (default=&lt;code&gt;False&lt;/code&gt;). If &lt;code&gt;True&lt;/code&gt;, symmetric (Schwarz) permutations are retained explicitly instead of being canonicalized and reduced, which is useful for debugging or inspecting non-reduced layouts.&lt;/p&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Returns&lt;/strong&gt;:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;An instance of &lt;code&gt;thoad.Controller&lt;/code&gt; wrapping the same tensor and graph.
&lt;/li&gt;
&lt;/ul&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;thoad&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;

&lt;span class="c1"&gt;#### Normal PyTorch workflow
&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;requires_grad&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;requires_grad&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;Z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;#### Call thoad backward
&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;span class="n"&gt;thoad&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Z&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;#### Checks
## check derivative shapes
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
   &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hgrad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Z&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;numel&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nf"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
   &lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hgrad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Z&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;numel&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;o&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="nf"&gt;tuple&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;span class="c1"&gt;## check first derivatives (jacobians)
&lt;/span&gt;&lt;span class="n"&gt;fn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;J&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;autograd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;functional&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;jacobian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hgrad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;J&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hgrad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;## check second derivatives (hessians)
&lt;/span&gt;&lt;span class="n"&gt;fn&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;H&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;autograd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;functional&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hessian&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;fn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;H&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hgrad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;H&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hgrad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;flatten&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;atol&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
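
&lt;p&gt;As anticipated in the argument list, &lt;code&gt;groups&lt;/code&gt; restricts which mixed partials are computed without enabling all crossings. A minimal sketch reusing &lt;code&gt;X&lt;/code&gt;, &lt;code&gt;Y&lt;/code&gt; and &lt;code&gt;Z&lt;/code&gt; from the example above:&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;## mixed partials are computed only when all participating leaves
## lie within a single group: here, X-Y crossings are included
thoad.backward(tensor=Z, order=2, groups=[(X, Y)])

## with groups=[(X,), (Y,)] no cross derivatives would be computed
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;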

&lt;h3&gt;
  
  
  &lt;strong&gt;thoad.Controller&lt;/strong&gt;
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;Controller&lt;/code&gt; class wraps a tensor’s backward subgraph in a controller object, performing the same core high-order backward pass as &lt;code&gt;thoad.backward&lt;/code&gt; while exposing advanced customization, inspection, and override capabilities.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Instantiation&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Use the constructor to create a controller for any tensor requiring gradients:&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;controller = thoad.Controller(tensor=GO)  ## takes graph output tensor
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;tensor&lt;/code&gt;: A PyTorch &lt;code&gt;Tensor&lt;/code&gt; with &lt;code&gt;requires_grad=True&lt;/code&gt; and a non-&lt;code&gt;None&lt;/code&gt; &lt;code&gt;grad_fn&lt;/code&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Properties&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;.tensor → Tensor&lt;/code&gt; The output tensor underlying this controller. &lt;strong&gt;Setter&lt;/strong&gt;: Replaces the tensor (after validation), rebuilds the internal computation graph, and invalidates any previously computed gradients.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;.compatible → bool&lt;/code&gt; Indicates whether every backward function in the tensor’s subgraph has a supported high-order implementation. If &lt;code&gt;False&lt;/code&gt;, some derivatives may fall back or be unavailable.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;.index → Dict[Type[torch.autograd.Function], Type[ExtendedAutogradFunction]]&lt;/code&gt; A mapping from base PyTorch &lt;code&gt;autograd.Function&lt;/code&gt; classes to thoad’s &lt;code&gt;ExtendedAutogradFunction&lt;/code&gt; implementations. &lt;strong&gt;Setter&lt;/strong&gt;: Validates and injects your custom high-order extensions.&lt;/li&gt;
&lt;/ul&gt;
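
&lt;p&gt;A small sketch of how these properties can be combined before running the backward pass (assuming &lt;code&gt;Z&lt;/code&gt; is a graph output tensor, as in the examples):&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;controller = thoad.Controller(tensor=Z)
if not controller.compatible:
    ## some backward nodes lack a high-order implementation;
    ## display_graph (below) annotates them with "(not supported)"
    controller.display_graph()
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;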

&lt;p&gt;&lt;strong&gt;Core Methods&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;.backward(order, gradient=None, crossings=False, groups=None, keep_batch=False, keep_schwarz=False) → None&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Performs the high-order backward pass up to the specified derivative &lt;code&gt;order&lt;/code&gt;, storing all computed partials in each leaf tensor’s &lt;code&gt;.hgrad&lt;/code&gt; attribute.&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;order&lt;/code&gt; (&lt;code&gt;int &amp;gt; 0&lt;/code&gt;): maximum derivative order.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;gradient&lt;/code&gt; (&lt;code&gt;Optional[Tensor]&lt;/code&gt;): custom upstream gradient with the same shape as &lt;code&gt;controller.tensor&lt;/code&gt;.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;crossings&lt;/code&gt; (&lt;code&gt;bool&lt;/code&gt;, default &lt;code&gt;False&lt;/code&gt;): If &lt;code&gt;True&lt;/code&gt;, mixed partial derivatives across different leaf tensors will be computed.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;groups&lt;/code&gt; (&lt;code&gt;Optional[Iterable[Iterable[Tensor]]]&lt;/code&gt;, default &lt;code&gt;None&lt;/code&gt;): When &lt;code&gt;crossings=False&lt;/code&gt;, restricts mixed partials to those whose leaf tensors all lie within a single group. If &lt;code&gt;crossings=True&lt;/code&gt; and &lt;code&gt;groups&lt;/code&gt; is provided, a &lt;em&gt;ValueError&lt;/em&gt; is raised.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;keep_batch&lt;/code&gt; (&lt;code&gt;bool&lt;/code&gt;, default &lt;code&gt;False&lt;/code&gt;): controls whether independent output axes are kept separate (batched) or merged (flattened) in stored/retrieved gradients.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;keep_schwarz&lt;/code&gt; (&lt;code&gt;bool&lt;/code&gt;, default &lt;code&gt;False&lt;/code&gt;): if &lt;code&gt;True&lt;/code&gt;, retains symmetric permutations explicitly (no Schwarz reduction).&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;.display_graph() → None&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Prints a tree representation of the tensor’s backward subgraph. Supported nodes are shown normally; unsupported ones are annotated with &lt;code&gt;(not supported)&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;.register_backward_hook(variables: Sequence[Tensor], hook: Callable) → None&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Registers a user-provided &lt;code&gt;hook&lt;/code&gt; to run during the backward pass whenever gradients for any of the specified leaf &lt;code&gt;variables&lt;/code&gt; are computed.&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;variables&lt;/code&gt; (&lt;code&gt;Sequence[Tensor]&lt;/code&gt;): Leaf tensors to monitor.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;hook&lt;/code&gt; (&lt;code&gt;Callable[[Tuple[Tensor, Tuple[Shape, ...], Tuple[Indep, ...]], dict[AutogradFunction, set[Tensor]]], Tuple[Tensor, Tuple[Shape, ...], Tuple[Indep, ...]]]&lt;/code&gt;): Receives the current &lt;code&gt;(Tensor, shapes, indeps)&lt;/code&gt; plus contextual info, and must return the modified triple.&lt;/li&gt;
&lt;/ul&gt;
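
&lt;p&gt;As an illustrative sketch following the signature above (the transformation itself is arbitrary), a hook that rescales the partials of &lt;code&gt;X&lt;/code&gt; as they are computed could look like:&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;def scale_hook(payload, context):
    ## payload is the (Tensor, shapes, indeps) triple described above
    tensor, shapes, indeps = payload
    ## illustrative only: halve every partial involving X
    return (tensor * 0.5, shapes, indeps)

controller.register_backward_hook(variables=(X,), hook=scale_hook)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;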

&lt;p&gt;&lt;strong&gt;.require_grad_(variables: Sequence[Tensor]) → None&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Marks the given leaf &lt;code&gt;variables&lt;/code&gt; so that all intermediate partials involving them are retained, even if not required for the final requested gradients. Useful for inspecting or re-using higher-order intermediates.&lt;/p&gt;
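
&lt;p&gt;For instance, a one-line sketch to retain all intermediate partials involving &lt;code&gt;X&lt;/code&gt;:&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;## retain intermediate partials that involve X, even if not strictly needed
controller.require_grad_(variables=(X,))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;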

&lt;p&gt;&lt;strong&gt;.fetch_hgrad(variables: Sequence[Tensor], keep_batch: bool = False, keep_schwarz: bool = False) → Tuple[Tensor, Tuple[Tuple[Shape, ...], Tuple[Indep, ...], VPerm]]&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Retrieves the precomputed high-order partial corresponding to the ordered sequence of leaf &lt;code&gt;variables&lt;/code&gt;.&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;variables&lt;/code&gt; (&lt;code&gt;Sequence[Tensor]&lt;/code&gt;): the leaf tensors whose mixed partial you want.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;keep_batch&lt;/code&gt; (&lt;code&gt;bool&lt;/code&gt;, default &lt;code&gt;False&lt;/code&gt;): if &lt;code&gt;True&lt;/code&gt;, each independent output axis remains a separate batch dimension in the returned tensor; if &lt;code&gt;False&lt;/code&gt;, independent axes are distributed/merged into derivative dimensions.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;keep_schwarz&lt;/code&gt; (&lt;code&gt;bool&lt;/code&gt;, default &lt;code&gt;False&lt;/code&gt;): if &lt;code&gt;True&lt;/code&gt;, returns derivatives retaining symmetric permutations explicitly.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Returns a pair:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Gradient tensor&lt;/strong&gt;: the computed partial derivatives, shaped according to output and input dimensions (respecting &lt;code&gt;keep_batch&lt;/code&gt;/&lt;code&gt;keep_schwarz&lt;/code&gt;).&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Metadata tuple&lt;/strong&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Shapes&lt;/strong&gt; (&lt;code&gt;Tuple[Shape, ...]&lt;/code&gt;): the original shape of each leaf tensor.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Indeps&lt;/strong&gt; (&lt;code&gt;Tuple[Indep, ...]&lt;/code&gt;): for each variable, indicates which output axes remained independent (batch) vs. which were merged into derivative axes.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;VPerm&lt;/strong&gt; (&lt;code&gt;Tuple[int, ...]&lt;/code&gt;): a permutation that maps the internal derivative layout to the requested &lt;code&gt;variables&lt;/code&gt; order.&lt;/li&gt;
&lt;/ul&gt;
&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Use the combination of independent-dimension info and shapes to reshape or interpret the returned gradient tensor in your workflow.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;thoad&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;

&lt;span class="c1"&gt;#### Normal PyTorch workflow
&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;requires_grad&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;rand&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;15&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;requires_grad&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;Z&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scaled_dot_product_attention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;query&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;value&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;#### Instantiate thoad controller and call backward
&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;
&lt;span class="n"&gt;controller&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;thoad&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Controller&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;Z&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;controller&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;order&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;crossings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;#### Fetch Partial Derivatives
## fetch T0 and T1 2nd order derivatives
&lt;/span&gt;&lt;span class="n"&gt;partial_XX&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;controller&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fetch_hgrad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;variables&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;partial_YY&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;controller&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fetch_hgrad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;variables&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;partial_XX&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hgrad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="k"&gt;assert&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;partial_YY&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;hgrad&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="c1"&gt;## fetch cross derivatives
&lt;/span&gt;&lt;span class="n"&gt;partial_XY&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;controller&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fetch_hgrad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;variables&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;partial_YX&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;controller&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fetch_hgrad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;variables&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
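
&lt;p&gt;The metadata tuple returned by &lt;code&gt;fetch_hgrad&lt;/code&gt; (ignored with &lt;code&gt;_&lt;/code&gt; above) can be unpacked as described earlier. A sketch continuing the example:&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;partial_XY, (shapes, indeps, vperm) = controller.fetch_hgrad(variables=(X, Y))
## shapes: original shape of each requested leaf (X first, then Y)
## indeps: which output axes remained independent (batch) for each leaf
## vperm: permutation mapping the internal layout to the (X, Y) order
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;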



&lt;blockquote&gt;
&lt;p&gt;NOTE. A more detailed user guide with examples and feature walkthroughs is available in the notebook: &lt;a href="https://github.com/mntsx/thoad/blob/master/examples/user_guide.ipynb" rel="noopener noreferrer"&gt;https://github.com/mntsx/thoad/blob/master/examples/user_guide.ipynb&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;If you give it a try, I would love feedback on the API, corner cases, and models where you want better plug-and-play support.&lt;/p&gt;

</description>
      <category>python</category>
      <category>ai</category>
      <category>machinelearning</category>
      <category>opensource</category>
    </item>
  </channel>
</rss>
