DEV Community

Kurt

Posted on • Originally published at getcode.substack.com on

Beyond Backpropagation - Higher Order, Forward and Reverse-mode Automatic Differentiation for Tensorken

This post describes how I added automatic differentiation to Tensorken. Tensorken is my attempt to build a fully featured yet easy-to-understand and hackable implementation of a deep learning library in Rust. It takes inspiration from the likes of PyTorch, Tinygrad, and JAX.

Tensorken's approach to automatic differentiation (or AD) is heavily inspired by JAX. Like JAX, Tensorken supports higher-order derivatives - besides the first derivative, it can calculate the second, third, and so on. Tensorken supports both forward and reverse-mode AD, and can arbitrarily compose the two. Finally, thanks to good fundamentals explained in the previous two posts (part 1 and part 2), Tensorken can compute derivatives on the CPU or GPU.

All code for this post is in the Tensorken repository, tagged v0.3.

Four turtles on top of each other on a forest background
There's an "it's turtles all the way down" reference somewhere in this post, and only then will this image make sense. Generated by Lexica's Aperture model.

Previously in Tensors From Scratch: neural networks, matrix multiplication, and GPU acceleration

Modern neural networks, for example, large language models (LLMs) like OpenAI's ChatGPT and GPT-4, Microsoft's Bing, Google's Bard, and Anthropic's Claude, are powered by tensors. Tensors are multi-dimensional arrays augmented with operations that execute efficiently on modern hardware, most notably GPUs.

To understand all that, I am building a neural network library like PyTorch or JAX, from the ground up in Rust. These libraries consist of:

  1. A tensor library, to provide efficient operations to slice, dice, and apply bulk operations to tensors.

  2. Accelerators, to accelerate tensor operations on the GPU.

  3. Automatic differentiation, to train neural networks via gradient descent.

  4. Neural network building blocks, to simplify using common activation functions and layers.

In the first post, I focussed on the tensor library. I described almost twenty fundamental tensor operations and abstracted them in a Rust trait called RawTensor. RawTensor had a single implementation, CpuRawTensor, which executes tensor computations on the CPU. In the second post, I implemented RawTensor again in WgpuRawTensor to execute on the GPU using wgpu, Rust's implementation of WebGPU. We dove into the nitty-gritty of GPU programming in general and wgpu in particular.

This third part of the series describes how to add automatic differentiation to Tensorken. Automatic differentiation (AD) is a technique to compute derivatives of tensor computations, without programmer intervention. AD is crucial because neural networks are trained via gradient descent, which relies on the efficient calculation of derivatives.

How to train your neural network

Let's sketch how to train a neural network to emphasize how important AD is for deep learning.

First, gather training data, and lots of it. Training data consists of many pairs of an input and its expected output. The input examples are encoded as numbers and aggregated in a tensor 𝚇. The outputs go in a tensor 𝚈. Think of 𝚈 as the correct predictions for the inputs 𝚇. For a language model, each example in 𝚇 could be a sequence of words, and 𝚈 the next word, encoded as numbers. (How to encode text as numbers is an interesting problem that's not relevant to this story.)

Second, decide on the architecture of your neural network. A neural network consists of tensors 𝚆ᵢ that contain the parameters of the network. That's what you download when you get a model's weights. The architecture determines how many parameters we have and how we combine the input 𝚇 with parameters 𝚆ᵢ to obtain an output 𝚈'. Whatever the architecture is, we can execute it to predict 𝚈':

𝚈' = 𝚏(𝚇, 𝚆ᵢ).

I'm simplifying - researchers distinguish weights 𝚆 and biases b, but in the end, they're both part of the trainable parameters, so I'm just lumping them together in the 𝚆s.

Third, using the expected 𝚈 and the prediction 𝚈', calculate the loss 𝙻. The loss is a single number that is high when the prediction is bad, and low when it is good. The loss is calculated by comparing the network's prediction 𝚈' with the expected output 𝚈:

𝙻 = 𝚕(𝚈, 𝚈').

Fourth, calculate the gradient 𝙢 of the loss. The scalar loss value 𝙻 is a function of 𝚇, 𝚆ᵢ, and 𝚈. Imagine the loss function as describing a (high-dimensional!) landscape. Training the network to improve its predictions means changing the parameters 𝚆ᵢ to make the loss small. We'd like to know how we should change the parameters to achieve that.

Now is when the gradient comes in. Going back to the landscape analogy, to make the loss smaller we'd like to know the best direction to "move" in to go "down" - that is, from the current value of the parameters, find the direction in which the slope is steepest. If you remember some calculus, the derivative of a function at a point is that slope. So, to calculate the gradient, we calculate the loss function's derivative with respect to each parameter 𝚆ᵢ. In other words, we'll have a number for each parameter that tells us how to change that parameter to make the loss smaller.

Fifth, update the parameters using the gradient. There are many ways of doing this. The simplest is to multiply the gradient with a small number ϵ and subtract it from the parameters:

𝚆ᵢ <- 𝚆ᵢ − ϵ𝙢ᵢ.

That's one training step done! Your neural network just got a tiny bit better. Now repeat from step 2 until you've had enough. You can stop when the loss becomes small enough, when it stops changing for some number of iterations, or when your AWS bill exceeds the budget.
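To make the loop concrete, here is a toy sketch in plain Rust (not Tensorken): a single parameter w, a made-up loss L(w) = (w − 3)² whose gradient 2(w − 3) is derived by hand, and the update rule above with ϵ = 0.1. The loss, starting point, step size, and number of steps are all illustrative assumptions.

```rust
// Toy gradient descent on a single parameter (illustrative; not Tensorken code).
// Made-up loss L(w) = (w - 3)^2, with its gradient dL/dw = 2 (w - 3) derived by hand.
fn loss(w: f64) -> f64 {
    (w - 3.0).powi(2)
}

fn grad(w: f64) -> f64 {
    2.0 * (w - 3.0)
}

/// Repeatedly apply the update rule w <- w - epsilon * gradient.
fn train(mut w: f64, epsilon: f64, steps: usize) -> f64 {
    for _ in 0..steps {
        w -= epsilon * grad(w);
    }
    w
}

fn main() {
    let w = train(0.0, 0.1, 100);
    println!("w = {w:.6}, loss = {:.2e}", loss(w));
}
```

Everything here is trivial except the hand-written grad function; for a real network with millions of parameters, that is exactly the part automatic differentiation takes over.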

Tensorken can already do almost all of those steps. Running a network, calculating a loss, and updating the parameters amounts to applying tensor operations. What's missing is calculating the gradient via the loss function's derivative. In the olden days, people would calculate the derivative of the network by hand, symbolically, and then implement it manually. Clearly tedious and error-prone, not to mention limiting the complexity and size of the networks. Modern neural network libraries calculate a function's output and its derivative without programmer intervention using the miracle of automatic differentiation.

AD is a vast and intricate topic. For a (much) longer primer on the basics, see my earlier post. If you are unfamiliar with AD I encourage you to read it or any of the AD primers in the links.

The following section demonstrates Tensorken's AD capabilities and interface via small examples. Then we'll dive into implementation details, but I'll stay away from the detailed mechanics of AD since that's already covered elsewhere. Instead, I'll focus on how Tensorken implements higher-order, mixed-mode, JAX-style AD as an elegant and minimal Rust library.

The Autodiff Cookbook in Tensorken

To demonstrate Tensorken's AD capabilities, I translated a significant part of JAX's Autodiff Cookbook to Tensorken. I reproduced and edited part of the original text here. JAX's license is Apache 2.0, so I hope this does not incur the wrath of Google. The titles in this section are similar to the ones in JAX's cookbook, in case you want to compare. The full example code is in jax_autodiff_cookbook.rs.

Before we begin - Tensorken runs on the CPU if you create tensors via the Cpu32 type alias and on the GPU via Wgpu32. (The 32 is because they work with 32-bit floating point numbers.) To make it easy to switch, I'll use the Tr type alias throughout:

type Tr = Cpu32; // or Wgpu32

Gradients

You can differentiate a function using grad1. The 1 indicates the number of arguments of the function - a poor man's variadic arguments. In the text, I'll sometimes refer to the family of grad1, grad2 functions as grad. In the code, I'll use the function with the correct number of arguments.

To start with, we'll use a simple scalar function - a function that takes a single number and returns a single number:

let p = Tr::scalar(2.0);
let df = grad1(|x| x.tanh(), &p);

> df: [0.07065082]

In Tensorken, all arguments must be tensors of type Tr - it doesn't support mixing tensors and scalar numbers. To turn a number into a tensor, we use Tr::scalar. It makes a tensor with shape [1].

grad1 takes a function of one argument 𝚏 and evaluates ∇𝚏(𝚙), the derivative of 𝚏 at a given point 𝚙. You can think of ∇ as a higher-order function that takes a differentiable function 𝚏 and produces a function that evaluates the derivative.

Pronouncing ∇: I say "grad", I've heard people say "del", and the symbol's Unicode name is "nabla".

Similarly, if you have a Rust function f that evaluates the mathematical function 𝚏, then grad(f, p) computes the value ∇𝚏(𝚙).

Unlike JAX, Tensorken does not directly expose ∇𝚏 as a first-class function, mostly because I had a hard time accomplishing that in Rust and staying sane! It would require returning a closure from grad(f) so you can write grad(f)(p), but satisfying the compiler proved difficult. So far this hasn't been a constraint in practice.

Like JAX, Tensorken does support applying grad to functions that themselves call grad to calculate higher-order derivatives:

let ddf = grad1(|x| grad1(|x| x.tanh(), x), &p);
let dddf = grad1(|x| grad1(|x| grad1(|x| x.tanh(), x), x), &p);

> ddf: [-0.13621868]
> dddf: [0.25265408]
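These values can be checked by hand. Writing t = tanh(x), the first three derivatives of tanh are 1 − t², −2t(1 − t²), and (6t² − 2)(1 − t²). The plain-Rust sketch below evaluates those closed forms at x = 2 in f64, so expect agreement with the f32 output above to within f32 rounding rather than digit for digit.

```rust
/// Closed-form first, second, and third derivatives of tanh at x.
fn tanh_derivatives(x: f64) -> (f64, f64, f64) {
    let t = x.tanh();
    let d1 = 1.0 - t * t; // d/dx tanh(x) = 1 - tanh^2(x)
    let d2 = -2.0 * t * d1; // derivative of 1 - tanh^2
    let d3 = (6.0 * t * t - 2.0) * d1; // derivative of -2 tanh + 2 tanh^3
    (d1, d2, d3)
}

fn main() {
    let (df, ddf, dddf) = tanh_derivatives(2.0);
    println!("df: {df:.8}, ddf: {ddf:.8}, dddf: {dddf:.8}");
}
```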

Let’s try computing gradients with grad in a linear logistic regression model. In other words, a simple neural network with one neuron. First, the setup:

// Outputs probability of a label being true.
fn predict<'t, T>(w: &'t T, b: &'t T, inputs: &T) -> T
where
    T: TensorLike<'t>,
{
    (inputs.dot(w) + b).sigmoid()
}

The function predict encodes the architecture of our toy model. Its parameters are a vector w and a scalar b, for weights and bias. As you can see, we're multiplying the weights with the inputs and adding the bias. Then we use sigmoid to squish the output values into the [0, 1] interval. This model predicts the probability of an outcome based on some input measurements.

Why is this equivalent to a neural network with a single neuron? Say the vector w has three elements - three weights. We thus have three inputs as well, in inputs. The function dot multiplies each input with its corresponding weight, and then adds them up. The bias b in the neuron analogy is typically a negative number, which represents a threshold that inputs.dot(w) must exceed to "activate" the neuron.
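Stripped of tensors, one row of the prediction is just a weighted sum, a bias, and a sigmoid. In this sketch, sigmoid and neuron are illustrative plain-Rust helpers over f64 slices, not Tensorken functions, and the weights in main are made up rather than the random ones used above:

```rust
/// Logistic sigmoid: squashes any real number into (0, 1).
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// One neuron for one input row: weighted sum of the inputs plus the bias,
/// then sigmoid. Mirrors (inputs.dot(w) + b).sigmoid() for a single row.
fn neuron(inputs: &[f64], w: &[f64], b: f64) -> f64 {
    assert_eq!(inputs.len(), w.len(), "one weight per input");
    let weighted_sum: f64 = inputs.iter().zip(w).map(|(x, wi)| x * wi).sum();
    sigmoid(weighted_sum + b)
}

fn main() {
    let row = [0.52, 1.12, 0.77]; // first row of the toy dataset
    let w = [1.0, -0.5, 2.0]; // illustrative weights
    println!("p = {}", neuron(&row, &w, -1.0));
}
```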

All arguments are tensors, but are represented by a generic type parameter T. The type needs to be generic so automatic differentiation can work; we'll see why later. T: TensorLike is a handy constraint to make tensor operations like dot, +, and sigmoid available on T. You'll see the TensorLike constraint often when using Tensorken's AD: to make functions differentiable, replace concrete Tensor types with T: TensorLike.

Let's run the model.

// Build a toy dataset.
// These are four measurements of some unspecified variable, one in each row.
let inputs = Tr::new(
    &[4, 3],
    &[
        0.52, 1.12, 0.77, //
        0.88, -1.08, 0.15, //
        0.52, 0.06, -1.30, //
        0.74, -2.49, 1.39,
    ],
);
// These are four observed outcomes, one for each row in the input.
let targets = Tr::new(&[4], &[1.0, 1.0, 0.0, 1.0]);

// Initialize the parameters w and b randomly
let key = 0;
let mut rng = StdRng::seed_from_u64(key);
let w = Tr::randn(&[3], &mut rng);
let b = Tr::randn(&[1], &mut rng);

let prediction = predict(&w, &b, &inputs);

> prediction: [0.4059896 0.37711427 0.9770815 0.007901279]

The inputs could be "changes in temperature observed on three consecutive days" and the targets could be "temperature went up or down on the next day". We're then training a model that predicts the probability of the temperature going up given three days' changes in temperature.

Since we initialized the model randomly, its prediction is random. We got unlucky: if you compare targets (what we want) with prediction (what we have) there is a big difference. The 3rd and 4th predictions are especially bad, almost the exact opposite of the training data.

To improve our model, we first need to quantify how crap the model is via a loss function.

// Training loss is the negative log-likelihood of the training examples.
fn loss<'t, T>(w: &'t T, b: &'t T, inputs: &T, targets: &'t T) -> T
where
    T: TensorLike<'t>,
    for<'s> &'s T: TensorLikeRef<T>,
{
    let prediction = predict(w, b, inputs);
    // ones_like makes a tensor of the same shape with all values equal to 1.
    let label_probs = &prediction * targets
        + (&prediction.ones_like() - &prediction) * (targets.ones_like() - targets);
    -label_probs.log().sum(&[0])
}

let l = loss(&w, &b, &inputs, &targets);

> loss: [10.4931755]

This loss function is the negative log-likelihood. You can intuit why it works: prediction is "compared" with targets in label_probs, which holds the probability the model assigned to each observed outcome. It is high for predictions that are close to the target. We then take the log of each, which heavily penalizes bad predictions: the logarithm goes to -infinity as label_probs approaches zero. Since the logarithm of a probability is never positive, we negate it. Then we sum the vector, so we have a single non-negative loss number that is high when the model is doing badly and low when it's making good predictions.
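To see the arithmetic, the same negative log-likelihood can be computed on plain numbers by plugging in the printed predictions and targets. This is a standalone f64 sketch (nll is an illustrative name, not a Tensorken function); its result should agree with the loss printed above to within f32 rounding.

```rust
/// Negative log-likelihood of binary targets given predicted probabilities.
/// Mirrors the loss function above, but on plain slices.
fn nll(prediction: &[f64], targets: &[f64]) -> f64 {
    prediction
        .iter()
        .zip(targets)
        // Probability the model assigned to the observed label:
        // p when the target is 1, (1 - p) when it is 0.
        .map(|(p, y)| p * y + (1.0 - p) * (1.0 - y))
        .map(|label_prob| -label_prob.ln())
        .sum()
}

fn main() {
    // Predictions and targets copied from the output above.
    let prediction = [0.4059896, 0.37711427, 0.9770815, 0.007901279];
    let targets = [1.0, 1.0, 0.0, 1.0];
    println!("loss: {}", nll(&prediction, &targets));
}
```

Most of the loss comes from the 3rd and 4th examples, where -ln of a small label probability is large.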

Now we can improve the model by adjusting its weights and biases. We use grad to differentiate the loss function with respect to the parameters w and b:

// Differentiate loss wrt weights
let w_grad = grad1(
    |w| {
        loss(
            w,
            &Reverse::lift(&b),
            &Reverse::lift(&inputs),
            &Reverse::lift(&targets),
        )
    },
    &w,
);
print!("w_grad: {w_grad}");

// Differentiate loss wrt bias
let b_grad = grad1(
    |b| {
        loss(
            &Reverse::lift(&w),
            b,
            &Reverse::lift(&inputs),
            &Reverse::lift(&targets),
        )
    },
    &b,
);

> w_grad: [-1.0830948 2.5363755 -3.2000453]
> b_grad: [-1.2319121]

To make the types work out, we need to Reverse::lift all the arguments to loss we do NOT want to differentiate. They are treated as constants. The type is called Reverse because Tensorken uses reverse mode AD in this case. The Reverse type reveals why we need to make the arguments to loss and predict generic: the grad function, while taking a plain Tr type as the second argument, passes Reverse<Tr> to the closure. So the function f can be called with Tr, Reverse<Tr>, or other types we'll see later.

Here's the simplified signature for grad1. We'll get to the full signature later:

pub fn grad1<F>(f: F, at: &Tr) -> Tr where F: Fn(&Reverse<Tr>) -> Reverse<Tr>

Briefly, Reverse is a wrapper to interpret tensor operations so they calculate the derivative along with the main result. In this example, it'll run a different dot, +, and sigmoid compared to calling loss with plain tensors of type Tr.

Calling the loss function twice is not ideal - we're doing twice the work. We can also calculate the gradients with respect to both w and b at the same time, using grad2.

let (w_grad, b_grad) = grad2(
    |w, b| loss(w, b, &Reverse::lift(&inputs), &Reverse::lift(&targets)),
    &w,
    &b,
);

> w_grad: [-1.0830948 2.5363755 -3.2000453]
> b_grad: [-1.2319121]

Finally, let's do a single training iteration and check if that improves our model.

// Update parameters (step size 1: subtract the full gradient)
let new_w = &w - &w_grad;
let new_b = &b - &b_grad;

// Predict
let new_prediction = predict(&new_w, &new_b, &inputs);
let new_loss = loss(&new_w, &new_b, &inputs, &targets);

> new_prediction: [0.7384342 0.99262685 0.7747804 0.9996524]
> new_loss: [1.8016509]

A massive improvement - we're now only 1.8 crap, down from 10.5!

Evaluate a function and its gradient using value_and_grad

In a real training run, we'd do the above in a loop while keeping an eye on the loss to see when to stop. Again, loss is called twice: once inside grad and once outside. Luckily, that's avoidable: the value_and_grad family of functions efficiently computes a function's value and its gradient in one go.

let (loss_value, (w_grad, b_grad)) = value_and_grad2(
    |w, b| loss(w, b, &Reverse::lift(&inputs), &Reverse::lift(&targets)),
    &w,
    &b,
);

> loss: [10.4931755]
> w_grad: [-1.0830948 2.5363755 -3.2000453]
> b_grad: [-1.2319121]

Checking against numerical differences

Our loss improved, which is a good indication that things work. To gain confidence we can compare Tensorken's derivatives with finite differences.

// step size for finite difference
let eps = Tr::scalar(1e-4);
let half_eps = &eps / Tr::scalar(2.);
let b_grad_numerical = (loss(&w, &(&b + &half_eps), &inputs, &targets)
    - loss(&w, &(&b - &half_eps), &inputs, &targets))
    / &eps;

> b_grad_numerical [-1.2207031]
> b_grad_autodiff [-1.2319121]

Close enough.
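Part of the remaining gap is likely floating-point rounding in the finite difference itself. The loss is about 10.5, and adjacent f32 values in [8, 16) are 2⁻²⁰ ≈ 9.5 × 10⁻⁷ apart, so the difference of the two losses (about 1.2 × 10⁻⁴) keeps only two or three significant digits. Consistent with that, -1.2207031 is, to f32 precision, -128 × 2⁻²⁰ / 10⁻⁴. The plain-Rust sketch below (hypothetical helper names, not Tensorken) shows the same effect on tanh, whose derivative is known exactly, by running the central difference in both f32 and f64:

```rust
/// Central finite difference (f(x + h/2) - f(x - h/2)) / h, computed in f64.
fn central_diff_f64(f: impl Fn(f64) -> f64, x: f64, h: f64) -> f64 {
    (f(x + h / 2.0) - f(x - h / 2.0)) / h
}

/// The same formula, computed entirely in f32.
fn central_diff_f32(f: impl Fn(f32) -> f32, x: f32, h: f32) -> f32 {
    (f(x + h / 2.0) - f(x - h / 2.0)) / h
}

fn main() {
    let exact = 1.0 - 2.0_f64.tanh().powi(2); // d/dx tanh(x) at x = 2
    let approx64 = central_diff_f64(f64::tanh, 2.0, 1e-4);
    let approx32 = central_diff_f32(f32::tanh, 2.0, 1e-4);
    println!("exact     {exact:.9}");
    println!("f64 diff  {approx64:.9} (error {:.1e})", (approx64 - exact).abs());
    println!("f32 diff  {approx32:.9} (error {:.1e})", (approx32 as f64 - exact).abs());
}
```

With the same step size, the f64 estimate should be much closer to the exact value, which is why numerical checks in 32-bit floats only confirm the first couple of digits.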

Jacobians using jacfwd and jacrev

Ignoring bias b for now, the loss function is a function of three parameters, represented as a single tensor w with three elements. It has a single scalar output, represented as a tensor with a single element. Taking the gradient of this function results in a vector of three elements, the sensitivity of the loss to each parameter. This picture becomes more complicated if there is more than one output. grad still gives an answer, but what does it mean?

let deriv = grad1(
    |w| predict(w, &Reverse::lift(&b), &Reverse::lift(&inputs)),
    &w,
);

> deriv: [0.34956074 -0.0017646346 0.20271438]

Remember that predict returns a vector with four elements, and the input w is a vector with three elements. We get a vector with three sensitivities - one for each input. But the sensitivity of which output? There are four. As we'll check below, grad returns the sum of the sensitivity of all outputs. That's typically not what we want: we'd like to disaggregate the sensitivities.

The usual approach is to represent the sensitivity of each output with respect to each input as a matrix, called the Jacobian. In this case, a 4 by 3 matrix - number of outputs by number of inputs. Tensorken can compute Jacobians, in forward and reverse mode using jacfwd and jacrev:

let J = jacfwd(
    |w| predict(w, &Forward::lift(&b), &Forward::lift(&inputs)),
    &w,
);

> jacfwd result, with shape [4, 3]
┌ ┐
│ 0.12540425 0.2701015 0.18569478 │
│ 0.20671119 -0.25369102 0.03523486 │
│ 0.01164451 0.0013435973 -0.029111274 │
│ 0.0058007482 -0.019518733 0.010895999 │
└ ┘

let J = jacrev(
    |w| predict(w, &Reverse::lift(&b), &Reverse::lift(&inputs)),
    &w,
);

> jacrev result, with shape [4, 3]
┌ ┐
│ 0.12540427 0.27010152 0.18569478 │
│ 0.20671119 -0.25369102 0.03523486 │
│ 0.01164451 0.0013435973 -0.029111274 │
│ 0.005800748 -0.019518731 0.010895998 │
└ ┘

These two functions compute the same values (up to machine precision), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.

We can now check that grad computed the sum of the sensitivity of all four outputs:

J.sum(&[0])

> [0.34956074 -0.0017646346 0.20271438]
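The column-by-column nature of forward mode is easy to see on a tiny example. The sketch below uses plain Rust with a hand-rolled scalar dual number (not Tensorken's Forward), and the function f is an arbitrary illustration. It computes a 2 × 3 Jacobian with one forward pass per input, then sums over the output axis, the same aggregation grad performed above.

```rust
use std::ops::{Add, Mul};

/// Scalar dual number (illustrative stand-in for Forward).
#[derive(Clone, Copy, Debug)]
struct Dual {
    primal: f64,
    tangent: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { primal: self.primal + o.primal, tangent: self.tangent + o.tangent }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule.
    fn mul(self, o: Dual) -> Dual {
        Dual {
            primal: self.primal * o.primal,
            tangent: self.primal * o.tangent + o.primal * self.tangent,
        }
    }
}

/// Hypothetical example with 3 inputs and 2 outputs: f(w) = [w0 * w1, w1 * w2 + w0].
fn f(w: [Dual; 3]) -> [Dual; 2] {
    [w[0] * w[1], w[1] * w[2] + w[0]]
}

/// Forward-mode Jacobian: one pass per input, seeded with a one-hot tangent.
/// Pass j fills in column j of the 2 x 3 Jacobian.
fn jac_forward(at: [f64; 3]) -> [[f64; 3]; 2] {
    let mut jac = [[0.0; 3]; 2];
    for j in 0..3 {
        let inputs: [Dual; 3] = std::array::from_fn(|i| Dual {
            primal: at[i],
            tangent: if i == j { 1.0 } else { 0.0 },
        });
        for (i, out) in f(inputs).iter().enumerate() {
            jac[i][j] = out.tangent;
        }
    }
    jac
}

/// Sum over the output axis, like J.sum(&[0]) in the text.
fn sum_over_outputs(jac: &[[f64; 3]; 2]) -> [f64; 3] {
    std::array::from_fn(|j| jac[0][j] + jac[1][j])
}

fn main() {
    let jac = jac_forward([1.0, 2.0, 3.0]);
    println!("J = {jac:?}");
    println!("sum over outputs = {:?}", sum_over_outputs(&jac));
}
```

A reverse-mode version would instead run one backward pass per output and produce a row each time, which is why the shape of the Jacobian decides which mode is cheaper.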

Using a composition of jacfwd and jacrev gives us a way to compute dense Hessian matrices. Hessian matrices contain all the second-order partial derivatives.

let hessian = jacfwd(
    |w| {
        jacrev(
            |w| {
                predict(
                    w,
                    &Reverse::lift(&Forward::lift(&b)),
                    &Reverse::lift(&Forward::lift(&inputs)),
                )
            },
            w,
        )
    },
    &w,
);
println!("hessian shape {:?}", hessian.shape());

> hessian shape [4, 3, 3]

Why this shape? We start with a function 𝚏:𝙽→𝙼. Traditionally, we'd write 𝚏:ℝⁿ→ℝᵐ, but the fact that there are 𝙽 inputs and 𝙼 outputs matters more here than the fact that we're talking about real numbers, so I'll omit the ℝ from now on.

At a point 𝚑 ∈ 𝙽 we expect to get the shapes

  • 𝚏(𝚑) ∈ 𝙼, the value of 𝚏 at 𝚑,

  • ∂𝚏(𝚑) ∈ 𝙼 × 𝙽, the Jacobian matrix at 𝚑,

  • ∂²𝚏(𝚑) ∈ 𝙼 × 𝙽 × 𝙽, the Hessian at 𝚑,

and so on.
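Index by index, the entries behind these shapes are the usual first and second partial derivatives. For predict, 𝙽 = 3 weights and 𝙼 = 4 outputs, which gives the [4, 3] Jacobian and the [4, 3, 3] Hessian seen above.

```latex
\big[\partial f(x)\big]_{ij} = \frac{\partial f_i}{\partial x_j}(x),
\qquad
\big[\partial^2 f(x)\big]_{ijk} = \frac{\partial^2 f_i}{\partial x_j \, \partial x_k}(x),
\qquad
1 \le i \le M, \quad 1 \le j, k \le N
```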

To implement hessian we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because, in the inner Jacobian computation, we’re often differentiating a function with a wide Jacobian (maybe like a loss function 𝚏:𝙽→1), while in the outer Jacobian computation, we’re differentiating a function with a square Jacobian (since ∇𝚏:𝙽→𝙽), which is where forward-mode wins out.

Note we now need to lift the inputs twice to make Rust's type checker happy.

That concludes the tour of Tensorken's AD capabilities. It packs a lot of punch - now let's see how to fit it in a small package.

A tale of two functions

All AD functions like jacfwd, jacrev, and grad are implemented in terms of two function-type pairs: jvp with Forward, and vjp with Reverse. JVP stands for Jacobian-vector product, and VJP stands for Vector-Jacobian product. These functions are directly inspired by JAX. To explain their names, we need some math background that deserves a standalone post. If you can't wait, refer to this section in JAX's Autodiff Cookbook.

I'll now introduce jvp and vjp, and the beginnings of how AD works in Tensorken. I assume some background knowledge about AD, in particular AD for scalar functions. See my earlier post for a primer.

From scalars to tensors

Forward AD on scalar functions works by replacing operators and functions on numbers with versions that operate on a dual number - a (f32, f32) tuple. The first element is the primal, which the function computes without AD. The second is the derivative, or tangent. Operations on dual numbers are straightforward:

  • apply the operation to the primal(s), and

  • apply differentiation rules to the tangent(s).

For example, multiplication on dual numbers is:

(p₁, t₁) · (p₂, t₂) = (p₁ · p₂, p₁ · t₂ + p₂ · t₁)
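In code, a scalar dual number with the sum and product rules takes only a few lines. This is a standalone sketch in plain Rust, not Tensorken's Forward type; var and lift are hypothetical names for seeding the variable we differentiate with respect to (tangent 1) and a constant (tangent 0), loosely analogous to how the earlier examples lift arguments that are not being differentiated.

```rust
use std::ops::{Add, Mul};

/// A scalar dual number: the primal value and its tangent (derivative).
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    primal: f64,
    tangent: f64,
}

impl Dual {
    /// The variable we differentiate with respect to: seed its tangent with 1.
    fn var(x: f64) -> Dual {
        Dual { primal: x, tangent: 1.0 }
    }

    /// A constant: its tangent is 0, so it contributes nothing to the derivative.
    fn lift(c: f64) -> Dual {
        Dual { primal: c, tangent: 0.0 }
    }
}

impl Add for Dual {
    type Output = Dual;
    // Sum rule: tangents add.
    fn add(self, o: Dual) -> Dual {
        Dual { primal: self.primal + o.primal, tangent: self.tangent + o.tangent }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (p1, t1) * (p2, t2) = (p1 * p2, p1 * t2 + p2 * t1).
    fn mul(self, o: Dual) -> Dual {
        Dual {
            primal: self.primal * o.primal,
            tangent: self.primal * o.tangent + o.primal * self.tangent,
        }
    }
}

fn main() {
    // y = w * x + b, differentiated with respect to w while x and b are constants.
    let (x, b) = (Dual::lift(4.0), Dual::lift(1.5));
    let w = Dual::var(2.0);
    let y = w * x + b;
    println!("{y:?}"); // Dual { primal: 9.5, tangent: 4.0 }
}
```

The tangent of y is 4.0, which is x, as expected for dy/dw.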

Reverse mode is more involved. The primal computation is identical, but instead of calculating the tangent alongside the primal, we collect a trace - essentially a stack of operations. A reverse pass through the trace calculates the derivatives.
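A scalar version of such a trace fits in a few dozen lines. In this sketch (plain Rust; Tape, Op, and backward are hypothetical names, and Tensorken records tensor operations rather than scalars), the forward pass records each operation with its operands, and the reverse pass walks the tape backwards, applying the chain rule to accumulate each node's derivative:

```rust
/// One recorded operation; operands are indices of earlier nodes on the tape.
#[derive(Clone, Copy)]
enum Op {
    Input,
    Add(usize, usize),
    Mul(usize, usize),
}

/// A minimal scalar trace: primal values plus the operation that produced each.
struct Tape {
    values: Vec<f64>,
    ops: Vec<Op>,
}

impl Tape {
    fn new() -> Self {
        Tape { values: Vec::new(), ops: Vec::new() }
    }

    fn push(&mut self, value: f64, op: Op) -> usize {
        self.values.push(value);
        self.ops.push(op);
        self.values.len() - 1
    }

    fn input(&mut self, x: f64) -> usize {
        self.push(x, Op::Input)
    }

    fn add(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] + self.values[b];
        self.push(v, Op::Add(a, b))
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let v = self.values[a] * self.values[b];
        self.push(v, Op::Mul(a, b))
    }

    /// Reverse pass: walk the tape backwards, accumulating d(out)/d(node).
    fn backward(&self, out: usize) -> Vec<f64> {
        let mut adjoint = vec![0.0; self.values.len()];
        adjoint[out] = 1.0;
        for node in (0..=out).rev() {
            let g = adjoint[node];
            match self.ops[node] {
                Op::Input => {}
                Op::Add(a, b) => {
                    adjoint[a] += g;
                    adjoint[b] += g;
                }
                Op::Mul(a, b) => {
                    adjoint[a] += g * self.values[b];
                    adjoint[b] += g * self.values[a];
                }
            }
        }
        adjoint
    }
}

fn main() {
    // f(x, y) = x * y + x at (3, 4): df/dx = y + 1 = 5, df/dy = x = 3.
    let mut tape = Tape::new();
    let x = tape.input(3.0);
    let y = tape.input(4.0);
    let xy = tape.mul(x, y);
    let f = tape.add(xy, x);
    let grads = tape.backward(f);
    println!("f = {}, df/dx = {}, df/dy = {}", tape.values[f], grads[x], grads[y]);
}
```

One backward pass yields the derivative with respect to every input at once, which is the property that makes reverse mode attractive for many-to-one functions like a loss.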

Exactly how these operations are replaced is a concern for the implementation. Common methods are code transformation in the compiler, code generation, and operator overloading. Tensorken uses trait-based overloading.

Forward mode needs little extra memory beyond bringing the tangent along for the ride, while reverse mode needs to keep a trace that's as long as the computation is deep. As a result, for scalar-to-scalar functions, forward mode is more efficient.

That situation changes if we consider functions from many scalars to one, or vice versa. One extreme is a function that takes a single input and computes n outputs. That's great for forward mode: in one execution of the function on dual numbers, we'll have both the primal result and the derivative - or in other words the sensitivity of each output to a small change in the single input.

However, a function that takes many inputs and has a single output is far more efficient to differentiate in reverse mode. In forward mode, we'd need as many executions of the function as there are inputs - we'd have to pass 1 as tangent for each input separately. In reverse mode, we still need the extra memory for the trace, but one forward pass for the primal and one backward pass for the partial derivatives is all we need.

The good news is that if you understand this, nothing much changes if we allow tensors instead of scalars. After all, a tensor is a container of scalars, and operations on tensors can be broken down into operations on scalars. That's not how we want to implement them though! Bulk operations are where the performance is at.

For forward mode, we'll overload tensor operations to propagate a "dual tensor", a tuple of a primal and a tangent tensor. For reverse mode, we'll build up a trace of tensor operations in the forward pass and get the tangent tensors from a backward pass.

One difference with scalar AD is that we need to take the shape of tensors into account. Besides arithmetic operations like addition and multiplication, we also need to figure out differentiation rules for sum, reshape, and others, which affect the shape of both primal and derivative tensors.
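For example, the rules for sum and expand mirror each other: the JVP of sum sums the tangent, the VJP of sum broadcasts (expands) the cotangent back to the input shape, and, conversely, the VJP of expand is a sum. The 1-D sketch below (plain Rust, hypothetical function names) checks the two sum rules against each other with the identity ⟨Jᵀc, t⟩ = ⟨c, Jt⟩, which any pair of JVP and VJP rules for the same operation must satisfy:

```rust
/// Forward rule (JVP) for a full sum: the tangent of sum(x) is sum(tangent).
fn sum_jvp(tangent: &[f64]) -> f64 {
    tangent.iter().sum()
}

/// Reverse rule (VJP) for a full sum: broadcast the scalar cotangent back to
/// the input's shape, i.e. an expand.
fn sum_vjp(cotangent: f64, input_len: usize) -> Vec<f64> {
    vec![cotangent; input_len]
}

/// Dot product, used to check that the two rules are consistent.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn main() {
    let tangent = [0.5, -1.0, 2.0];
    let cotangent = 3.0;
    // <vjp(c), t> must equal c * jvp(t).
    let lhs = dot(&sum_vjp(cotangent, tangent.len()), &tangent);
    let rhs = cotangent * sum_jvp(&tangent);
    println!("{lhs} == {rhs}");
}
```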

JVP for forward-mode AD

Here's the signature of jvp:

pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
    for<'a> F: Fn(&'a Forward<T>) -> Forward<T>,

As the 1 suffix indicates, this version takes a single primal tensor at and a single tangent tensor tangent. It evaluates the primal and tangent of the function f and returns them as a tuple.

To understand why this signature makes sense, think of AD as a program transformation. Without AD we'd write programs that boil down to:

let p1 = f1(x);
let p2 = f2(p1);
...

With forward AD and jvp we can rewrite them as:

let (p1,t1) = jvp(f1, x, x.ones_like());
let (p2,t2) = jvp(f2, p1, t1);
...

That illustrates how programs that compose functions can be transformed into programs that compose jvp-wrapped functions.
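Here is the same chaining on scalars. A hand-rolled dual number stands in for Forward, and a scalar jvp mimics jvp1's (primal, tangent) output; it's a plain-Rust sketch, not Tensorken's implementation:

```rust
use std::ops::Mul;

/// Scalar dual number standing in for Forward<T> (illustrative only).
#[derive(Clone, Copy)]
struct Dual {
    primal: f64,
    tangent: f64,
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule.
    fn mul(self, o: Dual) -> Dual {
        Dual {
            primal: self.primal * o.primal,
            tangent: self.primal * o.tangent + o.primal * self.tangent,
        }
    }
}

/// Scalar analogue of jvp1: run f on a dual number and split the result
/// into (primal, tangent).
fn jvp(f: impl Fn(Dual) -> Dual, at: f64, tangent: f64) -> (f64, f64) {
    let out = f(Dual { primal: at, tangent });
    (out.primal, out.tangent)
}

fn main() {
    let f1 = |x: Dual| x * x; // x^2
    let f2 = |x: Dual| x * x * x; // x^3
    // Chain the jvp-wrapped functions: the tangent of one feeds the next.
    let (p1, t1) = jvp(f1, 2.0, 1.0);
    let (p2, t2) = jvp(f2, p1, t1);
    // f2(f1(x)) = x^6 with derivative 6 x^5, i.e. 64 and 192 at x = 2.
    println!("p2 = {p2}, t2 = {t2}");
}
```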

Importantly, at and tangent must have the same shape, and the two tensors in the output tuple have the same shape. f computes out.0 from at, and jvp additionally computes out.1 from at and tangent.

jvp works for any tensor-like type that implements Diffable. Diffable is the foundational trait that defines Tensorken's primitive, differentiable tensor operations. Higher-level operations like matmul are built on these. Keeping the tensor generic in jvp allows it to work with any Diffable implementation - something we'll use when doing higher-order AD. We'll come back to the details soon.
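To see why genericity buys higher-order derivatives, here is a scalar sketch in plain Rust (not Tensorken code). Dual<T> works for any T that implements a small Scalar trait, including Dual itself, so derivative can be nested the way grad1 is nested in the tanh example at the start of the post. Scalar, lift, and derivative are hypothetical names; in Tensorken, Forward<T> over a Diffable T plays the role of Dual<T>.

```rust
use std::ops::{Add, Mul, Sub};

/// The few operations our toy "tensor" needs. Both f64 and Dual<T> implement
/// it, which is what allows dual numbers to be nested.
trait Scalar: Copy + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self> {
    /// Embed a constant (for dual numbers: tangent zero).
    fn lift(c: f64) -> Self;
    fn tanh(self) -> Self;
}

impl Scalar for f64 {
    fn lift(c: f64) -> Self {
        c
    }
    fn tanh(self) -> Self {
        f64::tanh(self) // the inherent f64 method, not a recursive call
    }
}

/// Dual number over any Scalar, playing the role of Forward<T>.
#[derive(Clone, Copy)]
struct Dual<T> {
    primal: T,
    tangent: T,
}

impl<T: Scalar> Add for Dual<T> {
    type Output = Self;
    fn add(self, o: Self) -> Self {
        Dual { primal: self.primal + o.primal, tangent: self.tangent + o.tangent }
    }
}

impl<T: Scalar> Sub for Dual<T> {
    type Output = Self;
    fn sub(self, o: Self) -> Self {
        Dual { primal: self.primal - o.primal, tangent: self.tangent - o.tangent }
    }
}

impl<T: Scalar> Mul for Dual<T> {
    type Output = Self;
    // Product rule, computed with T's own operations.
    fn mul(self, o: Self) -> Self {
        Dual {
            primal: self.primal * o.primal,
            tangent: self.primal * o.tangent + o.primal * self.tangent,
        }
    }
}

impl<T: Scalar> Scalar for Dual<T> {
    fn lift(c: f64) -> Self {
        Dual { primal: T::lift(c), tangent: T::lift(0.0) }
    }
    fn tanh(self) -> Self {
        let t = self.primal.tanh();
        // Chain rule with d/dx tanh(x) = 1 - tanh(x)^2.
        Dual { primal: t, tangent: (T::lift(1.0) - t * t) * self.tangent }
    }
}

/// Derivative of f at x: seed the tangent with 1 and read it back out.
fn derivative<T: Scalar>(f: impl Fn(Dual<T>) -> Dual<T>, x: T) -> T {
    f(Dual { primal: x, tangent: T::lift(1.0) }).tangent
}

fn main() {
    let p = 2.0_f64;
    let df = derivative(|x| x.tanh(), p);
    // Nesting works because the inner call is generic over T = Dual<f64>.
    let ddf = derivative(|x| derivative(|y| y.tanh(), x), p);
    let dddf = derivative(|x| derivative(|y| derivative(|z| z.tanh(), y), x), p);
    println!("df: {df}, ddf: {ddf}, dddf: {dddf}");
}
```

Each extra level of nesting adds one more layer of Dual around f64, so the same tanh rule is applied to derivatives of derivatives without writing any second-derivative code.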

VJP for reverse-mode AD

Here is the signature for vjp:

pub fn vjp1<'b, 't, T: Diffable + Clone + 't, F>(f: F, at: &T) -> (T, PullBack<'t, T>)
where
    for<'a> F: Fn(&'a Reverse<'a, 't, T>) -> Reverse<'a, 't, T>,

It looks a bit different because reverse mode has a backward pass. What's the same are the differentiable function f: F and the primal input at. vjp calls f with Reverse wrapping the input T. Since reverse mode needs two passes, vjp returns only the primal output directly, along with a PullBack to run the backward pass later.

PullBack (a term from differential geometry) is a named struct that executes the backward pass. It takes a cotangent, a tensor in the shape of the output of f, and calculates the tangent, a tensor in the shape of the input of f.

impl<T: Diffable + Clone> PullBack<'_, T> {
    pub fn call(&self, cotangent: &T) -> T
}

It's all backward! But that's why reverse mode AD is more efficient if you have the right tensor shape.

A short note on why jvp and vjp have different signatures. On the one hand, we could rewrite jvp to return a PushForward struct with a call function that works similarly to vjp's PullBack. However, that would require keeping a trace of the operations around so users can call it multiple times with different tangents. That jeopardizes the memory efficiency of forward mode, and the ability to re-execute the differentiating pass with different tangents does not offset the added memory usage.

On the other hand, we could also write vjp with a signature like jvp's by making the PullBack internal and calling it at the end. But in reverse mode, we have to spend the memory on the trace anyway, so we might as well make the PullBack available to the user for potential reuse.
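The reuse argument is easiest to see with a closure. In this sketch (plain Rust; vjp_example and PullBackFn are hypothetical names, and the pull-back is written by hand for one fixed function, whereas Tensorken derives it from the trace), the pull-back keeps the forward-pass values it needs and can be called any number of times with different cotangents:

```rust
/// A pull-back: maps a cotangent in the output's shape to one in the input's shape.
type PullBackFn = Box<dyn Fn([f64; 2]) -> [f64; 3]>;

/// Hypothetical scalar analogue of vjp1 for one fixed function,
/// f(x) = [x0 * x1, x1 + x2], with a hand-written pull-back.
fn vjp_example(x: [f64; 3]) -> ([f64; 2], PullBackFn) {
    let primal = [x[0] * x[1], x[1] + x[2]];
    // The closure keeps the forward-pass values it needs (here, x) and
    // computes the transposed Jacobian times the cotangent.
    let pullback: PullBackFn =
        Box::new(move |ct: [f64; 2]| [ct[0] * x[1], ct[0] * x[0] + ct[1], ct[1]]);
    (primal, pullback)
}

fn main() {
    let (y, pullback) = vjp_example([1.0, 2.0, 3.0]);
    // Calling the same pull-back with one-hot cotangents yields the
    // Jacobian row by row, the way a reverse-mode Jacobian is assembled.
    let row0 = pullback([1.0, 0.0]);
    let row1 = pullback([0.0, 1.0]);
    println!("y = {y:?}");
    println!("J = [{row0:?}, {row1:?}]");
}
```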

Interpreters all the way down

We're now at the point where we can dive into the code, and it's interpreters all the way down.

Before AD, Tensorken's core was the RawTensor trait, with implementations for the CPU and the GPU. It's useful to think of this trait as the definition of a language for primitive tensor operations, and implementations of the trait as interpreters of that language. Interpreters don't necessarily have to produce a tensor - for debugging and testing a pretty-printing interpreter for RawTensor is useful:

impl RawTensor for String {
    type Elem = f32;

    fn exp(&self) -> Self {
        format!("{self}.exp()")
    }

    fn add(&self, other: &Self) -> Self {
        format!("({self} + {other})")
    }

    // etc
}

We can use it as follows:

let t1: String = RawTensor::new(&[2, 2], &[1., 2., 3., 4.]);
let t2: String = RawTensor::new(&[2, 2], &[5., 6., 7., 8.]);
let r = t1.exp().add(&t2.log());

> r: "(new([2, 2], [1.0, 2.0, 3.0, 4.0]).exp() + new([2, 2], [5.0, 6.0, 7.0, 8.0]).log())"

Or even:

let t1: String = "A".to_string();
let t2: String = "B".to_string();
let r = t1.exp().add(&t2.log());

> r: "(A.exp() + B.log())"

We could generate source code or an abstract syntax tree this way, turning the interpreter into a compiler of sorts. That is the essence of the final tagless approach I described in depth in an earlier post. It has many extensibility benefits, which we'll exploit soon.
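The payoff is that one generic program can be run by any interpreter. Here is a self-contained miniature of the idea in plain Rust; the Lang trait and program function are hypothetical and much smaller than RawTensor. One interpreter evaluates to a number, the other pretty-prints:

```rust
/// A tiny "language" of operations, in the final tagless style.
trait Lang {
    /// A named leaf with a value; each interpreter uses the part it needs.
    fn var(name: &str, value: f64) -> Self;
    fn exp(&self) -> Self;
    fn log(&self) -> Self;
    fn add(&self, other: &Self) -> Self;
}

/// Interpreter 1: evaluate to a number.
impl Lang for f64 {
    fn var(_name: &str, value: f64) -> Self {
        value
    }
    fn exp(&self) -> Self {
        f64::exp(*self) // inherent f64::exp, not a recursive call
    }
    fn log(&self) -> Self {
        self.ln()
    }
    fn add(&self, other: &Self) -> Self {
        self + other
    }
}

/// Interpreter 2: pretty-print the expression.
impl Lang for String {
    fn var(name: &str, _value: f64) -> Self {
        name.to_string()
    }
    fn exp(&self) -> Self {
        format!("{self}.exp()")
    }
    fn log(&self) -> Self {
        format!("{self}.log()")
    }
    fn add(&self, other: &Self) -> Self {
        format!("({self} + {other})")
    }
}

/// One generic program, written once, runnable by either interpreter.
fn program<T: Lang>() -> T {
    let a = T::var("A", 0.0);
    let b = T::var("B", 1.0);
    a.exp().add(&b.log())
}

fn main() {
    let shown: String = program();
    let value: f64 = program();
    println!("{shown} = {value}"); // (A.exp() + B.log()) = 1
}
```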

What does this have to do with automatic differentiation? AD works by hard-coding how to differentiate primitive operations like addition and multiplication, and composing those primitive rules via the chain rule. The primitive operations define a language of tensor operations, which we can interpret in a few ways: as straightforward tensor operations without differentiation via Tensor, as a forward-mode differentiated program via Forward, or as a reverse-mode differentiated program via Reverse. As with RawTensor, we represent the primitive operations of the differentiable language as a trait, Diffable, and then implement this trait for each interpreter.

Let's start with the trait definition:

pub trait Diffable {
    type Elem: Num;

    fn log(&self) -> Self;
    fn exp(&self) -> Self;

    fn elementwise_add(&self, other: &Self) -> Self;
    fn elementwise_sub(&self, other: &Self) -> Self;
    fn elementwise_mul(&self, other: &Self) -> Self;
    fn elementwise_div(&self, other: &Self) -> Self;
    fn elementwise_pow(&self, exp: &Self) -> Self;
    fn elementwise_eq(&self, other: &Self) -> Self;

    fn sum(&self, axes: &[usize]) -> Self;
    fn max(&self, axes: &[usize]) -> Self;

    fn reshape(&self, shape: &[usize]) -> Self;
    fn permute(&self, dims: &[usize]) -> Self;
    fn expand(&self, shape: &[usize]) -> Self;
    fn pad(&self, padding: &[(usize, usize)]) -> Self;
    fn crop(&self, limits: &[(usize, usize)]) -> Self;

    fn new(shape: &[usize], data: &[Self::Elem]) -> Self;
    fn shape(&self) -> &[usize];
}

Diffable's operations are similar to RawTensor's, and we can categorize them in much the same way: unary operations, binary operations, reduce-like operations, and shape-changing operations. What's missing is RawTensor's optimized fused multiply-add, which illustrates the difference in intent between RawTensor and Diffable. While we could make RawTensor differentiable, I'll now try to convince you we don't want to.

Fused multiply-add is an optimized operation that we need on the lowest level to have some hope of efficiency. It is likely that to make Tensorken more efficient, we'll need to add more special-purpose operations to better exploit hardware primitives, reduce memory usage, and so on.

We don't (necessarily) want to figure out how to differentiate those special-purpose operations. Instead, we'd like to define a small set of primitive operations and their derivatives, and then compose those into higher-level operations. We then get derivatives of those higher-level operations for free, because differentiation is so beautifully composable. Separating Diffable from RawTensor allows us to add efficient, special-purpose operations to RawTensor without figuring out their derivatives. Vice versa, we can add operations to Diffable without having to change RawTensor and its implementations.

Before Diffable, we translated user-facing operations like matrix multiplication to RawTensor operations, which were interpreted by a concrete RawTensor like CpuRawTensor. Now we add another interpreter layer, Diffable, between the user-facing operations and RawTensor, which calculates not only the primal results but also their derivatives. Diffable interpreters execute both primal and derivative calculations as RawTensor operations. That means we can combine any implementation of Diffable with any implementation of RawTensor, so we can do forward AD on the GPU, reverse AD on the CPU, or any other combination.

Let's make our way down the interpreter layers to see how this works in practice. We'll start with matrix multiplication and end up at CpuRawTensor.

Each of the sections that follow is one layer of the interpreter lasagne:

  • High-level tensor operations like matmul are translated to Diffable operations like sum and elementwise_mul.

  • A Diffable interpreter like Forward or Reverse translates primitive operations like sum and elementwise_mul to RawTensor operations, adding the calculation of derivatives.

  • A RawTensor interpreter like CpuRawTensor executes the operations on a particular device.


A nicely layered design. Is it lunchtime yet? Photo by Parnis Azimi on Unsplash

User-facing layer: matrix multiplication in terms of Diffable

Here is a sketch of matmul, omitting everything that is not an operation on Diffable:

pub trait DiffableExt: Diffable
{
    fn matmul(&self, other: &Self) -> Self {
        // preconditions, shape manipulation omitted
        // special cases omitted

        let l = self.reshape(&l_shape);
        // shape manipulation omitted
        let r = other
            .reshape(&r_shape)
            .transpose(r_shape.ndims() - 1, r_shape.ndims() - 2);

        // after multiply: [..., m, o, n]
        let prod = l.mul(&r);

        // after sum: [..., m, o, 1]
        let sum = prod.sum(&[prod.shape().ndims() - 1]);

        // after reshape: [..., m, o]
        let s = sum.shape();
        sum.reshape(&s[..s.ndims() - 1])
    }
}


Tensorken has three implementations of Diffable: Tensor, Forward, and Reverse. Tensor doesn't do any differentiation at all - it translates Diffable to RawTensor operations. Forward and Reverse augment the operations with their respective mode of AD. We'll come back to these later - first, we need a Rust vehicle for the user-facing operations that are not in Diffable. We could re-implement them on each implementation of Diffable, but that is redundant. Instead, I've defined DiffableExt, a sub-trait of Diffable with a blanket implementation:

pub trait DiffableExt: Diffable
{
    // all the fns we want, like matmul, go here.
    // They'll need to be defined in terms of Diffable,
    // because that's all that's available.

    fn matmul(&self, other: &Self) -> Self { ... }
}

impl<T: Diffable> DiffableExt for T {}


The advantage is that we only have to implement Diffable on a concrete type, and then everything defined on DiffableExt is available too (as long as DiffableExt is in scope).
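The blanket-implementation trick can be sketched on plain numbers. Prim, PrimExt, square, mul_add_self and Vector are hypothetical stand-ins for Diffable, DiffableExt and its derived operations, assumed here only for illustration:

```rust
// Hypothetical primitive operations, implemented per concrete type (like Diffable).
trait Prim: Sized {
    fn elementwise_add(&self, other: &Self) -> Self;
    fn elementwise_mul(&self, other: &Self) -> Self;
}

// Hypothetical derived operations, written only in terms of Prim (like DiffableExt).
trait PrimExt: Prim {
    fn square(&self) -> Self {
        self.elementwise_mul(self)
    }
    // x * y + x, built from two primitives.
    fn mul_add_self(&self, other: &Self) -> Self {
        self.elementwise_mul(other).elementwise_add(self)
    }
}

// Blanket impl: every Prim gets PrimExt for free.
impl<T: Prim> PrimExt for T {}

impl Prim for f64 {
    fn elementwise_add(&self, other: &Self) -> Self {
        self + other
    }
    fn elementwise_mul(&self, other: &Self) -> Self {
        self * other
    }
}

#[derive(Debug, Clone, PartialEq)]
struct Vector(Vec<f64>);

impl Prim for Vector {
    fn elementwise_add(&self, other: &Self) -> Self {
        Vector(self.0.iter().zip(&other.0).map(|(a, b)| a + b).collect())
    }
    fn elementwise_mul(&self, other: &Self) -> Self {
        Vector(self.0.iter().zip(&other.0).map(|(a, b)| a * b).collect())
    }
}
```

Neither f64 nor Vector mentions PrimExt, yet both get square and mul_add_self.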

The first Diffable implementation: Tensor

We now need a concrete type to present to users. Tensor is that type. Its definition is mysteriously simple:

pub struct Tensor<T>(T);

The idea is that the generic type argument T is a Diffable. Why not add the type constraint here? Because it's unnecessary - for all interesting implementations, T is Diffable. Constraining T here adds nothing new.

We can now make Tensor<T> implement Diffable for any T that's Diffable:

impl<T: Diffable> Diffable for Tensor<T> {
    type Elem = T::Elem;

    fn log(&self) -> Self {
        Tensor(self.0.log())
    }

    // etc
}

All operations delegate to T. Full implementation here.
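A stripped-down version of this delegation, assuming a hypothetical two-operation Diff trait in place of Diffable and f64 in place of a real backend, looks like this:

```rust
// Hypothetical stand-in for Diffable, with just two operations.
trait Diff: Sized {
    fn exp(&self) -> Self;
    fn elementwise_mul(&self, other: &Self) -> Self;
}

impl Diff for f64 {
    fn exp(&self) -> Self {
        f64::exp(*self) // the inherent f64 method
    }
    fn elementwise_mul(&self, other: &Self) -> Self {
        self * other
    }
}

// User-facing newtype; no trait bound on the struct itself.
#[derive(Debug, Clone, PartialEq)]
struct Tensor<T>(T);

// Every operation unwraps, delegates to the inner T, and re-wraps.
impl<T: Diff> Diff for Tensor<T> {
    fn exp(&self) -> Self {
        Tensor(self.0.exp())
    }
    fn elementwise_mul(&self, other: &Self) -> Self {
        Tensor(self.0.elementwise_mul(&other.0))
    }
}
```

Because the impl only requires T: Diff, wrappers nest: Tensor<Tensor<f64>> works as well.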

From Diffable to RawTensor

That alone gets us nowhere - we can only have a differentiable Tensor<T> if we already have a differentiable T. To execute tensor operations we need to get to a RawTensor. We can do that by interpreting Diffable operations as RawTensor operations. In Rust, this means creating a blanket implementation of Diffable for any RawTensor:

impl<T: Num, TTensor: RawTensor<Elem = T>> Diffable for TTensor {
    type Elem = T;

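    fn log(&self) -> Self {
        // Both traits define log, so name RawTensor explicitly
        // to avoid an ambiguous method call.
        RawTensor::log(self)
    }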

    // etc
}

Since Diffable is a subset of RawTensor, the implementation is again straightforward. A type like Tensor<CpuRawTensor> now works, and we can apply all operations in Diffable and DiffableExt to it.
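One Rust detail worth spelling out: when two traits in scope both define a method called log, a plain self.log() inside the blanket impl is ambiguous, so the delegation has to name the trait explicitly. A minimal sketch, with hypothetical Raw and Diff traits standing in for RawTensor and Diffable and a hypothetical Cpu backend:

```rust
// Hypothetical stand-in for RawTensor: low-level, executable operations.
trait Raw: Sized {
    fn log(&self) -> Self;
    fn add(&self, other: &Self) -> Self;
}

// Hypothetical stand-in for Diffable: the differentiable language.
trait Diff: Sized {
    fn log(&self) -> Self;
    fn elementwise_add(&self, other: &Self) -> Self;
}

// Blanket impl: anything executable is also a (non-differentiating) Diff.
impl<T: Raw> Diff for T {
    fn log(&self) -> Self {
        // Fully qualified syntax selects the Raw version of log.
        Raw::log(self)
    }
    fn elementwise_add(&self, other: &Self) -> Self {
        Raw::add(self, other)
    }
}

// A toy concrete backend.
#[derive(Debug, Clone, PartialEq)]
struct Cpu(Vec<f64>);

impl Raw for Cpu {
    fn log(&self) -> Self {
        Cpu(self.0.iter().map(|x| x.ln()).collect())
    }
    fn add(&self, other: &Self) -> Self {
        Cpu(self.0.iter().zip(&other.0).map(|(a, b)| a + b).collect())
    }
}
```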

It seems like we went around in a big circle. After Tensorken parts 1 and 2, we had a Tensor<T: RawTensor> with high-level operations like matmul and primitive operations on RawTensor. Now we have Tensor<T: Diffable> with high-level operations like matmul moved to DiffableExt, differentiable primitive operations on Diffable, and primitive executable operations still on RawTensor.

What we gained is the ability to have other Diffable implementations. We're going to use that ability now.

Forward-mode AD with Forward

The Forward type wraps T with extra stuff so we can transform and trace the computation to calculate the derivative. In interpreter terms, Forward is an interpreter for the Diffable language that calculates the derivative alongside the primal result using forward-mode AD. It does that by applying all tensor operations on a dual tensor.

The Forward type:

pub enum Forward<T> {
    Lift(T),
    Forward(T, T),
}

Like for Tensor<T>, the T here is a Diffable tensor. The Forward case should make sense - it's the primal and the tangent tensors. We use the Lift case if we're not interested in computing the derivative of a tensor. Lifted tensors are treated as constants for the derivative computation. Another way of saying this is that their derivative is zero. We avoid many multiplications with zero by having a dedicated case instead of using the functionally equivalent Forward(t, zero).

We can understand jvp1's implementation now:

pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
    for<'a> F: Fn(&'a Forward<T>) -> Forward<T>,
{
    let forward = Forward::Forward(at.clone(), tangent.clone());
    let result = f(&forward);

    match result {
        Forward::Lift(p) => (p.clone(), p.zeros_like()),
        Forward::Forward(p, t) => (p, t),
    }
}

We wrap the at and tangent arguments in Forward, call f with them, and unwrap the Forward from the result. If f returns a Lift, the result doesn't depend on the input, so its tangent is zero.
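A scalar analogue of Forward and jvp1 shows the mechanics without tensors. This sketch assumes f64 in place of a tensor type and implements only sin; Fwd and this jvp1 are hypothetical simplifications, not Tensorken's code:

```rust
// Hypothetical scalar analogue of Forward<T>: a constant (Lift),
// or a (primal, tangent) pair (Forward). f64 stands in for a tensor.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Fwd {
    Lift(f64),
    Forward(f64, f64),
}

impl Fwd {
    // Unary rule for sin: primal sin(a), tangent cos(a) * da.
    // A Lift input stays a Lift, so no tangent work is done for constants.
    fn sin(&self) -> Fwd {
        match *self {
            Fwd::Lift(p) => Fwd::Lift(p.sin()),
            Fwd::Forward(p, t) => Fwd::Forward(p.sin(), p.cos() * t),
        }
    }
}

// Hypothetical scalar analogue of jvp1: seed the input, run f, unwrap.
fn jvp1<F: Fn(&Fwd) -> Fwd>(f: F, at: f64, tangent: f64) -> (f64, f64) {
    match f(&Fwd::Forward(at, tangent)) {
        // f ignored its input: the result is constant, so its tangent is zero.
        Fwd::Lift(p) => (p, 0.0),
        Fwd::Forward(p, t) => (p, t),
    }
}
```

For example, jvp1(|x| x.sin(), 0.0, 1.0) yields (sin 0, cos 0) = (0.0, 1.0).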

Forward must implement Diffable for this to work. Finally, we come to the implementation of differentiation rules for the primitive operations:

impl<T: Clone + Diffable> Diffable for Forward<T> {
    type Elem = T::Elem;

    fn elementwise_mul(&self, rhs: &Self) -> Self {
        self.binary::<MulOp<T>>(rhs)
    }

    fn sum(&self, axes: &[usize]) -> Self {
        self.unary::<SumOp, _>(axes)
    }

    // etc
}

Full implementation here.

Calculating the primal and the derivatives is encapsulated in Op structs. The unary and binary functions handle the Lift and Forward enum cases in one place, and delegate to a given Op struct for the calculation:

impl<T: Diffable> Forward<T> {
    fn unary<Op: UnaryOp<T, Args = TArgs> + UnaryDiffOp<T>, TArgs: ?Sized>(
        &self,
        args: &TArgs,
    ) -> Self {
        let (primal, op) = Op::f(self.primal(), args);
        match self {
            Forward::Lift(_) => Forward::Lift(primal),
            Forward::Forward(_, tan) => Self::Forward(primal, op.dfda(tan)),
        }
    }
}

binary is similar but more involved because it has 4 combinations of Lift and Forward.

Here's MulOp:

pub(crate) struct MulOp<TTensor>(TTensor, TTensor);

impl<TTensor: Clone + Diffable> BinaryOp<TTensor> for MulOp<TTensor> {
    fn f(a: &TTensor, b: &TTensor) -> (TTensor, Self) {
        (a.elementwise_mul(b), MulOp(a.clone(), b.clone()))
    }
}

impl<TTensor: Diffable> BinaryDiffOp<TTensor> for MulOp<TTensor> {
    fn dfda(&self, d: &TTensor) -> TTensor {
        d.elementwise_mul(&self.1) // da * b
    }

    fn dfdb(&self, d: &TTensor) -> TTensor {
        d.elementwise_mul(&self.0) // db * a
    }
}

Differentiation rules often capture intermediate results or arguments of the primal computation. So f returns not only the result of the primal computation but also a struct to store whatever data is needed for the derivative computation. For MulOp, it captures the input tensors a and b.

dfda and dfdb define how to compute the derivative with respect to the first and second argument, given d, the incoming derivative: in forward mode, the tangent of that argument; in reverse mode, the cotangent flowing back from the output. The differentiation rule for elementwise tensor multiplication is essentially the same as for scalar multiplication.
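The same division of labour can be sketched on scalars: MulOp::f does the primal computation and captures its inputs, while a hypothetical Fwd::mul handles the four Lift/Forward combinations that binary deals with, summing both partial contributions when both arguments carry a tangent. Everything here is a simplified stand-in under the assumption that f64 replaces a tensor:

```rust
// Hypothetical scalar analogue of Forward<T>; f64 stands in for a tensor.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Fwd {
    Lift(f64),
    Forward(f64, f64),
}

// Scalar analogue of MulOp: the primal pass captures both inputs,
// because each partial derivative needs the other input.
struct MulOp(f64, f64);

impl MulOp {
    fn f(a: f64, b: f64) -> (f64, MulOp) {
        (a * b, MulOp(a, b))
    }
    fn dfda(&self, d: f64) -> f64 {
        d * self.1 // da * b
    }
    fn dfdb(&self, d: f64) -> f64 {
        d * self.0 // db * a
    }
}

impl Fwd {
    fn primal(&self) -> f64 {
        match *self {
            Fwd::Lift(p) | Fwd::Forward(p, _) => p,
        }
    }

    // Hypothetical `binary` specialised to MulOp: four Lift/Forward cases.
    fn mul(&self, rhs: &Fwd) -> Fwd {
        let (primal, op) = MulOp::f(self.primal(), rhs.primal());
        match (*self, *rhs) {
            (Fwd::Lift(_), Fwd::Lift(_)) => Fwd::Lift(primal),
            (Fwd::Forward(_, ta), Fwd::Lift(_)) => Fwd::Forward(primal, op.dfda(ta)),
            (Fwd::Lift(_), Fwd::Forward(_, tb)) => Fwd::Forward(primal, op.dfdb(tb)),
            // Both carry tangents: the product rule sums both contributions.
            (Fwd::Forward(_, ta), Fwd::Forward(_, tb)) => {
                Fwd::Forward(primal, op.dfda(ta) + op.dfdb(tb))
            }
        }
    }
}
```

Squaring Fwd::Forward(3.0, 1.0) gives Fwd::Forward(9.0, 6.0), the familiar d(x^2)/dx = 2x at x = 3.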

Unary operations are similar but don't define dfdb:

pub(crate) struct SumOp(Vec<usize>);

impl<TTensor: Diffable> UnaryOp<TTensor> for SumOp {
    type Args = [usize];
    fn f(a: &TTensor, axes: &Self::Args) -> (TTensor, Self) {
        let r = a.sum(axes);
        (r, SumOp(axes.to_vec()))
    }
}

impl<TTensor: Diffable> UnaryDiffOp<TTensor> for SumOp {
    fn dfda(&self, d: &TTensor) -> TTensor {
        d.sum(&self.0)
    }
}

SumOp only needs the reduced axes from the primal computation to calculate dfda. The derivative of the sum is the sum of the derivatives, so we can apply the same sum to primal and tangent.
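For a full reduction over a vector, the rule boils down to reducing the primal and the tangent with the same sum. This sketch assumes plain slices instead of tensors and ignores the axes argument; sum_jvp and sum_of_squares_jvp are hypothetical helpers:

```rust
// Forward-mode rule for a full reduction: sum is linear, so the tangent
// of sum(a) is sum(da), computed with the very same reduction.
// Slices stand in for tensors; the axes argument is ignored here.
fn sum_jvp(primal: &[f64], tangent: &[f64]) -> (f64, f64) {
    (primal.iter().sum::<f64>(), tangent.iter().sum::<f64>())
}

// Hypothetical example: f(x) = sum(x * x), pushed forward along direction v.
// The elementwise product rule gives the tangent 2 * x_i * v_i,
// which the sum rule then reduces.
fn sum_of_squares_jvp(x: &[f64], v: &[f64]) -> (f64, f64) {
    let sq: Vec<f64> = x.iter().map(|a| a * a).collect();
    let dsq: Vec<f64> = x.iter().zip(v).map(|(a, da)| 2.0 * a * da).collect();
    sum_jvp(&sq, &dsq)
}
```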

You can find all the ops here and here.

Forward<Forward<T>> for higher order derivatives

Reiterating this signature:

pub fn jvp1<T: Diffable + Clone, F>(f: F, at: &T, tangent: &T) -> (T, T)
where
    for<'a> F: Fn(&'a Forward<T>) -> Forward<T>

Since the only requirement on T is that it's Diffable and Forward is Diffable, besides a Tensor<T> we can pass a Forward<T> to jvp1 to calculate higher-order derivatives.

let p: Tensor<CpuRawTensor<f32>> = Tensor::scalar(2.0);
let ddf = diff1(
    |x: &Forward<Tensor<_>>| diff1(|x: &Forward<Forward<Tensor<_>>>| x.tanh(), x),
    &p,
);

As we'll see soon, this same design will allow us to combine forward and reverse modes up to arbitrary depth, by building up types like Forward<Reverse<Tensor<..>>>.

This scheme works because the Diffable operations are implemented in terms of Diffable. That looks circular, but it's not: it's a stack of interpreters with a Tensor at the bottom, which translates Diffable operations to RawTensor operations:

Forward<Forward<...>>: Diffable ->... -> Tensor: Diffable -> RawTensor

Somewhat surprisingly, differentiating a differentiated program gets us the second derivative. One way to make sense of that is that if you compute second or third derivatives symbolically, that's exactly what you do: you apply the differentiation rules multiple times. If you want to do a "fun" exercise, you can work out that stacking Forward types amounts to operating on duals-of-duals up to the desired order. If you work out the differentiation rules by hand, you'll find that it yields the correct higher-order derivative.
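The duals-of-duals claim can be checked on scalars. The sketch below uses a generic Dual<T> built on std::ops, an assumption made for brevity rather than how Forward is written, and nests it once to get the second derivative of x^3. Dual, cube and seed2 are hypothetical names:

```rust
use std::ops::{Add, Mul};

// Hypothetical generic dual number: primal + tangent * e, with e^2 = 0.
// T can itself be a Dual, which gives duals-of-duals.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Dual<T> {
    primal: T,
    tangent: T,
}

impl<T: Add<Output = T>> Add for Dual<T> {
    type Output = Dual<T>;
    fn add(self, rhs: Self) -> Self {
        Dual {
            primal: self.primal + rhs.primal,
            tangent: self.tangent + rhs.tangent,
        }
    }
}

impl<T: Add<Output = T> + Mul<Output = T> + Copy> Mul for Dual<T> {
    type Output = Dual<T>;
    // Product rule: (a + a' e)(b + b' e) = ab + (a' b + a b') e
    fn mul(self, rhs: Self) -> Self {
        Dual {
            primal: self.primal * rhs.primal,
            tangent: self.tangent * rhs.primal + self.primal * rhs.tangent,
        }
    }
}

// f(x) = x^3, written once for any type that supports multiplication.
fn cube<T: Mul<Output = T> + Copy>(x: T) -> T {
    x * x * x
}

// Represents x + e1 + e2: f(x + e1 + e2) = f + f'(e1 + e2) + f'' e1 e2,
// so the e1 e2 coefficient (tangent.tangent) is the second derivative.
fn seed2(x: f64) -> Dual<Dual<f64>> {
    Dual {
        primal: Dual { primal: x, tangent: 1.0 },
        tangent: Dual { primal: 1.0, tangent: 0.0 },
    }
}
```

For f(x) = x^3 at x = 1, cube(seed2(1.0)).tangent.tangent is f''(1) = 6, while the two mixed slots both hold f'(1) = 3: each nesting level contributes one differentiation.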

Reverse-mode AD with Reverse

The implementation of Diffable for Reverse follows the same pattern as Forward but is more involved. Because reverse mode accumulates the derivative in a separate backward pass, we can no longer compute everything on the fly when we compute the primal. Instead, Reverse builds a trace of operations in a forward pass while calculating the primal result, then accumulates derivatives in the backward pass.

The difference with forward mode is visible in the signature of vjp1:

pub fn vjp1<'b, 't, T: Diffable + Clone + 't, F>(f: F, at: &T) -> (T, PullBack<'t, T>)
where
    for<'a> F: Fn(&'a Reverse<'a, 't, T>) -> Reverse<'a, 't, T>

Like jvp1, it returns the primal result. Unlike jvp1, it doesn't return the tangent, but a PullBack struct instead. The only available operation on a PullBack is call:

pub fn call(&self, cotangent: &T) -> Vec<T>
where
    T: Diffable + Clone,

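This takes a cotangent tensor - in other words, a tensor with the same shape as the result of f - and returns one tensor per argument of f, each holding the vector-Jacobian product with respect to that argument. Here's how vjpn, the multi-argument variant of vjp1, is used: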

pub fn value_and_gradn<'t, T: Diffable + Clone + 't, F>(f: F, at: &[&T]) -> (T, Vec<T>)
where
    for<'a> F: Fn(&'a [Reverse<'a, 't, T>]) -> Reverse<'a, 't, T>,
{
    // one forward pass, tracing
    let (primal, pullback) = vjpn(f, at);
    // one backward pass, accumulating derivatives
    let tangents = pullback.call(&primal.ones_like());
    // but we get multiple tangents in one go
    (primal, tangents)
}

Other implementations for grad functions follow a similar pattern.

The details of how this is implemented (via a Trace type) are explained in my post on AD, so I won't repeat them here. It is not substantially different from the scalar case. Briefly, here is the Reverse type:

pub enum Reverse<'a, 't, T> {
    Lift(T),
    Reverse(&'a Trace<'t, T>, T, usize),
}

Like Forward, it has a Lift case for tensors we don't want to differentiate. The Reverse case contains the primal T, and some administrative data to record the trace and do the backward pass.
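For readers who haven't seen the earlier post, here is a scalar sketch of the tracing idea under simplifying assumptions: a RefCell-backed list of (parent index, local partial) pairs instead of Tensorken's Trace, f64 values instead of tensors, no Lift case, and only add and mul. Trace, Var and value_and_grad2 are hypothetical names:

```rust
use std::cell::RefCell;

// Hypothetical trace: for each recorded value, its parents and the local
// partial derivative of the value with respect to each parent.
struct Trace {
    nodes: RefCell<Vec<Vec<(usize, f64)>>>,
}

#[derive(Clone, Copy)]
struct Var<'t> {
    trace: &'t Trace,
    value: f64,
    index: usize,
}

impl Trace {
    fn new() -> Self {
        Trace { nodes: RefCell::new(Vec::new()) }
    }

    fn push(&self, parents: Vec<(usize, f64)>) -> usize {
        let mut nodes = self.nodes.borrow_mut();
        nodes.push(parents);
        nodes.len() - 1
    }

    fn var(&self, value: f64) -> Var<'_> {
        Var { trace: self, value, index: self.push(Vec::new()) }
    }

    // Backward pass: seed the output with `cotangent`, then walk the trace
    // in reverse (parents always have lower indices than their children).
    fn pullback(&self, output: usize, cotangent: f64) -> Vec<f64> {
        let nodes = self.nodes.borrow();
        let mut adjoints = vec![0.0; nodes.len()];
        adjoints[output] = cotangent;
        for i in (0..nodes.len()).rev() {
            for &(parent, partial) in &nodes[i] {
                let contribution = adjoints[i] * partial;
                adjoints[parent] += contribution;
            }
        }
        adjoints
    }
}

impl<'t> Var<'t> {
    fn add(&self, other: &Var<'t>) -> Var<'t> {
        let index = self.trace.push(vec![(self.index, 1.0), (other.index, 1.0)]);
        Var { trace: self.trace, value: self.value + other.value, index }
    }
    fn mul(&self, other: &Var<'t>) -> Var<'t> {
        let index = self.trace.push(vec![(self.index, other.value), (other.index, self.value)]);
        Var { trace: self.trace, value: self.value * other.value, index }
    }
}

// One traced forward pass, then one backward pass seeded with 1.0
// (the scalar counterpart of ones_like), giving both gradients at once.
fn value_and_grad2<F>(f: F, x: f64, y: f64) -> (f64, [f64; 2])
where
    F: for<'t> Fn(Var<'t>, Var<'t>) -> Var<'t>,
{
    let trace = Trace::new();
    let (vx, vy) = (trace.var(x), trace.var(y));
    let out = f(vx, vy);
    let adjoints = trace.pullback(out.index, 1.0);
    (out.value, [adjoints[vx.index], adjoints[vy.index]])
}
```

For f(x, y) = x * y + x at (3, 4), this returns 15 and the gradient (y + 1, x) = (5, 3) from a single backward pass.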

The implementation of Diffable looks similar to Forward:

impl<T: Clone + Diffable> Diffable for Reverse<'_, '_, T> {
    type Elem = T::Elem;

    fn elementwise_mul(&self, rhs: &Self) -> Self {
        self.binary::<MulOp<T>>(rhs)
    }

    fn sum(&self, axes: &[usize]) -> Self {
        self.unary::<SumOp, _>(axes)
    }

    // other omitted
}

Again we have unary and binary helper methods to deal with Lift and call the appropriate functions on the Op structs.

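Interestingly, even though reverse mode calculates derivatives backward, from the output to the input, MulOp is identical for forward and reverse mode. This is true for all elementwise operations: their Jacobian is diagonal and therefore equal to its own transpose, so multiplying a tangent by the Jacobian (forward mode) is the same elementwise product as multiplying a cotangent by its transpose (reverse mode).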

sum, however, is different from forward mode. In the backward pass, we get a d in the shape of the result of sum (i.e. with fewer elements), and we need to produce a tensor in the shape of sum's input. To do that, we need expand:

pub(crate) struct SumOp(Vec<usize>);

impl<TTensor: Diffable> UnaryOp<TTensor> for SumOp {
    type Args = [usize];
    fn f(a: &TTensor, axes: &Self::Args) -> (TTensor, Self) {
        let r = a.sum(axes);
        (r, SumOp(a.shape().to_vec()))
    }
}

impl<TTensor: Diffable> UnaryDiffOp<TTensor> for SumOp {
    fn dfda(&self, d: &TTensor) -> TTensor {
        d.expand(&self.0)
    }
}

Full implementation for reverse mode is in ad_reverse.rs and the reverse operations are in ad_ops_reverse.rs.
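To see why the reverse rule needs expand, here is a sketch on a row-major 2-D matrix that sums over axis 1 while keeping it with size 1, matching the [..., m, o, 1] shape comments in the matmul sketch. Mat, sum_axis1 and this SumOp are hypothetical simplifications of the tensor versions:

```rust
// Hypothetical row-major 2-D matrix, just enough for this example.
#[derive(Debug, Clone, PartialEq)]
struct Mat {
    shape: [usize; 2],
    data: Vec<f64>,
}

impl Mat {
    // Sum over axis 1, keeping it with size 1: [rows, cols] -> [rows, 1].
    fn sum_axis1(&self) -> Mat {
        let [rows, cols] = self.shape;
        let data: Vec<f64> = (0..rows)
            .map(|r| self.data[r * cols..(r + 1) * cols].iter().sum::<f64>())
            .collect();
        Mat { shape: [rows, 1], data }
    }

    // Broadcast a [rows, 1] matrix to [rows, cols] by repeating each value.
    fn expand(&self, shape: [usize; 2]) -> Mat {
        let [rows, cols] = shape;
        assert_eq!(self.shape, [rows, 1], "can only expand a [rows, 1] matrix");
        let data: Vec<f64> = self
            .data
            .iter()
            .flat_map(|&v| std::iter::repeat(v).take(cols))
            .collect();
        Mat { shape, data }
    }
}

// Reverse-mode rule for sum: the primal pass records the *input* shape,
// and the backward pass expands the smaller cotangent back to it.
struct SumOp([usize; 2]);

impl SumOp {
    fn f(a: &Mat) -> (Mat, SumOp) {
        (a.sum_axis1(), SumOp(a.shape))
    }
    fn dfda(&self, d: &Mat) -> Mat {
        d.expand(self.0)
    }
}
```

Every input element contributes with weight 1 to its row's sum, so each one receives its row's cotangent unchanged, which is exactly what expand produces.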

After all that, we can run all the examples in the demo section. However, there is one remaining issue.

Un-blowing up matmul, again

The problem is serious. Repeating the (pseudo-code) implementation of matmul in DiffableExt:

pub trait DiffableExt: Diffable
{
    fn matmul(&self, other: &Self) -> Self {
        // preconditions, shape manipulation omitted
        // special cases omitted

        let l = self.reshape(&l_shape);
        // shape manipulation omitted
        let r = other
            .reshape(&r_shape)
            .transpose(r_shape.ndims() - 1, r_shape.ndims() - 2);

        // TROUBLE BEGINS HERE
        // after multiply: [..., m, o, n]
        let prod = l.mul(&r);

        // after sum: [..., m, o, 1]
        let sum = prod.sum(&[prod.shape().ndims() - 1]);

        // after reshape: [..., m, o]
        let s = sum.shape();
        sum.reshape(&s[..s.ndims() - 1])
    }
}

See that mul followed by sum? In the second part of Tensors from Scratch, I explained that this blows up memory, to the point where this approach is utterly unscalable. The fused multiply-add function in RawTensor came to the rescue - we rewrote the separate sum and mul calls into one l.fused_multiply_add(&r, dims), which made it efficient. Now we've regressed to the previous bad situation. What gives?
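The blow-up is easy to quantify with plain loops. This hypothetical sketch multiplies an [m, n] matrix by an [n, o] one whose transpose is passed in as r_t: the mul-then-sum version materializes all m * o * n products before reducing, while the fused version accumulates each dot product in a local variable. It illustrates the idea only and is not Tensorken's CPU or GPU kernel:

```rust
// Hypothetical illustration, not Tensorken's kernels.
// l is [m, n] row-major; r_t is the right operand already transposed, [o, n].

// Unfused: materialize the [m, o, n] elementwise product, then sum the last
// axis. Also returns how many f64s the intermediate buffer held.
fn matmul_mul_then_sum(l: &[f64], r_t: &[f64], m: usize, n: usize, o: usize) -> (Vec<f64>, usize) {
    let mut prod = Vec::with_capacity(m * o * n);
    for i in 0..m {
        for j in 0..o {
            for k in 0..n {
                prod.push(l[i * n + k] * r_t[j * n + k]);
            }
        }
    }
    let out: Vec<f64> = prod.chunks(n).map(|c| c.iter().sum::<f64>()).collect();
    (out, prod.len())
}

// Fused multiply-add: accumulate each dot product in a local variable,
// so no [m, o, n] intermediate is ever allocated.
fn matmul_fused(l: &[f64], r_t: &[f64], m: usize, n: usize, o: usize) -> Vec<f64> {
    let mut out = vec![0.0; m * o];
    for i in 0..m {
        for j in 0..o {
            let mut acc = 0.0;
            for k in 0..n {
                acc += l[i * n + k] * r_t[j * n + k];
            }
            out[i * o + j] = acc;
        }
    }
    out
}
```

For two 1000 x 1000 matrices, the unfused intermediate holds 10^9 elements, which is 4 GB as f32; the fused version only ever allocates the 1000 x 1000 output.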

First, Diffable doesn't have fused_multiply_add, so we can't write the optimized version directly. We could add fused_multiply_add to Diffable as a primitive operation, but then we have to define a differentiation rule for it in the various modes. One of the main reasons for Diffable's existence is to avoid that.

Second, while manually fusing mul and sum worked for this particular case, users may inadvertently write a mul followed by a sum, and fall into this trap themselves. Worse, while we're calculating derivatives by composing operations in forward or reverse mode, Tensorken itself may introduce a mul followed by a sum. Manually fusing all cases is not going to work. We need a better solution.

If we were writing a compiler, it'd be straightforward to go through the abstract syntax tree of tensor operations and transform any l.mul(r).sum(axes) into l.fused_multiply_add(r, axes). Can we do a similar optimization here?

Let's think about what's happening from the perspective of interpreters. We have defined a language for writing differentiable programs using the trait Diffable. Everything we do with tensors - matmul, crop, max, sigmoid, as well as getting derivatives - is eventually a program in terms of the operations on Diffable. We have three interpreters for Diffable: one that translates the differentiable program to RawTensor operations, and two that augment the differentiable program with forward- or reverse-mode AD. No matter how many times we stack Diffable on top of Diffable, eventually the program gets run via a RawTensor interpreter.

So far, we only have concrete RawTensor interpreters: they calculate the results on the CPU or GPU, or print a string representing the result. But we can also write an interpreter that builds a new, optimized RawTensor computation, with every mul + sum fused into fused_multiply_add.

This technique - which I didn't invent at all, to be clear - is introduced more gradually and gracefully in my post on typed tagless final interpreters. Here I'll give a whirlwind tour of the implementation.

We'll use a type called Fuse<T>. T is the target optimized RawTensor. Whenever mul followed by sum is detected in the unoptimized, original RawTensor, Fuse rewrites the two operations to a fused equivalent.

enum FuseCtx {
    Sum(Vec<usize>),
    NotSum,
}

pub struct Fuse<T>(Rc<dyn Fn(&FuseCtx) -> T>);

The function from FuseCtx to the fused T: RawTensor is a factory function we'll build up while interpreting the original RawTensor as Fuse<T>. In other words, Fuse<T> interprets RawTensor as a function that, given a FuseCtx, produces an optimized RawTensor. It works in two passes: the first pass builds up the factory function, and the second pass runs it to get a new RawTensor.

Since Fuse only needs to fuse multiply and sum operations, it delays the application of sum, and instead passes Sum(axes) to the continuation via FuseCtx. The continuation calls the delayed sum if it can't fuse or fused_multiply_add if it can. Here's the implementation of mul where fusing happens:

impl<TRaw: RawTensor + Clone + 'static> RawTensor for Fuse<TRaw> {
    type Elem = TRaw::Elem;

    fn mul(&self, other: &Self) -> Self {
        let f_lhs = Rc::clone(&self.0);
        let f_rhs = Rc::clone(&other.0);
        let nextctx = FuseCtx::NotSum;
        Fuse::new(move |ctx| match ctx {
            FuseCtx::Sum(axes) => f_lhs(&nextctx).fused_multiply_add(&f_rhs(&nextctx), axes),
            FuseCtx::NotSum => f_lhs(&nextctx).mul(&f_rhs(&nextctx)),
        })
    }
}

The context passed in the closure represents what the next operation is, from the perspective of the current operation. If it's sum, the Sum enum case, we fuse. If it's anything else, represented by NotSum, we know the operation has already been applied and we can't fuse. Since mul is not a sum, we pass NotSum as the next context.

Here is the implementation of sum:

fn sum(&self, axes: &[usize]) -> Self {
    let f = Rc::clone(&self.0);
    let my_axes = axes.to_vec();
    Fuse::new(move |ctx| match ctx {
        FuseCtx::Sum(sum_axes) => f(&FuseCtx::Sum(combine_axes(&my_axes, sum_axes))),
        FuseCtx::NotSum => f(&FuseCtx::Sum(my_axes.clone())),
    })
}

We do not apply sum straight away. Instead, we pass Sum through to the operation that produced sum's input, so it gets a chance to fuse it. Any operation that doesn't fuse must apply the delayed sum itself when it receives the Sum case. We might as well fuse consecutive sum calls into one by combining their axes, hence the first match arm.

Fusing happens in two passes: the first pass builds the FuseCtx -> RawTensor function. The second pass creates the optimized RawTensor by calling the function:

impl<T> Fuse<T> {
    fn run(&self) -> T {
        (self.0)(&FuseCtx::NotSum)
    }
}

Link to full implementation of fusing.
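Here is a self-contained miniature of the whole scheme. It assumes a hypothetical five-operation Raw trait instead of RawTensor, and a Show backend that renders the program as a string so the effect of fusing is visible. In place of combine_axes, whose definition isn't shown here, the sketch merges consecutive sums by taking the union of their axes, which assumes (as in the matmul sketch) that sum keeps reduced axes with size 1:

```rust
use std::rc::Rc;

// Hypothetical stand-in for RawTensor with only the operations needed here.
trait Raw: Sized {
    fn leaf(name: &str) -> Self;
    fn exp(&self) -> Self;
    fn mul(&self, other: &Self) -> Self;
    fn sum(&self, axes: &[usize]) -> Self;
    fn fused_multiply_add(&self, other: &Self, axes: &[usize]) -> Self;
}

// Hypothetical target interpreter: renders the program it receives.
#[derive(Debug, Clone, PartialEq)]
struct Show(String);

impl Raw for Show {
    fn leaf(name: &str) -> Self { Show(name.to_string()) }
    fn exp(&self) -> Self { Show(format!("exp({})", self.0)) }
    fn mul(&self, o: &Self) -> Self { Show(format!("mul({}, {})", self.0, o.0)) }
    fn sum(&self, axes: &[usize]) -> Self { Show(format!("sum({}, {:?})", self.0, axes)) }
    fn fused_multiply_add(&self, o: &Self, axes: &[usize]) -> Self {
        Show(format!("fma({}, {}, {:?})", self.0, o.0, axes))
    }
}

enum FuseCtx {
    Sum(Vec<usize>),
    NotSum,
}

struct Fuse<T>(Rc<dyn Fn(&FuseCtx) -> T>);

impl<T> Fuse<T> {
    fn new<F: Fn(&FuseCtx) -> T + 'static>(f: F) -> Self { Fuse(Rc::new(f)) }
    // Second pass: run the factory with no pending sum.
    fn run(&self) -> T { (self.0)(&FuseCtx::NotSum) }
}

// Non-fusing operations apply a pending sum after themselves.
fn apply_delayed<T: Raw>(t: T, ctx: &FuseCtx) -> T {
    match ctx {
        FuseCtx::Sum(axes) => t.sum(axes),
        FuseCtx::NotSum => t,
    }
}

impl<T: Raw + 'static> Raw for Fuse<T> {
    fn leaf(name: &str) -> Self {
        let name = name.to_string();
        Fuse::new(move |ctx| apply_delayed(T::leaf(&name), ctx))
    }
    fn exp(&self) -> Self {
        let f = Rc::clone(&self.0);
        Fuse::new(move |ctx| apply_delayed(f(&FuseCtx::NotSum).exp(), ctx))
    }
    fn mul(&self, other: &Self) -> Self {
        let (l, r) = (Rc::clone(&self.0), Rc::clone(&other.0));
        Fuse::new(move |ctx| match ctx {
            FuseCtx::Sum(axes) => l(&FuseCtx::NotSum).fused_multiply_add(&r(&FuseCtx::NotSum), axes),
            FuseCtx::NotSum => l(&FuseCtx::NotSum).mul(&r(&FuseCtx::NotSum)),
        })
    }
    fn sum(&self, axes: &[usize]) -> Self {
        let f = Rc::clone(&self.0);
        let my_axes = axes.to_vec();
        Fuse::new(move |ctx| match ctx {
            // Consecutive sums: union of axes (assumes sum keeps reduced dims).
            FuseCtx::Sum(outer) => {
                let mut axes = [my_axes.clone(), outer.clone()].concat();
                axes.sort_unstable();
                axes.dedup();
                f(&FuseCtx::Sum(axes))
            }
            FuseCtx::NotSum => f(&FuseCtx::Sum(my_axes.clone())),
        })
    }
    fn fused_multiply_add(&self, other: &Self, axes: &[usize]) -> Self {
        self.mul(other).sum(axes)
    }
}
```

Running a.mul(&b).sum(&[1]) through Fuse<Show> renders as a single fma node, whereas an operation that cannot fuse, such as exp, receives the pending sum and applies it after itself.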

Now I can finally reveal the full Cpu32 and Wgpu32 types:

pub type Cpu32 = Tensor<ShapeTracker<Fuse<CpuRawTensor<f32>>>>;
pub type Wgpu32<'d> = Tensor<ShapeTracker<Fuse<WgpuRawTensor<'d, f32>>>>;

The remaining unknown there is ShapeTracker. ShapeTracker is a RawTensor implementation that abstractly interprets the operations by only tracking tensor shapes. It delegates all operations to the RawTensor it wraps, except shape:

pub struct ShapeTracker<T>(ShapeStrider, T);

/// This implementation passes every operation through
/// to self.1, except for shape.
impl<TRaw: RawTensor> RawTensor for ShapeTracker<TRaw> {
    type Elem = TRaw::Elem;

    fn exp(&self) -> Self {
        Self(self.0.clone(), self.1.exp())
    }

    // etc

    fn shape(&self) -> &[usize] {
        self.0.shape()
    }
}

Because Fuse does not track shapes but does need to implement RawTensor::shape, it can only return its shape by running the delayed computation. We don't want that - some derivative operations require access to the shape of tensors, and it would be bad if we had to run the tensor computation at that point.

ShapeTracker solves this for us - it can answer shape queries without executing tensor operations. The order is important here: ShapeTracker needs to wrap Fuse, which needs to wrap the concrete CpuRawTensor or WgpuRawTensor.
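The ordering argument can be sketched in a few lines. In this hypothetical example, Deferred stands in for Fuse: it can only answer shape by running its delayed computation, which it counts. ShapeTracker keeps its own copy of the shape (a plain Vec<usize> here, rather than Tensorken's ShapeStrider) and delegates everything else:

```rust
use std::cell::Cell;
use std::rc::Rc;

// Hypothetical minimal stand-in for RawTensor: one op plus a shape query.
trait Raw: Sized {
    fn reshape(&self, shape: &[usize]) -> Self;
    fn shape(&self) -> Vec<usize>;
}

// Stands in for Fuse: answering `shape` means running the delayed
// computation, which we simulate and count via `runs`.
struct Deferred {
    shape: Vec<usize>,
    runs: Rc<Cell<usize>>,
}

impl Raw for Deferred {
    fn reshape(&self, shape: &[usize]) -> Self {
        Deferred { shape: shape.to_vec(), runs: Rc::clone(&self.runs) }
    }
    fn shape(&self) -> Vec<usize> {
        self.runs.set(self.runs.get() + 1); // pretend this ran the computation
        self.shape.clone()
    }
}

// Tracks the shape alongside the wrapped tensor.
struct ShapeTracker<T>(Vec<usize>, T);

impl<T: Raw> ShapeTracker<T> {
    // The caller supplies the initial shape, as when a tensor is created
    // from a shape and data.
    fn new(inner: T, shape: &[usize]) -> Self {
        ShapeTracker(shape.to_vec(), inner)
    }
}

impl<T: Raw> Raw for ShapeTracker<T> {
    // Pass the op through, and update the tracked shape abstractly.
    fn reshape(&self, shape: &[usize]) -> Self {
        ShapeTracker(shape.to_vec(), self.1.reshape(shape))
    }
    // Answered from the tracked shape, without touching the inner tensor.
    fn shape(&self) -> Vec<usize> {
        self.0.clone()
    }
}
```

Asking the tracker for its shape leaves the counter at zero; only querying the inner value directly pays for running it.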

I love it when a plan comes together

Thanks to the power of interpreters, a.k.a. the tagless final encoding, Tensorken gained a capable yet small and extensible AD implementation. So far, I'm really happy with how Tensorken turned out. I started seriously researching deep learning from an implementation perspective at the beginning of 2023 with only some prior exposure to automatic differentiation. I randomly ran into the typed tagless final interpreters paper while I was studying TinyGrad, and figured that TinyGrad's style would lend itself well to the tagless final style. I could not have hoped for a better outcome.

After that, I saw a post on Reddit praising JAX and immediately preferred the functional style over PyTorch's imperative AD interface. It was much more challenging to implement in Rust though! Those signatures look straightforward now, but it took a lot of struggling with closures, lifetimes and lifetimes and closures and then lifetimes some more before everything came together. All this to say - I got lucky trying an implementation style I hadn't ever used and struggled for a long time. When AD finally worked, it felt almost magical. Persistence is worth some IQ points.

Now that Tensorken has all the pieces of a full-fledged deep learning library, it's time to put it to the test. I intend to follow along with Andrej Karpathy's Zero to Hero neural networks course, translating it from PyTorch to Tensorken. At the end of that, we should have a home-grown, walking and talking nanoGPT. Without a doubt, there'll be many interesting problems in Tensorken itself to solve along the way.

Many thanks for reading!
