In 2017, Google published a paper called "Attention Is All You Need."
Before it, NLP was dominated by RNNs and LSTMs. They processed text word by word, left to right. They had memory problems. They were slow to train. They forgot long-range context.
The transformer threw all of that out. No recurrence. No convolutions. Just attention. And it worked better than everything before it, trained faster on GPUs, and scaled in ways nobody expected.
GPT, BERT, T5, LLaMA, Claude, ChatGPT. All of them are transformers. Understanding the architecture is understanding the foundation of modern AI.
What You'll Learn Here
- Why transformers replaced RNNs
- The encoder and decoder: what each one does
- Multi-head attention: the mechanism from Post 90, applied at scale
- Positional encoding: how the model knows word order
- Layer normalization and residual connections
- Feed-forward layers inside the transformer
- How to build a mini transformer in PyTorch
The Problem Transformers Solved
RNNs process sequences one token at a time. To build a representation of word 100, the network first has to step through words 1 through 99. That causes three problems:
Slow training. You can't parallelize a sequential process. GPUs sit idle.
Forgetting. Long sequences cause gradients to vanish. The model struggles to connect a word at position 1 to a word at position 100.
Bottleneck. In encoder-decoder RNNs, the entire input sequence is compressed into one fixed-size vector. Summarizing "War and Peace" into one vector loses information.
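The transformer's answer: process all positions simultaneously using attention. Every token attends to every other token in one batched matrix operation. Fully parallelizable during training. Any two positions are one step apart, so long-range context is far easier to keep, and there is no single fixed-size vector to squeeze everything through. The price: attention cost grows quadratically with sequence length.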
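To make the contrast concrete, here is a minimal sketch of the two computation patterns. The weight matrices `W_h` and `W_x` are hypothetical random stand-ins, not a trained RNN, and the attention part skips the learned Q/K/V projections to keep the focus on the shape of the computation.

```python
import torch

torch.manual_seed(0)
seq_len, d_model = 6, 8
x = torch.randn(seq_len, d_model)  # one sequence of 6 token vectors

# RNN-style: a loop, where step t cannot start until step t-1 is done.
W_h = torch.randn(d_model, d_model) * 0.1  # hypothetical recurrent weights
W_x = torch.randn(d_model, d_model) * 0.1  # hypothetical input weights
h = torch.zeros(d_model)
for t in range(seq_len):
    h = torch.tanh(h @ W_h + x[t] @ W_x)

# Attention-style: all pairwise scores come from one matrix product.
scores = x @ x.T / d_model ** 0.5        # (seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)  # each row sums to 1
mixed = weights @ x                      # (seq_len, d_model)

print(h.shape, weights.shape, mixed.shape)
```

The loop has a data dependency between iterations. The attention path is three tensor operations with no loop, which is what lets a GPU work on every position at once.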
The Big Picture: Encoder and Decoder
The original transformer has two parts:
Encoder: reads the input sequence and builds a rich representation of it. Used for understanding tasks (classification, question answering).
Decoder: generates the output sequence one token at a time, attending to both the encoder's output and its own previous outputs. Used for generation tasks (translation, summarization, text generation).
Some models use only the encoder (BERT). Some use only the decoder (GPT). Some use both (T5, original translation models).
Input sequence  → [Encoder] → Context representations
                      ↓
Target (so far) → [Decoder] → Output token
The Transformer Building Blocks
Let's build each piece and assemble them.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
1. Positional Encoding
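Attention by itself has no sense of order. Without position information, "The cat sat" and "sat cat The" produce the same set of output vectors, just in a different order, so the model cannot tell the two sentences apart. Positional encoding injects position information into the token embeddings.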
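A quick way to check the order-blindness claim is to shuffle the input and compare outputs. This is a minimal sketch with a bare single-head attention function (no learned weights, no positional encoding); the random vectors are stand-ins for word embeddings.

```python
import torch

torch.manual_seed(0)

def self_attention(x):
    # Plain single-head self-attention with no learned weights and
    # no positional information: softmax(x x^T / sqrt(d)) x
    scores = x @ x.T / x.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ x

x = torch.randn(3, 4)           # stand-ins for "The", "cat", "sat"
perm = torch.tensor([2, 1, 0])  # reorder to "sat", "cat", "The"

out = self_attention(x)
out_shuffled = self_attention(x[perm])

# Shuffling the input only shuffles the output rows in the same way.
print(torch.allclose(out[perm], out_shuffled, atol=1e-5))  # True
```

Each word gets exactly the same output vector regardless of where it sits in the sentence, which is why position has to be added explicitly.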
The original paper uses sine and cosine functions at different frequencies:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        # Frequencies for each dimension
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_seq_len, d_model)
        self.register_buffer('pe', pe)  # not a parameter, but saved with model

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
# Show what positional encoding looks like
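# dropout=0.0 so the plot shows the raw encoding, not randomly zeroed and rescaled values
pe_demo = PositionalEncoding(d_model=16, max_seq_len=20, dropout=0.0)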
x_demo = torch.zeros(1, 20, 16) # batch=1, seq=20, d_model=16
pe_out = pe_demo(x_demo)
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 4))
plt.imshow(pe_out[0].detach().numpy(), aspect='auto', cmap='RdBu')
plt.colorbar()
plt.xlabel('Embedding Dimension')
plt.ylabel('Position')
plt.title('Positional Encoding (each row = one position in sequence)')
plt.savefig('positional_encoding.png', dpi=100)
plt.show()
print("Each position gets a unique pattern of sine/cosine values")
print(f"Position 0: {pe_out[0, 0, :8].detach().numpy().round(3)}")
print(f"Position 1: {pe_out[0, 1, :8].detach().numpy().round(3)}")
print(f"Position 10: {pe_out[0, 10, :8].detach().numpy().round(3)}")
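The class above computes its frequencies with an exp/log expression, which can hide the underlying rule. In the original paper the rule is PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The sketch below writes that rule out in plain Python and checks that the exp/log form gives the same frequencies. The helper name `sinusoid_row` is made up for this illustration.

```python
import math

def sinusoid_row(pos, d_model):
    """Positional encoding for one position, written straight from the formula."""
    row = []
    for i in range(0, d_model, 2):  # i runs over the even dimension indices
        angle = pos / (10000 ** (i / d_model))
        row.append(math.sin(angle))  # even dimension i
        row.append(math.cos(angle))  # odd dimension i + 1
    return row

d_model = 16
# The exp/log form used for div_term is the same number as 10000^(-i/d_model).
for i in range(0, d_model, 2):
    via_exp = math.exp(i * (-math.log(10000.0) / d_model))
    direct = 10000 ** (-i / d_model)
    assert math.isclose(via_exp, direct, rel_tol=1e-9)

print([round(v, 3) for v in sinusoid_row(0, d_model)[:4]])  # [0.0, 1.0, 0.0, 1.0]
print([round(v, 3) for v in sinusoid_row(1, d_model)[:4]])
```

Low dimensions oscillate quickly and high dimensions slowly, so every position ends up with a distinct pattern.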
2. Multi-Head Attention
We covered single-head attention in Post 90. Multi-head attention runs several attention operations in parallel, each with its own learned projections. That lets different heads specialize: in trained models some heads tend to track syntactic relations, others nearby context or long-range dependencies. The roles are learned during training, not assigned by hand.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # dimension per head
        # Linear projections for Q, K, V and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Q, K, V: (batch, n_heads, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        return torch.matmul(attn_weights, V), attn_weights

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # Linear projections
        Q = self.W_q(Q)  # (batch, seq_len, d_model)
        K = self.W_k(K)
        V = self.W_v(V)
        # Split into n_heads
        # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
        Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Attention for each head
        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate heads
        # (batch, n_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # Final linear projection
        return self.W_o(attn_output), attn_weights
# Test it
mha = MultiHeadAttention(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64) # batch=2, seq_len=10, d_model=64
output, weights = mha(x, x, x) # self-attention: Q=K=V=x
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights: {weights.shape} = (batch, n_heads, seq, seq)")
Output:
Input shape: torch.Size([2, 10, 64])
Output shape: torch.Size([2, 10, 64])
Attention weights: torch.Size([2, 8, 10, 10]) = (batch, n_heads, seq, seq)
8 heads. Each one produces a 10x10 attention map showing which tokens attend to which. The outputs are concatenated and projected back to d_model=64.
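The split and concatenate steps are just reshapes, which is easy to miss when reading the class. The sketch below repeats only those two steps on a random tensor, with no attention in between, to show that a head is a contiguous slice of each token's features and that merging undoes the split exactly.

```python
import torch

batch, seq_len, d_model, n_heads = 2, 10, 64, 8
d_k = d_model // n_heads
x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_k)
heads = x.view(batch, seq_len, n_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 8])

# Head 0 is just the first d_k features of every token.
print(torch.equal(heads[:, 0], x[:, :, :d_k]))  # True

# Merge: transpose back, make memory contiguous, then flatten the heads.
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print(torch.equal(merged, x))  # True
```

The `.contiguous()` call matters: `transpose` returns a view with rearranged strides, and `view` refuses to flatten it until the data is laid out in order.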
3. Feed-Forward Layer
After attention, each position passes through a small feed-forward network independently. It's the same network applied to every position, adding capacity to transform the attention output.
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # expand
        self.linear2 = nn.Linear(d_ff, d_model)  # contract
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))
# Typically d_ff = 4 * d_model
ff = FeedForward(d_model=64, d_ff=256)
x = torch.randn(2, 10, 64)
print(f"FF input: {x.shape}")
print(f"FF output: {ff(x).shape}")
This expand-contract structure (d_model → 4*d_model → d_model) gives each position a private computation step after the global mixing of attention.
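"Private computation step" can be verified directly: feeding one position through the network alone gives the same result as feeding the whole sequence and picking that position out. A minimal sketch, using `nn.Sequential` as a dropout-free stand-in for the `FeedForward` class above:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 64, 256
# Same expand -> ReLU -> contract shape as above, without dropout for clarity.
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 10, d_model)
full = ffn(x)             # all 10 positions at once
single = ffn(x[:, 3, :])  # position 3 on its own

# Position 3's output does not depend on any other position.
print(torch.allclose(full[:, 3, :], single, atol=1e-5))  # True

n_params = sum(p.numel() for p in ffn.parameters())
print(n_params)  # 33088 = (64*256 + 256) + (256*64 + 64)
```

Attention is the only place where positions exchange information. The feed-forward layer then processes each position's mixed vector separately.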
4. Layer Normalization and Residual Connections
Every sublayer (attention, feed-forward) is wrapped in two things:
Residual connection: adds the input back to the output. This is x + sublayer(x). It keeps gradients flowing cleanly through deep networks.
Layer normalization: normalizes the activations across the embedding dimension. Stabilizes training.
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual + layer norm
        attn_out, _ = self.self_attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_out))  # Add & Norm
        # Feed-forward with residual + layer norm
        ff_out = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_out))  # Add & Norm
        return x
# Test encoder layer
enc_layer = EncoderLayer(d_model=64, n_heads=8, d_ff=256)
x = torch.randn(2, 10, 64)
out = enc_layer(x)
print(f"Encoder layer: {x.shape} -> {out.shape}")
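To see what "normalizes across the embedding dimension" means in numbers, the sketch below pushes deliberately badly scaled activations through `nn.LayerNorm` and checks the per-token statistics. The scale and shift applied to the input are arbitrary values picked for the demo.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(2, 10, 64) * 5 + 3  # activations with a shifted mean and large scale
norm = nn.LayerNorm(64)

y = norm(x)
# Each token vector is normalized on its own, across its 64 features.
print(y.mean(dim=-1).abs().max().item() < 1e-4)                       # per-token mean ~ 0
print((y.var(dim=-1, unbiased=False) - 1).abs().max().item() < 1e-3)  # per-token variance ~ 1

# Residual connection: the sublayer output is added to its input, so a
# sublayer that outputs zeros passes the signal through unchanged.
sublayer_out = torch.zeros_like(x)
print(torch.equal(x + sublayer_out, x))  # True
```

Statistics are computed per token, not per batch, so the result does not depend on batch size or on the other sequences in the batch.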
5. The Decoder Layer
The decoder has three sublayers instead of two:
- Masked self-attention: attends to previously generated tokens only (not future tokens, because at inference time you don't have them yet)
- Cross-attention: attends to the encoder's output (connects input to output)
- Feed-forward: same as encoder
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.masked_self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attention = MultiHeadAttention(d_model, n_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        # 1. Masked self-attention (decoder attends to itself)
        self_attn_out, _ = self.masked_self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_attn_out))
        # 2. Cross-attention (decoder attends to encoder output)
        cross_attn_out, attn_weights = self.cross_attention(
            x, encoder_output, encoder_output, src_mask
        )
        x = self.norm2(x + self.dropout(cross_attn_out))
        # 3. Feed-forward
        ff_out = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_out))
        return x, attn_weights
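In cross-attention the queries and the keys/values come from sequences of different lengths, so the attention map is no longer square. The sketch below shows the shapes using PyTorch's built-in `nn.MultiheadAttention` in place of the hand-written class, purely so the snippet runs by itself. The tensor sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads = 64, 8
cross_attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

decoder_states = torch.randn(2, 8, d_model)   # (batch, tgt_len, d_model) -> queries
encoder_output = torch.randn(2, 12, d_model)  # (batch, src_len, d_model) -> keys and values

out, weights = cross_attn(decoder_states, encoder_output, encoder_output)
print(out.shape)      # torch.Size([2, 8, 64])
print(weights.shape)  # torch.Size([2, 8, 12]), averaged over heads by default
```

The output has one vector per target position, and each row of the weights says how much that target position draws from each of the 12 source positions.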
6. The Causal Mask
During decoder self-attention, position i should only attend to positions 0 through i. Not future positions. This is called the causal or autoregressive mask.
def make_causal_mask(seq_len):
    # Lower triangular mask
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # (1, 1, seq, seq)
mask = make_causal_mask(seq_len=5)
print("Causal mask (1 = can attend, 0 = cannot):")
print(mask[0, 0].int())
Output:
Causal mask (1 = can attend, 0 = cannot):
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]], dtype=torch.int32)
Position 0 sees only itself. Position 4 sees all previous positions and itself. Never the future.
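The mask only matters through its effect on the softmax. This sketch applies it to a bare single-head attention (no learned projections, an assumption made to keep it short) and then changes the last token to confirm that earlier positions are unaffected.

```python
import torch

torch.manual_seed(0)
seq_len, d = 5, 4
x = torch.randn(seq_len, d)
mask = torch.tril(torch.ones(seq_len, seq_len))

def causal_attention(x, mask):
    scores = x @ x.T / x.size(-1) ** 0.5
    scores = scores.masked_fill(mask == 0, -1e9)  # same trick as in the class above
    weights = torch.softmax(scores, dim=-1)
    return weights @ x, weights

out, weights = causal_attention(x, mask)
print(weights[0])  # position 0 puts all its weight on itself

# Change only the last token: outputs for positions 0..3 stay the same.
x_changed = x.clone()
x_changed[-1] = torch.randn(d)
out_changed, _ = causal_attention(x_changed, mask)
print(torch.allclose(out[:-1], out_changed[:-1]))  # True
```

Filling masked scores with a large negative number makes their softmax weight zero, so future tokens contribute nothing to the weighted sum.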
7. The Full Transformer
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8,
                 n_encoder_layers=3, n_decoder_layers=3,
                 d_ff=1024, max_seq_len=512, dropout=0.1):
        super().__init__()
        # Embeddings
        self.src_embedding = nn.Embedding(vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        # Encoder stack
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_encoder_layers)
        ])
        # Decoder stack
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_decoder_layers)
        ])
        # Output projection: d_model -> vocab_size
        self.output_projection = nn.Linear(d_model, vocab_size)
        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def encode(self, src, src_mask=None):
        x = self.pos_encoding(self.src_embedding(src))
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, encoder_output, src_mask=None, tgt_mask=None):
        x = self.pos_encoding(self.tgt_embedding(tgt))
        for layer in self.decoder_layers:
            x, attn = layer(x, encoder_output, src_mask, tgt_mask)
        return x

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        encoder_output = self.encode(src, src_mask)
        decoder_output = self.decode(tgt, encoder_output, src_mask, tgt_mask)
        return self.output_projection(decoder_output)
# Build it
vocab_size = 1000
model = Transformer(
    vocab_size=vocab_size,
    d_model=128,
    n_heads=4,
    n_encoder_layers=2,
    n_decoder_layers=2,
    d_ff=512
)
n_params = sum(p.numel() for p in model.parameters())
print(f"Transformer parameters: {n_params:,}")
# Test forward pass
batch_size = 2
src_len = 12
tgt_len = 8
src = torch.randint(0, vocab_size, (batch_size, src_len))
tgt = torch.randint(0, vocab_size, (batch_size, tgt_len))
# Causal mask for decoder
tgt_mask = make_causal_mask(tgt_len)
output = model(src, tgt, tgt_mask=tgt_mask)
print(f"Input src: {src.shape}")
print(f"Input tgt: {tgt.shape}")
print(f"Output: {output.shape} = (batch, tgt_len, vocab_size)")
Output:
Transformer parameters: 1,310,696
Input src: torch.Size([2, 12])
Input tgt: torch.Size([2, 8])
Output: torch.Size([2, 8, 1000]) = (batch, tgt_len, vocab_size)
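For each position in the target sequence, the model outputs a vector of vocab_size raw scores (logits). A softmax turns those scores into a probability distribution over the entire vocabulary. With greedy decoding, the highest-probability token is the model's prediction for the next token at that position.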
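Going from the model's output tensor to actual token predictions takes two small steps. The sketch below uses a random tensor as a stand-in for the `(batch, tgt_len, vocab_size)` output, so the predicted IDs themselves are meaningless. Only the mechanics matter.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 1000
logits = torch.randn(2, 8, vocab_size)  # stand-in for the model output above

probs = F.softmax(logits, dim=-1)       # (batch, tgt_len, vocab_size)
print(probs.sum(dim=-1)[0, :3])         # each position sums to 1

greedy_tokens = probs.argmax(dim=-1)    # (batch, tgt_len)
print(greedy_tokens.shape)              # torch.Size([2, 8])

# Softmax preserves ordering, so taking argmax of the logits directly
# picks the same token when only the top choice is needed.
print(torch.equal(greedy_tokens, logits.argmax(dim=-1)))
```

Training code usually skips the explicit softmax too, since `F.cross_entropy` expects raw logits.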
Encoder-Only vs Decoder-Only vs Encoder-Decoder
Modern models specialize:
Encoder-Only (BERT, RoBERTa):
- Reads full input bidirectionally
- Every token attends to every other token
- Best for: classification, NER, question answering
- Not for: text generation
Decoder-Only (GPT, LLaMA, Claude):
- Causal: each token only attends to previous tokens
- Best for: text generation, chat, code completion
- Fine-tuned for: instruction following (RLHF)
Encoder-Decoder (T5, BART, original transformer):
- Encoder reads input fully
- Decoder generates output attending to encoder
- Best for: translation, summarization, seq2seq tasks
# Encoder-only: just the encoder stack
class EncoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, n_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.classifier = nn.Linear(d_model, n_classes)

    def forward(self, x, mask=None):
        x = self.pos_encoding(self.embedding(x))
        for layer in self.layers:
            x = layer(x, mask)
        # Pool: use [CLS] token (position 0) for classification
        cls_output = x[:, 0, :]
        return self.classifier(cls_output)

# Decoder-only: just the decoder self-attention (no cross-attention)
class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        # Encoder layers but with causal masking = decoder-only
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        seq_len = x.size(1)
        mask = make_causal_mask(seq_len).to(x.device)
        x = self.pos_encoding(self.embedding(x))
        for layer in self.layers:
            x = layer(x, mask)
        return self.output(x)  # (batch, seq_len, vocab_size)
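A decoder-only model generates text by calling itself in a loop: predict the next token, append it, run again. Here is a minimal sketch of that loop. The `generate` function and the `TinyLM` class are hypothetical helpers written for this example; `TinyLM` is an untrained stand-in with the same input/output shapes as `DecoderOnlyTransformer`, included only so the snippet runs without the rest of the post.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, temperature=1.0):
    """Autoregressive sampling: predict one token, append it, repeat."""
    model.eval()
    ids = prompt_ids
    for _ in range(max_new_tokens):
        logits = model(ids)                                # (batch, seq_len, vocab_size)
        next_logits = logits[:, -1, :] / temperature       # only the last position matters
        probs = torch.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (batch, 1)
        ids = torch.cat([ids, next_id], dim=1)
    return ids

# Hypothetical stand-in so the sketch runs by itself; swap in a trained model.
class TinyLM(nn.Module):
    def __init__(self, vocab_size=50, d_model=16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        return self.output(self.embedding(x))

torch.manual_seed(0)
prompt = torch.randint(0, 50, (1, 4))
out = generate(TinyLM(), prompt, max_new_tokens=6)
print(out.shape)  # torch.Size([1, 10])
```

With the real model, keep the running sequence within the `max_seq_len` of the positional encoding (512 by default above), for example by passing only the most recent tokens once the sequence grows past that.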
Scaling Laws: Why Bigger Works Better
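One of the most important empirical discoveries about transformers: performance improves predictably as you scale up. Test loss falls smoothly, roughly as a power law, as parameters, training data, and compute grow together.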
| Parameters | Training data | Rough capability |
|---|---|---|
| 125M (GPT-2 small) | ~40 GB of text (WebText) | Basic coherence |
| 1.3B | 300B tokens | Decent paragraphs |
| 7B (LLaMA-2) | 2T tokens | Strong reasoning |
| 13B | 2T tokens | Near GPT-3.5 |
| 70B | 2T tokens | Competitive |
| 175B (GPT-3) | 300B tokens | Remarkable breadth |
The transformer architecture scales. Add more layers, more heads, wider d_model, train on more data. Performance keeps improving, mostly predictably.
This is called the scaling hypothesis and it's why companies are racing to build larger models.
Quick Cheat Sheet
| Component | What it does |
|---|---|
| Token embedding | Maps each token ID to a d_model vector |
| Positional encoding | Adds position info via sine/cosine patterns |
| Self-attention | Every token attends to every other token |
| Multi-head attention | Parallel attention with different learned projections |
| Causal mask | Prevents decoder from seeing future tokens |
| Cross-attention | Decoder attends to encoder's output |
| Feed-forward | Per-position transformation after attention |
| Layer norm | Stabilizes activations across embedding dimension |
| Residual connection | x + sublayer(x), keeps gradients flowing |
| Output projection | Maps d_model back to vocab_size for next-token prediction |
Practice Challenges
Level 1:
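Build the EncoderOnlyTransformer with vocab_size=5000, d_model=64, n_heads=4, n_layers=2. The constructor also needs d_ff and n_classes, so pick values for those (for example d_ff=256 and n_classes=3). Pass a random batch of token IDs through it. Print the output shape and verify it matches (batch_size, n_classes).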
Level 2:
Visualize the causal mask for seq_len=10. Then visualize the attention weights from a multi-head attention layer on a random sequence. Do different heads attend to different positions?
Level 3:
Train the DecoderOnlyTransformer on a character-level language model. Use any short text (Shakespeare works well). Train it to predict the next character. After training, sample from it by repeatedly predicting the next token and feeding it back in.