This is Part 2 of the “Hands-on Transformer Deep Dive” series. We’ll walk step-by-step through modern Transformers’ algorithms and components, and build our own LLM from scratch. If you missed Part 1, check it out here.
In this article, we dive deep into the multi-head attention mechanism, a foundational building block of modern Transformers. We’ll look at four of its variants: MHA, MQA, GQA, and MLA, implement them from scratch using only PyTorch, and discuss their characteristics, pros, and cons.
Introduction
Multi-head attention allows the model to capture complex patterns by looking at the data from multiple “perspectives”. While single-head attention computes attention once over the whole input, multi-head attention splits the model’s total feature dimension across multiple heads and runs them in parallel. Each head learns its own query, key, and value projections. All heads are then combined to recover the full model feature dimension.
Splitting across multiple attention heads allows the model to represent various aspects of the data more effectively. Each head can focus on a smaller, specialized subspace. For example, one head might focus on nearby words to capture phrases or recognize named entities, another head might specialize in understanding relative word positions or long-range semantic links, while still another head could attend to negations or modifiers that change the sentiment or meaning. Combining all heads then recovers the full representation capacity.
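To make the splitting concrete, here is a minimal sketch with made-up sizes (a 512-dimensional model split across 8 heads) showing how the feature dimension is divided into heads and merged back:

import torch

batch_size, seq_len, embed_dim, num_heads = 2, 16, 512, 8   # illustrative sizes
head_dim = embed_dim // num_heads                           # 512 / 8 = 64 features per head

x = torch.randn(batch_size, seq_len, embed_dim)
# split the feature dimension into num_heads heads of head_dim features each
heads = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
print(heads.shape)   # torch.Size([2, 8, 16, 64])
# merging the heads back recovers the original feature dimension
merged = heads.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
print(merged.shape)  # torch.Size([2, 16, 512])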
Multi-head attention has shown better learning dynamics and improved expressiveness compared to single-head attention. However, different practical constraints, such as memory limitation and inference speed, require trade-offs between model performance, computation cost, and flexibility.
In the following sections, we’ll discuss and implement four of the most popular multi-head attention variants: the classic MHA, Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-head Latent Attention (MLA).
MHA: Multi-Head Attention
The classic MHA works just as described in the introduction: the attention mechanism is split across multiple attention heads, each learning its own smaller Q, K, and V projections (each head’s dimension is the model dimension divided by the number of heads). All heads are then concatenated, and the model learns a final output projection back to the original model dimension.
MHA Implementation
Below is an MHA implementation with step-by-step explanation in comments:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, \
            f"model dimension (received embed_dim: {embed_dim}) must be divisible " \
            f"by the number of attention heads (received num_heads: {num_heads})"
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Initialize the query, key, value and final output projections
        # shape: (embed_dim, embed_dim)
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_output = nn.Linear(embed_dim, embed_dim)
        # Initialize the dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Get the batch size, sequence length, embedding dimension from input x
        batch_size, seq_len, embed_dim = x.size()
        # Step 1. Pass the input through query, key, and value projections
        #   shape of input x: (batch_size, seq_len, embed_dim)
        #   shape of projection layer: (embed_dim, embed_dim)
        #   shape after projection: (batch_size, seq_len, embed_dim)
        # Step 2. Split the last dimension into multiple heads
        #   shape after split: (batch_size, seq_len, num_heads, head_dim)
        # Step 3. Rearrange dimensions 1 and 2 for parallel computation
        #   shape after: (batch_size, num_heads, seq_len, head_dim)
        queries = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Step 4. Calculate attention values
        # Step 4-1. scaled dot-product attention: attn_scores = QK^T / sqrt(d_k)
        #   Note: since attention is calculated per head,
        #   we scale by the head dimension instead of the model dimension
        #   shape of queries: (batch_size, num_heads, seq_len, head_dim)
        #   shape of keys transposed: (batch_size, num_heads, head_dim, seq_len)
        #   shape of attn_scores: (batch_size, num_heads, seq_len, seq_len)
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Step 4-2. apply mask
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        # Step 4-3. softmax
        attn_scores = F.softmax(attn_scores, dim=-1)
        # Step 4-4. dropout
        attn_scores = self.dropout(attn_scores)
        # Step 4-5. attention values
        #   shape of attn_scores: (batch_size, num_heads, seq_len, seq_len)
        #   shape of values: (batch_size, num_heads, seq_len, head_dim)
        #   shape of attn_values: (batch_size, num_heads, seq_len, head_dim)
        attn_values = torch.matmul(attn_scores, values)
        # Step 5. Rearrange values dimension and reshape to concatenate heads
        #   shape after rearrange: (batch_size, seq_len, num_heads, head_dim)
        #   shape after concatenation: (batch_size, seq_len, embed_dim)
        attn_values = attn_values.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        # Step 6. Go through the final output projection
        #   shape of output: (batch_size, seq_len, embed_dim)
        output = self.W_output(attn_values)
        return output
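Here is a quick usage sketch of the MultiHeadAttention module defined above, with made-up sizes and a lower-triangular causal mask, to show the expected input and output shapes:

torch.manual_seed(0)
mha = MultiHeadAttention(embed_dim=512, num_heads=8)
x = torch.randn(2, 16, 512)                   # (batch_size, seq_len, embed_dim)
causal_mask = torch.tril(torch.ones(16, 16))  # broadcasts over batch and heads
out = mha(x, mask=causal_mask)
print(out.shape)                              # torch.Size([2, 16, 512])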
For simplicity of demonstration, we included the masked scaled dot-product attention code directly in the forward() method (Step 4, sub-steps 4-1 to 4-5). For more flexibility in choosing between different attention mechanisms, you can abstract the attention computation into a separate function or module and plug it in here.
Here is an example implementation that puts it in a separate module. We’ll use it in the MQA, GQA, and MLA implementations.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MaskedScaledDotProductAttention(nn.Module):
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, mask=None):
        d_k = queries.size(-1)
        attn_scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # masked_fill is not in-place, so reassign the result
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_values = torch.matmul(attn_weights, values)
        return attn_values
If you’d like to learn more about implementation details of masked scaled dot-product attention, check out Part 1 of this series.
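As a side note, if you are on PyTorch 2.0 or newer, the library also ships a fused F.scaled_dot_product_attention that performs the same steps (scaling, masking, softmax, dropout, weighted sum) in one call and can dispatch to optimized kernels. A minimal sketch with arbitrary sizes:

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 16, 64)   # (batch_size, num_heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
# is_causal=True applies a causal mask internally, similar to passing a
# lower-triangular mask to the module above
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)                # torch.Size([2, 8, 16, 64])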
MQA: Multi-Query Attention
While classic MHA delivers richer representations and improves model performance, maintaining each head’s own set of queries, keys, and values greatly increases memory and computation overhead. For example, during the intermediate attention score computation (QK^T), the multi-head attention score tensor is of shape (batch_size, num_heads, seq_len, seq_len), making its size num_heads times as big as its single-head counterpart.
This overhead is especially problematic at inference time. To address it, Multi-Query Attention (MQA) was introduced by Noam Shazeer in 2019 to improve the efficiency of autoregressive Transformer decoders during inference (Shazeer, 2019).
MQA modifies the MHA architecture by sharing a single set of keys and values across all heads, while still allowing each head its own queries. At inference, this dramatically shrinks the KV cache and the memory bandwidth spent reading it, leading to faster token generation with minimal impact on model quality.
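To get a feel for the savings, here is some back-of-the-envelope KV-cache arithmetic for a hypothetical decoder (illustrative numbers only, fp16 values at 2 bytes each), counting the keys and values cached for one sequence:

num_layers, num_heads, head_dim = 32, 32, 128   # hypothetical model configuration
seq_len, bytes_per_value = 4096, 2              # fp16

# keys + values, across all layers and cached tokens
mha_cache = 2 * num_layers * num_heads * head_dim * seq_len * bytes_per_value
mqa_cache = 2 * num_layers * 1 * head_dim * seq_len * bytes_per_value  # one shared KV head
print(f"MHA KV cache: {mha_cache / 2**20:.0f} MiB")   # 2048 MiB
print(f"MQA KV cache: {mqa_cache / 2**20:.0f} MiB")   # 64 MiB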
MQA Implementation
Below is an implementation of MQA that uses the above MaskedScaledDotProductAttention module in the attention calculation step. The key differences from MHA are:
- While the query projection has the same dimensions as MHA’s query projection, (embed_dim, embed_dim), the key and value projections are initialized with only the per-head dimension, (embed_dim, head_dim).
- To share the single key head and value head across multiple query heads, we insert a dummy dimension of size 1 at the position of the query tensor’s num_heads dimension (i.e., dimension 1). During the attention computation, PyTorch automatically broadcasts this dimension num_heads times, enabling simultaneous computation of the shared keys/values against each head’s separate queries, as shown in the short sketch after this list.
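Here is that broadcasting behavior in isolation, with arbitrary sizes: a single shared key head at dimension 1 is matched against 8 query heads without copying any data:

import torch

queries = torch.randn(2, 8, 16, 64)  # (batch_size, num_heads, seq_len, head_dim)
keys = torch.randn(2, 1, 16, 64)     # (batch_size, 1, seq_len, head_dim): one shared head
scores = torch.matmul(queries, keys.transpose(-2, -1))
print(scores.shape)                  # torch.Size([2, 8, 16, 16]): broadcast over 8 heads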
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, \
            f'Model hidden dimension (embed_dim) must be divisible by the ' \
            f'number of heads (num_heads). Got embed_dim: {embed_dim}, num_heads: {num_heads}.'
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # initialize linear projections
        ## Each head has its own Q, so the Q projection has
        ## the shape of the full model dimensions
        self.W_q = nn.Linear(embed_dim, embed_dim)
        ## K and V are shared across heads,
        ## so the projection's second dimension is only head_dim
        self.W_k = nn.Linear(embed_dim, self.head_dim)
        self.W_v = nn.Linear(embed_dim, self.head_dim)
        ## Final output projection, also has full model dimensions
        self.W_output = nn.Linear(embed_dim, embed_dim)
        # Initialize the attention module with dropout
        self.attention = MaskedScaledDotProductAttention(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, embed_dim = x.size()
        # Step 1. Pass input x through the Q projection, split heads and
        #   rearrange dimensions 1, 2 for parallelism
        #   shape: (batch_size, num_heads, seq_len, head_dim)
        queries = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Step 2. Pass input x through the K and V projections, insert a dimension at
        #   position 1 to represent the single shared K and V head
        #   shape: (batch_size, 1, seq_len, head_dim)
        keys = self.W_k(x).unsqueeze(1)
        values = self.W_v(x).unsqueeze(1)
        # Step 3. Calculate attention
        #   shape: (batch_size, num_heads, seq_len, head_dim)
        attn_values = self.attention(queries, keys, values, mask)
        # Step 4. Concatenate attention values across heads
        attn_values = attn_values.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        # Step 5. Pass attention values through the final output projection
        output = self.W_output(attn_values)
        return output
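A quick sanity check with illustrative sizes, using the two modules defined above: the output shape matches MHA, while the K and V projections (and thus the cached keys and values) shrink by a factor of num_heads:

torch.manual_seed(0)
x = torch.randn(2, 16, 512)
mha = MultiHeadAttention(embed_dim=512, num_heads=8)
mqa = MultiQueryAttention(embed_dim=512, num_heads=8)
print(mha(x).shape, mqa(x).shape)   # both torch.Size([2, 16, 512])

count_params = lambda m: sum(p.numel() for p in m.parameters())
print(count_params(mha), count_params(mqa))   # MQA has noticeably fewer parameters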
GQA: Grouped-Query Attention
While MQA greatly improves inference efficiency by saving memory and compute, sharing one set of keys and values across all query heads can restrict each head’s expressiveness and its ability to independently “attend” to its own representation subspace. As a result, training models directly with MQA can lead to degraded performance and training instability.
To overcome these limitations, Grouped-Query Attention (GQA) was introduced by Google Research in 2023 (Ainslie et al., 2023). Instead of sharing one set of key and value heads across all query heads, GQA partitions the query heads into groups and lets each group share one set of key and value heads. For example, with 8 query heads and 2 KV groups, the first 4 query heads share one key/value head and the remaining 4 share another. This approach keeps most of MQA’s efficiency while preserving more of the model’s representational capacity, making it suitable for both training and inference.
GQA has been adopted in several notable LLMs, including LLaMA 2, LLaMA 3, Qwen2 and Qwen3.
GQA Implementation
Below is an implementation of GQA. The key differences from MQA are:
- While the query projection has the same shape as MHA and MQA: (embed_dim, embed_dim), the key and the value projections are of shape (embed_dim, num_kv_groups * head_dim), so that we can easily split them into groups.
- To share keys and values within each query group during the attention calculation, we first split the key and value tensors into num_kv_groups groups and rearrange the dimensions so the group dimension aligns with the query tensor’s num_heads dimension. Then we repeat the keys and values along that dimension heads_per_group (= num_heads / num_kv_groups) times. Since num_kv_groups * heads_per_group = num_heads, we can then compute attention for all query heads simultaneously, just like in MHA and MQA (see the short repeat_interleave sketch after this list).
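Here is a tiny sketch of why repeat_interleave (rather than repeat) is used: it repeats each group consecutively, so consecutive query heads end up sharing the same K/V group:

import torch

# pretend we have 2 KV groups and want 4 query heads (heads_per_group = 2)
kv_groups = torch.tensor([0., 1.]).view(1, 2, 1, 1)   # (batch, num_kv_groups, seq, dim)
expanded = kv_groups.repeat_interleave(2, dim=1)      # (batch, num_heads, seq, dim)
print(expanded.flatten())   # tensor([0., 0., 1., 1.]): heads 0-1 share group 0, heads 2-3 share group 1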
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, num_heads, num_kv_groups, embed_dim, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, \
            f'Model dimension must be divisible by number of heads. ' \
            f'Got embed_dim: {embed_dim}, num_heads: {num_heads}'
        assert num_heads % num_kv_groups == 0, \
            f'Number of heads must be divisible by number of KV groups. ' \
            f'Got num_heads: {num_heads}, num_kv_groups: {num_kv_groups}'
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.num_kv_groups = num_kv_groups
        self.heads_per_group = num_heads // num_kv_groups
        # initialize Q projection
        self.W_q = nn.Linear(embed_dim, embed_dim)
        # initialize K, V projections: one K/V head per group
        # shape: (embed_dim, num_kv_groups * head_dim)
        self.W_k = nn.Linear(embed_dim, num_kv_groups * self.head_dim)
        self.W_v = nn.Linear(embed_dim, num_kv_groups * self.head_dim)
        # initialize final output projection
        self.W_output = nn.Linear(embed_dim, embed_dim)
        # initialize attention with dropout
        self.attention = MaskedScaledDotProductAttention(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, embed_dim = x.size()
        # Step 1. pass x through the Q projection, split heads and rearrange for parallelism
        #   shape -> (batch_size, num_heads, seq_len, head_dim)
        queries = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Step 2. pass x through the K and V projections, split groups
        #   shape -> (batch_size, num_kv_groups, seq_len, head_dim)
        keys = self.W_k(x).view(batch_size, seq_len, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = self.W_v(x).view(batch_size, seq_len, self.num_kv_groups, self.head_dim).transpose(1, 2)
        # Step 3. repeat keys and values heads_per_group times along the group dimension
        #   shape: (batch_size, num_kv_groups, seq_len, head_dim) ->
        #          (batch_size, num_heads(=num_kv_groups * heads_per_group), seq_len, head_dim)
        keys = keys.repeat_interleave(self.heads_per_group, dim=1)
        values = values.repeat_interleave(self.heads_per_group, dim=1)
        # Step 4. compute attention
        #   shape: (batch_size, num_heads, seq_len, head_dim)
        attn_values = self.attention(queries, keys, values, mask)
        # Step 5. concatenate heads
        #   shape -> (batch_size, seq_len, embed_dim)
        attn_values = attn_values.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        # Step 6. pass attn_values through the final output projection
        output = self.W_output(attn_values)
        return output
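As a quick illustration with the class defined above and made-up sizes, GQA spans the whole spectrum: with num_kv_groups equal to num_heads it is structurally the same as MHA, and with num_kv_groups = 1 it is structurally the same as MQA, with the interesting trade-offs in between:

torch.manual_seed(0)
x = torch.randn(2, 16, 512)
gqa_like_mha = GroupedQueryAttention(num_heads=8, num_kv_groups=8, embed_dim=512)
gqa_like_mqa = GroupedQueryAttention(num_heads=8, num_kv_groups=1, embed_dim=512)
gqa_middle   = GroupedQueryAttention(num_heads=8, num_kv_groups=2, embed_dim=512)
print(gqa_like_mha(x).shape, gqa_like_mqa(x).shape, gqa_middle(x).shape)
# all three: torch.Size([2, 16, 512])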
MLA: Multi-head Latent Attention
While GQA strikes a balance between efficiency and quality, further improvements are needed to scale to even larger models and longer, more complex inputs. Multi-head Latent Attention (MLA) was introduced by DeepSeek-AI in their 2024 DeepSeek-v2 paper to address this (DeepSeek-AI, 2024).
MLA uses a low-rank factorization approach to jointly compress keys and values into a single, much smaller learned latent vector. This compression significantly reduces the memory and computation needed for the KV cache, enabling efficient processing of longer and more complex inputs and boosting generation throughput, while maintaining model performance.
Ablations and empirical tests on four hard benchmarks showed that, while MHA outperforms GQA and MQA, MLA performs even better than MHA while requiring a much smaller KV cache.
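To see where the KV-cache savings come from, here is some rough per-token, per-layer arithmetic with illustrative numbers (the full DeepSeek-V2 design also caches a small shared RoPE key, which the simplified implementation below omits):

num_heads, head_dim = 128, 128          # hypothetical MHA-style configuration
kv_latent_dim = 4 * head_dim            # DeepSeek-V2 sets the latent dimension d_c to 4 * d_h

mha_elements = 2 * num_heads * head_dim # full keys + values for every head
mla_elements = kv_latent_dim            # only the shared compressed latent vector c_kv
print(mha_elements, mla_elements, mha_elements // mla_elements)   # 32768 512 64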
MLA Implementation
Here is the full MLA formulation as provided in the DeepSeek-V2 paper:
c_t^Q = W^{DQ} h_t
[q_{t,1}^C; q_{t,2}^C; …; q_{t,n_h}^C] = q_t^C = W^{UQ} c_t^Q
[q_{t,1}^R; q_{t,2}^R; …; q_{t,n_h}^R] = q_t^R = RoPE(W^{QR} c_t^Q)
q_{t,i} = [q_{t,i}^C; q_{t,i}^R]
c_t^{KV} = W^{DKV} h_t
[k_{t,1}^C; k_{t,2}^C; …; k_{t,n_h}^C] = k_t^C = W^{UK} c_t^{KV}
k_t^R = RoPE(W^{KR} h_t)
k_{t,i} = [k_{t,i}^C; k_t^R]
[v_{t,1}^C; v_{t,2}^C; …; v_{t,n_h}^C] = v_t^C = W^{UV} c_t^{KV}
o_{t,i} = sum_{j=1}^{t} Softmax_j( q_{t,i}^T k_{j,i} / sqrt(d_h + d_h^R) ) v_{j,i}^C
u_t = W^O [o_{t,1}; o_{t,2}; …; o_{t,n_h}]
Where:
- h_t: input token embedding at position t
- n_h: number of attention heads
- W^{DQ}, W^{DKV}: down-projection matrices for the query and key-value content vectors
- W^{UQ}, W^{UK}, W^{UV}: up-projection matrices for query, key and value from the content vectors
- W^{QR}, W^{KR}: linear projections generating relative positional queries and keys (before RoPE)
- W^O: output linear projection matrix
- c_t^Q: content query vector (down-projected from input h_t)
- c_t^{KV}: content key-value vector (also down-projected from the input)
- q_t^C / q_{t,i}^C: content queries of all heads / head i
- q_t^R / q_{t,i}^R: relative positional queries of all heads / head i
- q_{t,i}: concatenated content and relative query vectors of head i
- k_t^C / k_{t,i}^C: content keys of all heads / head i
- k_t^R: relative positional keys (shared across heads)
- k_{t,i}: concatenated content and relative key vectors of head i
- v_t^C / v_{t,i}^C: content values of all heads / head i
- d_h, d_h^R: dimensions of the content and relative positional subspaces per head
- o_{t,i}: attention output of head i at position t
- u_t: final output
DeepSeek’s MLA is deeply integrated with RoPE (Rotary Position Embedding). We will do a deep dive into positional embeddings in the next article and implement the full MLA with RoPE there. For now, we’ll implement a simplified version without RoPE to demonstrate MLA’s core idea of learning compressed latent content vectors instead of full Q, K, V projections.
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadLatentAttentionSimplified(nn.Module):
    def __init__(self, embed_dim, num_heads, q_latent_dim, kv_latent_dim, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_latent_dim = q_latent_dim
        self.kv_latent_dim = kv_latent_dim
        # Initialize projections
        self.W_DQ = nn.Linear(embed_dim, q_latent_dim)    # Q down-projection
        self.W_UQ = nn.Linear(q_latent_dim, embed_dim)    # Q up-projection
        self.W_DKV = nn.Linear(embed_dim, kv_latent_dim)  # joint K/V down-projection
        self.W_UK = nn.Linear(kv_latent_dim, embed_dim)   # K up-projection
        self.W_UV = nn.Linear(kv_latent_dim, embed_dim)   # V up-projection
        self.W_output = nn.Linear(embed_dim, embed_dim)   # final output projection
        # Initialize attention module
        self.attention = MaskedScaledDotProductAttention(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        # Step 1. Compress and decompress Q
        c_q = self.W_DQ(x)           # (batch_size, seq_len, q_latent_dim)
        q_content = self.W_UQ(c_q)   # (batch_size, seq_len, embed_dim)
        # Step 2. Compress K and V into one shared latent subspace
        c_kv = self.W_DKV(x)         # (batch_size, seq_len, kv_latent_dim)
        # Step 3. Decompress K and V respectively
        k_content = self.W_UK(c_kv)  # (batch_size, seq_len, embed_dim)
        v_content = self.W_UV(c_kv)  # (batch_size, seq_len, embed_dim)
        # Step 4. Split heads and reshape for multi-head attention
        #   -> (batch_size, num_heads, seq_len, head_dim)
        queries = q_content.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = k_content.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = v_content.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Step 5. Apply attention
        #   -> (batch_size, num_heads, seq_len, head_dim)
        attn_output = self.attention(queries, keys, values, mask)
        # Step 6. Concatenate heads and reshape
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        # Step 7. Apply final output projection
        output = self.W_output(attn_output)
        return output, c_kv  # return c_kv to show that it is what would be cached
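A small usage sketch with made-up sizes, showing that the output shape matches the other variants while the returned latent c_kv (the tensor that would be cached at inference) is much smaller than a full set of keys and values:

torch.manual_seed(0)
mla = MultiheadLatentAttentionSimplified(embed_dim=512, num_heads=8,
                                         q_latent_dim=128, kv_latent_dim=128)
x = torch.randn(2, 16, 512)
output, c_kv = mla(x)
print(output.shape)   # torch.Size([2, 16, 512])
print(c_kv.shape)     # torch.Size([2, 16, 128]) vs. 2 * 512 values per token for full K + V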
References
- Noam Shazeer. (2019). Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150
- Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv:2305.13245
- DeepSeek-AI (2024). DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv:2405.04434