Part 1: The Core Idea
Attention is like a spotlight - it helps models focus on what's important.
import torch
import torch.nn.functional as F
# Simple example: Which word is most important?
sentence = ["I", "love", "pizza"]
importance = torch.tensor([0.1, 0.3, 0.6]) # pizza is most important
Intuition: Instead of treating all words equally, attention assigns different weights to focus on what matters most.
Part 2: Basic Attention Weights
# Raw attention scores (how much to focus on each word)
scores = torch.tensor([2.0, 1.0, 3.0]) # [I, love, pizza]
# Convert to probabilities (softmax)
weights = F.softmax(scores, dim=0)
print(weights) # [0.24, 0.09, 0.67] - pizza gets most attention
What happened: Softmax converts raw scores to probabilities that sum to 1.
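To see what softmax does under the hood, the same weights can be computed by hand (a quick sanity check, not library code):
# Manual softmax: exponentiate, then normalize so the results sum to 1
exp_scores = torch.exp(scores)                 # ~[7.39, 2.72, 20.09]
manual_weights = exp_scores / exp_scores.sum()
print(manual_weights)                          # matches F.softmax(scores, dim=0)
print(manual_weights.sum())                    # tensor(1.)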
Part 3: Weighted Combination
# Word representations (simplified vectors)
words = torch.tensor([[1.0, 0.0], # "I"
[0.0, 1.0], # "love"
[1.0, 1.0]]) # "pizza"
# Apply attention weights
attended = torch.sum(weights.unsqueeze(1) * words, dim=0)
print(attended) # Mostly "pizza" representation
Intuition: We combine all word vectors, but "pizza" contributes most because it has the highest attention weight.
Part 4: Computing Attention Scores
# How similar are words? (dot product)
query = torch.tensor([1.0, 1.0]) # What we're looking for
key1 = torch.tensor([1.0, 0.0]) # "I"
key2 = torch.tensor([0.0, 1.0]) # "love"
score1 = torch.dot(query, key1) # 1.0
score2 = torch.dot(query, key2) # 1.0
Theory: Attention scores measure how well a query matches each key.
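Both keys above match the query equally (both scores are 1.0). For contrast, a hypothetical third key that points in the same direction as the query scores higher, which is exactly what earns it more attention:
key3 = torch.tensor([1.0, 1.0])   # hypothetical "pizza" key, aligned with the query
score3 = torch.dot(query, key3)   # 2.0 - the best match, so the most attention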
Part 5: The Q, K, V Concept
# Three roles for each word:
# Q (Query): "What am I looking for?"
# K (Key): "What do I represent?"
# V (Value): "What information do I carry?"
query = torch.tensor([1.0, 0.0]) # Looking for subject
keys = torch.tensor([[1.0, 0.0], # "I" - matches query well
[0.0, 1.0]]) # "love" - doesn't match
values = torch.tensor([[2.0, 3.0], # "I" carries this info
[1.0, 4.0]]) # "love" carries this info
Intuition: Query asks "what do I need?", Keys answer "what do I offer?", Values provide the actual information.
Part 6: One-Line Attention
# Complete attention in one line
attention_output = torch.sum(F.softmax(torch.mv(keys, query), dim=0).unsqueeze(1) * values, dim=0)
What it does: Computes scores (query·keys), applies softmax, weights the values.
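The same one-liner, unpacked into named steps (equivalent code, just easier to follow):
scores = torch.mv(keys, query)      # dot product of the query with each key -> [1.0, 0.0]
weights = F.softmax(scores, dim=0)  # probabilities over the two keys -> ~[0.73, 0.27]
attention_output = torch.sum(weights.unsqueeze(1) * values, dim=0)  # weighted mix of the values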
Part 7: Self-Attention Intuition
# In self-attention, each word can attend to every other word
sentence = ["The", "cat", "sat"]
# "cat" might attend to "sat" (what did the cat do?)
# "sat" might attend to "cat" (who sat?)
Key insight: Words can look at each other to understand relationships and context.
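A toy sketch of that idea, using made-up 2-d vectors for the three words: every word scores every other word, so each row of the score matrix is one word's view of the sentence.
# Hypothetical 2-d vectors for ["The", "cat", "sat"]
X = torch.tensor([[1.0, 0.0],   # "The"
                  [0.0, 1.0],   # "cat"
                  [1.0, 1.0]])  # "sat"
pairwise_scores = torch.mm(X, X.t())                   # [3, 3]: word i attending to word j
pairwise_weights = F.softmax(pairwise_scores, dim=1)   # each row sums to 1
print(pairwise_weights[1])  # "cat" row: ~[0.16, 0.42, 0.42] - split between itself and "sat"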
Part 8: Multi-Head Attention (Simple)
# Multiple "attention heads" look for different things
head1_query = torch.tensor([1.0, 0.0]) # Looking for subjects
head2_query = torch.tensor([0.0, 1.0]) # Looking for actions
# Each head focuses on different aspects
Why multiple heads: Different heads can specialize in different types of relationships (subject-verb, adjective-noun, etc.).
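A minimal sketch of that idea, reusing the keys from Part 5: the two head queries above produce different attention patterns over the same words.
# Each head computes its own scores over the same keys
head1_weights = F.softmax(torch.mv(keys, head1_query), dim=0)  # leans toward "I" (the subject)
head2_weights = F.softmax(torch.mv(keys, head2_query), dim=0)  # leans toward "love" (the action)
print(head1_weights)  # ~[0.73, 0.27]
print(head2_weights)  # ~[0.27, 0.73]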
Part 9: Scaling Up
# Real sentences have many words
seq_len = 10 # 10 words in sentence
d_model = 64 # Each word is 64-dimensional vector
# Q, K, V matrices transform word vectors
Q = torch.randn(seq_len, d_model) # Queries for each word
K = torch.randn(seq_len, d_model) # Keys for each word
V = torch.randn(seq_len, d_model) # Values for each word
Scale: Real models use hundreds or thousands of dimensions and process sequences of thousands of tokens.
Part 10: Attention Matrix
# Attention scores between all word pairs
attention_scores = torch.mm(Q, K.transpose(0, 1)) # [10, 10] matrix (unscaled here for simplicity)
attention_weights = F.softmax(attention_scores, dim=1) # Each row sums to 1
# Row i, column j = how much word i attends to word j
Visualization: Each row shows where one word "looks" in the sentence.
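A quick check of that structure, just inspecting the tensors built above:
print(attention_weights.sum(dim=1))  # each of the 10 rows sums to 1
print(attention_weights[0])          # where word 0 "looks": one weight per word in the sentence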
Part 11: Why Attention Works
# Traditional RNN: Information flows sequentially
# Word 1 → Word 2 → Word 3 → Word 4
# Attention: All words can interact directly
# Word 1 ↔ Word 2 ↔ Word 3 ↔ Word 4
Advantage: Information does not have to survive a long chain of sequential steps, so long-range dependencies are easier to capture, and all positions can be processed in parallel.
Part 12: Putting It All Together
# Complete self-attention step by step
def simple_attention(X):
    Q = X  # Queries (simplified)
    K = X  # Keys
    V = X  # Values
    scores = torch.mm(Q, K.transpose(0, 1))  # Compute similarities
    weights = F.softmax(scores, dim=1)  # Convert to probabilities
    output = torch.mm(weights, V)  # Weighted combination
    return output
# Usage
word_vectors = torch.randn(5, 8) # 5 words, 8 dimensions each
attended_vectors = simple_attention(word_vectors)
Result: Each word vector is now updated with information from all other words, weighted by attention.
Key Takeaways
- Attention = Weighted Average: Focus more on important parts
- Q·K = Similarity: How well query matches key
- Softmax = Probability: Convert scores to weights that sum to 1
- Weighted V = Output: Combine values using attention weights
- Self-Attention = Words talking to each other: Every word can attend to every other word
This foundation prepares you for transformer models, which use attention as their core building block!
Understanding Attention: From Words to Vectors
1. Word Embeddings - The Foundation
import torch
import torch.nn as nn
import torch.nn.functional as F
# Sample sentence: "The cat sat on the mat"
vocab = {"<pad>": 0, "the": 1, "cat": 2, "sat": 3, "on": 4, "mat": 5}
sentence = [1, 2, 3, 4, 1, 5] # token IDs
# Create embeddings
vocab_size = len(vocab)
embed_dim = 64
embedding = nn.Embedding(vocab_size, embed_dim)
# Convert tokens to vectors
tokens = torch.tensor(sentence)
embeddings = embedding(tokens)
print(f"Shape: {embeddings.shape}") # [6, 64]
print(f"'cat' vector: {embeddings[1][:8]}...") # First 8 dimensions
Each token becomes a 64-dimensional vector; after training, these vectors capture semantic meaning (here they are freshly initialized and still random).
2. The Q, K, V Matrices - Core of Attention
# Attention dimensions
d_model = 64
num_heads = 8
d_k = d_model // num_heads # 8
# Linear transformations to create Q, K, V
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)
# Transform embeddings
Q = W_q(embeddings) # Queries: "What am I looking for?"
K = W_k(embeddings) # Keys: "What do I represent?"
V = W_v(embeddings) # Values: "What information do I carry?"
print(f"Q shape: {Q.shape}") # [6, 64]
print(f"K shape: {K.shape}") # [6, 64]
print(f"V shape: {V.shape}") # [6, 64]
Intuition:
- Q (Query): "What information does this word need?"
- K (Key): "What kind of information does this word offer?"
- V (Value): "What actual information does this word contain?"
3. Computing Attention Scores
# Reshape for multi-head attention
batch_size, seq_len = 1, 6
Q = Q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2) # [1, 8, 6, 8]
K = K.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2) # [1, 8, 6, 8]
V = V.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2) # [1, 8, 6, 8]
# Attention scores: How much should each word pay attention to others?
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
print(f"Attention scores shape: {scores.shape}") # [1, 8, 6, 6]
# Example: How much does "cat" attend to each word?
cat_attention = scores[0, 0, 1, :] # First head, "cat" position
words = ["the", "cat", "sat", "on", "the", "mat"]
for i, word in enumerate(words):
    print(f"cat -> {word}: {cat_attention[i]:.3f}")
4. Softmax and Weighted Values
# Convert scores to probabilities
attention_weights = F.softmax(scores, dim=-1)
print(f"Attention weights shape: {attention_weights.shape}") # [1, 8, 6, 6]
# Apply attention to values
attended_values = torch.matmul(attention_weights, V) # [1, 8, 6, 8]
# Concatenate heads and project back
attended_values = attended_values.transpose(1, 2).contiguous().view(
    batch_size, seq_len, d_model)  # [1, 6, 64]
print(f"Final attended values shape: {attended_values.shape}")
# Show attention pattern for "cat"
print("\nAttention pattern for 'cat':")
cat_weights = attention_weights[0, 0, 1, :] # First head
for i, word in enumerate(words):
    print(f" {word}: {cat_weights[i]:.3f}")
5. Complete Self-Attention Implementation
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()
        # Linear transformations
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attended_values = torch.matmul(attention_weights, V)
        # Concatenate heads
        attended_values = attended_values.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model)
        # Final projection
        output = self.W_o(attended_values)
        return output, attention_weights
# Usage
attention = MultiHeadAttention(d_model=64, num_heads=8)
output, weights = attention(embeddings.unsqueeze(0))
print(f"Output shape: {output.shape}") # [1, 6, 64]
6. Visualizing Attention Patterns
# Extract attention weights for visualization
attention_matrix = weights[0, 0].detach().numpy() # First head
words = ["the", "cat", "sat", "on", "the", "mat"]
print("Attention Matrix (first head):")
print("From -> To:")
for i, from_word in enumerate(words):
    print(f"{from_word:>4}: ", end="")
    for j, to_word in enumerate(words):
        print(f"{attention_matrix[i,j]:.2f} ", end="")
    print()
7. Key Insights
What happens in attention?
- Each word creates a query (what it's looking for)
- Each word creates a key (what it represents)
- We compute similarity between queries and keys
- Higher similarity = more attention
- We use attention weights to combine values (actual information), as the formula below puts together
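Those five steps are exactly the standard scaled dot-product attention formula from "Attention Is All You Need":
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V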
Example: When processing "cat", the model might:
- Query: "I need information about animals"
- Look at all keys: "the" (determiner), "sat" (action), "mat" (object)
- Pay most attention to "sat" because it's the relevant action
- Combine information weighted by attention scores
8. Practical Example with Real Meaning
# Sentence: "The cat chased the mouse"
sentence = "The cat chased the mouse"
words = sentence.lower().split()
# Simulate what attention might learn
print("Attention patterns the model might learn:")
print("- 'cat' attends to 'chased' (subject-verb relationship)")
print("- 'chased' attends to 'cat' and 'mouse' (verb-subject-object)")
print("- 'mouse' attends to 'chased' (object-verb relationship)")
print("- 'the' attends to following nouns ('cat', 'mouse')")
# This allows the model to understand:
# - Who did what to whom
# - Grammatical relationships
# - Semantic dependencies
Summary
Attention mechanism allows models to:
- Focus on relevant parts of the input
- Relate different words to each other
- Combine information based on relevance
- Understand long-range dependencies
The magic is in the learned Q, K, V matrices that transform word embeddings into queries, keys, and values that can interact meaningfully.