DEV Community

Gary Jackson

Originally published at garyjackson.dev

Chapter 7: The Training Loop and Adam Optimiser

What You'll Build

A complete training loop that processes documents, computes loss, backpropagates gradients, and updates parameters using the Adam optimiser.

Depends On

All previous chapters.

The Training Loop

A training step is just five things in a row:

  1. Pick a document and tokenize it
  2. Forward pass for each token, building up the loss
  3. Backward pass to fill in every gradient
  4. Nudge the parameters using those gradients
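  5. Zero the gradients so they don't carry over into the next step (the code below does this just before each backward pass, which has the same effect)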

Step 4 is where Adam lives. Before we look at the code, it's worth slowing down on what Adam actually does and why we use it.

Understanding Adam

You could update parameters with simple gradient descent: p.Data -= learningRate * p.Grad. Adam is smarter in two ways.
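For comparison, here is a minimal sketch of that plain update on a single made-up parameter. It uses plain doubles rather than the Value class. The loss f(x) = (x - 3)^2 and its hand-written gradient 2(x - 3) are illustrative assumptions that stand in for what Backward() would compute.

```csharp
using System;

// Plain gradient descent on a toy loss f(x) = (x - 3)^2.
// The gradient df/dx = 2 * (x - 3) is worked out by hand here.
double x = 0.0;             // the single "parameter"
double learningRate = 0.1;

for (int step = 0; step < 50; step++)
{
    double grad = 2 * (x - 3);  // derivative of the loss at the current x
    x -= learningRate * grad;   // same shape as p.Data -= learningRate * p.Grad
}

Console.WriteLine($"x after 50 steps: {x:F4}"); // close to the minimum at 3
```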

Momentum (momentum). Instead of reacting to each individual gradient, Adam tracks a running average of recent gradients. This smooths out noisy updates, like a rolling ball that doesn't reverse direction every time it hits a bump.
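The sketch below runs the momentum update on its own. It feeds in a made-up gradient that alternates between 3 and -1, which is an underlying slope of 1 plus noise that flips sign every step. The raw gradient swings by 4 from one step to the next. The running average settles into a narrow band around 1, roughly 0.84 to 1.16 with MomentumSmoothing = 0.85.

```csharp
using System;

// Momentum as an exponential moving average of gradients.
// The made-up gradient alternates between 3 and -1: a "true" slope of 1
// plus noise of +/-2 that flips sign every step.
const double MomentumSmoothing = 0.85;
double momentum = 0.0;

for (int step = 0; step < 200; step++)
{
    double grad = 1.0 + (step % 2 == 0 ? 2.0 : -2.0);
    momentum = MomentumSmoothing * momentum + (1 - MomentumSmoothing) * grad;
    if (step >= 196)
    {
        // Print the last few steps, long after the start-up transient has faded.
        Console.WriteLine($"step {step}: grad {grad,5:F2}  momentum {momentum:F3}");
    }
}
```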

Squared gradient average (squaredGradAvg). Adam also tracks the running average of each parameter's squared gradient. Squaring serves two purposes:

  1. Makes values positive. We want to track how large gradients have been, not their direction. A gradient of -5 and +5 should both count as "large".
  2. Emphasises larger gradients. A gradient of 10 contributes 100 to the average, a gradient of 1 contributes just 1. So a parameter with occasional huge gradients gets dampened more than one with steady moderate gradients.
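Here is the arithmetic behind the second point, sketched with two made-up gradient histories that have the same average magnitude. The squared average is much larger for the spiky history, and so is the RMS that Adam divides by: about 3.16 versus 1.

```csharp
using System;
using System.Linq;

// Two made-up gradient histories with the same average magnitude (1.0 per step).
double[] steady = Enumerable.Repeat(1.0, 10).ToArray();                               // 1, 1, 1, ...
double[] spiky = Enumerable.Range(0, 10).Select(i => i == 0 ? 10.0 : 0.0).ToArray(); // one 10, nine 0s

double MeanAbs(double[] g) => g.Average(v => Math.Abs(v));
double MeanSquare(double[] g) => g.Average(v => v * v);

Console.WriteLine(
    $"steady: mean |g| = {MeanAbs(steady)}, mean g^2 = {MeanSquare(steady)}, RMS = {Math.Sqrt(MeanSquare(steady)):F2}");
Console.WriteLine(
    $"spiky:  mean |g| = {MeanAbs(spiky)}, mean g^2 = {MeanSquare(spiky)}, RMS = {Math.Sqrt(MeanSquare(spiky)):F2}");
```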

When the update happens, Adam divides by the square root of this number. Parameters with consistently large gradients get a smaller effective step size and vice versa, so each parameter ends up with its own adapted learning rate. The squaring-then-square-rooting gives us what's called the RMS (root mean square) of the gradient, effectively a rolling "typical size" of recent gradients.
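Below is a minimal sketch of the per-parameter scaling. It assumes two made-up parameters whose gradients are a constant 100 and a constant 0.01. Plain gradient descent would move them by 1 and 0.0001 per step. Once the running averages have settled, Adam moves both by about the learning rate, because correctedMomentum / sqrt(correctedSquaredGrad) is roughly g / |g|.

```csharp
using System;

// Adam's per-parameter scaling, with momentum and the squared-gradient average
// run to steady state on two made-up, constant gradients of very different size.
const double LearningRate = 0.01, MomentumSmoothing = 0.85, SquaredGradSmoothing = 0.99, Epsilon = 1e-8;
double[] grads = { 100.0, 0.01 };
double[] adamSteps = new double[grads.Length];

for (int i = 0; i < grads.Length; i++)
{
    double m = 0, v = 0;
    for (int step = 0; step < 1000; step++)
    {
        m = MomentumSmoothing * m + (1 - MomentumSmoothing) * grads[i];
        v = SquaredGradSmoothing * v + (1 - SquaredGradSmoothing) * grads[i] * grads[i];
    }
    // After 1000 steps the bias-correction divisors are within 1e-4 of 1, so they are skipped.
    adamSteps[i] = LearningRate * m / (Math.Sqrt(v) + Epsilon);
    Console.WriteLine(
        $"grad {grads[i],7}: plain step {LearningRate * grads[i]:G4}, Adam step {adamSteps[i]:G4}");
}
```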

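Bias correction (correctedMomentum, correctedSquaredGrad). Because momentum and squaredGradAvg start at zero, they're biased toward zero in early steps. The correction factors 1 / (1 - beta^(step+1)) compensate for this warm-up period. Here beta is MomentumSmoothing for the momentum and SquaredGradSmoothing for the squared gradient average. As step grows, beta^(step+1) shrinks toward zero and the factors approach 1, so the correction only matters early on.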
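To see the warm-up bias and its fix, the sketch below feeds a made-up constant gradient of 2.0 through the same update and correction formulas as the training loop. The raw averages creep up from zero. The corrected values equal 2.0 and 4.0 (the gradient and its square) from the very first step.

```csharp
using System;

// Bias correction with a made-up constant gradient of 2.0.
const double MomentumSmoothing = 0.85, SquaredGradSmoothing = 0.99;
const double Grad = 2.0;
double momentum = 0, squaredGradAvg = 0;
double correctedMomentum = 0, correctedSquaredGrad = 0;

for (int step = 0; step < 5; step++)
{
    momentum = MomentumSmoothing * momentum + (1 - MomentumSmoothing) * Grad;
    squaredGradAvg = SquaredGradSmoothing * squaredGradAvg + (1 - SquaredGradSmoothing) * Grad * Grad;
    // Divide out the (1 - beta^(step+1)) shortfall caused by starting at zero.
    correctedMomentum = momentum / (1 - Math.Pow(MomentumSmoothing, step + 1));
    correctedSquaredGrad = squaredGradAvg / (1 - Math.Pow(SquaredGradSmoothing, step + 1));
    Console.WriteLine(
        $"step {step}: momentum {momentum:F4} -> {correctedMomentum:F4}, " +
        $"squaredGradAvg {squaredGradAvg:F4} -> {correctedSquaredGrad:F4}");
}
```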

Learning rate decay. The learning rate decreases linearly over training: currentLearningRate = learningRate * (1 - step/numSteps). This allows large steps early on, when parameters are far from good values, and smaller, more precise steps later. On the final step (step = numSteps - 1), the rate has shrunk to learningRate / numSteps, which is close to zero. The model therefore makes progressively smaller adjustments as training continues, effectively locking in what it has learned.
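The schedule is easy to tabulate. The sketch below evaluates the same formula with the chapter's learningRate = 0.01 and numSteps = 1000. The helper name LearningRateAt is made up for illustration. The last step that actually runs (step 999) uses learningRate / numSteps = 0.00001, which is small but not exactly zero.

```csharp
using System;

// The linear decay schedule from the training loop, with the chapter's values.
const double LearningRate = 0.01;
const int NumSteps = 1000;

double LearningRateAt(int step) => LearningRate * (1 - (double)step / NumSteps);

foreach (int step in new[] { 0, 250, 500, 750, NumSteps - 1 })
{
    Console.WriteLine($"step {step,4}: learning rate {LearningRateAt(step):G4}");
}
```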

The constants MomentumSmoothing, SquaredGradSmoothing, and Epsilon in the code are Adam's hyperparameters. MomentumSmoothing controls how much smoothing is applied to the momentum (higher = more smoothing, more memory of past gradients), SquaredGradSmoothing does the same for the squared gradient average, and Epsilon is a tiny number that prevents division by zero. Standard defaults are MomentumSmoothing=0.9, SquaredGradSmoothing=0.999, but we use 0.85, 0.99 for faster training on this small problem.
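One way to read the smoothing constants is a common rule of thumb: an exponential moving average with factor beta remembers roughly the last 1 / (1 - beta) values. The sketch below applies that rule to the chapter's values and the standard defaults; the helper name Horizon is made up. By this rule, 0.85 and 0.99 average over roughly 7 and 100 steps, versus 10 and 1,000 for 0.9 and 0.999, so the chapter's settings react faster to recent gradients.

```csharp
using System;

// Rule of thumb: an exponential moving average with smoothing factor beta
// averages over roughly the last 1 / (1 - beta) steps.
double Horizon(double beta) => 1.0 / (1.0 - beta);

foreach (double beta in new[] { 0.85, 0.9, 0.99, 0.999 })
{
    Console.WriteLine($"beta {beta,-5}: roughly the last {Horizon(beta):F0} steps");
}
```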

Code

// --- Chapter7Exercise.cs ---

using static MicroGPT.Helpers;

namespace MicroGPT;

public static class Chapter7Exercise
{
    public static void Run()
    {
        var random = new Random(42);
        List<string> docs = Tokenizer.LoadDocs("input.txt", random);
        var tokenizer = new Tokenizer(docs);
        Console.WriteLine($"num docs: {docs.Count}");
        Console.WriteLine($"vocab size: {tokenizer.VocabSize}");

        // ── Simplified model (replaced by GptModel in Chapter 11) ──
        int embeddingSize = 16;
        int maxSequenceLength = 8;
        int numSteps = 1000;
        double learningRate = 0.01;

        List<List<Value>> tokenEmbeddings = CreateMatrix(
            random,
            tokenizer.VocabSize,
            embeddingSize
        );
        List<List<Value>> positionEmbeddings = CreateMatrix(
            random,
            maxSequenceLength,
            embeddingSize
        );
        List<List<Value>> outputProjection = CreateMatrix(
            random,
            tokenizer.VocabSize,
            embeddingSize
        );

        // Collect all parameters into a flat list for the optimiser
        var paramsList = new List<Value>();
        foreach (List<Value> row in tokenEmbeddings)
        {
            paramsList.AddRange(row);
        }

        foreach (List<Value> row in positionEmbeddings)
        {
            paramsList.AddRange(row);
        }

        foreach (List<Value> row in outputProjection)
        {
            paramsList.AddRange(row);
        }

        Console.WriteLine($"num params: {paramsList.Count}");

        // ── Adam optimiser ──
        // Note: the standard Adam defaults are MomentumSmoothing=0.9, SquaredGradSmoothing=0.999.
        // We use more aggressive values here to train faster on this small problem.
        const double MomentumSmoothing = 0.85,
            SquaredGradSmoothing = 0.99,
            Epsilon = 1e-8;
        double[] momentum = new double[paramsList.Count];
        double[] squaredGradAvg = new double[paramsList.Count];

        // Reusable buffers for Backward. These are what the parameterless Backward()
        // overload from Chapter 2 allocates internally on every call. Here we hoist
        // them out of the hot loop so 1,000 training steps don't allocate 1,000
        // fresh copies.
        var topo = new List<Value>();
        var visited = new HashSet<Value>();
        var backwardStack = new Stack<(Value, int)>();

        for (int step = 0; step < numSteps; step++)
        {
            string doc = docs[step % docs.Count];
            var tokens = new List<int> { tokenizer.Bos };
            tokens.AddRange(doc.Select(tokenizer.Encode));
            tokens.Add(tokenizer.Bos);
            // Any name longer than maxSequenceLength - 1 is silently truncated here.
            int tokenCount = Math.Min(maxSequenceLength, tokens.Count - 1);

            var losses = new List<Value>();
            for (int posId = 0; posId < tokenCount; posId++)
            {
                List<Value> logits = Forward(
                    tokens[posId],
                    posId,
                    tokenEmbeddings,
                    positionEmbeddings,
                    outputProjection,
                    embeddingSize
                );
                List<Value> probabilities = Softmax(logits);
                losses.Add(-probabilities[tokens[posId + 1]].Log());
            }

            // Average the per-position losses into a single scalar
            var loss = new Value(0);
            foreach (Value l in losses)
            {
                loss += l;
            }

            loss *= 1.0 / tokenCount;

            foreach (Value p in paramsList)
            {
                p.Grad = 0;
            }

            topo.Clear();
            visited.Clear();
            backwardStack.Clear();
            loss.Backward(topo, visited, backwardStack);

            double currentLearningRate = learningRate * (1 - (double)step / numSteps);
            for (int i = 0; i < paramsList.Count; i++)
            {
                Value p = paramsList[i];
                momentum[i] =
                    MomentumSmoothing * momentum[i] + (1 - MomentumSmoothing) * p.Grad;
                squaredGradAvg[i] =
                    SquaredGradSmoothing * squaredGradAvg[i]
                    + (1 - SquaredGradSmoothing) * Math.Pow(p.Grad, 2);
                double correctedMomentum =
                    momentum[i] / (1 - Math.Pow(MomentumSmoothing, step + 1));
                double correctedSquaredGrad =
                    squaredGradAvg[i] / (1 - Math.Pow(SquaredGradSmoothing, step + 1));
                p.Data -=
                    currentLearningRate
                    * correctedMomentum
                    / (Math.Sqrt(correctedSquaredGrad) + Epsilon);
            }

            if (step == 0 || (step + 1) % 100 == 0)
            {
                Console.WriteLine(
                    $"step {step + 1, 4} / {numSteps, 4} | loss {loss.Data:F4}"
                );
            }
        }
    }

    private static List<Value> Forward(
        int tokenId,
        int posId,
        List<List<Value>> tokenEmbeddings,
        List<List<Value>> positionEmbeddings,
        List<List<Value>> outputProjection,
        int embeddingSize
    )
    {
        List<Value> tokenEmbedding = tokenEmbeddings[tokenId];
        List<Value> positionEmbedding = positionEmbeddings[posId];
        var x = new List<Value>();
        for (int i = 0; i < embeddingSize; i++)
        {
            x.Add(tokenEmbedding[i] + positionEmbedding[i]);
        }

        return Linear(x, outputProjection);
    }
}

Code Walkthrough

Breaking the code down section by section:

Setup (docs through tokenizer). Load the dataset, build the tokenizer, print stats. Same as Chapter 6.

Hyperparameters (embeddingSize through learningRate). Two we've seen (embeddingSize, maxSequenceLength), and two new ones:

  • numSteps = 1000 - how many training iterations to run
  • learningRate = 0.01 - the starting size of each parameter update

Model setup (tokenEmbeddings through outputProjection). Three embedding/projection matrices. The whole setup will be replaced by GptModel in Chapter 11.

Parameter list (paramsList). Adam needs to update every learnable number, so we flatten all three matrices into one big list of Value objects. For our sizes (vocab size 27, embeddingSize 16, maxSequenceLength 8) that's 27*16 + 8*16 + 27*16 = 992 parameters.
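Here is the same count as a quick calculation, assuming a vocab size of 27:

```csharp
using System;

// Parameter count for the three matrices, assuming a vocab size of 27.
int vocabSize = 27, embeddingSize = 16, maxSequenceLength = 8;

int tokenEmbeddingParams = vocabSize * embeddingSize;            // 27 x 16
int positionEmbeddingParams = maxSequenceLength * embeddingSize; // 8 x 16
int outputProjectionParams = vocabSize * embeddingSize;          // 27 x 16
int total = tokenEmbeddingParams + positionEmbeddingParams + outputProjectionParams;

Console.WriteLine($"{tokenEmbeddingParams} + {positionEmbeddingParams} + {outputProjectionParams} = {total}"); // 432 + 128 + 432 = 992
```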

Adam state (MomentumSmoothing through squaredGradAvg). The Adam constants (smoothing factors and Epsilon), plus two double arrays, one for momentum and one for squared gradient averages. Each array has one entry per parameter, all starting at zero.

Backward buffers (topo, visited, backwardStack). Pre-allocated lists/sets/stacks for Backward() to reuse across all 1000 steps. This is the 3-argument Backward() overload we built in Chapter 2. The parameterless version would allocate fresh buffers every call.

Training loop (for (int step ...)). This is the heart of training. Each iteration is one step:

  • Pick and tokenize a doc. Cycle through docs with step % docs.Count, wrap with BOS on both sides, cap length at maxSequenceLength. The modulo is defensive: if numSteps ever exceeds docs.Count, the loop wraps back to the start.
  • Forward pass (the posId loop). For each position in the name, run Forward to get logits, softmax to get probabilities, then collect -log(probability of correct next token) as the loss. A 5-character name produces 6 losses: one for each character, plus one for predicting the closing BOS.
  • Average losses (loss *= 1.0 / tokenCount). Sum the per-position losses and divide by count to get one scalar loss for the whole document.
  • Backward (loss.Backward(...)). Reset all gradients to zero, clear the buffers, then call Backward(), which fills in .Grad on every Value using the algorithm from Chapter 2.
  • Adam update (the paramsList loop). The five lines from the Adam explanation above:
    1. Compute the decayed learning rate for this step
    2. Update momentum (running avg of gradients)
    3. Update squared gradient average
    4. Apply bias correction to both
    5. Take the step: p.Data -= currentLearningRate * correctedMomentum / (sqrt(correctedSquaredGrad) + Epsilon)
  • Print progress. Log every 100 steps so you can watch the loss go down.
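To make the per-document bookkeeping concrete, the sketch below replays the tokens and tokenCount logic from the loop, with characters in place of token ids. TrainingPairs and the "<BOS>" placeholder are hypothetical stand-ins, not the course's Tokenizer API. A 5-character name like "emily" yields 6 (input, target) pairs, and the last pair predicts the closing BOS. For "alexandra", the Math.Min line cuts the output to 8 pairs, so the model never learns the name's final "a" or its closing BOS.

```csharp
using System;
using System.Collections.Generic;

// Replays the tokens / tokenCount logic from the training loop for one name.
// Characters stand in for token ids; "<BOS>" is a hypothetical placeholder for tokenizer.Bos.
const int MaxSequenceLength = 8;

List<(string Input, string Target)> TrainingPairs(string doc)
{
    var tokens = new List<string> { "<BOS>" };
    foreach (char c in doc)
    {
        tokens.Add(c.ToString());
    }
    tokens.Add("<BOS>");

    // Same cap as the loop: at most MaxSequenceLength positions per document.
    int tokenCount = Math.Min(MaxSequenceLength, tokens.Count - 1);
    var pairs = new List<(string Input, string Target)>();
    for (int posId = 0; posId < tokenCount; posId++)
    {
        pairs.Add((tokens[posId], tokens[posId + 1])); // input token -> next token to predict
    }
    return pairs;
}

foreach (string name in new[] { "emily", "alexandra" })
{
    List<(string Input, string Target)> namePairs = TrainingPairs(name);
    Console.WriteLine($"{name}: {namePairs.Count} losses");
    foreach ((string input, string target) in namePairs)
    {
        Console.WriteLine($"  {input} -> {target}");
    }
}
```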

Forward method. Look up token embedding, add position embedding, project to vocab size. Takes the three matrices and embeddingSize as explicit parameters so the forward pass's dependencies are visible at the call site. This is what gets replaced by GptModel.Forward in Chapter 11.

The loop runs 1,000 times, each time nudging the 992 parameters slightly toward something that gives lower loss. By the end, you'll see loss drop from ~3.3 (random guessing) down toward the bigram baseline of ~2.45.
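The ~3.3 starting point is the cost of uniform guessing. Assuming a vocab size of 27, a model that gives every token probability 1/27 pays -log(1/27) = log(27), about 3.2958, at every position:

```csharp
using System;

// Loss for uniform guessing: every token gets probability 1 / vocabSize.
// Assumes a vocab size of 27.
int vocabSize = 27;
double uniformProbability = 1.0 / vocabSize;
double randomGuessLoss = -Math.Log(uniformProbability); // equals Math.Log(vocabSize)

Console.WriteLine($"uniform guessing over {vocabSize} tokens: loss {randomGuessLoss:F4}"); // loss 3.2958
```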

A note on name length. maxSequenceLength = 8 means we train on at most the first 7 characters of each name plus a BOS token. Longer names like "alexandra" or "elizabeth" are silently truncated by the Math.Min on the tokenCount line above. If later on you see the model under-generating long names during inference, this is why. Raising maxSequenceLength to 16 covers ~100% of the dataset but roughly doubles training time, because every position still runs a forward pass. We keep it at 8 for course speed.

Uncomment the Chapter 7 case in the dispatcher in Program.cs:

case "ch7":
    Chapter7Exercise.Run();
    break;

Then run it:

dotnet run -c Release -- ch7

(Release mode matters here. Debug mode is significantly slower because Value allocations dominate. On a modern CPU this runs in under a minute in Release mode.)

What You Should See

The first step prints a loss around 3.3 (random guessing). Over 1,000 steps, the trend moves downward. Don't worry if individual steps bounce around. Each step trains on a single document, so the loss is noisy. One step might land at 1.7, the next at 2.8. What matters is the overall trend, not any single value.

By the end of training, the loss lands in a similar range to the bigram baseline from Chapter 4 (~2.45). That might feel disappointing. Why build a neural network to match a counting table? The answer is that this model still processes each position independently. It has no way for tokens to look at each other, so it's effectively a neural version of the bigram. Attention (Chapter 9) and multi-head attention (Chapter 10) are what let tokens use longer context. Together with the MLP (Chapter 10), they will push the loss well below the bigram baseline when we assemble the full model in Chapter 11.

A note on evaluation. We're computing the loss on the same data we train on. In a production setting, you'd hold out a portion of the data for validation, to detect overfitting (the model memorising training examples rather than learning general patterns). For the purpose of understanding the architecture, this simplification is fine.
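Below is a minimal sketch of a hold-out split. It uses a Fisher-Yates shuffle with a fixed seed and a made-up list of names. SplitDocs and the 90/10 ratio are illustrative assumptions, not part of the course code. You would train only on the training set, and periodically compute the loss on the validation set without calling Backward() or updating parameters.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hold-out split: shuffle a copy of the docs with a fixed seed (Fisher-Yates),
// then reserve the first validationFraction of them for validation.
(List<string> Train, List<string> Validation) SplitDocs(
    List<string> allDocs, double validationFraction, int seed)
{
    var random = new Random(seed);
    var shuffled = new List<string>(allDocs);
    for (int i = shuffled.Count - 1; i > 0; i--)
    {
        int j = random.Next(i + 1);                              // 0 <= j <= i
        (shuffled[i], shuffled[j]) = (shuffled[j], shuffled[i]); // swap
    }

    int validationCount = (int)(shuffled.Count * validationFraction);
    return (shuffled.Skip(validationCount).ToList(), shuffled.Take(validationCount).ToList());
}

// A made-up list of names standing in for the real dataset.
var docs = new List<string>
{
    "emma", "olivia", "ava", "isabella", "sophia",
    "mia", "charlotte", "amelia", "harper", "evelyn",
};
var (train, validation) = SplitDocs(docs, validationFraction: 0.1, seed: 42);
Console.WriteLine($"train: {train.Count} docs, validation: {validation.Count} docs");
```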
