Akhilesh
80. The Transformer: The Architecture That Changed Everything

In 2017, eight researchers at Google published a paper called "Attention Is All You Need."

The title was a provocation. The dominant view was that you needed recurrent networks to process sequences. You needed memory and sequential processing. The paper said you needed none of that. Just attention, applied in a smart architecture. The result was faster to train, easier to parallelize, and better on nearly every benchmark.

The transformer they described in that paper is the direct ancestor of BERT, GPT-2, GPT-3, GPT-4, Claude, Gemini, and every other large language model that has changed how we interact with computers.

This post builds the full transformer encoder from scratch. You already have all the pieces from the last post. This is where they come together.


The Full Architecture

A transformer has two main components:

Encoder: reads the input sequence and builds a rich contextual representation. BERT is an encoder. Used for understanding tasks: classification, named entity recognition, question answering.

Decoder: generates output tokens one at a time. GPT is a decoder. Used for generation tasks: text completion, translation, summarization.

The original "Attention Is All You Need" transformer had both for translation. Modern LLMs often use one or the other.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import warnings
warnings.filterwarnings("ignore")

torch.manual_seed(42)

def attention(Q, K, V, mask=None):
    d_k    = Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k     = d_model // n_heads
        self.W_q     = nn.Linear(d_model, d_model, bias=False)
        self.W_k     = nn.Linear(d_model, d_model, bias=False)
        self.W_v     = nn.Linear(d_model, d_model, bias=False)
        self.W_o     = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x):
        b, s, _ = x.shape
        return x.reshape(b, s, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        Q = self.split_heads(self.W_q(q))
        K = self.split_heads(self.W_k(k))
        V = self.split_heads(self.W_v(v))
        out, w = attention(Q, K, V, mask)
        out    = out.transpose(1, 2).reshape(q.shape[0], -1, self.n_heads * self.d_k)
        return self.W_o(out), w

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe  = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.pow(10000.0, torch.arange(0, d_model, 2).float() / d_model)
        pe[:, 0::2] = torch.sin(pos / div)
        pe[:, 1::2] = torch.cos(pos / div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return self.dropout(x + self.pe[:, :x.size(1)])

print("Building blocks assembled. Now constructing the full transformer.")
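As a quick sanity check on the attention helper (redefined here so the snippet runs on its own), the softmax rows should each sum to one and the output should keep the input shape:

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    # scaled dot-product attention, same as the helper above
    d_k = Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V), weights

x = torch.randn(2, 5, 16)        # (batch, seq_len, d_k)
out, w = attention(x, x, x)
print(out.shape)                 # same shape as the input
print(w.sum(dim=-1))             # every row of weights sums to 1.0
```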

The Encoder Layer

class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn       = FeedForward(d_model, d_ff, dropout)
        self.norm1     = nn.LayerNorm(d_model)
        self.norm2     = nn.LayerNorm(d_model)
        self.dropout   = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _  = self.self_attn(x, x, x, mask)
        x            = self.norm1(x + self.dropout(attn_out))
        ffn_out      = self.ffn(x)
        x            = self.norm2(x + self.dropout(ffn_out))
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, d_ff, n_layers,
                 max_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_enc   = PositionalEncoding(d_model, max_len, dropout)
        self.layers    = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.norm      = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        x = self.pos_enc(self.embedding(x) * math.sqrt(self.embedding.embedding_dim))
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

VOCAB_SIZE = 10000
D_MODEL    = 256
N_HEADS    = 8
D_FF       = 1024
N_LAYERS   = 6
MAX_LEN    = 512

encoder = TransformerEncoder(VOCAB_SIZE, D_MODEL, N_HEADS, D_FF, N_LAYERS, MAX_LEN)

total_params = sum(p.numel() for p in encoder.parameters())
print(f"Transformer Encoder configuration:")
print(f"  vocab_size: {VOCAB_SIZE:,}")
print(f"  d_model:    {D_MODEL}")
print(f"  n_heads:    {N_HEADS}  (d_k = {D_MODEL//N_HEADS} per head)")
print(f"  d_ff:       {D_FF}")
print(f"  n_layers:   {N_LAYERS}")
print(f"  max_len:    {MAX_LEN}")
print(f"\n  Total parameters: {total_params:,}")
print()

x_ids = torch.randint(1, VOCAB_SIZE, (2, 20))
out   = encoder(x_ids)
print(f"Input shape:  {x_ids.shape}   (batch=2, seq_len=20)")
print(f"Output shape: {out.shape}  (batch=2, seq_len=20, d_model={D_MODEL})")
print()
print("Every token now has a context-aware 256-dim representation.")
print("'bank' in 'river bank' and 'bank account' get different vectors.")
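The "river bank" vs "bank account" claim is easy to verify in miniature. A minimal sketch (single-head self-attention with no learned projections, toy token ids): the same token enters with one static embedding but leaves with different vectors depending on its neighbors.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
emb = torch.nn.Embedding(10, 8)

def self_attn(x):
    # single-head scaled dot-product self-attention, no W_q/W_k/W_v
    scores = x @ x.transpose(-2, -1) / math.sqrt(x.shape[-1])
    return F.softmax(scores, dim=-1) @ x

a = emb(torch.tensor([[1, 2, 7, 4]]))   # token 7 in context A
b = emb(torch.tensor([[9, 8, 7, 6]]))   # the same token 7 in context B
out_a, out_b = self_attn(a), self_attn(b)

# identical static embedding going in...
assert torch.allclose(a[0, 2], b[0, 2])
# ...different contextual vector coming out
assert not torch.allclose(out_a[0, 2], out_b[0, 2])
print("same input embedding, different contextual outputs")
```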

The Decoder Layer

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn  = MultiHeadAttention(d_model, n_heads)
        self.cross_attn = MultiHeadAttention(d_model, n_heads)
        self.ffn        = FeedForward(d_model, d_ff, dropout)
        self.norm1      = nn.LayerNorm(d_model)
        self.norm2      = nn.LayerNorm(d_model)
        self.norm3      = nn.LayerNorm(d_model)
        self.dropout    = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        self_out, _  = self.self_attn(x, x, x, tgt_mask)
        x            = self.norm1(x + self.dropout(self_out))

        cross_out, _ = self.cross_attn(x, enc_out, enc_out, src_mask)
        x            = self.norm2(x + self.dropout(cross_out))

        ffn_out      = self.ffn(x)
        x            = self.norm3(x + self.dropout(ffn_out))
        return x

print("Decoder Layer has THREE sublayers:")
print()
print("1. MASKED SELF-ATTENTION")
print("   Causal mask: each token only attends to previous tokens.")
print("   Prevents 'cheating' by looking at future tokens during training.")
print()
print("2. CROSS-ATTENTION")
print("   Query from decoder, Key+Value from encoder output.")
print("   Lets decoder attend to the full encoded input.")
print("   How the decoder 'reads' what it is translating/summarizing.")
print()
print("3. FEED-FORWARD")
print("   Same as encoder. Processes each position independently.")
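The causal mask from sublayer 1 is just a lower-triangular matrix. A small sketch of how it zeroes out future positions after the softmax:

```python
import torch
import torch.nn.functional as F

seq_len = 4
# lower-triangular mask: 1 = allowed, 0 = future position (blocked)
tgt_mask = torch.tril(torch.ones(seq_len, seq_len))
print(tgt_mask)

scores = torch.randn(seq_len, seq_len)
scores = scores.masked_fill(tgt_mask == 0, float("-inf"))
weights = F.softmax(scores, dim=-1)
# every entry above the diagonal is exactly zero:
# token i can attend to tokens 0..i, never to i+1..
print(weights)
```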

Comparing BERT and GPT Architectures

architectures = {
    "BERT-base": {
        "type":      "Encoder only",
        "layers":    12,
        "d_model":   768,
        "n_heads":   12,
        "params":    "110M",
        "training":  "Masked Language Model + Next Sentence Prediction",
        "use_case":  "Classification, NER, Q&A, Embeddings",
        "attention": "Bidirectional (sees full context)",
    },
    "BERT-large": {
        "type":      "Encoder only",
        "layers":    24,
        "d_model":   1024,
        "n_heads":   16,
        "params":    "340M",
        "training":  "Masked Language Model + Next Sentence Prediction",
        "use_case":  "Same as base, higher accuracy",
        "attention": "Bidirectional",
    },
    "GPT-2": {
        "type":      "Decoder only",
        "layers":    12,
        "d_model":   768,
        "n_heads":   12,
        "params":    "117M",
        "training":  "Next Token Prediction (autoregressive)",
        "use_case":  "Text generation, completion",
        "attention": "Causal (left-to-right only)",
    },
    "GPT-3": {
        "type":      "Decoder only",
        "layers":    96,
        "d_model":   12288,
        "n_heads":   96,
        "params":    "175B",
        "training":  "Next Token Prediction at massive scale",
        "use_case":  "Few-shot learning, generation",
        "attention": "Causal",
    },
    "T5-base": {
        "type":      "Encoder-Decoder",
        "layers":    "12+12",
        "d_model":   768,
        "n_heads":   12,
        "params":    "220M",
        "training":  "Text-to-Text format, span masking",
        "use_case":  "Translation, summarization, Q&A",
        "attention": "Encoder=bidirectional, Decoder=causal",
    },
}

print(f"{'Model':<14} {'Type':<20} {'Layers':>8} {'d_model':>8} {'Params':>8} {'Attention'}")
print("=" * 80)
for name, config in architectures.items():
    print(f"{name:<14} {config['type']:<20} {str(config['layers']):>8} "
          f"{str(config['d_model']):>8} {config['params']:>8} "
          f"{config['attention'][:25]}")
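The parameter counts in the table can be roughly reproduced from the configs. A back-of-envelope estimate, ignoring biases, LayerNorm parameters, and position embeddings (30,522 is BERT's WordPiece vocabulary size):

```python
def approx_params(n_layers, d_model, d_ff, vocab_size):
    # per layer: 4 attention projections (W_q, W_k, W_v, W_o)
    # plus the two feed-forward matrices
    per_layer = 4 * d_model**2 + 2 * d_model * d_ff
    # plus the token embedding table
    return n_layers * per_layer + vocab_size * d_model

# BERT-base: 12 layers, d_model=768, d_ff=3072
print(f"{approx_params(12, 768, 3072, 30522) / 1e6:.0f}M")  # ~108M vs the reported 110M
```

The small gap comes from exactly the terms the estimate drops; the same formula scales straightforwardly to the larger configs.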

Building a Text Classifier With the Encoder

class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, d_ff,
                 n_layers, num_classes, max_len=256, dropout=0.1):
        super().__init__()
        self.encoder    = TransformerEncoder(
            vocab_size, d_model, n_heads, d_ff, n_layers, max_len, dropout)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model // 2, num_classes)
        )

    def forward(self, x, mask=None):
        enc_out = self.encoder(x, mask)
        cls_rep = enc_out[:, 0, :]
        return self.classifier(cls_rep)

classifier = TransformerClassifier(
    vocab_size=10000, d_model=128, n_heads=4,
    d_ff=512, n_layers=3, num_classes=2
)

params     = sum(p.numel() for p in classifier.parameters())
batch_ids  = torch.randint(1, 10000, (4, 30))
logits     = classifier(batch_ids)

print(f"Text Classifier (sentiment analysis setup):")
print(f"  Parameters:  {params:,}")
print(f"  Input shape:  {batch_ids.shape}   (4 sentences, 30 tokens each)")
print(f"  Output shape: {logits.shape}  (4 sentences, 2 classes)")
print()
print("The [CLS] token (position 0) aggregates the full sequence.")
print("Its representation is used for classification.")
print("This is exactly how BERT does classification tasks.")

Training a Small Transformer

from torch.utils.data import DataLoader, Dataset

class FakeTextDataset(Dataset):
    """Random tokens with random labels: there is nothing to generalize,
    so expect test accuracy near 50%. The point is exercising the loop."""
    def __init__(self, n_samples=1000, seq_len=20, vocab_size=1000):
        self.data = torch.randint(1, vocab_size, (n_samples, seq_len))
        self.labels = torch.randint(0, 2, (n_samples,))

    def __len__(self): return len(self.data)
    def __getitem__(self, i): return self.data[i], self.labels[i]

dataset    = FakeTextDataset(n_samples=2000, seq_len=20, vocab_size=5000)
train_size = int(0.8 * len(dataset))
test_size  = len(dataset) - train_size
train_ds, test_ds = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False)

model_small = TransformerClassifier(
    vocab_size=5000, d_model=64, n_heads=4,
    d_ff=256, n_layers=2, num_classes=2, dropout=0.1
)

device    = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_small = model_small.to(device)
optimizer = torch.optim.AdamW(model_small.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

print(f"Training small transformer ({sum(p.numel() for p in model_small.parameters()):,} params):")
print(f"{'Epoch':>6} {'Train Loss':>12} {'Train Acc':>10} {'Test Acc':>10}")
print("=" * 42)

for epoch in range(10):
    model_small.train()
    total_loss = correct = total = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        out  = model_small(x)
        loss = criterion(out, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model_small.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item()
        correct    += out.argmax(1).eq(y).sum().item()
        total      += y.size(0)

    model_small.eval()
    t_correct = t_total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            t_correct += model_small(x).argmax(1).eq(y).sum().item()
            t_total   += y.size(0)

    if (epoch + 1) % 2 == 0:
        print(f"{epoch+1:>6} {total_loss/len(train_loader):>12.4f} "
              f"{correct/total:>10.2%} {t_correct/t_total:>10.2%}")

Layer Normalization vs Batch Normalization

print("Why transformers use LayerNorm, not BatchNorm:")
print()
print("BatchNorm normalizes ACROSS the batch dimension:")
print("  Computes mean/std across all samples in batch for each feature.")
print("  Requires large batch sizes for stable statistics.")
print("  Cannot be used with variable-length sequences easily.")
print("  Behavior changes between train and eval (uses running statistics).")
print()
print("LayerNorm normalizes ACROSS the feature dimension:")
print("  Computes mean/std within each single sample.")
print("  Works with any batch size, even batch_size=1.")
print("  Identical behavior at train and eval time.")
print("  Natural fit for variable-length sequences.")
print()

batch = torch.randn(4, 10, 64)

bn  = nn.BatchNorm1d(64)
ln  = nn.LayerNorm(64)

bn_out = bn(batch.reshape(-1, 64)).reshape(4, 10, 64)
ln_out = ln(batch)

print(f"Input shape:       {batch.shape}")
print(f"BatchNorm output:  {bn_out.shape}  (normalizes across batch×seq dimension)")
print(f"LayerNorm output:  {ln_out.shape}  (normalizes across feature dimension)")
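Beyond shapes, the normalization itself is easy to verify numerically: LayerNorm should leave each (sample, position) row with mean ≈ 0 and variance ≈ 1 across its features, independent of every other sample in the batch.

```python
import torch
import torch.nn as nn

x = torch.randn(4, 10, 64)
ln = nn.LayerNorm(64)
y = ln(x)

# each of the 4x10 rows is normalized over its own 64 features
print(y.mean(dim=-1).abs().max())            # ~0
print(y.var(dim=-1, unbiased=False).mean())  # ~1
```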

Why the Transformer Won

reasons = {
    "Parallelism": (
        "RNNs: step 3 cannot run until step 2 finishes.\n"
        "  Transformers: all positions computed simultaneously.\n"
        "  Training speedup: 10-100x on modern GPUs."
    ),
    "Long-range dependencies": (
        "RNNs: gradient signal must travel through every step.\n"
        "  Vanishing gradient → cannot learn dependencies >50 steps.\n"
        "  Transformers: any two positions connected by one attention step.\n"
        "  Direct path regardless of distance. No vanishing."
    ),
    "Scalability": (
        "Transformers scale predictably with data and compute.\n"
        "  More data + more parameters = better performance.\n"
        "  This scaling law enabled GPT-3, GPT-4, Claude.\n"
        "  RNNs did not show the same scaling behavior."
    ),
    "Transfer learning": (
        "One pretrained transformer = many fine-tuned tasks.\n"
        "  BERT: pretrain once, fine-tune on any classification task.\n"
        "  GPT: pretrain once, prompt for any generation task.\n"
        "  Efficiency of shared representations."
    ),
}

for reason, explanation in reasons.items():
    print(f"Why transformers won: {reason}")
    for line in explanation.split("\n"):
        print(f"  {line}")
    print()
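The parallelism point can be made concrete. A toy sketch (random weights, not a full RNN cell): the recurrent version needs a Python-level loop because each hidden state depends on the previous one, while self-attention scores every pair of positions in a single matrix multiply.

```python
import torch

seq_len, d = 8, 16
x = torch.randn(1, seq_len, d)

# RNN-style: step t cannot start until step t-1 finishes
W = torch.randn(d, d) * 0.1
h = torch.zeros(1, d)
hidden = []
for t in range(seq_len):             # inherently sequential
    h = torch.tanh(x[:, t] @ W + h @ W)
    hidden.append(h)

# attention-style: all pairwise scores in one parallel matmul
scores = x @ x.transpose(1, 2)       # (1, seq_len, seq_len)
print(scores.shape)
```

On a GPU the single matmul saturates the hardware; the loop cannot, no matter how fast each step is.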

A Resource Worth Reading

The original paper "Attention Is All You Need" by Vaswani et al. (2017) is required reading. At 15 pages it is completely accessible, explains every architectural decision clearly, and includes training details and ablation studies. The architecture diagram (Figure 1) is now the most reproduced diagram in NLP. Search "Vaswani attention all you need 2017 NeurIPS."

Peter Bloem wrote "Transformers from Scratch" at peterbloem.nl which builds the entire transformer in PyTorch with unusually clear mathematical explanations of why each component exists. More mathematical than Jay Alammar but equally valuable for understanding deeply. Search "Peter Bloem transformers from scratch."


Try This

Create transformer_practice.py.

Build the complete TransformerEncoder from this post from scratch. Do not copy-paste. Type it out to understand what each piece does.

Configure it as a miniature BERT: 4 layers, d_model=128, 4 heads, d_ff=512. Count the parameters. Confirm the shape: input (batch, seq_len) → output (batch, seq_len, 128).

Train it on any text classification dataset (20 Newsgroups, IMDB, or SST-2). Use the [CLS] token (position 0) for classification. Train for 10 epochs. Compare accuracy to a simple BiLSTM trained on the same data.

Implement one additional feature: padding masks. Create a mask that is 0 for padding tokens and 1 for real tokens. Pass it to the attention layers so padding tokens are ignored. Verify the mask works by printing attention weights and confirming padding positions get zero weight.
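As a starting point for the padding-mask exercise (the shapes here are one common convention, matching the `mask == 0` check in the attention function above): a boolean mask of shape (batch, 1, 1, seq_len) broadcasts over both heads and query positions.

```python
import torch

PAD = 0
ids = torch.tensor([[5, 9, 3, PAD, PAD],
                    [7, 2, PAD, PAD, PAD]])

# True for real tokens, False for padding; broadcastable over
# the (batch, n_heads, seq_len, seq_len) attention scores
pad_mask = (ids != PAD).unsqueeze(1).unsqueeze(2)
print(pad_mask.shape)      # torch.Size([2, 1, 1, 5])
print(pad_mask[0, 0, 0])   # tensor([ True,  True,  True, False, False])
```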


What's Next

You have built the transformer. Now you use it. BERT shows how to pretrain an encoder on masked language modeling and fine-tune it on downstream tasks, and it remains one of the most widely deployed NLP models in production. That's the next post.
