Vuk Rosić

Attention Mechanism Tutorial: From Simple to Advanced

Part 1: The Core Idea

Attention is like a spotlight - it helps models focus on what's important.

import torch
import torch.nn.functional as F

# Simple example: Which word is most important?
sentence = ["I", "love", "pizza"]
importance = torch.tensor([0.1, 0.3, 0.6])  # pizza is most important

Intuition: Instead of treating all words equally, attention assigns different weights to focus on what matters most.

Part 2: Basic Attention Weights

# Raw attention scores (how much to focus on each word)
scores = torch.tensor([2.0, 1.0, 3.0])  # [I, love, pizza]

# Convert to probabilities (softmax)
weights = F.softmax(scores, dim=0)
print(weights)  # [0.24, 0.09, 0.67] - pizza gets most attention

What happened: Softmax converts raw scores to probabilities that sum to 1.
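
To see what the softmax step does to those scores, here is a minimal sketch in plain Python (no PyTorch). The `softmax` helper is written for illustration only and is not how PyTorch implements `F.softmax` internally:

```python
import math

def softmax(scores):
    """Exponentiate each score, then divide by the total so the results sum to 1."""
    # Subtracting the max leaves the result unchanged but avoids overflow for large scores.
    biggest = max(scores)
    exps = [math.exp(s - biggest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

weights = softmax([2.0, 1.0, 3.0])      # [I, love, pizza]
print([round(w, 2) for w in weights])   # [0.24, 0.09, 0.67]
print(sum(weights))                     # 1.0, up to floating-point rounding
```

Because of the exponential, a score that is 1 higher gets about 2.7 times more weight, which is why "pizza" ends up with roughly two thirds of the attention.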

Part 3: Weighted Combination

# Word representations (simplified vectors)
words = torch.tensor([[1.0, 0.0],  # "I"
                      [0.0, 1.0],  # "love" 
                      [1.0, 1.0]]) # "pizza"

# Apply attention weights
attended = torch.sum(weights.unsqueeze(1) * words, dim=0)
print(attended)  # ≈ [0.91, 0.76], closest to the "pizza" vector [1.0, 1.0]

Intuition: We combine all word vectors, but "pizza" contributes most because it has the highest attention weight.
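
The same weighted combination can be worked out by hand. This is a sketch in plain Python that recomputes the softmax weights from Part 2 and then adds up the scaled word vectors one at a time:

```python
import math

# Softmax weights for the scores [2.0, 1.0, 3.0] from Part 2
scores = [2.0, 1.0, 3.0]
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]

# The same simplified word vectors as above
words = [[1.0, 0.0],   # "I"
         [0.0, 1.0],   # "love"
         [1.0, 1.0]]   # "pizza"

# Scale each word vector by its weight and add the results together
attended = [0.0, 0.0]
for w, vec in zip(weights, words):
    for d in range(len(vec)):
        attended[d] += w * vec[d]

print([round(a, 3) for a in attended])  # [0.91, 0.755]
```

The first component is the weight of "I" plus the weight of "pizza"; the second is the weight of "love" plus the weight of "pizza".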

Part 4: Computing Attention Scores

# How similar are words? (dot product)
query = torch.tensor([1.0, 1.0])  # What we're looking for
key1 = torch.tensor([1.0, 0.0])  # "I"
key2 = torch.tensor([0.0, 1.0])  # "love"

score1 = torch.dot(query, key1)  # 1.0
score2 = torch.dot(query, key2)  # 1.0

Theory: Attention scores measure how well a query matches each key.
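
In the example above both keys tie at 1.0. As a small sketch, adding a third key makes the idea of "matching" more visible. The "pizza" key below is an assumption for illustration (it reuses the [1.0, 1.0] vector from Part 3), and `dot` is a hypothetical helper:

```python
def dot(a, b):
    """Dot product: multiply matching components and add them up."""
    return sum(x * y for x, y in zip(a, b))

query = [1.0, 1.0]  # What we're looking for
keys = {
    "I": [1.0, 0.0],
    "love": [0.0, 1.0],
    "pizza": [1.0, 1.0],  # assumed key, points in the same direction as the query
}

scores = {word: dot(query, key) for word, key in keys.items()}
print(scores)  # {'I': 1.0, 'love': 1.0, 'pizza': 2.0}

best = max(scores, key=scores.get)
print(best)    # pizza
```

The key that overlaps with the query in both dimensions gets the highest score.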

Part 5: The Q, K, V Concept

# Three roles for each word:
# Q (Query): "What am I looking for?"
# K (Key): "What do I represent?" 
# V (Value): "What information do I carry?"

query = torch.tensor([1.0, 0.0])    # Looking for subject
keys = torch.tensor([[1.0, 0.0],   # "I" - matches query well
                     [0.0, 1.0]])   # "love" - doesn't match
values = torch.tensor([[2.0, 3.0], # "I" carries this info
                       [1.0, 4.0]]) # "love" carries this info

Intuition: Query asks "what do I need?", Keys answer "what do I offer?", Values provide the actual information.

Part 6: One-Line Attention

# Complete attention in one line
attention_output = torch.sum(F.softmax(torch.mv(keys, query), dim=0).unsqueeze(1) * values, dim=0)

What it does: Computes scores (query·keys), applies softmax, weights the values.
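
The one-liner packs three steps together. Here is a sketch that unpacks them in plain Python, using the same query, keys, and values as Part 5, so each intermediate result can be inspected:

```python
import math

query = [1.0, 0.0]
keys = [[1.0, 0.0],    # "I"
        [0.0, 1.0]]    # "love"
values = [[2.0, 3.0],  # "I"
          [1.0, 4.0]]  # "love"

# Step 1: one score per key (what torch.mv(keys, query) computes)
scores = [sum(k * q for k, q in zip(key, query)) for key in keys]

# Step 2: softmax turns the scores into weights that sum to 1
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]

# Step 3: weighted sum of the value vectors
output = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(2)]

print(scores)                          # [1.0, 0.0]
print([round(w, 3) for w in weights])  # [0.731, 0.269]
print([round(o, 3) for o in output])   # [1.731, 3.269]
```

The output sits closer to the value of "I" ([2.0, 3.0]) because its key matched the query.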

Part 7: Self-Attention Intuition

# In self-attention, each word can attend to every other word
sentence = ["The", "cat", "sat"]
# "cat" might attend to "sat" (what did the cat do?)
# "sat" might attend to "cat" (who sat?)

Key insight: Words can look at each other to understand relationships and context.

Part 8: Multi-Head Attention (Simple)

# Multiple "attention heads" look for different things
head1_query = torch.tensor([1.0, 0.0])  # Looking for subjects
head2_query = torch.tensor([0.0, 1.0])  # Looking for actions

# Each head focuses on different aspects

Why multiple heads: Different heads can specialize in different types of relationships (subject-verb, adjective-noun, etc.).
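
As a rough sketch of that idea, the two head queries can be run against the keys and values from Part 5. This is a simplification: in a real model each head also has its own learned projections for keys and values. The `attend` helper is hypothetical:

```python
import math

def attend(query, keys, values):
    """Single-query attention: score each key, softmax, then average the values."""
    scores = [sum(k * q for k, q in zip(key, query)) for key in keys]
    exps = [math.exp(s) for s in scores]
    weights = [e / sum(exps) for e in exps]
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

keys = [[1.0, 0.0], [0.0, 1.0]]    # "I", "love"
values = [[2.0, 3.0], [1.0, 4.0]]  # "I", "love"

head1 = attend([1.0, 0.0], keys, values)  # looking for subjects
head2 = attend([0.0, 1.0], keys, values)  # looking for actions

print([round(x, 3) for x in head1])  # [1.731, 3.269]
print([round(x, 3) for x in head2])  # [1.269, 3.731]
```

Same words, different queries, different outputs: head 1 leans toward the value of "I", head 2 toward the value of "love".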

Part 9: Scaling Up

# Real sentences have many words
seq_len = 10  # 10 words in sentence
d_model = 64  # Each word is 64-dimensional vector

# Q, K, V matrices transform word vectors
Q = torch.randn(seq_len, d_model)  # Queries for each word
K = torch.randn(seq_len, d_model)  # Keys for each word
V = torch.randn(seq_len, d_model)  # Values for each word

Scale: Real models typically use hundreds to thousands of dimensions per token and sequences of thousands of tokens.

Part 10: Attention Matrix

# Attention scores between all word pairs
attention_scores = torch.mm(Q, K.transpose(0, 1))  # [10, 10] matrix
attention_weights = F.softmax(attention_scores, dim=1)  # Each row sums to 1

# Row i, column j = how much word i attends to word j

Visualization: Each row shows where one word "looks" in the sentence.
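
A 10x10 random matrix is hard to read, so here is a sketch of the same computation on three words in plain Python. The 2-dimensional vectors in `X` are made up for illustration and, as a simplification, are used as both queries and keys:

```python
import math

words = ["The", "cat", "sat"]
# Hypothetical word vectors, chosen by hand for this example
X = [[1.0, 0.0],
     [0.0, 2.0],
     [1.0, 1.0]]

def softmax(row):
    exps = [math.exp(s - max(row)) for s in row]
    return [e / sum(exps) for e in exps]

# scores[i][j] = dot product of word i (query) with word j (key)
scores = [[sum(a * b for a, b in zip(q, k)) for k in X] for q in X]
# Softmax over each row, so every row sums to 1
weights = [softmax(row) for row in scores]

print(scores)  # [[1.0, 0.0, 1.0], [0.0, 4.0, 2.0], [1.0, 2.0, 2.0]]
for word, row in zip(words, weights):
    print(f"{word:>4}: " + " ".join(f"{w:.2f}" for w in row))
```

Reading the "cat" row from left to right gives how much "cat" attends to "The", "cat", and "sat".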

Part 11: Why Attention Works

# Traditional RNN: Information flows sequentially
# Word 1 → Word 2 → Word 3 → Word 4

# Attention: every word can interact with every other word directly
# Word i ↔ Word j for all pairs (i, j), in a single step

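Advantage: Any two words are connected in a single step, so long-range information does not have to pass through many sequential updates, and all positions can be processed in parallel.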

Part 12: Putting It All Together

# Complete self-attention step by step
def simple_attention(X):
    Q = X  # Queries (simplified)
    K = X  # Keys 
    V = X  # Values

    scores = torch.mm(Q, K.transpose(0, 1))  # Compute similarities
    weights = F.softmax(scores, dim=1)        # Convert to probabilities
    output = torch.mm(weights, V)             # Weighted combination
    return output

# Usage
word_vectors = torch.randn(5, 8)  # 5 words, 8 dimensions each
attended_vectors = simple_attention(word_vectors)

Result: Each word vector is now a blend of every word vector in the sequence (including itself), weighted by attention.

Key Takeaways

  1. Attention = Weighted Average: Focus more on important parts
  2. Q·K = Similarity: How well query matches key
  3. Softmax = Probability: Convert scores to weights that sum to 1
  4. Weighted V = Output: Combine values using attention weights
  5. Self-Attention = Words talking to each other: Every word can attend to every other word

This foundation prepares you for transformer models, which are built around attention mechanisms!

Understanding Attention: From Words to Vectors

1. Word Embeddings - The Foundation

import torch
import torch.nn as nn
import torch.nn.functional as F

# Sample sentence: "The cat sat on the mat"
vocab = {"<pad>": 0, "the": 1, "cat": 2, "sat": 3, "on": 4, "mat": 5}
sentence = [1, 2, 3, 4, 1, 5]  # token IDs

# Create embeddings
vocab_size = len(vocab)
embed_dim = 64
embedding = nn.Embedding(vocab_size, embed_dim)

# Convert tokens to vectors
tokens = torch.tensor(sentence)
embeddings = embedding(tokens)
print(f"Shape: {embeddings.shape}")  # [6, 64]
print(f"'cat' vector: {embeddings[1][:8]}...")  # First 8 dimensions

Each word becomes a 64-dimensional vector. Here the vectors are randomly initialized; they only come to capture semantic meaning as the model is trained.
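
Conceptually, an embedding layer is a lookup table with one row per vocabulary entry. A minimal sketch in plain Python, assuming a tiny `embed_dim` of 4 for readability (the `table` name is hypothetical, and the random values stand in for untrained weights):

```python
import random

random.seed(0)

vocab = {"<pad>": 0, "the": 1, "cat": 2, "sat": 3, "on": 4, "mat": 5}
embed_dim = 4  # kept small here; the example above uses 64

# One random row per vocabulary entry, like an untrained embedding layer
table = [[random.gauss(0.0, 1.0) for _ in range(embed_dim)] for _ in vocab]

sentence = [1, 2, 3, 4, 1, 5]  # "the cat sat on the mat"
# Embedding lookup is just row indexing by token ID
embeddings = [table[token_id] for token_id in sentence]

print(len(embeddings), len(embeddings[0]))  # 6 4
print(embeddings[0] == embeddings[4])       # True: both are "the"
```

Both occurrences of "the" get exactly the same vector, which is one reason attention is needed to bring in context.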

2. The Q, K, V Matrices - Core of Attention

# Attention dimensions
d_model = 64
num_heads = 8
d_k = d_model // num_heads  # 8

# Linear transformations to create Q, K, V
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

# Transform embeddings
Q = W_q(embeddings)  # Queries: "What am I looking for?"
K = W_k(embeddings)  # Keys: "What do I represent?"
V = W_v(embeddings)  # Values: "What information do I carry?"

print(f"Q shape: {Q.shape}")  # [6, 64]
print(f"K shape: {K.shape}")  # [6, 64]
print(f"V shape: {V.shape}")  # [6, 64]

Intuition:

  • Q (Query): "What information does this word need?"
  • K (Key): "What kind of information does this word offer?"
  • V (Value): "What actual information does this word contain?"
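
Each of these three roles comes from multiplying the same embedding by a different weight matrix, which is what a bias-free linear layer does. A sketch in plain Python, with hand-picked 2x2 matrices that are assumptions for illustration (real weights are learned), and a hypothetical `project` helper:

```python
def project(x, W):
    """Matrix-vector product W @ x, the computation of a linear layer without bias."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

x = [1.0, 2.0]  # one word embedding (made up)

# Hand-picked weights; a trained model would learn these
W_q = [[1.0, 0.0], [0.0, 1.0]]  # identity: keeps x as is
W_k = [[0.0, 1.0], [1.0, 0.0]]  # swaps the two components
W_v = [[1.0, 1.0], [0.0, 1.0]]  # mixes the second component into the first

q = project(x, W_q)
k = project(x, W_k)
v = project(x, W_v)

print(q)  # [1.0, 2.0]
print(k)  # [2.0, 1.0]
print(v)  # [3.0, 2.0]
```

One input vector yields three different views of the same word, one for each role.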

3. Computing Attention Scores

# Reshape for multi-head attention
batch_size, seq_len = 1, 6
Q = Q.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]
K = K.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]
V = V.view(batch_size, seq_len, num_heads, d_k).transpose(1, 2)  # [1, 8, 6, 8]

# Attention scores: How much should each word pay attention to others?
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
print(f"Attention scores shape: {scores.shape}")  # [1, 8, 6, 6]

# Example: How much does "cat" attend to each word?
cat_attention = scores[0, 0, 1, :]  # First head, "cat" position
words = ["the", "cat", "sat", "on", "the", "mat"]
for i, word in enumerate(words):
    print(f"cat -> {word}: {cat_attention[i]:.3f}")
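
The division by `d_k ** 0.5` is there because dot products of longer vectors tend to be larger in magnitude (for components with unit variance, their standard deviation grows like the square root of d_k), and large scores push softmax toward a one-hot output. A small sketch in plain Python, with made-up raw scores chosen for illustration:

```python
import math

def softmax(scores):
    exps = [math.exp(s - max(scores)) for s in scores]
    return [e / sum(exps) for e in exps]

d_k = 64
raw = [8.0, 0.0, -8.0]                      # hypothetical unscaled scores
scaled = [s / math.sqrt(d_k) for s in raw]  # [1.0, 0.0, -1.0]

print([round(p, 3) for p in softmax(raw)])     # [1.0, 0.0, 0.0]
print([round(p, 3) for p in softmax(scaled)])  # [0.665, 0.245, 0.09]
```

Without scaling, almost all attention lands on one word. With scaling, the other words still contribute, which also tends to keep gradients healthier during training.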

4. Softmax and Weighted Values

# Convert scores to probabilities
attention_weights = F.softmax(scores, dim=-1)
print(f"Attention weights shape: {attention_weights.shape}")  # [1, 8, 6, 6]

# Apply attention to values
attended_values = torch.matmul(attention_weights, V)  # [1, 8, 6, 8]

# Concatenate heads (the output projection W_o is applied in the full implementation in section 5)
attended_values = attended_values.transpose(1, 2).contiguous().view(
    batch_size, seq_len, d_model)  # [1, 6, 64]

print(f"Final attended values shape: {attended_values.shape}")

# Show attention pattern for "cat"
print("\nAttention pattern for 'cat':")
cat_weights = attention_weights[0, 0, 1, :]  # First head
for i, word in enumerate(words):
    print(f"  {word}: {cat_weights[i]:.3f}")
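
The `view(..., num_heads, d_k)` reshape splits each token's vector into consecutive chunks, one per head, and the final `view(..., d_model)` glues them back together. A sketch of that bookkeeping for a single token in plain Python, assuming a small `d_model` of 8 and 2 heads for readability:

```python
d_model, num_heads = 8, 2
d_k = d_model // num_heads  # 4 dimensions per head

# One token's vector; the values 0..7 make it easy to see where each entry goes
token = [float(i) for i in range(d_model)]

# Split: head h gets entries h*d_k .. (h+1)*d_k - 1
heads = [token[h * d_k:(h + 1) * d_k] for h in range(num_heads)]
print(heads)  # [[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]]

# Concatenate: lay the heads side by side again
merged = [x for head in heads for x in head]
print(merged == token)  # True
```

In the real computation each head's chunk is changed by attention before concatenation, so the merged vector differs from the input. Only the layout is the same.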

5. Complete Self-Attention Implementation

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Linear transformations
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attended_values = torch.matmul(attention_weights, V)

        # Concatenate heads
        attended_values = attended_values.transpose(1, 2).contiguous().view(
            batch_size, seq_len, d_model)

        # Final projection
        output = self.W_o(attended_values)
        return output, attention_weights

# Usage
attention = MultiHeadAttention(d_model=64, num_heads=8)
output, weights = attention(embeddings.unsqueeze(0))
print(f"Output shape: {output.shape}")  # [1, 6, 64]

6. Visualizing Attention Patterns

# Extract attention weights for visualization
attention_matrix = weights[0, 0].detach().numpy()  # First head
words = ["the", "cat", "sat", "on", "the", "mat"]

print("Attention Matrix (first head):")
print("From -> To:")
for i, from_word in enumerate(words):
    print(f"{from_word:>4}: ", end="")
    for j, to_word in enumerate(words):
        print(f"{attention_matrix[i,j]:.2f} ", end="")
    print()
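
A printed matrix is easier to read if you also pull out the strongest link in each row. A sketch in plain Python on a three-word example. The numbers in `attention_matrix` are made up for illustration and are not output from the model above:

```python
words = ["the", "cat", "sat"]

# Hypothetical attention weights (each row sums to 1), invented for this example
attention_matrix = [
    [0.2, 0.7, 0.1],  # "the" looks mostly at "cat"
    [0.1, 0.3, 0.6],  # "cat" looks mostly at "sat"
    [0.1, 0.6, 0.3],  # "sat" looks mostly at "cat"
]

strongest = []
for from_word, row in zip(words, attention_matrix):
    j = max(range(len(row)), key=row.__getitem__)  # column with the largest weight
    strongest.append(words[j])
    print(f"{from_word} -> {words[j]} ({row[j]:.2f})")
```

With an untrained model the real matrix will look close to random; patterns like these only appear after training.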

7. Key Insights

What happens in attention?

  1. Each word creates a query (what it's looking for)
  2. Each word creates a key (what it represents)
  3. We compute similarity between queries and keys
  4. Higher similarity = more attention
  5. We use attention weights to combine values (actual information)

Example: When processing "cat", the model might:

  • Query: "What is this cat doing?"
  • Look at all keys: "the" (determiner), "sat" (action), "mat" (object)
  • Pay most attention to "sat" because it's the relevant action
  • Combine information weighted by attention weights

8. Practical Example with Real Meaning

# Sentence: "The cat chased the mouse"
sentence = "The cat chased the mouse"
words = sentence.lower().split()

# Simulate what attention might learn
print("Attention patterns the model might learn:")
print("- 'cat' attends to 'chased' (subject-verb relationship)")
print("- 'chased' attends to 'cat' and 'mouse' (verb-subject-object)")
print("- 'mouse' attends to 'chased' (object-verb relationship)")
print("- 'the' attends to following nouns ('cat', 'mouse')")

# This allows the model to understand:
# - Who did what to whom
# - Grammatical relationships
# - Semantic dependencies

Summary

The attention mechanism allows models to:

  • Focus on relevant parts of the input
  • Relate different words to each other
  • Combine information based on relevance
  • Understand long-range dependencies

The magic is in the learned projection matrices (W_q, W_k, W_v) that transform word embeddings into queries, keys, and values that can interact meaningfully.
