Gary Jackson

Posted on • Originally published at garyjackson.dev

Chapter 10: Multi-Head Attention and the MLP Block

What You'll Build

Multi-head attention (running several attention computations in parallel, each on its own slice of the per-token embedding vector) and the MLP block (a two-layer feed-forward network for per-position "thinking"). Both concepts are introduced here and implemented in Model.cs in Chapter 11.

Depends On

Chapters 5, 8, 9 (Helpers, RmsNorm, residual connections, single-head attention).

Why Multiple Heads?

A single attention head can only learn one kind of "what am I looking for?" pattern. With multiple heads, the model can look for different kinds of relationships at the same time. In larger models with bigger embedding dimensions, individual heads often specialise in distinct patterns (one might track syntax, another semantics). At our small scale (headDimension = 4), the specialisation is fuzzier, but the mechanism is the same.

The trick: instead of running 4 full-size attention computations, we split the embedding dimension into 4 slices. If embeddingSize = 16 and headCount = 4, each head operates on 4 dimensions (headDimension = 4). This doesn't lose information because the projections (queryWeights, keyWeights, valueWeights) can learn to put related information into the same slice. The heads compute independently and their outputs are concatenated (not averaged or summed) back to the full embedding size. Concatenation keeps all the per-head information in distinct dimensions, so nothing is lost before the next step.

Multi-Head Attention

// Shape reference - integrated into GptModel.Forward in Chapter 11.
// The for loop is sequential, but conceptually each head is independent.
// In production, all heads are computed in a single matrix multiply on a GPU.

var concatenatedHeads = new List<Value>();

for (int h = 0; h < headCount; h++)
{
    int headStart = h * headDimension;
    List<Value> queryForHead = q.GetRange(headStart, headDimension);

    var attentionLogits = new List<Value>();
    for (int t = 0; t < cachedKeys.Count; t++)
    {
        List<Value> keyForHead = cachedKeys[t].GetRange(headStart, headDimension);
        var dot = new Value(0);
        for (int j = 0; j < headDimension; j++)
        {
            dot += queryForHead[j] * keyForHead[j];
        }

        attentionLogits.Add(dot / Math.Sqrt(headDimension));
    }

    List<Value> attentionWeights = Helpers.Softmax(attentionLogits);

    var headOutput = new List<Value>();
    for (int j = 0; j < headDimension; j++)
    {
        headOutput.Add(new Value(0));
    }

    for (int t = 0; t < cachedValues.Count; t++)
    {
        List<Value> valueForHead = cachedValues[t].GetRange(headStart, headDimension);
        Value w = attentionWeights[t];
        for (int j = 0; j < headDimension; j++)
        {
            headOutput[j] += w * valueForHead[j];
        }
    }

    concatenatedHeads.AddRange(headOutput); // concatenate this head's output
}

// After concatenation, project through outputWeights to mix information across heads
x = Helpers.Linear(concatenatedHeads, outputWeights);

The final Linear(concatenatedHeads, outputWeights) is important. After concatenation, each dimension still belongs to a single head. The outputWeights projection mixes information across heads, letting the model combine what different heads found.
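
If Chapter 5 isn't fresh in your mind, Helpers.Linear is just a matrix-vector product. Here's a minimal sketch consistent with how it's called in this chapter (the real helper lives in Helpers.cs; each row of the weight matrix produces one output dimension):

// Sketch of the Chapter 5 helper - one weight row per output dimension.
public static List<Value> Linear(List<Value> input, List<List<Value>> weights)
{
    var output = new List<Value>();
    foreach (List<Value> row in weights)
    {
        var sum = new Value(0);
        for (int i = 0; i < input.Count; i++)
        {
            sum += row[i] * input[i];
        }

        output.Add(sum);
    }

    return output;
}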

The MLP Block

MLP stands for Multi-Layer Perceptron, a generic term for a stack of linear layers with nonlinearities between them. In transformers it's specifically a two-layer feed-forward network.

Attention is the communication mechanism (tokens talk to each other). The MLP is the computation mechanism (each position "thinks" independently). Concretely, it projects up from embeddingSize to 4 * embeddingSize, applies ReLU, then projects back down.

// Shape reference - integrated into GptModel.Forward in Chapter 11.

x = Helpers.Linear(x, mlpUpWeights); // project up: embeddingSize -> 4*embeddingSize
x = [.. x.Select(xi => xi.Relu())]; // nonlinearity
x = Helpers.Linear(x, mlpDownWeights); // project down: 4*embeddingSize -> embeddingSize

Why project up and then back down? The wider intermediate layer gives the model more "room to think" (more dimensions to combine features in) before compressing back to the residual stream size.

We use ReLU here for simplicity. Production transformers typically use smoother variants like GeLU or SwiGLU, but the role is the same: introduce a nonlinearity between the two linear projections.
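
If you're curious what "smoother" means, here's the common tanh approximation of GeLU on plain doubles, for illustration only (our Value type would need a Tanh operation with its own backward pass before this could be used in training):

static double Gelu(double x) =>
    0.5 * x * (1.0 + Math.Tanh(Math.Sqrt(2.0 / Math.PI) * (x + 0.044715 * x * x * x)));

// ReLU zeroes negatives outright; GeLU lets small negative values leak through:
// Relu(-0.5) = 0        Gelu(-0.5) ≈ -0.154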

The Transformer Block

A single transformer layer combines attention and MLP, each wrapped with RMSNorm and a residual connection:

// Shape reference - integrated into GptModel.Forward in Chapter 11.

// Attention with residual
var xResidual = new List<Value>(x);
x = Helpers.RmsNorm(x);
x = /* multi-head attention + outputWeights projection */;
for (int i = 0; i < embeddingSize; i++)
{
    x[i] += xResidual[i];
}

// MLP with residual
xResidual = new List<Value>(x);
x = Helpers.RmsNorm(x);
x = /* MLP block */;
for (int i = 0; i < embeddingSize; i++)
{
    x[i] += xResidual[i];
}

Stacking Blocks

Our model uses layerCount = 1 (a single block), but the architecture supports stacking multiple blocks in sequence. Each block reads from and writes to the same residual stream:

Embeddings
    ↓
┌─ Block 1 ──┐
│ Attention  │
│ MLP        │
└────────────┘
    ↓
┌─ Block 2 ──┐
│ Attention  │
│ MLP        │
└────────────┘
    ↓
   ...
    ↓
┌─ Block N ──┐
│ Attention  │
│ MLP        │
└────────────┘
    ↓
Output projection (lmHead)

Deeper models (more blocks) can learn more complex patterns because each block refines the representation further. GPT-2's largest variant used 48 blocks.
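
In code, stacking is just a loop. A hypothetical sketch of the forward pass with layerCount > 1 (TransformerBlock and blockWeights are illustrative names, not part of the actual Model.cs):

// Hypothetical sketch - the real GptModel.Forward arrives in Chapter 11.
// Every block reads and writes the same residual stream x.
for (int layer = 0; layer < layerCount; layer++)
{
    x = TransformerBlock(x, blockWeights[layer]); // attention + MLP, each with RMSNorm + residual
}

List<Value> logits = Helpers.Linear(x, lmHead); // project to vocabulary logits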

Exercise: Multi-Head Attention + MLP

Like Chapter 9, this exercise uses hand-crafted Q/K/V so you can see the behaviour rather than waiting for training to discover it. The setup: embeddingSize = 8, headCount = 2, headDimension = 4, three cached positions. Head 0's Q slice is aligned with K[1], and head 1's Q slice is aligned with K[2], so the two heads should attend to different positions. After the demo, a second pass runs an MLP block on a fixed input to show the up-project → ReLU → down-project shape change.
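
To see why head 0 locks onto position 1: its Q slice is [0, 5, 0, 0] and K[1]'s first half is [0, 1, 0, 0], so their dot product is 5 while every other key gives 0. Dividing by √headDimension = 2 gives logits [0, 2.5, 0], which softmax turns into roughly [0.07, 0.86, 0.07]. Head 1 gets the mirror-image result at position 2.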

Create Chapter10Exercise.cs:

// --- Chapter10Exercise.cs ---

using static MicroGPT.Helpers;

namespace MicroGPT;

public static class Chapter10Exercise
{
    public static void Run()
    {
        MultiHeadAttentionDemo();
        Console.WriteLine();
        MlpBlockDemo();
    }

    // Hand-crafted multi-head attention on a 3-position sequence.
    // embeddingSize = 8, headCount = 2, headDimension = 4. Head 0 and Head 1 are set up to
    // attend to *different* positions so we can see the independence.
    private static void MultiHeadAttentionDemo()
    {
        const int EmbeddingSize = 8;
        const int HeadCount = 2;
        const int HeadDimension = EmbeddingSize / HeadCount;

        // Each cached key has two halves: the first 4 dims serve head 0,
        // the last 4 dims serve head 1. Both halves happen to match here,
        // but they could be completely different - each head only reads its slice.
        var cachedKeys = new List<List<Value>>
        {
            new() { new(1), new(0), new(0), new(0), new(1), new(0), new(0), new(0) }, // K[0]
            new() { new(0), new(1), new(0), new(0), new(0), new(1), new(0), new(0) }, // K[1]
            new() { new(0), new(0), new(1), new(0), new(0), new(0), new(1), new(0) }, // K[2]
        };

        var cachedValues = new List<List<Value>>
        {
            new() { new(10), new(0), new(0), new(0), new(100), new(0), new(0), new(0) }, // V[0]
            new() { new(0), new(20), new(0), new(0), new(0), new(200), new(0), new(0) }, // V[1]
            new() { new(0), new(0), new(30), new(0), new(0), new(0), new(300), new(0) }, // V[2]
        };

        // Q is designed so head 0 matches K[1] and head 1 matches K[2].
        var query = new List<Value>
        {
            new(0), new(5), new(0), new(0), // head 0 slice
            new(0), new(0), new(5), new(0), // head 1 slice
        };

        var concatenatedHeads = new List<Value>();

        for (int h = 0; h < HeadCount; h++)
        {
            int headStart = h * HeadDimension;
            List<Value> queryForHead = query.GetRange(headStart, HeadDimension);

            var attentionLogits = new List<Value>();
            for (int t = 0; t < cachedKeys.Count; t++)
            {
                List<Value> keyForHead = cachedKeys[t].GetRange(headStart, HeadDimension);
                var dot = new Value(0);
                for (int j = 0; j < HeadDimension; j++)
                {
                    dot += queryForHead[j] * keyForHead[j];
                }

                attentionLogits.Add(dot / Math.Sqrt(HeadDimension));
            }

            List<Value> attentionWeights = Softmax(attentionLogits);

            var headOutput = new List<Value>();
            for (int j = 0; j < HeadDimension; j++)
            {
                headOutput.Add(new Value(0));
            }

            for (int t = 0; t < cachedValues.Count; t++)
            {
                List<Value> valueForHead = cachedValues[t].GetRange(headStart, HeadDimension);
                Value w = attentionWeights[t];
                for (int j = 0; j < HeadDimension; j++)
                {
                    headOutput[j] += w * valueForHead[j];
                }
            }

            concatenatedHeads.AddRange(headOutput); // concatenate this head's output

            Console.WriteLine(
                $"--- Head {h} (dims {headStart}..{headStart + HeadDimension - 1}) ---"
            );
            Console.WriteLine(
                $"  Q slice = [{string.Join(", ", queryForHead.Select(v => v.Data))}]"
            );
            for (int t = 0; t < attentionWeights.Count; t++)
            {
                Console.WriteLine($"  attn weight[{t}] = {attentionWeights[t].Data:F4}");
            }

            Console.WriteLine(
                $"  head output = [{string.Join(", ", headOutput.Select(v => v.Data.ToString("F2")))}]"
            );
        }

        Console.WriteLine();
        Console.WriteLine("Concatenated multi-head output (length embeddingSize = 8):");
        Console.WriteLine(
            $"  [{string.Join(", ", concatenatedHeads.Select(v => v.Data.ToString("F2")))}]"
        );
        Console.WriteLine(
            "Note how the first 4 dims are dominated by V[1] and the last 4 by V[2] -"
        );
        Console.WriteLine(
            "the two heads attended to different positions and both contributions survived."
        );
    }

    // Shows the MLP block: up-project -> ReLU -> down-project.
    // We don't train anything here; we just run a fixed input through random weights
    // to show that the shape goes embeddingSize -> 4*embeddingSize -> embeddingSize.
    private static void MlpBlockDemo()
    {
        const int EmbeddingSize = 4;
        var random = new Random(42);

        List<List<Value>> mlpUpWeights = CreateMatrix(random, 4 * EmbeddingSize, EmbeddingSize);
        List<List<Value>> mlpDownWeights = CreateMatrix(random, EmbeddingSize, 4 * EmbeddingSize);

        var x = new List<Value> { new(0.5), new(-0.3), new(1.0), new(-0.8) };
        Console.WriteLine($"--- MLP Block (embeddingSize = {EmbeddingSize}) ---");
        Console.WriteLine($"  input           ({x.Count, 2} dims): [{Format(x)}]");

        List<Value> hidden = Linear(x, mlpUpWeights);
        Console.WriteLine($"  after up-proj   ({hidden.Count, 2} dims): [{Format(hidden)}]");

        var activated = hidden.Select(v => v.Relu()).ToList();
        int negBefore = hidden.Count(v => v.Data < 0);
        Console.WriteLine(
            $"  after ReLU      ({activated.Count, 2} dims): [{Format(activated)}]  (zeroed {negBefore} negatives)"
        );

        List<Value> output = Linear(activated, mlpDownWeights);
        Console.WriteLine($"  after down-proj ({output.Count, 2} dims): [{Format(output)}]");

        static string Format(IEnumerable<Value> vs) =>
            string.Join(", ", vs.Select(v => v.Data.ToString("F3")));
    }
}

Uncomment the Chapter 10 case in the dispatcher in Program.cs:

case "ch10":
    Chapter10Exercise.Run();
    break;

Then run it:

dotnet run -- ch10

You should see head 0's attention peak at position 1, head 1's peak at position 2, and the concatenated output with distinct contributions in each half. The MLP demo shows the dimensionality change: 4 → 16 → 4, with ReLU zeroing out 6 of the 16 intermediate entries (the exact count is deterministic with Random(42)).

This exercise lives in Chapter10Exercise.cs so you can come back to it any time.

Key Distinction: Communication vs. Computation

The transformer alternates between two fundamentally different operations:

  • Attention is communication across time. The token at position t looks at tokens 0..t (earlier positions plus itself).
  • MLP is computation at a single position. No cross-position information flow.

That's the design pattern of the entire transformer: communicate, compute, communicate, compute, on a residual stream that carries information forward.
