DEV Community

Akhilesh


91. The Transformer Architecture: The Invention That Changed AI

In 2017, Google published a paper called "Attention Is All You Need."

Before it, NLP was dominated by RNNs and LSTMs. They processed text one token at a time, left to right. That made them slow to train, and they struggled to carry information across long stretches of text.

The transformer threw all of that out. No recurrence. No convolutions. Just attention. It beat the previous best results on machine translation (WMT 2014 English-German and English-French), trained faster on GPUs, and scaled in ways nobody expected.

GPT, BERT, T5, LLaMA, Claude, ChatGPT. All of them are transformers. Understanding the architecture is understanding the foundation of modern AI.


What You'll Learn Here

  • Why transformers replaced RNNs
  • The encoder and decoder: what each one does
  • Multi-head attention: the mechanism from Post 90, applied at scale
  • Positional encoding: how the model knows word order
  • Layer normalization and residual connections
  • Feed-forward layers inside the transformer
  • How to build a mini transformer in PyTorch

The Problem Transformers Solved

RNNs process sequences one token at a time. To understand word 100, you had to go through words 1 through 99 first. Three problems:

Slow training. You can't parallelize a sequential process. GPUs sit idle.

Forgetting. Long sequences cause gradients to vanish. The model struggles to connect a word at position 1 to a word at position 100.

Bottleneck. In encoder-decoder RNNs, the entire input sequence is compressed into one fixed-size vector. Summarizing "War and Peace" into one vector loses information.

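The transformer's answer: process all positions simultaneously using attention. In each layer, every token attends to every other token through a handful of matrix operations. Training parallelizes across the whole sequence, any two positions are connected in a single step instead of through a long chain of recurrent updates, and the decoder can look back at every encoder position instead of one compressed vector. The price is that attention compares every pair of tokens, so its cost grows with the square of the sequence length.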
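A toy NumPy sketch of that difference, under simplifying assumptions (random weights, a bare tanh RNN cell, and self-attention with Q = K = V = X and no learned projections). The RNN needs a loop in which step t must wait for step t-1; attention gets every pairwise interaction from one matrix product.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 6, 4
X = rng.standard_normal((seq_len, d))      # one sequence of toy token vectors

# RNN-style: step t needs the hidden state from step t-1, so the loop is inherently sequential
W_h = rng.standard_normal((d, d)) * 0.5
W_x = rng.standard_normal((d, d)) * 0.5
h = np.zeros(d)
rnn_states = []
for t in range(seq_len):
    h = np.tanh(h @ W_h + X[t] @ W_x)
    rnn_states.append(h)
rnn_states = np.stack(rnn_states)

# Attention-style: all pairwise scores at once, no loop over time steps
scores = X @ X.T / np.sqrt(d)              # (seq_len, seq_len)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
attn_out = weights @ X                     # each position mixes in every other position directly

print(rnn_states.shape, attn_out.shape)    # (6, 4) (6, 4)
```

The last row of `weights` already contains a nonzero weight for position 0: the first and last tokens interact in one step, with no chain of hidden states in between.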


The Big Picture: Encoder and Decoder

The original transformer has two parts:

Encoder: reads the input sequence and builds a rich representation of it. Used for understanding tasks (classification, question answering).

Decoder: generates the output sequence one token at a time, attending to both the encoder's output and its own previous outputs. Used for generation tasks (translation, summarization, text generation).

Some models use only the encoder (BERT). Some use only the decoder (GPT). Some use both (T5, original translation models).

Input sequence → [Encoder] → Context representations
                                      ↓
Target (so far) → [Decoder] → Output token

The Transformer Building Blocks

Let's build each piece and assemble them.

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

1. Positional Encoding

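Attention has no sense of order. Feed it "The cat sat" and "sat cat The" and it produces exactly the same output vectors, only shuffled into the new order, so nothing in the computation reflects word order. Positional encoding injects position information into the token embeddings.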
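A quick NumPy check of that claim, assuming plain self-attention with Q = K = V = X and no learned projections (adding projections does not change the conclusion). The three rows of `X` are random stand-ins for the embeddings of "The", "cat", "sat":

```python
import numpy as np

def self_attention(X):
    # Plain self-attention with Q = K = V = X: no projections and no positional information
    scores = X @ X.T / np.sqrt(X.shape[1])
    w = np.exp(scores - scores.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)
    return w @ X

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 4))        # random stand-ins for "The", "cat", "sat"
perm = [2, 1, 0]                        # reorder to "sat", "cat", "The"

out = self_attention(X)
out_perm = self_attention(X[perm])
print(np.allclose(out_perm, out[perm]))   # True: same vectors, just reordered
```

Permuting the input rows only permutes the output rows. This property is called permutation equivariance, and it is exactly what positional encoding is there to break.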

The original paper uses sine and cosine functions at different frequencies:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()

        # Frequencies for each dimension
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions

        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)  # not a parameter, but saved with model

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

# Show what positional encoding looks like
pe_demo = PositionalEncoding(d_model=16, max_seq_len=20, dropout=0.0)  # dropout off so the printout shows the raw encodings
x_demo  = torch.zeros(1, 20, 16)  # batch=1, seq=20, d_model=16
pe_out  = pe_demo(x_demo)

import matplotlib.pyplot as plt
plt.figure(figsize=(10, 4))
plt.imshow(pe_out[0].detach().numpy(), aspect='auto', cmap='RdBu')
plt.colorbar()
plt.xlabel('Embedding Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding (each row = one position in sequence)')
plt.savefig('positional_encoding.png', dpi=100)
plt.show()

print("Each position gets a unique pattern of sine/cosine values")
print(f"Position 0:  {pe_out[0, 0, :8].detach().numpy().round(3)}")
print(f"Position 1:  {pe_out[0, 1, :8].detach().numpy().round(3)}")
print(f"Position 10: {pe_out[0, 10, :8].detach().numpy().round(3)}")
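Written out, the class computes PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The `div_term` line is the same factor 1/10000^(2i/d_model), just computed through exp and log. The NumPy sketch below (a mirror of the class, assuming an even d_model) checks that equivalence. It also checks a useful property: because sin(a)sin(b) + cos(a)cos(b) = cos(a - b), the dot product of two position vectors depends only on how far apart they are.

```python
import numpy as np

def sinusoidal_pe(max_len, d_model):
    """NumPy mirror of the PositionalEncoding buffer (d_model assumed even)."""
    pos = np.arange(max_len)[:, None]                 # (max_len, 1)
    two_i = np.arange(0, d_model, 2)                  # 0, 2, 4, ...
    freqs = 1.0 / (10000.0 ** (two_i / d_model))      # direct form of div_term
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(pos * freqs)                 # even dimensions
    pe[:, 1::2] = np.cos(pos * freqs)                 # odd dimensions
    return pe, freqs

pe, freqs = sinusoidal_pe(max_len=50, d_model=16)

# The PyTorch code computes exp(2i * -log(10000) / d_model): the same numbers
div_term = np.exp(np.arange(0, 16, 2) * (-np.log(10000.0) / 16))
print(np.allclose(freqs, div_term))                   # True

# Similarity between positions p and p + k depends only on the offset k
k = 3
sims = [pe[p] @ pe[p + k] for p in range(0, 40, 10)]
print(np.allclose(sims, sims[0]))                     # True
```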

2. Multi-Head Attention

We covered single-head attention in Post 90. Multi-head attention runs several attention operations in parallel, each with its own learned projections of Q, K and V. Because the projections differ, different heads can learn to focus on different relationships (one might track syntax, another long-range dependencies), although which head learns what is not fixed in advance.

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k     = d_model // n_heads  # dimension per head

        # Linear projections for Q, K, V and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V: (batch, n_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        return torch.matmul(attn_weights, V), attn_weights

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)

        # Linear projections
        Q = self.W_q(Q)  # (batch, seq_len, d_model)
        K = self.W_k(K)
        V = self.W_v(V)

        # Split into n_heads
        # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Attention for each head
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads
        # (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        # Final linear projection
        return self.W_o(attn_output), attn_weights

# Test it
mha = MultiHeadAttention(d_model=64, n_heads=8)
x   = torch.randn(2, 10, 64)  # batch=2, seq_len=10, d_model=64

output, weights = mha(x, x, x)  # self-attention: Q=K=V=x
print(f"Input shape:          {x.shape}")
print(f"Output shape:         {output.shape}")
print(f"Attention weights:    {weights.shape}  = (batch, n_heads, seq, seq)")

Output:

Input shape:          torch.Size([2, 10, 64])
Output shape:         torch.Size([2, 10, 64])
Attention weights:    torch.Size([2, 8, 10, 10])  = (batch, n_heads, seq, seq)

8 heads. Each one produces a 10x10 attention map showing which tokens attend to which. The outputs are concatenated and projected back to d_model=64.
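The head split inside `forward()` is pure reshaping. The NumPy sketch below uses random inputs and deliberately leaves out the learned projections W_q, W_k, W_v and W_o. It mirrors the `view`/`transpose` calls and checks three things: splitting and merging heads loses nothing, each head works on its own d_k-wide slice of the features, and every head's attention rows sum to 1.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d_model, n_heads = 2, 10, 64, 8
d_k = d_model // n_heads

x = rng.standard_normal((batch, seq_len, d_model))

def split_heads(t):
    # (batch, seq, d_model) -> (batch, n_heads, seq, d_k), like view(...).transpose(1, 2)
    return t.reshape(batch, seq_len, n_heads, d_k).transpose(0, 2, 1, 3)

def merge_heads(t):
    # (batch, n_heads, seq, d_k) -> (batch, seq, d_model), like transpose(1, 2).view(...)
    return t.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)

heads = split_heads(x)
print(heads.shape)                                     # (2, 8, 10, 8)
print(np.array_equal(merge_heads(heads), x))           # True: the round trip is lossless
print(np.array_equal(heads[:, 3], x[..., 3 * d_k:4 * d_k]))  # True: head 3 owns features 24..31

# Per-head scaled dot-product attention (projections omitted for brevity)
scores = heads @ heads.transpose(0, 1, 3, 2) / np.sqrt(d_k)   # (2, 8, 10, 10)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
print(weights.shape, np.allclose(weights.sum(axis=-1), 1.0))  # (2, 8, 10, 10) True
```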


3. Feed-Forward Layer

After attention, each position passes through a small feed-forward network independently. It's the same network applied to every position, adding capacity to transform the attention output.

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # expand
        self.linear2 = nn.Linear(d_ff, d_model)    # contract
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

# Typically d_ff = 4 * d_model
ff = FeedForward(d_model=64, d_ff=256)
x  = torch.randn(2, 10, 64)
print(f"FF input:  {x.shape}")
print(f"FF output: {ff(x).shape}")

This expand-contract structure (d_model → 4*d_model → d_model) gives each position a private computation step after the global mixing of attention.
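"Position-wise" means the same weights are applied to each position on its own. The NumPy sketch below uses random weights with the same shapes as `FeedForward(d_model=64, d_ff=256)`, and dropout is left out as it would be in eval mode. It confirms that running the whole sequence at once matches running one position at a time, and it counts the layer's parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, seq_len = 64, 256, 10

# Random weights shaped like FeedForward(d_model=64, d_ff=256), biases included
W1, b1 = rng.standard_normal((d_model, d_ff)) * 0.1, np.zeros(d_ff)
W2, b2 = rng.standard_normal((d_ff, d_model)) * 0.1, np.zeros(d_model)

def ffn(x):
    # expand -> ReLU -> contract (dropout omitted, as in eval mode)
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2

x = rng.standard_normal((seq_len, d_model))
whole = ffn(x)                                         # all positions at once
one_by_one = np.stack([ffn(x[i]) for i in range(seq_len)])
print(np.allclose(whole, one_by_one))                  # True: positions never interact here

n_params = W1.size + b1.size + W2.size + b2.size
print(n_params)                                        # 33088
```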


4. Layer Normalization and Residual Connections

Every sublayer (attention, feed-forward) is wrapped in two things:

Residual connection: adds the input back to the output. This is x + sublayer(x). It keeps gradients flowing cleanly through deep networks.

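Layer normalization: normalizes the activations across the embedding dimension, which stabilizes training. The code below follows the original paper's "post-norm" order, norm(x + sublayer(x)). Many later models (GPT-2 among them) use "pre-norm", x + sublayer(norm(x)), which tends to train more stably in deep stacks.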

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward   = FeedForward(d_model, d_ff, dropout)

        self.norm1   = nn.LayerNorm(d_model)
        self.norm2   = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual + layer norm
        attn_out, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))   # Add & Norm

        # Feed-forward with residual + layer norm
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))     # Add & Norm

        return x

# Test encoder layer
enc_layer = EncoderLayer(d_model=64, n_heads=8, d_ff=256)
x = torch.randn(2, 10, 64)
out = enc_layer(x)
print(f"Encoder layer: {x.shape} -> {out.shape}")
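Here is what `nn.LayerNorm(d_model)` computes, written out in NumPy. The sketch assumes PyTorch's defaults (biased variance, eps = 1e-5, gamma initialized to 1 and beta to 0) and uses a random stand-in for the sublayer output. After "Add & Norm", every position's vector has mean 0 and standard deviation 1 across the embedding dimension, however badly scaled the input was.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each position's vector over the last axis (the embedding dimension),
    # using the biased variance and eps inside the square root, as nn.LayerNorm does
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

rng = np.random.default_rng(0)
d_model = 64
x = rng.standard_normal((2, 10, d_model)) * 5 + 3      # badly scaled activations
sublayer_out = rng.standard_normal((2, 10, d_model))   # stand-in for attention/FFN output

gamma, beta = np.ones(d_model), np.zeros(d_model)      # nn.LayerNorm's initial values
y = layer_norm(x + sublayer_out, gamma, beta)          # post-norm "Add & Norm"

print(np.allclose(y.mean(axis=-1), 0.0, atol=1e-6))    # True
print(np.allclose(y.std(axis=-1), 1.0, atol=1e-3))     # True
```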

5. The Decoder Layer

The decoder has three sublayers instead of two:

  1. Masked self-attention: attends only to earlier target tokens and itself. At inference time the future tokens don't exist yet, and during training (when the whole target sequence is fed in at once) the mask stops the model from peeking at the answer
  2. Cross-attention: attends to the encoder's output (connects input to output)
  3. Feed-forward: same as encoder

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()

        self.masked_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attention  = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward     = FeedForward(d_model, d_ff, dropout)

        self.norm1   = nn.LayerNorm(d_model)
        self.norm2   = nn.LayerNorm(d_model)
        self.norm3   = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # 1. Masked self-attention (decoder attends to itself)
        self_attn_out, _ = self.masked_self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_out))

        # 2. Cross-attention (decoder attends to encoder output)
        cross_attn_out, attn_weights = self.cross_attention(
            x, encoder_output, encoder_output, src_mask
        )
        x = self.norm2(x + self.dropout(cross_attn_out))

        # 3. Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_out))

        return x, attn_weights
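In cross-attention the queries come from the decoder, while the keys and values come from the encoder output. That is why the two sequences may have different lengths. A NumPy shape sketch with random stand-in tensors (learned projections omitted; the lengths 12 and 8 match the test run further down):

```python
import numpy as np

rng = np.random.default_rng(0)
src_len, tgt_len, d_k = 12, 8, 16

encoder_output = rng.standard_normal((src_len, d_k))  # supplies keys and values
decoder_state = rng.standard_normal((tgt_len, d_k))   # supplies queries

scores = decoder_state @ encoder_output.T / np.sqrt(d_k)   # (tgt_len, src_len)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
context = weights @ encoder_output                          # (tgt_len, d_k)

print(weights.shape)   # (8, 12): one row per target position, one column per source token
print(context.shape)   # (8, 16): output length follows the decoder, not the encoder
```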

6. The Causal Mask

During decoder self-attention, position i should only attend to positions 0 through i. Not future positions. This is called the causal or autoregressive mask.

def make_causal_mask(seq_len):
    # Lower triangular mask
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, seq)

mask = make_causal_mask(seq_len=5)
print("Causal mask (1 = can attend, 0 = cannot):")
print(mask[0, 0].int())

Output:

Causal mask (1 = can attend, 0 = cannot):
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]])

Position 0 sees only itself. Position 4 sees all previous positions and itself. Never the future.
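The mask only matters once it meets the softmax. The NumPy sketch below applies the same `masked_fill(mask == 0, -1e9)` trick to uniform (all-zero) scores, chosen so the mask's effect is easy to see. Each masked entry turns into a weight of essentially 0, and the remaining weight is shared among the allowed positions.

```python
import numpy as np

seq_len = 5
mask = np.tril(np.ones((seq_len, seq_len)))            # same pattern as make_causal_mask
scores = np.zeros((seq_len, seq_len))                  # uniform scores isolate the mask's effect

masked = np.where(mask == 0, -1e9, scores)             # NumPy version of masked_fill(mask == 0, -1e9)
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)

print(weights.round(2))
# Row i spreads its weight evenly over positions 0..i and gives 0 to every future position
```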


7. The Full Transformer

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8,
                 n_encoder_layers=3, n_decoder_layers=3,
                 d_ff=1024, max_seq_len=512, dropout=0.1):
        super().__init__()

        # Embeddings
        self.src_embedding = nn.Embedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding  = PositionalEncoding(d_model, max_seq_len, dropout)

        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_encoder_layers)
        ])

        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_decoder_layers)
        ])

        # Output projection: d_model -> vocab_size
        self.output_projection = nn.Linear(d_model, vocab_size)

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src, src_mask=None):
        x = self.pos_encoding(self.src_embedding(src))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        x = self.pos_encoding(self.tgt_embedding(tgt))
        for layer in self.decoder_layers:
            x, attn = layer(x, encoder_output, src_mask, tgt_mask)
        return x

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.output_projection(decoder_output)

# Build it
vocab_size = 1000
model = Transformer(
    vocab_size=vocab_size,
    d_model=128,
    n_heads=4,
    n_encoder_layers=2,
    n_decoder_layers=2,
    d_ff=512
)

n_params = sum(p.numel() for p in model.parameters())
print(f"Transformer parameters: {n_params:,}")

# Test forward pass
batch_size = 2
src_len    = 12
tgt_len    = 8

src = torch.randint(0, vocab_size, (batch_size, src_len))
tgt = torch.randint(0, vocab_size, (batch_size, tgt_len))

# Causal mask for decoder
tgt_mask = make_causal_mask(tgt_len)

output = model(src, tgt, tgt_mask=tgt_mask)
print(f"Input src:  {src.shape}")
print(f"Input tgt:  {tgt.shape}")
print(f"Output:     {output.shape}  = (batch, tgt_len, vocab_size)")

Output:

Transformer parameters: 1,310,696
Input src:  torch.Size([2, 12])
Input tgt:  torch.Size([2, 8])
Output:     torch.Size([2, 8, 1000])  = (batch, tgt_len, vocab_size)
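The parameter count can be checked by hand. The pure-Python sketch below tallies weights and biases for the configuration above (vocab_size=1000, d_model=128, 2 encoder and 2 decoder layers, d_ff=512). n_heads does not change the total, because the heads share the d_model-wide projections. The positional-encoding table is a registered buffer rather than a parameter, so it adds nothing.

```python
def linear(n_in, n_out):
    return n_in * n_out + n_out                  # nn.Linear: weight + bias

def count_params(vocab_size, d_model, n_enc, n_dec, d_ff):
    mha = 4 * linear(d_model, d_model)           # W_q, W_k, W_v, W_o
    ffn = linear(d_model, d_ff) + linear(d_ff, d_model)
    norm = 2 * d_model                           # LayerNorm weight + bias
    enc_layer = mha + ffn + 2 * norm
    dec_layer = 2 * mha + ffn + 3 * norm         # masked self-attention + cross-attention
    embeddings = 2 * vocab_size * d_model        # separate src and tgt embedding tables
    return embeddings + n_enc * enc_layer + n_dec * dec_layer + linear(d_model, vocab_size)

print(f"{count_params(1000, 128, 2, 2, 512):,}")   # 1,310,696
```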

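For each position in the target sequence, we get a vector of vocab_size logits (unnormalized scores). Applying softmax turns that vector into a probability distribution over the vocabulary, and the highest-scoring token is the model's prediction for that position. With the random weights used here, those predictions mean nothing until the model is trained.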
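At inference time the decoder produces one token per step and feeds each one back in as input. The sketch below shows greedy decoding. To keep it self-contained and deterministic, the trained model is replaced by a hypothetical scorer, `toy_next_token_logits`, and the special-token ids `BOS`/`EOS` and `VOCAB` are also made up. With the model above, the scorer call would roughly correspond to taking `model(src, tgt_so_far, tgt_mask=make_causal_mask(len))[:, -1, :]`, the logits for the last position.

```python
import numpy as np

BOS, EOS, VOCAB = 1, 2, 10    # hypothetical special-token ids and vocabulary size

def toy_next_token_logits(src, prefix):
    """Hypothetical stand-in for a trained model's last-position logits on one example.
    It prefers 'copy the source token for this step', then EOS once the source runs out."""
    logits = np.full(VOCAB, -5.0)
    step = len(prefix) - 1                       # tokens generated so far (excluding BOS)
    logits[src[step] if step < len(src) else EOS] = 5.0
    return logits

def greedy_decode(src, max_len=20):
    prefix = [BOS]
    for _ in range(max_len):
        logits = toy_next_token_logits(src, prefix)
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()                     # softmax: logits -> probabilities
        next_token = int(np.argmax(probs))       # greedy: take the most likely token
        prefix.append(next_token)                # feed it back in for the next step
        if next_token == EOS:
            break
    return prefix

print(greedy_decode([5, 7, 3]))   # [1, 5, 7, 3, 2]
```

Greedy decoding is the simplest choice. Beam search or sampling (see the temperature sketch in the challenges) replace the `argmax` line.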


Encoder-Only vs Decoder-Only vs Encoder-Decoder

Modern models specialize:

Encoder-Only (BERT, RoBERTa):
  - Reads full input bidirectionally
  - Every token attends to every other token
  - Best for: classification, NER, question answering
  - Not for: text generation

Decoder-Only (GPT, LLaMA, Claude):
  - Causal: each token only attends to previous tokens
  - Best for: text generation, chat, code completion
  - Chat versions are further fine-tuned for instruction following (e.g. with RLHF)

Encoder-Decoder (T5, BART, original transformer):
  - Encoder reads input fully
  - Decoder generates output attending to encoder
  - Best for: translation, summarization, seq2seq tasks
# Encoder-only: just the encoder stack
class EncoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, n_classes):
        super().__init__()
        self.embedding    = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.layers       = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, x, mask=None):
        x = self.pos_encoding(self.embedding(x))
        for layer in self.layers:
            x = layer(x, mask)
        # Pool: use the [CLS] token for classification (assumes one was prepended at position 0)
        cls_output = x[:, 0, :]
        return self.classifier(cls_output)

# Decoder-only: just the decoder self-attention (no cross-attention)
class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff):
        super().__init__()
        self.embedding    = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        # Encoder layers but with causal masking = decoder-only
        self.layers       = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        seq_len  = x.size(1)
        mask     = make_causal_mask(seq_len).to(x.device)
        x        = self.pos_encoding(self.embedding(x))
        for layer in self.layers:
            x = layer(x, mask)
        return self.output(x)  # (batch, seq_len, vocab_size)
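A decoder-only model is trained with plain next-token prediction: the target is the input shifted left by one position, and the causal mask keeps each position from seeing its own answer. The pure-Python sketch below builds such pairs at character level. The text and `block_size=4` are arbitrary choices for illustration. In PyTorch the loss would then typically be `F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))` on the model's `(batch, seq_len, vocab_size)` output.

```python
text = "hello world"
vocab = sorted(set(text))                          # character-level vocabulary
stoi = {ch: i for i, ch in enumerate(vocab)}       # char -> id
itos = {i: ch for ch, i in stoi.items()}           # id -> char
ids = [stoi[ch] for ch in text]

block_size = 4                                     # hypothetical context length
inputs, targets = [], []
for start in range(len(ids) - block_size):
    chunk = ids[start:start + block_size + 1]
    inputs.append(chunk[:-1])                      # what the model sees
    targets.append(chunk[1:])                      # the same chunk shifted left by one

print("".join(itos[i] for i in inputs[0]), "->", "".join(itos[i] for i in targets[0]))
# hell -> ello
```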

Scaling Laws: Why Bigger Works Better

One of the most important empirical discoveries about transformers: performance improves predictably as you scale up.

Model (parameters)   | Training data        | Rough capability (informal)
-------------------- | -------------------- | ---------------------------
GPT-2 small (125M)   | ~40 GB of web text   | Basic coherence
1.3B                 | 300B tokens          | Decent paragraphs
LLaMA-2 7B           | 2T tokens            | Strong reasoning
LLaMA-2 13B          | 2T tokens            | Near GPT-3.5
LLaMA-2 70B          | 2T tokens            | Competitive
GPT-3 (175B)         | 300B tokens          | Remarkable breadth

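The transformer architecture scales. Add more layers, more heads, a wider d_model and more training data, and test loss falls smoothly, roughly as a power law in parameters, data and compute.

These empirical regularities are called scaling laws. The bet that they will keep holding (sometimes called the scaling hypothesis) is why companies are racing to build larger models.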
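Scaling-law discussions usually measure size in training compute. A common rule of thumb is C ≈ 6·N·D FLOPs, where N is the parameter count and D is the number of training tokens. It is an approximation that ignores attention's sequence-length term. The pure-Python sketch below applies it to rows of the table above:

```python
# Rough training compute C ~= 6 * N * D FLOPs (N = parameters, D = training tokens)
models = {
    "GPT-3 (175B, 300B tokens)": (175e9, 300e9),
    "LLaMA-2 7B (2T tokens)": (7e9, 2e12),
    "LLaMA-2 70B (2T tokens)": (70e9, 2e12),
}
compute = {}
for name, (n_params, n_tokens) in models.items():
    compute[name] = 6 * n_params * n_tokens
    print(f"{name:28s} ~{compute[name]:.2e} FLOPs")
```

This puts GPT-3 at about 3.15e23 FLOPs. Note that LLaMA-2 7B, with 25x fewer parameters, gets within about 4x of that compute budget by training on far more tokens.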


Quick Cheat Sheet

Component            | What it does
-------------------- | ---------------------------------------------------------
Token embedding      | Maps each token ID to a d_model vector
Positional encoding  | Adds position info via sine/cosine patterns
Self-attention       | Every token attends to every other token
Multi-head attention | Parallel attention with different learned projections
Causal mask          | Prevents the decoder from seeing future tokens
Cross-attention      | Decoder attends to the encoder's output
Feed-forward         | Per-position transformation after attention
Layer norm           | Stabilizes activations across the embedding dimension
Residual connection  | x + sublayer(x), keeps gradients flowing
Output projection    | Maps d_model back to vocab_size for next-token prediction

Practice Challenges

Level 1:
Build the EncoderOnlyTransformer with vocab_size=5000, d_model=64, n_heads=4, n_layers=2. Pass a random batch of token IDs through it. Print the output shape and verify it matches (batch_size, n_classes).

Level 2:
Visualize the causal mask for seq_len=10. Then visualize the attention weights from a multi-head attention layer on a random sequence. Do different heads attend to different positions?

Level 3:
Train the DecoderOnlyTransformer on a character-level language model. Use any short text (Shakespeare works well). Train it to predict the next character. After training, sample from it by repeatedly predicting the next token and feeding it back in.
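For the sampling step, drawing from the softmax instead of always taking the argmax gives more varied text. The NumPy sketch below adds a temperature knob. The function name `sample_next` and the example logits are illustrative assumptions, not part of the code above. Temperatures below 1 sharpen the distribution toward greedy decoding; temperatures above 1 flatten it toward uniform.

```python
import numpy as np

def sample_next(logits, temperature=1.0, rng=None):
    """Draw one token id from logits; lower temperature means sharper, more greedy choices."""
    if rng is None:
        rng = np.random.default_rng()
    scaled = np.asarray(logits, dtype=float) / temperature
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs)), probs

rng = np.random.default_rng(0)
logits = [2.0, 1.0, 0.1]
token, p_sharp = sample_next(logits, temperature=0.1, rng=rng)
_, p_flat = sample_next(logits, temperature=10.0, rng=rng)
print(p_sharp.round(3))   # almost all probability mass on token 0
print(p_flat.round(3))    # close to uniform
```

In the generation loop, call this on the last position's logits, append the sampled id to the sequence, and repeat.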

