Compressing KV cache via low-rank projections - the attention mechanism behind DeepSeek-V2/V3 and Kimi K2.x
Why This Matters
Multi-Head Latent Attention (MLA) is the attention variant that replaces standard Multi-Head Attention (MHA) in DeepSeek-V2, DeepSeek-V3, and Kimi K2.x models. Instead of caching full per-head keys and values, MLA caches a low-dimensional latent vector (plus a small positional key) from which keys and values are reconstructed. The DeepSeek-V2 paper reports a 93.3% KV cache reduction relative to DeepSeek 67B, and its ablations show quality comparable to or better than MHA.
- MLA changes how prefix caching, chunked prefill, and paged attention must be implemented
Formal Definition
Standard Multi-Head Attention (MHA)
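For input $h_t \in \mathbb{R}^{d}$ (the hidden state of token $t$), MHA computes per-head projections:
$$q_t = W^Q h_t, \qquad k_t = W^K h_t, \qquad v_t = W^V h_t$$
where $W^Q, W^K, W^V \in \mathbb{R}^{n_h d_h \times d}$, and each result is split into $n_h$ heads $q_{t,i}, k_{t,i}, v_{t,i} \in \mathbb{R}^{d_h}$.
KV cache per token (per layer): $2 n_h d_h$ elements.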
MLA: Low-Rank Latent Projection
MLA replaces the per-head KV projections with a shared low-rank latent compression:
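Compression (KV → Latent):
$$c^{KV}_t = W^{DKV} h_t$$
where $W^{DKV} \in \mathbb{R}^{d_c \times d}$ is the down-projection matrix and $d_c \ll n_h d_h$.
Decompression (Latent → KV):
$$k^C_t = W^{UK} c^{KV}_t, \qquad v^C_t = W^{UV} c^{KV}_t$$
where $W^{UK}, W^{UV} \in \mathbb{R}^{n_h d_h \times d_c}$ are up-projection matrices. As in MHA, $k^C_t$ and $v^C_t$ are split into $n_h$ heads $k^C_{t,i}, v^C_{t,i} \in \mathbb{R}^{d_h}$.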
KV cache per token: only $c^{KV}_t$ is stored, a single vector of dimension $d_c$ (the decoupled RoPE key described below adds $d_r$ more elements).
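To make the shapes concrete, here is a minimal PyTorch sketch of the compress-then-decompress path. The toy sizes and random weight matrices are illustrative assumptions (no trained weights, no RoPE): only the $d_c$-dimensional latent goes into the cache, and per-head K and V are rebuilt from it when attention needs them.

```python
import torch

torch.manual_seed(0)
d, n_h, d_h, d_c = 64, 4, 16, 8     # hypothetical toy sizes; d_c << n_h * d_h = 64
T = 10                              # number of tokens seen so far

W_dkv = torch.randn(d_c, d)         # down-projection W^{DKV}
W_uk = torch.randn(n_h * d_h, d_c)  # up-projection W^{UK}
W_uv = torch.randn(n_h * d_h, d_c)  # up-projection W^{UV}

h = torch.randn(T, d)               # hidden states for T tokens

# Compression: the cache holds one d_c-dimensional vector per token
latent_cache = h @ W_dkv.T                       # [T, d_c]

# Decompression: rebuild per-head keys and values when attention needs them
k = (latent_cache @ W_uk.T).view(T, n_h, d_h)    # [T, n_h, d_h]
v = (latent_cache @ W_uv.T).view(T, n_h, d_h)    # [T, n_h, d_h]

mha_elems = 2 * n_h * d_h * T       # what MHA would cache for the same tokens
print(latent_cache.numel(), mha_elems)  # 80 1280
```

With these toy sizes the latent cache is 16× smaller than the equivalent MHA cache; the real ratio depends on $d_c$ relative to $2 n_h d_h$.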
Query Compression (Optional)
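MLA also compresses queries for training efficiency:
$$c^Q_t = W^{DQ} h_t, \qquad q^C_t = W^{UQ} c^Q_t$$
where $W^{DQ} \in \mathbb{R}^{d_c' \times d}$, $W^{UQ} \in \mathbb{R}^{n_h d_h \times d_c'}$, and $d_c' \ll n_h d_h$; $q^C_t$ is split into per-head queries $q^C_{t,i} = W^{UQ}_i c^Q_t$.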
This doesn't affect the KV cache but reduces the activation memory during training.
Rotary Position Embedding (RoPE) Handling
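RoPE cannot simply be applied to the decompressed keys: a position-dependent rotation matrix would sit between the query projection and $W^{UK}$, which blocks the weight absorption described below and forces recomputing keys for every cached token at each step. MLA instead carries position in a separate, narrow set of decoupled dimensions:
$$q^R_t = \mathrm{RoPE}(W^{QR} c^Q_t), \qquad k^R_t = \mathrm{RoPE}(W^{KR} h_t)$$
where $W^{QR} \in \mathbb{R}^{n_h d_r \times d_c'}$ produces one $d_r$-dimensional RoPE query per head and $W^{KR} \in \mathbb{R}^{d_r \times d}$ produces a single RoPE key shared by all heads. The latent $c^{KV}_t$ stays position-agnostic; $k^R_t$ is rotated once at its own position and cached next to it.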
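Caching $k^R_t$ already rotated works because RoPE scores depend only on the offset between query and key positions, so a key rotated once at its own position never has to be touched again. The sketch below checks that property numerically. The `rope` helper (interleaved-pair convention) and the toy `W_kr` matrix are illustrative assumptions, not DeepSeek's exact code.

```python
import torch


def rope(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Interleaved RoPE: rotate pairs (x[2i], x[2i+1]) of x [seq, d_r] by pos * theta_i."""
    d_r = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d_r, 2, dtype=torch.float64) / d_r))
    ang = pos.to(torch.float64)[:, None] * inv_freq  # [seq, d_r / 2]
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)


torch.manual_seed(0)
d, d_r = 32, 8
W_kr = torch.randn(d_r, d, dtype=torch.float64)  # hypothetical W^{KR}: one key for all heads
h = torch.randn(5, d, dtype=torch.float64)       # hidden states at positions 0..4

# Rotate once at write time and cache the rotated key next to the latent
k_rope_cache = rope(h @ W_kr.T, torch.arange(5))  # [5, d_r]

# Scores depend only on relative offsets: shifting every position by 100 changes nothing
q = torch.randn(1, d_r, dtype=torch.float64)
s1 = rope(q, torch.tensor([4])) @ k_rope_cache.T
s2 = rope(q, torch.tensor([104])) @ rope(h @ W_kr.T, torch.arange(100, 105)).T
print(torch.allclose(s1, s2))  # True
```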
Compression Ratio
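For a model with $n_h$ heads and head dimension $d_h$, the per-token, per-layer cache sizes are:
$$\text{MHA: } 2 n_h d_h, \qquad \text{MLA: } d_c + d_r$$
The total MLA KV cache per token stores:
- The latent vector $c^{KV}_t$ of dimension $d_c$
- The decoupled RoPE key $k^R_t$ of dimension $d_r$, shared by all heads
Using the DeepSeek-V2/V3 attention dimensions (see paper Table 1 for the general per-token comparison across attention variants), $n_h = 128$, $d_h = 128$, $d_c = 512$, $d_r = 64$:
$$\frac{2 n_h d_h}{d_c + d_r} = \frac{2 \times 128 \times 128}{512 + 64} = \frac{32768}{576} \approx 56.9\times$$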
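Real-world impact: with these dimensions, a single 64K-token sequence in FP16 needs 4 GiB of KV cache per layer under MHA but only 72 MiB under MLA. Across 60 layers (DeepSeek-V2's depth) that is 240 GiB versus about 4.2 GiB: the MHA cache for one sequence would exceed the combined HBM of three 80 GB H100s, while the MLA cache is a small fraction of one GPU's memory.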
Note: the practical ratio depends on the baseline. Against MHA with the same $n_h$ and $d_h$ it is $2 n_h d_h / (d_c + d_r)$; against a GQA baseline with $n_g$ KV groups it is smaller by a factor of $n_h / n_g$. The DeepSeek-V2 paper reports a 93.3% KV cache reduction relative to DeepSeek 67B.
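The ratios above reduce to a few lines of arithmetic. This sketch assumes one shared $d_r$-dimensional RoPE key per token (the DeepSeek-V2 formulation) and, for the GQA row, a hypothetical baseline with 8 KV groups; `kv_elems_per_token` is an illustrative helper, not a library function.

```python
def kv_elems_per_token(kind: str, n_h: int = 128, d_h: int = 128,
                       n_g: int = 8, d_c: int = 512, d_r: int = 64) -> int:
    """Cached elements per token per layer; n_g is a hypothetical GQA group count."""
    if kind == "mha":
        return 2 * n_h * d_h  # K and V for every head
    if kind == "gqa":
        return 2 * n_g * d_h  # K and V per KV group
    if kind == "mla":
        return d_c + d_r      # latent + one shared RoPE key
    raise ValueError(f"unknown attention kind: {kind}")


mla = kv_elems_per_token("mla")
for baseline in ("mha", "gqa"):
    ratio = kv_elems_per_token(baseline) / mla
    print(f"{baseline.upper()} / MLA: {ratio:.1f}x")
# MHA / MLA: 56.9x
# GQA / MLA: 3.6x
```

The GQA line illustrates why quoted compression figures vary so much: the same MLA cache looks roughly 57× smaller than MHA but under 4× smaller than an 8-group GQA model.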
Core Concepts
1. Weight Absorption (The Key Trick)
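The critical insight in MLA is that the key up-projection $W^{UK}$ can be absorbed into the query side of the attention computation. For head $i$, the content part of the score between query token $t$ and cached token $j$ is:
$$s^C_{t,j,i} = q^{C\top}_{t,i} \, k^C_{j,i}$$
Substituting the decompressed form $k^C_{j,i} = W^{UK}_i c^{KV}_j$, where $W^{UK}_i \in \mathbb{R}^{d_h \times d_c}$ is head $i$'s block of rows of $W^{UK}$:
$$s^C_{t,j,i} = q^{C\top}_{t,i} W^{UK}_i c^{KV}_j = \left(W^{UK\top}_i q^C_{t,i}\right)^\top c^{KV}_j$$
If we define the absorbed query $\tilde{q}_{t,i} = W^{UK\top}_i q^C_{t,i} = \left(W^{UK\top}_i W^{UQ}_i\right) c^Q_t \in \mathbb{R}^{d_c}$, then:
$$s^C_{t,j,i} = \tilde{q}_{t,i}^\top c^{KV}_j$$
This means the content score can be computed directly against the cached latents, without decompressing K. The same trick applies to V: since $\sum_j a_{t,j,i} W^{UV}_i c^{KV}_j = W^{UV}_i \sum_j a_{t,j,i} c^{KV}_j$, the softmax-weighted sum can be taken over latents, with $W^{UV}_i$ applied once per head afterwards (or folded into $W^O$).
Practical implication: during decoding, attention runs in the $d_c$-dimensional latent space. Neither K nor V is materialized per cached token; only one $d_c \to d_h$ up-projection per head is applied after the softmax.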
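A small numerical check of the absorption identity, in float64 PyTorch with toy sizes and random per-head weight blocks (all assumed for illustration). Softmax scaling and the RoPE term are omitted because they are identical on both paths. The naive path decompresses K and V for every cached token; the absorbed path attends directly over the latents and applies $W^{UV}_i$ once per head.

```python
import torch

torch.manual_seed(0)
n_h, d_h, d_c, d_cq, T = 4, 16, 8, 12, 6                  # hypothetical toy sizes
W_uq = torch.randn(n_h, d_h, d_cq, dtype=torch.float64)   # per-head blocks of W^{UQ}
W_uk = torch.randn(n_h, d_h, d_c, dtype=torch.float64)    # per-head blocks of W^{UK}
W_uv = torch.randn(n_h, d_h, d_c, dtype=torch.float64)    # per-head blocks of W^{UV}

c_q = torch.randn(d_cq, dtype=torch.float64)     # query latent of the current token
c_kv = torch.randn(T, d_c, dtype=torch.float64)  # cached KV latents

# Naive: decompress K and V for every cached token, then attend
q = torch.einsum("hdq,q->hd", W_uq, c_q)         # [n_h, d_h]
k = torch.einsum("hdc,tc->htd", W_uk, c_kv)      # [n_h, T, d_h]
v = torch.einsum("hdc,tc->htd", W_uv, c_kv)      # [n_h, T, d_h]
p = torch.softmax(torch.einsum("hd,htd->ht", q, k), dim=-1)
out_naive = torch.einsum("ht,htd->hd", p, v)     # [n_h, d_h]

# Absorbed: fold W^{UK} into the query, attend in latent space, apply W^{UV} once per head
q_lat = torch.einsum("hdc,hd->hc", W_uk, q)      # absorbed queries [n_h, d_c]
p_abs = torch.softmax(q_lat @ c_kv.T, dim=-1)    # [n_h, T], no K materialized
out_abs = torch.einsum("hdc,hc->hd", W_uv, p_abs @ c_kv)

print(torch.allclose(out_naive, out_abs))  # True
```

Production kernels additionally precompute $W^{UK\top}_i W^{UQ}_i$ as a single matrix; the sketch keeps the two factors separate so each step matches the equations above.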
2. Decoupled RoPE Strategy
RoPE requires position-dependent keys, which conflicts with caching a position-agnostic latent. MLA solves this with a decoupled key:
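- Content key: $k^C_{t,i} = W^{UK}_i c^{KV}_t$ - cached in latent form (as $c^{KV}_t$)
- Position key: $k^R_t = \mathrm{RoPE}(W^{KR} h_t)$ - small ($d_r$ dimensions), position-aware, shared across all heads, and cached separately
The attention score becomes:
$$s_{t,j,i} = \frac{q^{C\top}_{t,i} k^C_{j,i} + q^{R\top}_{t,i} k^R_j}{\sqrt{d_h + d_r}}$$
Practical implication: the KV cache stores both $c^{KV}_t$ (latent) and $k^R_t$ (decoupled RoPE key). Total cache per token: $d_c + d_r$ elements per layer.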
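The two-term score is just a dot product over concatenated [content ; RoPE] vectors, with the single cached RoPE key broadcast to every head. A toy PyTorch check, using random tensors as stand-ins for the (already rotated) projections; sizes are illustrative assumptions:

```python
import math

import torch

torch.manual_seed(0)
n_h, d_h, d_r, T = 4, 16, 8, 6                       # hypothetical toy sizes
q_c = torch.randn(n_h, d_h, dtype=torch.float64)     # content query, per head
q_r = torch.randn(n_h, d_r, dtype=torch.float64)     # RoPE query, per head
k_c = torch.randn(n_h, T, d_h, dtype=torch.float64)  # content keys (decompressed from c_KV)
k_r = torch.randn(T, d_r, dtype=torch.float64)       # cached RoPE keys: one per token, shared by heads

scale = 1.0 / math.sqrt(d_h + d_r)
# Sum of the two terms; q_r @ k_r.T scores every head against the same shared keys
scores_split = (torch.einsum("hd,htd->ht", q_c, k_c) + q_r @ k_r.T) * scale

# Same scores from a single dot product over concatenated [content ; rope] vectors
q_full = torch.cat([q_c, q_r], dim=-1)                      # [n_h, d_h + d_r]
k_full = torch.cat([k_c, k_r.expand(n_h, T, d_r)], dim=-1)  # [n_h, T, d_h + d_r]
scores_cat = torch.einsum("hd,htd->ht", q_full, k_full) * scale
print(torch.allclose(scores_split, scores_cat))  # True
```

Keeping the terms separate is what lets a decode kernel use the absorbed latent query for the first term while only the $d_r$-dimensional RoPE part touches the second cache component.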
3. MLA vs GQA vs MHA
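| Property | MHA | GQA | MLA |
|---|---|---|---|
| KV groups | $n_h$ | $n_g$ | 1 (latent) |
| Cache per token (per layer) | $2 n_h d_h$ | $2 n_g d_h$ | $d_c + d_r$ |
| Quality | Baseline | Slight drop | Comparable |
| Attention score | Per-head $q_i^\top k_i$ | Per-group (shared K) | Latent: $\tilde{q}_i^\top c^{KV}$ plus RoPE term |
| RoPE compatibility | Native | Native | Decoupled |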
GQA reduces cache by sharing KV heads across query groups. MLA reduces cache more aggressively by projecting to a shared latent. The reported quality difference is small, plausibly because the up-projection matrices are learned and can reconstruct head-specific keys and values from the shared latent.
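Another way to read the table: solve for the number of GQA groups whose cache matches MLA's, using the DeepSeek-V2/V3 dimensions ($d_c = 512$, $d_r = 64$, $d_h = 128$):

$$2 n_g d_h = d_c + d_r \;\Longrightarrow\; n_g = \frac{d_c + d_r}{2 d_h} = \frac{512 + 64}{2 \times 128} = 2.25$$

This matches the DeepSeek-V2 paper's characterization of MLA's cache as equivalent to GQA with only 2.25 groups, while still giving each of the $n_h = 128$ heads its own decompressed keys and values.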
4. Impact on Batched Serving
MLA dramatically changes the memory-vs-compute tradeoff in serving:
Memory-bound decoding phase: With MHA, long contexts exhaust GPU HBM due to KV cache. MLA's compression allows:
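- Longer context windows ($2 n_h d_h / (d_c + d_r)$ times more tokens in the same KV cache memory than MHA, about 57× at DeepSeek-V3 dimensions; less against a GQA baseline)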
- Larger batch sizes (more concurrent requests)
- Better prefix caching hit rates (smaller cache entries)
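A back-of-the-envelope way to see the batch-size effect: count how many full 64K-token sequences fit in a fixed KV cache budget. The 80 GiB budget, 60 layers, and FP16 storage below are illustrative assumptions, and the hypothetical `max_sequences` helper ignores page-table overhead and fragmentation.

```python
def max_sequences(kv_budget_gib: float, elems_per_token: int, n_layers: int,
                  seq_len: int, bytes_per_elem: int = 2) -> int:
    """Full-length sequences that fit in a KV cache budget (no paging overhead modeled)."""
    per_seq_bytes = elems_per_token * n_layers * seq_len * bytes_per_elem
    return int(kv_budget_gib * 1024**3 // per_seq_bytes)


# Hypothetical 80 GiB KV budget, 60 layers, 64K-token sequences, FP16
for name, elems in (("MHA", 2 * 128 * 128), ("MLA", 512 + 64)):
    print(name, max_sequences(80, elems, n_layers=60, seq_len=65536))
# MHA 0
# MLA 18
```

Under these assumptions a single MHA sequence (240 GiB) does not fit at all, while MLA (about 4.2 GiB per sequence) fits 18.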
Compute-bound prefill phase: MLA adds decompression overhead, but this is amortized:
- Prefill is already compute-heavy (O(n²) attention)
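- The additional up-projection matmuls (applying $W^{UK}$ and $W^{UV}$ to $n$ latents) cost O(n × d_c × n_h × (d_h + d_v)) per layer: linear in n, versus quadratic for attention
- Net effect: a modest prefill overhead in exchange for far less KV cache memory and bandwidth during decoding (DeepSeek-V2 reports 5.76× higher maximum generation throughput than DeepSeek 67B)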
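A rough per-layer FLOP count supports the amortization argument: the up-projection cost grows linearly in $n$ while the attention cost grows quadratically, so their ratio shrinks as prompts get longer. This sketch counts 2 FLOPs per multiply-add, uses DeepSeek-V3-like dimensions, and ignores causal masking, the other projections, and kernel efficiency, so treat it as an order-of-magnitude estimate only; `prefill_flops` is a hypothetical helper.

```python
def prefill_flops(n: int, n_h: int = 128, d_k: int = 128, d_v: int = 128,
                  d_c: int = 512, d_r: int = 64) -> tuple[int, int]:
    """Rough per-layer FLOPs for n prompt tokens (2 FLOPs per multiply-add)."""
    up_proj = 2 * n * d_c * n_h * (d_k + d_v)        # W^{UK} and W^{UV} applied to n latents
    attention = 2 * n * n * n_h * (d_k + d_r + d_v)  # QK^T over d_k + d_r, then PV over d_v
    return up_proj, attention


for n in (4096, 32768):
    up, attn = prefill_flops(n)
    print(n, f"{up / attn:.4f}")
# 4096 0.1000
# 32768 0.0125
```

The ratio simplifies to $d_c (d_k + d_v) / (n (d_k + d_r + d_v))$: at 4K tokens the up-projections cost about a tenth of the attention FLOPs, and at 32K about 1.25%.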
Implementation
```python
import math
from typing import Optional

import torch
import torch.nn as nn


class MultiHeadLatentAttention(nn.Module):
    """
    MLA attention layer following the DeepSeek-V2/V3 formulation (also used by Kimi K2.x).

    Key features:
    - Low-rank KV compression: the cache holds only the latent c_KV plus one
      shared, already-rotated RoPE key per token, i.e. d_c + d_r elements
    - Decoupled RoPE: one RoPE query per head, a single RoPE key shared by all heads
    - For readability, K and V are decompressed explicitly; weight absorption
      (and the RMSNorms DeepSeek applies to the latents) are not implemented here
    """

    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 128,
        d_k: int = 128,
        d_v: int = 128,
        d_c: int = 512,         # KV latent dimension (compression target)
        d_c_prime: int = 1536,  # Query latent dimension
        d_r: int = 64,          # Decoupled RoPE dimension (per query head; one shared key)
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_c = d_c
        self.d_c_prime = d_c_prime
        self.d_r = d_r

        # === Down-projections (compression) ===
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)       # KV latent
        self.w_dq = nn.Linear(d_model, d_c_prime, bias=False)  # Q latent

        # === Up-projections (decompression) ===
        # KV up-projections: latent -> per-head K and V
        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)
        # Q up-projection: latent -> per-head Q
        self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)

        # === Decoupled RoPE projections ===
        self.w_kr = nn.Linear(d_model, d_r, bias=False)              # one RoPE key from h_t, shared by all heads
        self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False)  # one RoPE query per head from c_Q

        # === Output projection ===
        self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)

        # RoPE inverse frequencies
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
        self.register_buffer("inv_freq", inv_freq)

    def _apply_rope(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        """Rotate x [batch, seq, ..., d_r] for absolute positions start_pos .. start_pos + seq - 1."""
        seq_len = x.shape[1]
        t = torch.arange(start_pos, start_pos + seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)  # [seq, d_r // 2]
        # Reshape to [1, seq, (1,) d_r // 2] so it broadcasts over batch and an optional head dim
        shape = [1, seq_len] + [1] * (x.dim() - 3) + [freqs.shape[-1]]
        cos = freqs.cos().view(shape).to(x.dtype)
        sin = freqs.sin().view(shape).to(x.dtype)
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos,
        ], dim=-1).flatten(-2)
        return rotated

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
        start_pos: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
            kv_cache: Cached [c_KV ; rotated k_R] for previous tokens,
                shape [batch, cache_len, d_c + d_r]; cache_len must equal start_pos
            start_pos: Absolute position of the first token in x
        Returns:
            output: [batch, seq_len, d_model]
            new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c + d_r]
        """
        B, S, _ = x.shape

        # === Step 1: Compress to latent space ===
        c_kv = self.w_dkv(x)  # [B, S, d_c] - cached
        c_q = self.w_dq(x)    # [B, S, d_c']

        # === Step 2: Decoupled RoPE key: computed from x, rotated once, cached ===
        k_rope = self._apply_rope(self.w_kr(x), start_pos)  # [B, S, d_r]

        # === Step 3: KV cache management: store only [c_KV ; k_R] per token ===
        new_entries = torch.cat([c_kv, k_rope], dim=-1)  # [B, S, d_c + d_r]
        if kv_cache is not None:
            new_kv_cache = torch.cat([kv_cache, new_entries], dim=1)
        else:
            new_kv_cache = new_entries
        T = new_kv_cache.shape[1]  # total tokens attended over
        c_kv_all, k_rope_all = new_kv_cache.split([self.d_c, self.d_r], dim=-1)

        # === Step 4: Queries: content part + per-head RoPE part ===
        q_content = self.w_uq(c_q).view(B, S, self.n_heads, self.d_k)
        q_rope = self._apply_rope(self.w_qr(c_q).view(B, S, self.n_heads, self.d_r), start_pos)
        q = torch.cat([q_content, q_rope], dim=-1)  # [B, S, n_heads, d_k + d_r]

        # === Step 5: Decompress K and V for all cached tokens ===
        k_content = self.w_uk(c_kv_all).view(B, T, self.n_heads, self.d_k)
        # The shared RoPE key is broadcast to every head
        k_rope_heads = k_rope_all.unsqueeze(2).expand(B, T, self.n_heads, self.d_r)
        k = torch.cat([k_content, k_rope_heads], dim=-1)  # [B, T, n_heads, d_k + d_r]
        v = self.w_uv(c_kv_all).view(B, T, self.n_heads, self.d_v)

        # === Step 6: Causal attention ===
        # Transpose for attention: [B, n_heads, seq, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        d_attn = self.d_k + self.d_r
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_attn)
        # Query i sits at absolute position start_pos + i and may only see keys 0..start_pos + i
        q_pos = torch.arange(start_pos, start_pos + S, device=x.device).unsqueeze(1)
        k_pos = torch.arange(T, device=x.device).unsqueeze(0)
        attn_weights = attn_weights.masked_fill(k_pos > q_pos, float("-inf"))
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)  # [B, n_heads, S, d_v]

        # === Step 7: Output projection ===
        attn_output = attn_output.transpose(1, 2).reshape(B, S, self.n_heads * self.d_v)
        output = self.w_o(attn_output)
        return output, new_kv_cache


# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes():
    """Demonstrate the KV cache savings of MLA over MHA for one 64K-token sequence."""
    n_heads = 128
    d_k = 128
    d_c = 512              # DeepSeek-V2/V3 KV latent dim
    d_r = 64               # Decoupled RoPE key dim (one key shared by all heads)
    seq_len = 65536        # 64K context
    n_layers = 60
    bytes_per_element = 2  # FP16/BF16

    # MHA: cache K and V for all heads
    mha_cache_per_token = 2 * n_heads * d_k
    mha_total = mha_cache_per_token * seq_len * bytes_per_element / (1024**3)

    # MLA: cache the c_KV latent plus the shared decoupled RoPE key
    mla_cache_per_token = d_c + d_r
    mla_total = mla_cache_per_token * seq_len * bytes_per_element / (1024**3)

    print(f"MHA KV cache (64K ctx): {mha_total:.2f} GiB per layer")
    print(f"MLA KV cache (64K ctx): {mla_total:.3f} GiB per layer")
    print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
    print(f"\nFor {n_layers} layers:")
    print(f"  MHA: {mha_total * n_layers:.1f} GiB")
    print(f"  MLA: {mla_total * n_layers:.1f} GiB")
    print(f"  Savings: {(mha_total - mla_total) * n_layers:.1f} GiB")


if __name__ == "__main__":
    # Test MLA forward pass with small dimensions
    torch.manual_seed(0)
    mla = MultiHeadLatentAttention(
        d_model=4096, n_heads=8, d_k=64, d_v=64,
        d_c=128, d_c_prime=256, d_r=32,
    )
    x = torch.randn(2, 10, 4096)  # batch=2, seq=10
    output, cache = mla(x)
    print(f"Output shape: {output.shape}")  # [2, 10, 4096]
    print(f"Cache shape: {cache.shape}")    # [2, 10, 160] - only d_c + d_r per token

    # Autoregressive step: one new token at position 10
    x2 = torch.randn(2, 1, 4096)
    output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)
    print(f"Output2 shape: {output2.shape}")  # [2, 1, 4096]
    print(f"Cache2 shape: {cache2.shape}")    # [2, 11, 160] - grew by 1

    print("\n--- Cache Comparison ---")
    compare_cache_sizes()
```
Prerequisites
- kv-cache - You must understand standard KV caching before understanding why MLA compresses it
- paged-attention - MLA changes what gets paged (latent vectors, not KV pairs)
- flash-attention - MLA's absorbed attention can be fused into FlashAttention-style kernels
- attention-mechanism - Foundation for understanding attention computation
References
DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model - Liu et al., 2024. arxiv:2405.04434 - Original MLA paper introducing the latent compression and decoupled RoPE strategy.
DeepSeek-V3 Technical Report - DeepSeek-AI, 2024. arxiv:2412.19437 - Scales MLA to a 671B MoE with auxiliary-loss-free load balancing. Also introduces a multi-token prediction (MTP) training objective whose modules can be repurposed for speculative decoding.
vLLM MLA Implementation - github.com/vllm-project/vllm - Production MLA kernel with weight absorption and FlashAttention integration.
FlashInfer MLA Attention - github.com/flashinfer-ai/flashinfer - Custom CUDA kernels for MLA that support both prefill and decode phases with batched latent cache.