
Nilavukkarasan R


Neural Network Optimizers: Training at Scale

"Adapt what is useful, reject what is useless, and add what is specifically your own."
Bruce Lee

From 4 Examples to 60,000

Backpropagation learned XOR from 4 training examples. Compute the gradient using all 4, update the weights, repeat. Every update sees the complete picture.

Now consider MNIST: 60,000 handwritten digit images, each 28×28 pixels. The task is to look at an image and predict which digit (0-9) it represents. The network needs 784 inputs, a hidden layer, and 10 outputs. Roughly 100,000 weights.
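(With a hidden layer of, say, 128 units, that's 784 × 128 + 128 × 10 = 101,632 weights, plus biases. The exact count depends on the hidden layer size, but it lands around 100,000 either way.)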

Computing the gradient using all 60,000 examples requires 60,000 forward and backward passes per update. On a simple NumPy implementation, that's a few seconds per update. Training for 100 epochs takes several minutes.

That's just MNIST. Models like GPT-4o and Claude train on trillions of tokens. Full-batch gradient descent doesn't scale.

You Don't Need Every Example

Think about cooking. You don't taste every grain of rice to know if you need more salt. A spoonful tells you enough.

That's mini-batch gradient descent. Instead of computing the gradient from all 60,000 examples, grab a small batch (say 64), compute the gradient from those, update the weights, grab the next 64, repeat.

import numpy as np

for epoch in range(num_epochs):
    perm = np.random.permutation(len(X_train))        # shuffle training data
    for start in range(0, len(X_train), 64):           # walk through it 64 examples at a time
        batch = perm[start:start + 64]
        preds = net.forward(X_train[batch])             # forward pass
        loss = cross_entropy(preds, y_train[batch])     # compute loss
        grads = net.backward(y_train[batch])            # backward pass
        net.update(grads, learning_rate)                # update weights
# `net`, `cross_entropy`, `X_train`, `y_train` stand in for the network, loss
# function, and data; any backprop implementation from the earlier posts fits.

Each mini-batch gradient is noisy, not the exact direction from all 60,000 examples. But it points roughly right. And it's fast. Instead of one slow update per epoch using all data, you get hundreds of quick updates. Training that took minutes with full-batch finishes in seconds with mini-batches.

One complete pass through all the data is an epoch. With 60,000 examples and a batch size of 64, one epoch is 938 updates (937 full batches plus one final partial batch). We shuffle the data before each epoch so the mini-batches differ every time. This randomness (the "stochastic" in stochastic gradient descent) prevents the network from memorizing the order of examples.

Not All Weights Need the Same Push

Mini-batches solve the speed problem. But remember the radio from the last post? Seed 5 with a small network got stuck between stations. At MNIST scale, this problem gets worse. With 100,000 weights, some are tuning into strong signals and getting large gradients on every update. Others are listening for faint signals and barely getting any gradient at all. One learning rate can't serve both. The loud signals overshoot while the faint ones barely move.

This is where optimizers diverge.

SGD applies the same learning rate to every weight. It's the basic radio dial. Turn at one speed, hope for the best. If the station is strong, you find it. If it's faint, you might turn right past it.
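As a concrete sketch (the array sizes and values below are made up for illustration; `w`, `g`, and `lr` stand in for a weight vector, its mini-batch gradient, and the learning rate), the whole SGD update is one line:

import numpy as np

w = np.random.randn(100_000) * 0.01    # stand-in weight vector
g = np.random.randn(100_000)           # stand-in gradient from one mini-batch
lr = 0.01                              # one global learning rate for everything

w -= lr * g                            # every weight gets the same-sized push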

Momentum keeps a running average of past gradients. If the last ten updates all pushed a weight in the same direction, momentum makes the next push bigger. If the updates keep flipping direction (up, down, up, down), momentum cancels them out and the weight stays steady. It smooths out the noise from mini-batches.
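Here's a minimal sketch of that, reusing `w`, `g`, and `lr` from the SGD snippet above. The exponential-average form is one common variant; `beta` around 0.9 is a typical default.

beta = 0.9                             # how much of the running average to keep
v = np.zeros_like(w)                   # running average of past gradients, one entry per weight

v = beta * v + (1 - beta) * g          # consistent directions build up, flip-flops cancel out
w -= lr * v                            # step along the smoothed direction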

Adam does what momentum does, plus one more thing: it tracks how large each weight's gradients typically are. A weight that always gets big gradients is already moving fast, so Adam gives it smaller steps to avoid overshooting. A weight that gets tiny gradients is barely moving, so Adam gives it bigger steps to catch up. Every weight gets its own learning rate, adjusted automatically.
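A sketch of Adam's update, again reusing `w`, `g`, and `lr`, following the rule from Kingma & Ba (2014) with the usual defaults for `beta1`, `beta2`, and `eps`:

beta1, beta2, eps = 0.9, 0.999, 1e-8
m = np.zeros_like(w)                   # running average of gradients (the momentum part)
s = np.zeros_like(w)                   # running average of squared gradients (typical size)
t = 1                                  # update counter, incremented every step

m = beta1 * m + (1 - beta1) * g
s = beta2 * s + (1 - beta2) * g**2
m_hat = m / (1 - beta1**t)             # bias correction for the zero-initialized averages
s_hat = s / (1 - beta2**t)
w -= lr * m_hat / (np.sqrt(s_hat) + eps)   # big typical gradients get smaller steps, tiny ones bigger

The division by the square root of `s_hat` is what gives each weight its own effective step size.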

In practice, Adam is the default choice for most problems. It converges faster and is less sensitive to the initial learning rate.

For more detail on the math behind these update rules --> optimizers

See It

Open the playground and train all three optimizers on MNIST side by side. Watch which one pulls ahead first.

The gap is most visible in the first few epochs. Adam adapts quickly because it adjusts per weight. SGD treats every weight the same and takes longer to find its footing.

Optimizer comparison on MNIST

What's Next

We can now train on real data, efficiently. Backprop computes gradients, mini-batches make it fast, Adam adapts the learning rate per weight. 99% accuracy on MNIST in minutes.

But train longer and something breaks. Training accuracy keeps climbing, past 99%. Test accuracy stalls and drops. The network isn't learning patterns anymore. It's memorizing training examples.

That gap between training and test accuracy is called overfitting. Closing it is the next problem.


References:
Kingma, D. P., & Ba, J. (2014). Adam: A Method for Stochastic Optimization

Series: From Perceptrons to Transformers | Code: GitHub
