Transformers revolutionized artificial intelligence! From BERT to GPT to ChatGPT, this architecture powers virtually every major breakthrough in natural language processing. But how do they actually work under the hood? Today, I'll walk you through building a complete Transformer from scratch using PyTorch, demystifying the "Attention Is All You Need" paper with practical code and clear explanations.
Why Transformers Changed Everything
Before Transformers, we had RNNs and LSTMs that processed sequences one word at a time - like reading a book with a narrow flashlight. Transformers said: "What if we could see the entire page at once?"
This parallel processing breakthrough enabled:
- Massive Parallelization: Train on modern GPUs efficiently
- Long-Range Dependencies: Connect words across entire documents
- Attention Mechanism: Focus on relevant parts dynamically
- Scalability: Build increasingly larger and more capable models
The Mathematics Behind the Magic
The Heart: Scaled Dot-Product Attention
The core innovation is surprisingly elegant:
Attention(Q, K, V) = softmax(QK^T / √d_k)V
Think of it like a recommendation system (a quick code sketch follows this list):
- Q (Query): "What am I looking for?"
- K (Key): "What's available to look at?"
- V (Value): "What information do I actually get?"
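To make the formula concrete, here is a minimal self-contained sketch that applies it to random tensors; the shapes are arbitrary and this is separate from the full module we build below.

```python
import math

import torch
import torch.nn.functional as F

# Toy shapes: batch of 1, sequence of 4 tokens, key dimension 8 (illustrative values)
Q = torch.randn(1, 4, 8)
K = torch.randn(1, 4, 8)
V = torch.randn(1, 4, 8)

# softmax(QK^T / sqrt(d_k)) V
scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
weights = F.softmax(scores, dim=-1)   # each row sums to 1
output = weights @ V                  # weighted mix of the values

print(weights.shape, output.shape)    # torch.Size([1, 4, 4]) torch.Size([1, 4, 8])
```

Each row of `weights` records how strongly one query token attends to every key token, and `output` is the corresponding weighted mix of the values.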
Multi-Head Attention: Multiple Perspectives
Instead of one attention mechanism, use many in parallel:
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
It's like having multiple experts, each focusing on different aspects of the input!
Positional Encoding: Teaching the Model About Position
Since attention has no inherent notion of order, we inject position information:
PE(pos, 2i) = sin(pos/10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))
Clever use of sine and cosine functions creates unique "fingerprints" for each position.
Architecture Implementation
Let's build each component step by step:
1. Multi-Head Attention
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Linear projections for Q, K, V and the output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # Softmax and apply to values
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        # Linear transformations and reshape for multi-head
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Apply attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate heads and apply output projection
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model)
        return self.W_o(attention_output)
```
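A quick sanity check of the module (hyperparameters here are arbitrary, not tied to the trained model): self-attention should preserve the input shape because the heads are concatenated back to `d_model`.

```python
mha = MultiHeadAttention(d_model=128, num_heads=8)
x = torch.randn(2, 10, 128)   # (batch, seq_len, d_model)
out = mha(x, x, x)            # self-attention: query = key = value
print(out.shape)              # torch.Size([2, 10, 128])
```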
2. Positional Encoding
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_length=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Create positional encoding matrix
        pe = torch.zeros(max_length, d_model)
        position = torch.arange(0, max_length).unsqueeze(1).float()
        # Apply sine and cosine functions at geometrically decreasing frequencies
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Register as buffer (not a parameter); shape: (1, max_length, d_model)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model) - add the encoding for the first seq_len positions
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
```
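A quick shape check, using placeholder dimensions, to confirm the encoding broadcasts over the batch:

```python
pos_enc = PositionalEncoding(d_model=128, max_length=512)
emb = torch.randn(2, 10, 128)   # (batch, seq_len, d_model), e.g. token embeddings
out = pos_enc(emb)
print(out.shape)                # torch.Size([2, 10, 128])
# The same position vector is added at position t for every sequence in the batch.
```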
3. Complete Transformer Block
```python
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, num_heads)
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attention_output = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attention_output))
        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
```
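The block passes `mask` straight through to the attention, where positions with `mask == 0` are filled with -1e9 before the softmax. One way to build a compatible padding mask, assuming index 0 is the padding token (that index is an assumption for this sketch, not something fixed by the code above), is:

```python
# token_ids: (batch, seq_len); 0 is assumed to be the padding index in this sketch
token_ids = torch.tensor([[5, 9, 2, 0, 0],
                          [7, 3, 8, 6, 1]])
pad_mask = (token_ids != 0).unsqueeze(1).unsqueeze(2)   # (batch, 1, 1, seq_len)

block = TransformerBlock(d_model=128, num_heads=8, d_ff=512)
x = torch.randn(2, 5, 128)
out = block(x, mask=pad_mask)   # mask broadcasts against the (batch, heads, seq, seq) scores
print(out.shape)                # torch.Size([2, 5, 128])
```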
4. Complete Transformer Model
```python
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_length, num_classes):
        super().__init__()
        self.d_model = d_model
        # Input processing
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_length)
        # Stack of Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        # Output processing
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        # Embedding with scaling
        x = self.embedding(x) * math.sqrt(self.d_model)
        # Add positional encoding
        x = self.positional_encoding(x)
        x = self.dropout(x)
        # Pass through Transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)
        # Final processing
        x = self.norm(x)
        x = torch.mean(x, dim=1)  # Global average pooling over the sequence
        return self.classifier(x)
```
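To tie it together, here is a small smoke test; the vocabulary size, sequence length, and class count are placeholder values, while the layer/head/dimension settings mirror the configuration reported below.

```python
model = TransformerClassifier(vocab_size=10000, d_model=128, num_heads=8,
                              num_layers=4, d_ff=512, max_length=256, num_classes=2)
token_ids = torch.randint(0, 10000, (2, 64))   # (batch, seq_len) of token indices
logits = model(token_ids)
print(logits.shape)                            # torch.Size([2, 2]): one score per class
```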
Training Results & Performance
Our implementation achieves impressive results on sentiment analysis:
Key Metrics
- Test Accuracy: 85%+ on movie review classification
- Model Size: ~200K parameters
- Training Time: ~10 minutes on Apple M4
- Architecture: 4 layers, 8 heads, 128 dimensions
- Convergence: Stable training without overfitting
Performance Highlights
- Fast Training: Parallel processing trains dramatically faster than comparable RNNs
- Strong Accuracy: Competitive performance on sentiment analysis
- Flexible Architecture: Easy to scale up or adapt for different tasks
- Stable Training: Consistent convergence across multiple runs
Key Implementation Insights
1. Attention Heads Capture Different Patterns
Each attention head learns to focus on different types of relationships (see the inspection sketch after this list):
- Head 1 might focus on syntactic dependencies
- Head 2 might capture semantic relationships
- Head 3 might look at positional patterns
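The `forward` above returns only the projected output, but you can inspect the per-head weights by reusing the module's projections directly, as in this rough sketch (untrained weights, random input, purely for illustration):

```python
mha = MultiHeadAttention(d_model=128, num_heads=8)
x = torch.randn(1, 12, 128)   # one sequence of 12 tokens
b = x.size(0)

# Project and split into heads exactly as MultiHeadAttention.forward does
Q = mha.W_q(x).view(b, -1, mha.num_heads, mha.d_k).transpose(1, 2)
K = mha.W_k(x).view(b, -1, mha.num_heads, mha.d_k).transpose(1, 2)
V = mha.W_v(x).view(b, -1, mha.num_heads, mha.d_k).transpose(1, 2)

_, weights = mha.scaled_dot_product_attention(Q, K, V)
print(weights.shape)          # torch.Size([1, 8, 12, 12]): one seq x seq map per head
```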
2. Residual Connections Are Critical
Without residual connections, deep Transformers suffer from vanishing gradients:
```python
# This is crucial!
x = self.norm1(x + self.dropout(attention_output))
```
3. Layer Normalization Placement Matters
The blocks above follow the original paper's Post-LN layout: normalization is applied after the residual addition (`self.norm1(x + ...)`). A common alternative is Pre-LN, which normalizes the input to each sublayer and often trains more stably in deeper stacks:
```python
# Pre-normalization helps with gradient flow
attention_output = self.attention(self.norm1(x), self.norm1(x), self.norm1(x), mask)
```
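If you want to experiment with Pre-LN without modifying the block above, a minimal variant could look like this sketch, which simply reuses the sublayers defined in `TransformerBlock`:

```python
class PreLNTransformerBlock(TransformerBlock):
    """Pre-LN variant: normalize the input to each sublayer, then add the residual."""
    def forward(self, x, mask=None):
        normed = self.norm1(x)
        x = x + self.dropout(self.attention(normed, normed, normed, mask))
        x = x + self.dropout(self.feed_forward(self.norm2(x)))
        return x
```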
4. Positional Encoding Can Also Be Learned
While we use sinusoidal encoding, you can also swap in learned position embeddings:
```python
# Alternative: learnable positional embeddings
self.pos_embedding = nn.Embedding(max_length, d_model)
```
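A drop-in module built around that idea might look like the following sketch; the class name and defaults are illustrative, not part of the repository:

```python
class LearnedPositionalEncoding(nn.Module):
    def __init__(self, d_model, max_length=5000, dropout=0.1):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_length, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x: (batch, seq_len, d_model); look up one embedding per position
        positions = torch.arange(x.size(1), device=x.device)
        return self.dropout(x + self.pos_embedding(positions))
```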
Real-World Applications
This foundational implementation opens doors to:
Language Models
- GPT-style: Decoder-only for text generation
- BERT-style: Encoder-only for understanding tasks
- T5-style: Encoder-decoder for text-to-text
Practical Tasks
- Sentiment Analysis: Movie reviews, product feedback
- Text Classification: Spam detection, topic categorization
- Named Entity Recognition: Extract people, places, organizations
- Question Answering: Build intelligent assistants
Research Directions
- Efficient Attention: Linear attention, sparse attention
- Vision Transformers: Apply to image classification
- Multimodal Models: Combine text, images, and audio
- Scientific Applications: Protein folding, drug discovery
Advanced Extensions
Ready to take it further? Here are exciting directions:
```python
# Encoder-decoder for translation
class TransformerEncoderDecoder(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model, num_heads, num_layers):
        # Full seq2seq implementation
        pass


# Vision Transformer for images
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes):
        # Apply Transformers to image patches
        pass


# Efficient attention variants
class LinearAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        # O(n) complexity instead of O(n^2)
        pass
```
Key Takeaways
- Attention is Revolutionary: Parallel processing transforms sequence modeling
- Simple Components, Powerful Combinations: Basic building blocks create sophisticated behavior
- Mathematics Drives Innovation: Understanding theory enables better applications
- Implementation is Accessible: Complex papers become manageable code
- Foundation for the Future: Basis for GPT, BERT, and beyond
Try It Yourself!
Ready to dive in? Here's how to get started:
GitHub Repository
```bash
git clone https://github.com/GruheshKurra/TransformersFromScratch.git
cd TransformersFromScratch
pip install torch matplotlib numpy pandas scikit-learn seaborn tqdm
jupyter notebook Transformers.ipynb
```
Hugging Face Model Hub
karthik-2905/TransformersFromScratch
Explore the complete implementation, trained models, and interactive examples!
What's Next?
This implementation provides a solid foundation for:
- Understanding modern NLP architectures
- Building production systems
- Conducting research in attention mechanisms
- Creating domain-specific applications
The Transformer revolution is just getting started. Whether you're building the next ChatGPT or exploring novel applications in science and creativity, understanding these fundamentals will serve you well.
Connect & Contribute
Found this helpful? Let's push the boundaries of AI together!
- GitHub: GruheshKurra
- Hugging Face: karthik-2905
Have questions, ideas, or want to contribute? Open an issue or submit a PR!
Happy coding, and may your attention weights be well-aligned!