
Fixing A Bug in micrograd BackProp (As Explained by Karpathy)

Hi there! I'm Shrijith Venkatramana, founder of Hexmos. Right now, I'm building LiveAPI, a tool that makes generating API docs from your code ridiculously easy.

A Bug In Our Code

In the previous post, we got automatic gradient calculation going for the whole expression graph.

However, it has a tricky bug. Here's a sample program that invokes the bug:

a = Value(3.0, label='a')
b = a + a  ;  b.label = 'b'

b.backward()
draw_dot(b)

Buggy Graph

In the above, the forward pass looks alright:

b = a + a = 3 + 3 = 6

But think about the backward pass:

b = a + a
db/da = 1 + 1 = 2

The answer should be 2, but we get 1 as the value of a.grad.
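We can confirm this by printing the gradient directly; a quick check, assuming the full Value class from the previous post (with the buggy __add__ shown below):

a = Value(3.0, label='a')
b = a + a

b.backward()
print(a.grad)  # prints 1.0, but d(a + a)/da = 2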

The problem is in the __add__ operation of the Value class:

class Value:
  def __init__(self, data, _children=(), _op='', label=''):
    self.data = data
    self._prev = set(_children)
    self._op = _op
    self.label = label
    self.grad = 0.0
    self._backward = lambda: None # by default doesn't do anything (for a leaf
                                  # node for ex)

  def __repr__(self):
    return f"Value(data={self.data})"

  def __add__(self, other):
    out = Value(self.data + other.data, (self, other), '+')
    # (when b.backward() runs, out.grad here will already have been set to 1.0)

    # derivative of '+' is just distributing the grad of the output to inputs
    def backward():
      self.grad = 1.0 * out.grad  # for b = a + a, this sets a.grad to 1.0
      other.grad = 1.0 * out.grad # ...and then writes a.grad = 1.0 again

    out._backward = backward

    return out
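To make the overwrite concrete: in b = a + a, both self and other inside __add__ refer to the same object a, so the two assignments in backward() write to the same attribute. Here is a rough simulation of that aliasing in plain Python (Box is just a hypothetical stand-in for a Value holding a grad):

class Box:
    grad = 0.0          # stand-in for Value.grad

a = Box()
out_grad = 1.0          # gradient flowing in from b

self_, other = a, a     # for b = a + a, self and other are the same object
self_.grad = 1.0 * out_grad   # a.grad becomes 1.0
other.grad = 1.0 * out_grad   # a.grad is written with 1.0 again, not summed

print(a.grad)  # 1.0 (the first contribution was overwritten; it should be 2.0)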

Here is another example of the same bug:

a = Value(-2.0, label='a')
b = Value(3.0, label='b')

d = a * b   ;   d.label = 'd'
e = a + b   ;   e.label = 'e'
f = d * e   ;   f.label = 'f'

f.backward()
draw_dot(f)

Another Bug Example

We know that for the multiplication operation:

self.grad = other.data * out.grad

Applying this to f = d * e (with f.grad = 1):

d.grad = e.data * f.grad = 1 * 1 = 1

e.grad = d.data * f.grad = -6 * 1 = -6

So far, so good.

Let's look at the next stage. For d = a * b, the same rule gives:

b.grad = a.data * d.grad = -2 * 1 = -2

But if we consider the expression e = a + b, addition passes e.grad through to both inputs:

a.grad = e.grad = -6
b.grad = e.grad = -6

So we have a conflict: b.grad = -6 (from the addition) vs b.grad = -2 (from the multiplication).

The general problem is that when a Value is used multiple times in the expression graph, each use writes its own gradient into the same .grad attribute.

Whichever _backward runs last wins: first the addition's gradient may be written, and then in a later step the multiplication's gradient overwrites it, so one contribution is simply lost.
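For reference, here is what the gradients should be, computed by hand with the chain rule on f = (a * b) * (a + b). This is just a sanity check in plain Python, not part of the Value class:

a, b = -2.0, 3.0
d = a * b            # -6.0
e = a + b            #  1.0
f = d * e            # -6.0

# df/da = e * d(d)/da + d * d(e)/da = e*b + d*1
df_da = e * b + d * 1    # 1*3 + (-6)*1 = -3.0
# df/db = e * d(d)/db + d * d(e)/db = e*a + d*1
df_db = e * a + d * 1    # 1*(-2) + (-6)*1 = -8.0

print(df_da, df_db)      # -3.0 -8.0

So the correct gradients are a.grad = -3 and b.grad = -8, while the buggy version keeps only whichever contribution happens to be written last.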

Solving the Bug: Accumulate Gradients Rather Than Replacing Them

The Wikipedia page for the Chain Rule has a section on the multivariable case.

The gist of the solution is that when a value contributes to the output along multiple paths, its gradients must be accumulated (summed), rather than replaced.
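Concretely: if a variable x reaches the output L through several intermediate results u1, u2, ..., the multivariable chain rule adds up the contribution from every path:

dL/dx = (dL/du1) * (du1/dx) + (dL/du2) * (du2/dx) + ...

For b = a + a the two uses of a each contribute 1, so db/da = 1 + 1 = 2, which is exactly what accumulating with += implements.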

So, the new Value class is as follows, where each _backward accumulates (+=) rather than replaces gradients:

import math

class Value:
  def __init__(self, data, _children=(), _op='', label=''):
    self.data = data
    self._prev = set(_children)
    self._op = _op
    self.label = label
    self.grad = 0.0
    self._backward = lambda: None # by default doesn't do anything (for a leaf
                                  # node for ex)

  def __repr__(self):
    return f"Value(data={self.data})"

  def __add__(self, other):
    out = Value(self.data + other.data, (self, other), '+')

    # derivative of '+' is just distributing the grad of the output to inputs
    def backward():
      self.grad += 1.0 * out.grad
      other.grad += 1.0 * out.grad

    out._backward = backward

    return out

  def __mul__(self, other):
    out = Value(self.data * other.data, (self, other), '*')

    # derivative of `mul` is gradient of result multiplied by sibling's data
    def backward():
      self.grad += other.data * out.grad
      other.grad += self.data * out.grad

    out._backward = backward

    return out

  def tanh(self):
      x = self.data
      t = (math.exp(2*x) - 1) / (math.exp(2*x) + 1)
      out = Value(t, (self, ), 'tanh')

      # derivative of tanh = 1 - (tanh)^2
      def backward():
        self.grad += (1 - t**2) * out.grad

      out._backward = backward
      return out

  def backward(self):
    topo = []
    visited = set()
    def build_topo(v):
        if v not in visited:
            visited.add(v)
            for child in v._prev:
                build_topo(child)
            topo.append(v)
    build_topo(self)

    self.grad = 1.0
    for node in reversed(topo):
        node._backward()


Now the gradient calculations are correct:

Calc1

Calc2
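As a final sanity check, we can re-run both examples with the fixed class (a sketch; note that fresh Value objects are needed, since gradients now accumulate across repeated backward() calls):

a = Value(3.0, label='a')
b = a + a  ;  b.label = 'b'
b.backward()
print(a.grad)            # 2.0

a = Value(-2.0, label='a')
b = Value(3.0, label='b')
d = a * b  ;  d.label = 'd'
e = a + b  ;  e.label = 'e'
f = d * e  ;  f.label = 'f'
f.backward()
print(a.grad, b.grad)    # -3.0 -8.0, matching the hand-computed values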

Reference

Andrej Karpathy, "The spelled-out intro to neural networks and backpropagation: building micrograd" (YouTube)
