

Building Transformers from Scratch: Understanding the Architecture That Changed AI

Transformers revolutionized artificial intelligence! From BERT to GPT to ChatGPT, this architecture powers virtually every major breakthrough in natural language processing. But how do they actually work under the hood? Today, I'll walk you through building a complete Transformer from scratch using PyTorch, demystifying the "Attention Is All You Need" paper with practical code and clear explanations.

🔮 Why Transformers Changed Everything

Before Transformers, we had RNNs and LSTMs that processed sequences one word at a time - like reading a book with a narrow flashlight. Transformers said: "What if we could see the entire page at once?"

This parallel processing breakthrough enabled:

  • ⚡ Massive Parallelization: Train on modern GPUs efficiently
  • 🔗 Long-Range Dependencies: Connect words across entire documents
  • 🎯 Attention Mechanism: Focus on relevant parts dynamically
  • 📈 Scalability: Build increasingly larger and more capable models

(Figure: training results)

🧮 The Mathematics Behind the Magic

The Heart: Scaled Dot-Product Attention

The core innovation is surprisingly elegant:

Attention(Q, K, V) = softmax(QK^T / √d_k) V

Think of it like a recommendation system:

  • Q (Query): "What am I looking for?"
  • K (Key): "What's available to look at?"
  • V (Value): "What information do I actually get?"
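
To make the analogy concrete, here's a minimal, runnable sketch of the formula on random tensors (seq_len and d_k are just toy values, not the model's real configuration):

import math
import torch
import torch.nn.functional as F

seq_len, d_k = 4, 8
Q = torch.randn(seq_len, d_k)   # queries: "what am I looking for?"
K = torch.randn(seq_len, d_k)   # keys: "what's available to look at?"
V = torch.randn(seq_len, d_k)   # values: "what information do I get?"

scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (seq_len, seq_len) similarity scores
weights = F.softmax(scores, dim=-1)                  # each row sums to 1
output = weights @ V                                 # weighted mix of the values
print(weights.shape, output.shape)                   # torch.Size([4, 4]) torch.Size([4, 8])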

Multi-Head Attention: Multiple Perspectives

Instead of one attention mechanism, use many in parallel:

MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

It's like having multiple experts, each focusing on different aspects of the input!

Positional Encoding: Teaching the Model About Order

Since attention has no inherent notion of order, we inject position information:

PE(pos, 2i) = sin(pos/10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))

Clever use of sine and cosine functions creates unique "fingerprints" for each position.

๐Ÿ—๏ธ Architecture Implementation

Let's build each component step by step:

1. Multi-Head Attention

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)  
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Apply mask if provided
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # Softmax and apply to values
        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        # Linear transformations and reshape for multi-head
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Concatenate heads and apply output projection
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model)

        return self.W_o(attention_output)
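
A quick shape check (a small sketch with toy dimensions): a (batch, seq_len, d_model) input should come out with exactly the same shape.

mha = MultiHeadAttention(d_model=128, num_heads=8)
x = torch.randn(2, 10, 128)    # (batch, seq_len, d_model)
out = mha(x, x, x)             # self-attention: query = key = value
print(out.shape)               # torch.Size([2, 10, 128])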

2. Positional Encoding

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_length=5000):
        super().__init__()
        self.dropout = nn.Dropout(0.1)

        # Create positional encoding matrix
        pe = torch.zeros(max_length, d_model)
        position = torch.arange(0, max_length).unsqueeze(1).float()

        # Apply sine and cosine functions
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register as buffer (not a parameter)
        pe = pe.unsqueeze(0)  # shape: (1, max_length, d_model) for batch-first inputs
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model) - add the encoding for the first seq_len positions
        return self.dropout(x + self.pe[:, :x.size(1)])
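
With the batch-first shapes used throughout this post, adding the encoding is a simple broadcast. A brief sanity check (toy sizes):

pos_enc = PositionalEncoding(d_model=128)
x = torch.randn(2, 10, 128)    # (batch, seq_len, d_model) token embeddings
x = pos_enc(x)                 # adds pe[:, :10, :] to every example in the batch
print(x.shape)                 # torch.Size([2, 10, 128])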

3. Complete Transformer Block

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, num_heads)

        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual connection
        attention_output = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout(attention_output))

        # Feed-forward with residual connection
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x
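
One block maps a (batch, seq_len, d_model) tensor to another of the same shape, which is what lets us stack blocks freely. A quick sketch with toy sizes:

block = TransformerBlock(d_model=128, num_heads=8, d_ff=512)
x = torch.randn(2, 10, 128)
out = block(x)                 # same shape in, same shape out
print(out.shape)               # torch.Size([2, 10, 128])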

4. Complete Transformer Model

class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_length, num_classes):
        super().__init__()
        self.d_model = d_model

        # Input processing
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_length)

        # Stack of Transformer blocks
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        # Output processing
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_classes)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        # Embedding with scaling
        x = self.embedding(x) * math.sqrt(self.d_model)

        # Add positional encoding
        x = self.positional_encoding(x)
        x = self.dropout(x)

        # Pass through Transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x, mask)

        # Final processing
        x = self.norm(x)
        x = torch.mean(x, dim=1)  # Global average pooling

        return self.classifier(x)
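
Here's a sketch of wiring it all together with the configuration reported below (4 layers, 8 heads, 128 dimensions). The vocabulary size, d_ff, sequence length, and the choice of 0 as the padding token ID are illustrative assumptions, not values from the repository:

model = TransformerClassifier(
    vocab_size=10_000, d_model=128, num_heads=8,
    num_layers=4, d_ff=512, max_length=256, num_classes=2,
)

tokens = torch.randint(0, 10_000, (2, 256))      # (batch, seq_len) token IDs
# Padding mask: 1 for real tokens, 0 for padding; broadcasts over heads and query positions
mask = (tokens != 0).unsqueeze(1).unsqueeze(2)   # (batch, 1, 1, seq_len)
logits = model(tokens, mask)
print(logits.shape)                              # torch.Size([2, 2])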

🚀 Training Results & Performance

Our implementation achieves impressive results on sentiment analysis:

Key Metrics

  • Test Accuracy: 85%+ on movie review classification
  • Model Size: ~200K parameters
  • Training Time: ~10 minutes on Apple M4
  • Architecture: 4 layers, 8 heads, 128 dimensions
  • Convergence: Stable training without overfitting

Performance Highlights

  • ⚡ Fast Training: Parallel processing beats RNNs by orders of magnitude
  • 🎯 Great Accuracy: Competitive performance on sentiment analysis
  • 🔧 Flexible Architecture: Easy to scale up or adapt for different tasks
  • 📊 Stable Training: Consistent convergence across multiple runs

🧠 Key Implementation Insights

1. Attention Heads Capture Different Patterns

Each attention head learns to focus on different types of relationships:

  • Head 1 might focus on syntactic dependencies
  • Head 2 might capture semantic relationships
  • Head 3 might look at positional patterns
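
One way to peek at what the heads learn is to rerun the Q/K/V projections from MultiHeadAttention outside of forward and look at the returned attention weights. A small inspection sketch (toy sizes, reusing the class defined above):

mha = MultiHeadAttention(d_model=128, num_heads=8)
x = torch.randn(1, 10, 128)
b = x.size(0)
Q = mha.W_q(x).view(b, -1, mha.num_heads, mha.d_k).transpose(1, 2)
K = mha.W_k(x).view(b, -1, mha.num_heads, mha.d_k).transpose(1, 2)
V = mha.W_v(x).view(b, -1, mha.num_heads, mha.d_k).transpose(1, 2)
_, attn_weights = mha.scaled_dot_product_attention(Q, K, V)
print(attn_weights.shape)      # torch.Size([1, 8, 10, 10]) - one map per head
# attn_weights[0, h] is head h's attention map; plot a few to compare heads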

2. Residual Connections Are Critical

Without residual connections, deep Transformers suffer from vanishing gradients:

# This is crucial!
x = self.norm1(x + self.dropout(attention_output))

3. Layer Normalization Placement Matters

We use "Pre-LN" (normalization before attention) for better training stability:

# Pre-LN variant: normalize the input before attention, then add the residual
attention_output = self.attention(self.norm1(x), self.norm1(x), self.norm1(x), mask)
x = x + self.dropout(attention_output)

4. Positional Encodings Can Also Be Learned

While we use sinusoidal encoding, you can also use learned embeddings:

# Alternative: learnable positional embeddings
self.pos_embedding = nn.Embedding(max_length, d_model)
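
In the forward pass, the learned variant would look up indices 0..seq_len-1 and add them to the token embeddings. A brief sketch of that step (assuming x is a batch-first embedding tensor):

# In forward (sketch): add learned position vectors to the token embeddings
positions = torch.arange(x.size(1), device=x.device)   # 0, 1, ..., seq_len-1
x = x + self.pos_embedding(positions)                   # broadcasts over the batch dimension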

🎯 Real-World Applications

This foundational implementation opens doors to:

🤖 Language Models

  • GPT-style: Decoder-only for text generation
  • BERT-style: Encoder-only for understanding tasks
  • T5-style: Encoder-decoder for text-to-text

🔧 Practical Tasks

  • Sentiment Analysis: Movie reviews, product feedback
  • Text Classification: Spam detection, topic categorization
  • Named Entity Recognition: Extract people, places, organizations
  • Question Answering: Build intelligent assistants

🔬 Research Directions

  • Efficient Attention: Linear attention, sparse attention
  • Vision Transformers: Apply to image classification
  • Multimodal Models: Combine text, images, and audio
  • Scientific Applications: Protein folding, drug discovery

🔮 Advanced Extensions

Ready to take it further? Here are exciting directions:

# Encoder-Decoder for Translation
class TransformerEncoderDecoder(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model, num_heads, num_layers):
        # Full seq2seq implementation
        pass

# Vision Transformer for Images  
class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes):
        # Apply Transformers to image patches
        pass

# Efficient Attention Variants
class LinearAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        # O(n) complexity instead of O(n²)
        pass

💡 Key Takeaways

  1. Attention is Revolutionary: Parallel processing transforms sequence modeling
  2. Simple Components, Powerful Combinations: Basic building blocks create sophisticated behavior
  3. Mathematics Drives Innovation: Understanding theory enables better applications
  4. Implementation is Accessible: Complex papers become manageable code
  5. Foundation for the Future: Basis for GPT, BERT, and beyond

🚀 Try It Yourself!

Ready to dive in? Here's how to get started:

GitHub Repository

git clone https://github.com/GruheshKurra/TransformersFromScratch.git
cd TransformersFromScratch
pip install torch matplotlib numpy pandas scikit-learn seaborn tqdm
jupyter notebook Transformers.ipynb

Hugging Face Model Hub

🤗 karthik-2905/TransformersFromScratch

Explore the complete implementation, trained models, and interactive examples!

🌟 What's Next?

This implementation provides a solid foundation for:

  • Understanding modern NLP architectures
  • Building production systems
  • Conducting research in attention mechanisms
  • Creating domain-specific applications

The Transformer revolution is just getting started. Whether you're building the next ChatGPT or exploring novel applications in science and creativity, understanding these fundamentals will serve you well.

๐Ÿค Connect & Contribute

Found this helpful? Let's push the boundaries of AI together!

Have questions, ideas, or want to contribute? Open an issue or submit a PR!


Happy coding, and may your attention weights be well-aligned! 🔮✨
