Demystifying loss.backward(): How PyTorch Autograd Actually Works

If you use PyTorch, there’s one line of code you probably type out of sheer muscle memory: loss.backward().

When you first learn neural networks, you spend hours deriving backpropagation equations and understanding chain rules. Then you meet PyTorch, type loss.backward(), and it just... works. It feels like magic. All that complex calculus is handled in a single line.

But as engineers, "It’s magic" isn’t a satisfying explanation. It’s actually a bit unsettling. What exactly is happening under the hood when we call that function?

Today, we’re going to look behind the curtain at Autograd, the engine that makes this magic possible.


0. The Blueprint: The Training Loop

Before we tear apart the code, let’s quickly recap the standard workflow of training a deep learning model. Training is essentially an iterative process of tweaking weights (parameters) to minimize error. It usually follows these four steps:

  1. Forward Pass: Push input data through the model to get a prediction.

  2. Loss Calculation: Compare the prediction to the actual target to calculate the loss.

  3. Backward Pass: Determine how much each parameter contributed to the error. We traverse from the loss back to the start, calculating gradients.

  4. Weight Update: Adjust the parameters based on the calculated gradients to reduce the error in the next round.

By repeating this, the model gets smarter.

The "Hello World" of Deep Learning

Here is how those steps translate into standard PyTorch code. If you’ve done any PyTorch tutorial, this will look familiar:

for epoch in range(epochs):
    # 1. Clear the whiteboard (Reset Gradients)
    optimizer.zero_grad()

    # 2. Take the test (Forward pass)
    output = model(input)

    # 3. Grade the test (Compute Loss)
    loss = criterion(output, target)

    # 4. Find out why we failed (Backward pass)
    loss.backward()

    # 5. Study and improve (Update Parameters)
    optimizer.step()

Let's dig into the five simple lines that power this complex process.

1. Autograd: The Puppet Master

Autograd is PyTorch’s automatic differentiation engine. Think of it as a silent stenographer. Whenever you create a Tensor and perform operations (add, sub, multiply), Autograd watches and records everything. It notes down what data was used, what operation was performed, and what the result connects to next.

The Dynamic Computational Graph

Autograd records this history in a structure called a Computational Graph. This graph maps out every operation: the inputs, the operator, the outputs, and the sequence of events.

PyTorch is famous for using a Dynamic Computational Graph. This means the graph isn't pre-compiled; it’s drawn on the fly as your code runs.

Note: PyTorch 2.0 introduced torch.compile to allow for static graph optimization, which we'll touch on at the end.

import torch

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)

z = x * y
loss = z ** 2

Setting requires_grad=True tells Autograd: "Watch this tensor closely." When we run the code above, PyTorch builds a graph that looks something like this:

x, y (Leaf Nodes) → Multiplication → z → Power → loss (Root)

2. The Core Components: grad_fn, next_functions, ctx, and .grad

To understand how Autograd navigates this graph, we need to meet the internal variables that hold it all together.

Member 1: grad_fn (The Signpost)

Every Tensor created by an operation has a grad_fn attribute. This attribute answers the question: "What operation created me?"

If a tensor was created by addition, its grad_fn is AddBackward0; if by multiplication, MulBackward0. It acts as a node in our graph and holds the instructions for how to differentiate that specific operation.
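As a quick sanity check, you can print grad_fn directly. A minimal sketch using fresh tensors rather than the example above:

import torch

a = torch.tensor([1.0], requires_grad=True)
b = a + 2   # created by addition
c = a * 3   # created by multiplication

print(a.grad_fn)  # None -> we created 'a' ourselves; no operation produced it
print(b.grad_fn)  # <AddBackward0 object at 0x...>
print(c.grad_fn)  # <MulBackward0 object at 0x...>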

Member 2: next_functions (The Connection)

If grad_fn says "who I am," next_functions says "who my parents are." It points to the grad_fn of the input tensors that created the current tensor. During backpropagation, this tells PyTorch where to send the calculated gradients next. This forms the edges of our graph.

We can actually inspect this structure with a recursive function:

import torch

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)

z = x * y
loss = z ** 2

def print_graph(grad_fn, depth=0):
    indent = '  ' * depth

    if grad_fn is None:
        return

    print(f"{indent}└─ {type(grad_fn).__name__}")

    for next_fn, _ in grad_fn.next_functions:
        print_graph(next_fn, depth + 1)

print("\n=== Autograd Graph Structure ===")
print_graph(loss.grad_fn)

Output:

=== Autograd Graph Structure ===
└─ PowBackward0
  └─ MulBackward0
    └─ AccumulateGrad
    └─ AccumulateGrad

The print_graph() function takes a Node (grad_fn) as input and traverses the next_functions list to map out the history. Let’s trace exactly what this output is telling us:

  1. PowBackward0: The journey starts at loss. Since loss was calculated using the formula loss = z ** 2, it possesses a grad_fn corresponding to the power operation, which is named PowBackward0.

  2. MulBackward0: The input used to create that square operation was z. Looking back at how z was born, we see z = x * y. Because this was a multiplication operation, its grad_fn is MulBackward0.

  3. AccumulateGrad: Finally, MulBackward0 lists two AccumulateGrad nodes as its parents. This is a special marker indicating a Leaf Node.

A Leaf Node is a Tensor that has requires_grad=True but wasn't created as the result of an intermediate operation. These are the tensors we created manually, like x and y. Since we made them ourselves, they don't have a natural grad_fn (no operation created them). So, Autograd artificially attaches an AccumulateGrad node to them so they can be connected to the graph and store gradients during backpropagation.
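You can verify this distinction yourself. A minimal sketch, separate from the example above:

import torch

x = torch.tensor([2.0], requires_grad=True)
z = x * 3.0

print(x.is_leaf, x.grad_fn)  # True None                -> created by us, no grad_fn
print(z.is_leaf, z.grad_fn)  # False <MulBackward0 ...> -> created by an operation

# The leaf is hooked into the graph through an AccumulateGrad node
print(z.grad_fn.next_functions)  # ((<AccumulateGrad ...>, 0), (None, 0))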

Member 3: ctx (The Storage Locker)

To calculate a derivative (gradient), you often need the original input data.

Example: The derivative of x^2 is 2x. To calculate 2x, you need to know what x was.

During the Forward Pass, PyTorch realizes: "I might need this value later for calculus." It saves these tensors in a special container called ctx (context).

This is a common source of memory issues. Even if your model parameters fit on the GPU, the training process consumes much more memory because ctx is holding onto all those intermediate tensors needed for the backward pass.
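You can peek into this storage on a tiny example. (The _saved_* attributes are internal and their exact names can vary between PyTorch versions, so treat this as an illustration rather than a stable API):

import torch

x = torch.tensor([3.0], requires_grad=True)
y = x ** 2

# PowBackward0 kept the original input so it can compute 2x during backward.
print(y.grad_fn)              # <PowBackward0 object at 0x...>
print(y.grad_fn._saved_self)  # tensor([3.], requires_grad=True)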

Member 4: .grad (The Report Card)

When you call loss.backward(), PyTorch navigates the graph, calculating gradients at every step. The final result for each parameter is stored in its .grad attribute.

Crucially, .grad values are cumulative. PyTorch doesn't overwrite them; it adds to them. (We'll explain why in section 5).

Let's see ctx and .grad in action by creating a custom Autograd function:

import torch

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_tensor):
        # 1. Forward Pass
        # Save input to the 'ctx locker' for later use
        ctx.save_for_backward(input_tensor) 
        return input_tensor ** 2

    @staticmethod
    def backward(ctx, grad_output):
        # 2. Backward Pass
        # Open 'ctx locker' and retrieve the saved input
        input_tensor, = ctx.saved_tensors

        # Calculate Gradient: 2x * (incoming gradient)
        gradient = (2 * input_tensor) * grad_output

        # Pass this gradient down the line
        return gradient


x = torch.tensor([3.0], requires_grad=True)

# Run Forward
y = MySquare.apply(x) 

print(f"Forward result (3^2): {y.item()}")  # 9.0

# Run Backward
# y.backward() is implicitly y.backward(torch.tensor(1.0))
y.backward()

# Check the calculated gradient
# Logic: (2 * 3.0) * 1 = 6.0
print(f"Backward result (2*3): {x.grad.item()}")  # 6.0

Output:

Forward result (3^2): 9.0
Backward result (2*3): 6.0

Here is exactly what happened in that code block:

In the Forward phase, MySquare calculates the square, but it also performs a critical step: it calls ctx.save_for_backward(input_tensor) to save the original input value into the ctx storage.

Then, in the Backward phase, it uses ctx.saved_tensors to retrieve that saved input and calculate the derivative. The result eventually lands in the tensor's .grad attribute, which we confirm in the last line by printing x.grad.item().

3. The loss.backward() Journey

Now, let's look closer at what happens when loss.backward() is called. PyTorch starts at loss and travels backward through the graph toward the inputs, calculating derivatives for each node.

Where is the final destination of this journey? It's the Tensors created directly by the user rather than by an operation, typically the model's Weights and Biases. These nodes sit at the far end of the graph and are called Leaf Nodes.

(The nodes in the middle of the graph are simply intermediate Tensors created as results of operations.)

Here is the actual flow:

  1. Departure: The journey begins at loss.

  2. Opening the ctx Locker: It accesses ctx via grad_fn to retrieve the "intermediate tensors" saved earlier.

  3. Calculating Derivatives: It calculates the gradient using the data and the chain rule.

  4. Propagation & Accumulation: It passes the calculated gradient to the previous nodes using next_functions. If it reaches a Leaf Node (a parameter), it adds the value to that Tensor's .grad attribute.
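Let's trace this flow on the x, y example from earlier. Working the chain rule by hand gives the same numbers Autograd produces:

import torch

x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)

z = x * y       # z = 6
loss = z ** 2   # loss = 36

loss.backward()

# Chain rule by hand:
#   dloss/dz = 2z     = 12
#   dloss/dx = 2z * y = 36
#   dloss/dy = 2z * x = 24
print(x.grad)  # tensor([36.])
print(y.grad)  # tensor([24.])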

4. optimizer.step(): Making the Move

Once backward() is finished, every parameter knows which direction it needs to move to reduce the error. The calculated directions are sitting in the .grad attributes.

The job of actually modifying the Tensor values belongs to optimizer.step(). While its job is simple—update the Leaf Nodes (Weights, Biases)—there are two implementation details worth noting.

First, optimizer.step() technically involves Tensor math (e.g., subtracting a value from the current weight). In principle, Autograd would record these operations too. However, updating weights is not part of the model's forward logic; we don't want to backpropagate through the optimization step itself. Therefore, optimizer.step() runs its updates inside a torch.no_grad() context, ensuring they aren't recorded in the graph.

Second, the update happens as an in-place operation. Usually, PyTorch tensor operations create new Tensors, but in the optimizer we want to modify the existing parameter Tensors directly, so that the model and the optimizer keep referencing the same objects. That's why the optimizer uses in-place methods like w.sub_(lr * grad).
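Here is a minimal sketch of what a single SGD-style step boils down to, written by hand (real optimizers also handle parameter groups, momentum, and more):

import torch

w = torch.tensor([1.0], requires_grad=True)
loss = (w * 3 - 6) ** 2
loss.backward()            # d(loss)/dw = 2 * (3w - 6) * 3 = -18

lr = 0.1
with torch.no_grad():      # don't record the update in the graph
    w.sub_(lr * w.grad)    # in-place: the same Tensor object is modified

print(w)                   # tensor([2.8000], requires_grad=True)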

5. optimizer.zero_grad(): A Fresh Start

We’ve reached the final piece of the training loop: zero_grad(). We know its role is to reset the gradients. But why do we have to do this at the start of every loop?

This is because PyTorch accumulates gradients. When PyTorch calculates a new gradient, it doesn't overwrite the existing value in .grad; it adds the new value to the old one.

Therefore, once a mini-batch is finished, we must call zero_grad() to wipe the slate clean. It’s like erasing a chalkboard after solving a math problem so you have space to solve the next one.
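A tiny standalone demonstration of this accumulation:

import torch

x = torch.tensor([2.0], requires_grad=True)

(3 * x).sum().backward()
print(x.grad)    # tensor([3.])

(3 * x).sum().backward()
print(x.grad)    # tensor([6.])  -> added to the previous value, not overwritten

x.grad.zero_()   # roughly what optimizer.zero_grad() does for each parameter
print(x.grad)    # tensor([0.])

(Newer PyTorch versions reset .grad to None by default in zero_grad(), which has the same effect with less memory traffic.)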

But this raises a question: Why doesn't PyTorch just automatically zero out the gradients after optimizer.step()? Why make us type this extra line?

This is deliberate: it enables a technique called Gradient Accumulation, which is a lifesaver when you are running out of GPU memory.

For example, let’s say you want to train with a Batch Size of 64, but your GPU only has enough memory for a Batch Size of 16. You can simulate the larger batch size using accumulation:

  1. Set your real batch size to 16.

  2. Run the training loop 4 times without calling zero_grad(). (16 x 4 = 64)

  3. During these 4 loops, the gradients will pile up on top of each other in .grad.

  4. After the 4th loop, call optimizer.step() to update the weights.

  5. Now call optimizer.zero_grad() to prepare for the next set.

By doing this, you physically calculated 16 items at a time, but mathematically, you performed an update equivalent to a batch of 64.
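In code, the pattern looks roughly like this. This is a self-contained sketch with a toy model and random data standing in for a real pipeline; dividing the loss by the number of accumulation steps keeps the update on the same scale as averaging over the full batch of 64:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

accumulation_steps = 4              # 4 mini-batches of 16 ≈ one batch of 64
optimizer.zero_grad()

for i in range(8):                  # pretend these are mini-batches from a dataloader
    data = torch.randn(16, 10)
    target = torch.randn(16, 1)

    loss = criterion(model(data), target) / accumulation_steps
    loss.backward()                 # gradients pile up in each parameter's .grad

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()            # one real update per 4 mini-batches
        optimizer.zero_grad()       # wipe the slate for the next set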

Appendix 1: model.eval() vs torch.no_grad()

These two commands are both used during inference, but they serve very different purposes. It's easy to confuse them.

1) model.eval(): "Exam Mode"

When you are studying for a test, you might use various tools to help you learn: you might take mock exams, use flashcards, or intentionally solve harder problems to challenge yourself.

But when you take the actual exam, you don't do those things. You just solve the problems in front of you.

Deep learning models are the same. During training (studying), they use techniques like Dropout (randomly turning off neurons) and BatchNorm (calculating statistics based on the current batch). But during inference (the exam), they need to behave differently. model.eval() tells the model: "We are in a real situation now. Turn off Dropout and freeze the BatchNorm statistics!"

However, this does not save memory. Even in eval mode, Autograd continues to build the graph and store values in ctx during the forward pass.
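You can see this for yourself: even after model.eval(), the output still carries a grad_fn, because the model's parameters require gradients:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
model.eval()                 # changes module behavior (Dropout, BatchNorm)...

out = model(torch.randn(1, 4))
print(out.grad_fn)           # <AddmmBackward0 object at 0x...> -> ...but Autograd is still recording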

2) torch.no_grad(): "Stop Recording"

torch.no_grad() is a direct command to the Autograd engine: "Stop recording. Take a break."

Autograd meticulously records everything solely for the purpose of the backward pass. But during inference, we don't need to do a backward pass. If Autograd keeps recording, it's just wasting memory and time.

Inside a with torch.no_grad(): block, PyTorch stops storing intermediate Tensors in ctx. This effectively reduces memory usage and speeds up computation.

Conclusion: Since they have different purposes, you should usually use both when running validation or testing:

# 1. Change Model Behavior (Disable Dropout, etc.)
model.eval() 

# 2. Save Memory & Speed Up (Stop tracking operations)
with torch.no_grad():
    output = model(data)

Appendix 2: torch.compile (PyTorch 2.0+)

The greatest strength of early PyTorch was the flexibility of its Dynamic Graph. However, rebuilding the graph from scratch every single iteration is inherently slower than the Static Graph approach (where the graph is fixed and optimized once).

PyTorch 2.0 introduced torch.compile to solve this performance gap.

1) The Principle: Operator Fusion

The key to torch.compile's speed is Operator Fusion.

Modern GPUs calculate math incredibly fast. However, moving data from memory to the chip (and back) is relatively slow. In fact, memory access is often the bottleneck.

Without torch.compile (Eager Mode), the process looks like this:

  1. Read value from GPU memory.

  2. Multiply it by 2.

  3. Write the result back to memory.

  4. Read that result from memory again.

  5. Add 5 to it.

  6. Write the final result to memory.

Notice how we are reading and writing to memory multiple times?

torch.compile optimizes this:

  1. It analyzes the entire graph and notices we are doing a multiply followed by an add.

  2. It fuses these into a single custom function (kernel).

  3. New Flow: Read value from memory.

  4. Multiply by 2 AND add 5 in one go (inside the GPU).

  5. Write the final result to memory.

By reducing the number of "slow" memory accesses, the overall speed increases significantly.
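As a quick sketch (assuming PyTorch 2.0+), you can compile exactly this multiply-then-add as a standalone function; the compiler is free to fuse the two operations into a single kernel:

import torch

def f(x):
    return x * 2 + 5                 # a multiply followed by an add

compiled_f = torch.compile(f)        # PyTorch 2.0+; the first call triggers compilation

x = torch.randn(1024, 1024)
print(torch.allclose(f(x), compiled_f(x)))  # True: same math, fewer memory round-trips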

2) How to use it

Using it is incredibly simple. Just wrap your model:

import torch

model = MyNet()
optimized_model = torch.compile(model)

# Use it exactly the same way as before
output = optimized_model(input)

Wrap Up

We started with loss.backward() as a magic spell. We ended with an understanding of the Autograd engine that powers it.

Now you know that when you call loss.backward(), it's not just a simple function call. PyTorch is following the grad_fn signposts, opening the ctx storage lockers, and diligently traversing the graph in reverse.

Understanding these internals is more than just intellectual trivia. It gives you practical engineering skills:

  • Memory Issues: You now understand that OOM errors are often caused by ctx holding onto tensors, not just the model size.

  • Custom Operations: You know how to inherit from Function to create custom layers that don't break the gradient chain.

  • Gradient Accumulation: You understand why zero_grad() is manual and how to leverage it to train huge models on small GPUs.

  • Inference Optimization: You know exactly why torch.no_grad() is essential for performance, distinct from model.eval().

Hopefully, the next time you type loss.backward(), you’ll see the invisible graph being traversed and appreciate the engineering happening behind the scenes.
