Compressing KV cache via low-rank projections — the attention mechanism behind DeepSeek-V2/V3 and Kimi K2.x
Why This Matters
Multi-Head Latent Attention (MLA) is the attention variant that replaces standard Multi-Head Attention (MHA) in DeepSeek-V2, DeepSeek-V3, and Kimi K2.x models. Instead of caching full KV pairs per head, MLA projects them into a low-dimensional latent space. At DeepSeek-V3's dimensions this shrinks the per-token KV cache roughly 57x relative to full MHA, with minimal reported quality loss.
- MLA changes how prefix caching, chunked prefill, and paged attention must be implemented
Formal Definition
Standard Multi-Head Attention (MHA)
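For input $h_t \in \mathbb{R}^{d}$, MHA computes per-head projections:
$$q_t = W^{Q} h_t,\qquad k_t = W^{K} h_t,\qquad v_t = W^{V} h_t$$
Each is split into $n_h$ heads $q_{t,i}, k_{t,i}, v_{t,i} \in \mathbb{R}^{d_h}$, where $W^{Q}, W^{K}, W^{V} \in \mathbb{R}^{n_h d_h \times d}$.
KV cache size per token: $2 n_h d_h$ elements per layer.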
MLA: Low-Rank Latent Projection
MLA replaces the per-head KV projections with a shared low-rank latent compression:
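Compression (KV → Latent):
$$c_t^{KV} = W^{DKV} h_t$$
where $W^{DKV} \in \mathbb{R}^{d_c \times d}$ is the down-projection matrix and $d_c \ll n_h d_h$.
Decompression (Latent → KV):
$$k_t^{C} = W^{UK} c_t^{KV},\qquad v_t^{C} = W^{UV} c_t^{KV}$$
where $W^{UK} \in \mathbb{R}^{n_h d_h \times d_c}$ and $W^{UV} \in \mathbb{R}^{n_h d_h \times d_c}$ are up-projection matrices.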
KV cache per token: only $c_t^{KV}$ is stored. This is a single vector of dimension $d_c$ shared by all heads, plus the small decoupled RoPE key described below.
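The compress/decompress path can be sketched in NumPy. The toy sizes are assumptions, far smaller than DeepSeek's, and random matrices stand in for the learned $W^{DKV}$, $W^{UK}$, and $W^{UV}$. The sketch shows that only the $d_c$-wide latent needs to be kept per token, while per-head K and V can be rebuilt from it on demand.

```python
import numpy as np

# Toy dimensions (assumed for illustration only)
d, n_h, d_h, d_c = 64, 4, 16, 8
rng = np.random.default_rng(0)

# Random stand-ins for the learned projections
W_DKV = rng.standard_normal((d_c, d)) / np.sqrt(d)            # down-projection
W_UK = rng.standard_normal((n_h * d_h, d_c)) / np.sqrt(d_c)   # latent -> keys
W_UV = rng.standard_normal((n_h * d_h, d_c)) / np.sqrt(d_c)   # latent -> values


def compress(h: np.ndarray) -> np.ndarray:
    """h: [T, d] -> c_KV: [T, d_c]. This is the per-token tensor that gets cached."""
    return h @ W_DKV.T


def decompress(c_kv: np.ndarray):
    """c_KV: [T, d_c] -> per-head keys and values, each [T, n_h, d_h]."""
    T = c_kv.shape[0]
    k = (c_kv @ W_UK.T).reshape(T, n_h, d_h)
    v = (c_kv @ W_UV.T).reshape(T, n_h, d_h)
    return k, v


h = rng.standard_normal((10, d))   # hidden states for 10 tokens
c_kv = compress(h)
k, v = decompress(c_kv)
print(c_kv.shape, k.shape, v.shape)                                  # (10, 8) (10, 4, 16) (10, 4, 16)
print("cached per token:", c_kv.shape[1], "vs MHA:", 2 * n_h * d_h)  # cached per token: 8 vs MHA: 128
```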
Compression Ratio
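For a model with $n_h$ heads and head dimension $d_h$, the ratio also counts the decoupled RoPE key of dimension $d_h^R$, which MLA caches next to the latent (see the RoPE section below):
$$\text{ratio} = \frac{2 n_h d_h}{d_c + d_h^R}$$
In DeepSeek-V3, with $n_h = 128$, $d_h = 128$, $d_c = 512$, and $d_h^R = 64$:
$$\text{ratio} = \frac{2 \cdot 128 \cdot 128}{512 + 64} = \frac{32768}{576} \approx 56.9\times$$
Without the RoPE key the ratio would be $32768 / 512 = 64\times$.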
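The same arithmetic as a small helper, which makes it easy to plug in other configurations. The function name is illustrative. It counts elements per token per layer and ignores storage dtype.

```python
def mla_compression_ratio(n_h: int, d_h: int, d_c: int, d_h_rope: int) -> float:
    """Per-token, per-layer cache elements of full MHA divided by MLA's."""
    mha_elements = 2 * n_h * d_h      # K and V for every head
    mla_elements = d_c + d_h_rope     # latent plus the shared decoupled RoPE key
    return mha_elements / mla_elements


ratio = mla_compression_ratio(n_h=128, d_h=128, d_c=512, d_h_rope=64)
print(f"DeepSeek-V3 dims: {ratio:.1f}x")                                          # DeepSeek-V3 dims: 56.9x
print(f"Ignoring the RoPE key: {mla_compression_ratio(128, 128, 512, 0):.0f}x")  # Ignoring the RoPE key: 64x
```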
Query Compression (Optional)
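MLA also compresses queries for training efficiency:
$$c_t^{Q} = W^{DQ} h_t,\qquad q_t^{C} = W^{UQ} c_t^{Q}$$
where $W^{DQ} \in \mathbb{R}^{d_c' \times d}$ and $W^{UQ} \in \mathbb{R}^{n_h d_h \times d_c'}$.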
This doesn't affect the KV cache but reduces the activation memory during training.
Rotary Position Embedding (RoPE) Handling
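RoPE cannot simply be applied to the decompressed keys. A position-dependent rotation sitting between the query and $W^{UK}$ would prevent the weight absorption trick described below. MLA therefore applies RoPE to separate, narrow "decoupled" projections:
$$q_{t,i}^{R} = \mathrm{RoPE}\big(W^{QR}_i c_t^{Q}\big),\qquad k_t^{R} = \mathrm{RoPE}\big(W^{KR} h_t\big)$$
Here $W^{QR} \in \mathbb{R}^{n_h d_h^R \times d_c'}$ produces per-head RoPE queries. $W^{KR} \in \mathbb{R}^{d_h^R \times d}$ is a narrow projection ($d_h^R = 64$ in DeepSeek-V3) that carries positional information. The latent $c_t^{KV}$ stays position-agnostic. The RoPE key $k_t^R$ is computed from the input $h_t$, shared by all heads, and cached next to the latent rather than recomputed from it.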
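The RoPE key has to be a separate, position-stamped vector, and a small NumPy sketch shows why. The toy sizes and random weights are assumptions, and the interleaved-pair RoPE layout matches the one used in the implementation code of this article. The sketch computes $k^R = \mathrm{RoPE}(W^{KR} h)$ once per token, shared by all heads. It then checks RoPE's defining property: the query-key score depends only on the relative offset between positions. Because the rotation itself depends on absolute position, it cannot live inside the position-agnostic latent.

```python
import numpy as np


def rope(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Interleaved RoPE: rotate each pair (x[2i], x[2i+1]) by pos * base**(-2i/d)."""
    d = x.shape[-1]
    angle = pos * base ** (-np.arange(0, d, 2) / d)
    cos, sin = np.cos(angle), np.sin(angle)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


# Toy sizes (assumed): model dim d, query latent d_cq, heads n_h, RoPE dim d_r
d, d_cq, n_h, d_r = 32, 12, 4, 8
rng = np.random.default_rng(0)
W_KR = rng.standard_normal((d_r, d))          # one shared key projection from h
W_QR = rng.standard_normal((n_h, d_r, d_cq))  # per-head query projections from c^Q

h_j = rng.standard_normal(d)       # hidden state of a cached token
c_q = rng.standard_normal(d_cq)    # query latent of the current token

k_r_raw = W_KR @ h_j                           # [d_r], shared by all heads
q_r_raw = np.einsum("hrc,c->hr", W_QR, c_q)    # [n_h, d_r]

# Query at position m, key at position n: the score depends only on m - n
s1 = rope(q_r_raw, 10) @ rope(k_r_raw, 3)      # offset 7
s2 = rope(q_r_raw, 107) @ rope(k_r_raw, 100)   # offset 7 again
print(np.allclose(s1, s2))  # True
```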
Core Concepts
1. Weight Absorption (The Key Trick)
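The critical insight in MLA is that the up-projection matrices can be absorbed into the query projection during attention computation. For head $i$, the content score between query token $t$ and cached token $j$ is $q_{t,i}^{C\top} k_{j,i}^{C}$.
Substituting the decompressed forms, where $W^{UQ}_i$ and $W^{UK}_i$ denote head $i$'s block of $d_h$ rows in $W^{UQ}$ and $W^{UK}$:
$$q_{t,i}^{C\top} k_{j,i}^{C} = \big(W^{UQ}_i c_t^{Q}\big)^{\top} \big(W^{UK}_i c_j^{KV}\big) = c_t^{Q\top} \, W^{UQ\top}_i W^{UK}_i \, c_j^{KV}$$
If we define the absorbed query $\tilde q_{t,i} = W^{UK\top}_i W^{UQ}_i \, c_t^{Q} \in \mathbb{R}^{d_c}$, then:
$$q_{t,i}^{C\top} k_{j,i}^{C} = \tilde q_{t,i}^{\top} c_j^{KV}$$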
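This means the attention score can be computed directly from the latent representations, without decompressing K. V can be handled the same way. With attention weights $a_{t,j,i}$, the head output is $\sum_j a_{t,j,i} W^{UV}_i c_j^{KV} = W^{UV}_i \sum_j a_{t,j,i} c_j^{KV}$. The weights can therefore be applied to the latents first, and $W^{UV}_i$ applied once per query. Alternatively, $W^{UV}_i$ can be folded into the output projection $W^O$, as the DeepSeek-V2 paper notes.
Practical implication: during decoding, neither K nor V has to be materialized for cached tokens. Attention runs over the $d_c$-wide latents, and the up-projections touch only the current query and the weighted sum.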
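Both absorption identities can be checked numerically. The sketch below uses toy sizes and random per-head weight slices, both assumptions. It computes attention once the naive way, materializing K and V for every cached token, and once the absorbed way, attending directly over the latents. The two paths should agree to floating-point precision.

```python
import numpy as np

# Toy sizes (assumed): heads, head dim, KV latent, query latent, cached tokens
n_h, d_h, d_c, d_cq, T = 4, 16, 8, 12, 6
rng = np.random.default_rng(0)

# Per-head slices of the up-projections: W_UQ[i] plays the role of W^{UQ}_i, etc.
W_UQ = rng.standard_normal((n_h, d_h, d_cq))
W_UK = rng.standard_normal((n_h, d_h, d_c))
W_UV = rng.standard_normal((n_h, d_h, d_c))

c_q = rng.standard_normal(d_cq)        # query latent of the current token
c_kv = rng.standard_normal((T, d_c))   # cached latents


def softmax(s: np.ndarray) -> np.ndarray:
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


# Naive path: materialize q, K and V for every cached token
q = np.einsum("hdq,q->hd", W_UQ, c_q)            # [n_h, d_h]
k = np.einsum("hdc,tc->thd", W_UK, c_kv)         # [T, n_h, d_h]
v = np.einsum("hdc,tc->thd", W_UV, c_kv)         # [T, n_h, d_h]
scores_naive = np.einsum("hd,thd->ht", q, k)     # [n_h, T]
a = softmax(scores_naive / np.sqrt(d_h))
out_naive = np.einsum("ht,thd->hd", a, v)        # [n_h, d_h]

# Absorbed path: fold W^{UK}_i into the query and attend over latents directly
W_absorbed = np.einsum("hdc,hdq->hcq", W_UK, W_UQ)   # (W^{UK}_i)^T W^{UQ}_i: [n_h, d_c, d_cq]
q_tilde = np.einsum("hcq,q->hc", W_absorbed, c_q)    # [n_h, d_c]
scores_absorbed = q_tilde @ c_kv.T                   # [n_h, T]
a_abs = softmax(scores_absorbed / np.sqrt(d_h))
latent_mix = a_abs @ c_kv                            # [n_h, d_c]: attention-weighted latents
out_absorbed = np.einsum("hdc,hc->hd", W_UV, latent_mix)  # W^{UV}_i applied once per head

print(np.allclose(scores_naive, scores_absorbed), np.allclose(out_naive, out_absorbed))  # True True
```

In a real kernel, `W_absorbed` would be precomputed once per layer, so the per-step cost scales with $d_c$ rather than $n_h d_h$ per cached token.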
2. Decoupled RoPE Strategy
RoPE requires position-dependent keys, which conflicts with caching a position-agnostic latent. MLA solves this with a decoupled key:
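- Content key: $k_{j,i}^{C} = W^{UK}_i c_j^{KV}$, cached in latent form as $c_j^{KV}$
- Position key: $k_j^{R} = \mathrm{RoPE}(W^{KR} h_j)$, which is small, position-aware, shared by all heads, and must be cached separately
The attention score becomes:
$$\text{score}_{t,j,i} = \frac{q_{t,i}^{C\top} k_{j,i}^{C} + q_{t,i}^{R\top} k_j^{R}}{\sqrt{d_h + d_h^R}}$$
Practical implication: The KV cache stores both $c_j^{KV}$ (latent) and $k_j^{R}$ (decoupled RoPE key). Total cache per token: $d_c + d_h^R$ elements per layer.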
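The decoupled score can be computed straight from a cache row laid out as $[c^{KV} \,|\, k^R]$, and the result should match the naive score built from full concatenated keys. This NumPy sketch uses toy sizes and random values as assumptions. It treats the RoPE queries as already rotated, and it assumes the RoPE key was rotated before it was cached.

```python
import numpy as np

# Toy sizes (assumed): heads, content head dim, KV latent, RoPE dim, cached tokens
n_h, d_h, d_c, d_r, T = 4, 16, 8, 4, 5
rng = np.random.default_rng(1)

W_UK = rng.standard_normal((n_h, d_h, d_c))   # per-head slices W^{UK}_i
cache = rng.standard_normal((T, d_c + d_r))   # one [c_KV | k_R] row per cached token
c_kv, k_r = cache[:, :d_c], cache[:, d_c:]    # k_R: one key shared by all heads

q_c = rng.standard_normal((n_h, d_h))   # per-head content queries q^C_{t,i}
q_r = rng.standard_normal((n_h, d_r))   # per-head RoPE queries q^R_{t,i} (already rotated)
scale = 1.0 / np.sqrt(d_h + d_r)

# Naive: build the full per-head key [k^C_{j,i} ; k^R_j] and take one dot product
k_c = np.einsum("hdc,tc->thd", W_UK, c_kv)                   # [T, n_h, d_h]
k_full = np.concatenate([k_c, np.broadcast_to(k_r[:, None, :], (T, n_h, d_r))], axis=-1)
q_full = np.concatenate([q_c, q_r], axis=-1)                 # [n_h, d_h + d_r]
scores_naive = np.einsum("hd,thd->ht", q_full, k_full) * scale

# Decoupled + absorbed: content term from latents, position term from the shared k_R
q_tilde = np.einsum("hdc,hd->hc", W_UK, q_c)                 # (W^{UK}_i)^T q^C_{t,i}
scores = (q_tilde @ c_kv.T + q_r @ k_r.T) * scale            # [n_h, T]

print(np.allclose(scores_naive, scores))  # True
```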
3. MLA vs GQA vs MHA
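| Property | MHA | GQA | MLA |
|---|---|---|---|
| KV groups | $n_h$ | $n_g$ | 1 (latent) |
| Cache per token | $2 n_h d_h$ | $2 n_g d_h$ | $d_c + d_h^R$ |
| Quality | Baseline | Slight drop | Comparable |
| Attention score | $q_{t,i}^{\top} k_{j,i}$ | $q_{t,i}^{\top} k_{j,g(i)}$ (shared K) | Latent: $\tilde q_{t,i}^{\top} c_j^{KV} + q_{t,i}^{R\top} k_j^{R}$ |
| RoPE compatibility | Native | Native | Decoupled |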
GQA reduces cache by sharing KV heads across query groups. MLA reduces cache more aggressively by projecting to a shared latent. The quality difference is minimal because the up-projection matrices are learned and can reconstruct head-specific information.
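The "Cache per token" row can be turned into numbers. The sketch below assumes DeepSeek-V3-like head sizes for all three variants and a GQA configuration with 8 KV groups, chosen only as a common example. The helper name is illustrative.

```python
def kv_cache_elements_per_token(kind: str, n_h: int = 128, d_h: int = 128,
                                n_g: int = 8, d_c: int = 512, d_h_rope: int = 64) -> int:
    """Per-layer cache elements per token for each attention variant."""
    if kind == "MHA":
        return 2 * n_h * d_h          # K and V for every query head
    if kind == "GQA":
        return 2 * n_g * d_h          # K and V for each of n_g shared groups
    if kind == "MLA":
        return d_c + d_h_rope         # latent + shared decoupled RoPE key
    raise ValueError(f"unknown attention kind: {kind}")


for kind in ("MHA", "GQA", "MLA"):
    print(kind, kv_cache_elements_per_token(kind))
# MHA 32768
# GQA 2048
# MLA 576

# MLA's cache is the same size as GQA with only this many groups:
print(kv_cache_elements_per_token("MLA") / (2 * 128))  # 2.25
```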
4. Impact on Batched Serving
MLA dramatically changes the memory-vs-compute tradeoff in serving:
Memory-bound decoding phase: With MHA, long contexts exhaust GPU HBM due to KV cache. MLA's compression allows:
- Longer context windows (at DeepSeek-V3's dimensions, roughly 57x more tokens than full MHA in the same memory)
- Larger batch sizes (more concurrent requests)
- Better prefix caching hit rates (smaller cache entries)
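The benefits above all come down to how many tokens' KV state fit in a fixed memory budget. The numbers in this sketch are assumptions for illustration: 40 GiB reserved for the KV cache, 60 layers, 2-byte elements, and DeepSeek-V3 head and latent sizes.

```python
def max_cached_tokens(budget_gib: int, elements_per_token: int,
                      n_layers: int = 60, bytes_per_element: int = 2) -> int:
    """How many tokens' KV state fit in a fixed memory budget across all layers."""
    bytes_per_token = elements_per_token * n_layers * bytes_per_element
    return budget_gib * 1024**3 // bytes_per_token


budget = 40  # GiB reserved for KV cache (assumed)
mha_tokens = max_cached_tokens(budget, 2 * 128 * 128)
mla_tokens = max_cached_tokens(budget, 512 + 64)
print(mha_tokens, mla_tokens)                                               # 10922 621378
print("8K-token requests that fit:", mha_tokens // 8192, "vs", mla_tokens // 8192)  # 1 vs 75
```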
Compute-bound prefill phase: MLA adds decompression overhead, but this is amortized:
- Prefill is already compute-heavy (O(n²) attention)
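- The additional matmuls for up-projection are O(n × d_c × n_h × (d_h + d_v)) per layer. This is linear in n, while attention grows as O(n²)
- Net effect: a modest prefill overhead in exchange for far less KV memory traffic during decoding, which is usually the bottleneck in long-context serving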
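The "amortized" claim can be made concrete with a rough FLOP ratio for one layer's prefill. The sketch assumes DeepSeek-V3 dimensions and counts dense (non-causal) attention. Causal masking roughly halves the attention FLOPs, which doubles the ratio. The sketch also ignores softmax and the query-side projections. The head count cancels because both terms scale with it.

```python
def upproj_to_attention_flops(n: int, d_c: int = 512, d_h: int = 128,
                              d_v: int = 128, d_h_rope: int = 64) -> float:
    """Ratio of K/V up-projection FLOPs to dense attention FLOPs for an n-token prefill."""
    upproj = 2 * n * d_c * (d_h + d_v)               # per head: c_KV -> k^C and v
    attention = 2 * n * n * (d_h + d_h_rope + d_v)   # per head: QK^T plus softmax(.)V
    return upproj / attention


for n in (4096, 32768):
    print(n, f"{upproj_to_attention_flops(n):.4f}")
# 4096 0.1000
# 32768 0.0125
```

The ratio falls as 1/n, so the up-projection cost becomes a small fraction of prefill compute as prompts grow.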
5. MLA + Speculative Decoding
This is where it gets interesting for Siraj's EAGLE-3 work:
Draft model constraints:
- If the draft model is meant to read or write the target's cache directly, it must produce latent KV states compatible with the target model's MLA projections
- In that case, simply using a smaller MHA model as drafter creates a KV format mismatch
- EAGLE-3's tree-based speculation must handle the latent→decompressed→verify→latent roundtrip
Verification with MLA:
- Draft tokens are generated by the draft model
- Target model verifies by running the full MLA attention (decompress latent, compute attention)
- Accepted tokens' KV entries must be added to the latent cache ($c_t^{KV}$ and $k_t^R$), not to a full KV cache
- The target's verification pass computes those latents itself, so a drafter with its own separate cache is unconstrained. Only a drafter that shares the target's cache needs to either (a) predict in latent space, or (b) have its KV outputs projected to latent space
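The cache bookkeeping during verification can be sketched with a simplified model. The drafter proposes a linear chain of tokens rather than EAGLE-3's tree, and acceptance is greedy (longest prefix matching the target's argmax). `LatentCache` and `accepted_prefix` are hypothetical helpers, not a vLLM or EAGLE API. Random rows stand in for the latents the target's own MLA projections would produce.

```python
import numpy as np


class LatentCache:
    """Hypothetical per-sequence MLA cache: one [c_KV | k_R] row per token."""

    def __init__(self, width: int):
        self.rows = np.zeros((0, width))

    def append(self, entries: np.ndarray) -> None:
        self.rows = np.concatenate([self.rows, entries], axis=0)

    def truncate(self, length: int) -> None:
        """Drop rows past `length`, e.g. latents written for rejected draft tokens."""
        self.rows = self.rows[:length]


def accepted_prefix(draft_tokens, target_tokens) -> int:
    """Greedy verification: accept draft tokens while they match the target's argmax."""
    n = 0
    for drafted, verified in zip(draft_tokens, target_tokens):
        if drafted != verified:
            break
        n += 1
    return n


d_c, d_r = 512, 64
rng = np.random.default_rng(0)
cache = LatentCache(d_c + d_r)
cache.append(rng.standard_normal((100, d_c + d_r)))   # prompt latents

draft = [17, 42, 7, 99]     # proposed by the drafter (toy token ids)
target = [17, 42, 8, 5]     # target's argmax at each draft position (toy values)
# The target's verification pass writes latents for every draft token it scores
cache.append(rng.standard_normal((len(draft), d_c + d_r)))
n_ok = accepted_prefix(draft, target)
cache.truncate(100 + n_ok)  # keep only latents of accepted tokens
print(n_ok, cache.rows.shape)  # 2 (102, 576)
```

With a tree of drafts, the same rollback applies per branch: only the rows on the accepted path survive.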
vLLM implementation challenge: vLLM's PagedAttention was designed for MHA/GQA. MLA requires:
- Modified page table storing latent vectors ($c_t^{KV}$ plus $k_t^R$, i.e. $d_c + d_h^R$ elements per token) instead of per-head KV pairs
- Custom attention kernels for the absorbed + decoupled-RoPE computation
- Integration with CUDAGraph captures for the decompression path
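The paging change can be illustrated with a simplified page pool whose pages hold latent rows instead of per-head K/V. This is a hypothetical layout, not vLLM's actual data structures. The block size, pool size, and function names are assumptions; the widths are DeepSeek-V3's $d_c = 512$ and $d_h^R = 64$.

```python
import numpy as np

BLOCK_SIZE = 16          # tokens per page (assumed)
D_C, D_R = 512, 64       # DeepSeek-V3 latent and RoPE-key widths
NUM_BLOCKS = 64          # pool size (assumed)

# Physical page pool: each page holds BLOCK_SIZE rows of [c_KV | k_R]
kv_pool = np.zeros((NUM_BLOCKS, BLOCK_SIZE, D_C + D_R), dtype=np.float16)


def write_token(block_table: list, pos: int, entry: np.ndarray) -> None:
    """Store one token's cache row at logical position pos of a sequence."""
    kv_pool[block_table[pos // BLOCK_SIZE], pos % BLOCK_SIZE] = entry


def gather(block_table: list, length: int) -> np.ndarray:
    """Reassemble the first `length` cached rows of a sequence, in order."""
    return kv_pool[block_table].reshape(-1, D_C + D_R)[:length]


block_table = [5, 9, 2]   # logical page i lives in physical page block_table[i]
rows = np.random.default_rng(0).standard_normal((40, D_C + D_R)).astype(np.float16)
for pos, row in enumerate(rows):
    write_token(block_table, pos, row)

assert np.array_equal(gather(block_table, 40), rows)
mla_page_bytes = BLOCK_SIZE * (D_C + D_R) * 2      # FP16 latent page
mha_page_bytes = BLOCK_SIZE * 2 * 128 * 128 * 2    # FP16 page of K and V for 128 heads
print(mla_page_bytes, mha_page_bytes)  # 18432 1048576
```

The block-table indirection is unchanged from standard paged attention; only the row width and the kernel that consumes the rows differ.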
Implementation
```python
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn


class MultiHeadLatentAttention(nn.Module):
    """
    Simplified MLA attention layer following the DeepSeek-V2/V3 design
    (also used by Kimi K2.x). The RMSNorms DeepSeek applies to the latents
    are omitted for brevity.

    Key features:
    - Low-rank KV compression: only c_KV and one shared RoPE key are cached
    - Decoupled RoPE: k_R is computed from the input and shared by all heads
    - Explicit decompression for clarity; weight absorption is an
      inference-time rewrite of the same math and is not applied here
    """

    def __init__(
        self,
        d_model: int = 7168,      # DeepSeek-V3 hidden size
        n_heads: int = 128,
        d_k: int = 128,           # per-head content (non-RoPE) query/key dim
        d_v: int = 128,           # per-head value dim
        d_c: int = 512,           # KV latent dimension (compression target)
        d_c_prime: int = 1536,    # query latent dimension
        d_r: int = 64,            # decoupled RoPE dim (d_h^R); one key shared by all heads
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_c = d_c
        self.d_r = d_r

        # === Down-projections (compression) ===
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)        # KV latent
        self.w_dq = nn.Linear(d_model, d_c_prime, bias=False)   # Q latent

        # === Up-projections (decompression) ===
        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)
        self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)

        # === Decoupled RoPE projections ===
        # Per-head RoPE queries come from the query latent; the RoPE key is a
        # single d_r vector per token computed from the input (not from c_KV).
        self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False)
        self.w_kr = nn.Linear(d_model, d_r, bias=False)

        # === Output projection ===
        self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)

        # RoPE frequencies
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _apply_rope(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        """Rotate x [batch, seq, heads, d_r] for positions start_pos .. start_pos + seq - 1."""
        seq_len = x.shape[1]
        t = torch.arange(start_pos, start_pos + seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)                 # [seq, d_r // 2]
        cos = freqs.cos()[None, :, None, :].to(x.dtype)       # [1, seq, 1, d_r // 2]
        sin = freqs.sin()[None, :, None, :].to(x.dtype)
        x1, x2 = x[..., ::2], x[..., 1::2]                    # interleaved pairs
        return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
        start_pos: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
            kv_cache: Cached [c_KV | k_R] rows for earlier tokens [batch, cache_len, d_c + d_r]
            start_pos: Absolute position of x[:, 0]; must equal cache_len when a cache is given
        Returns:
            output: [batch, seq_len, d_model]
            new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c + d_r]
        """
        B, S, _ = x.shape
        H = self.n_heads

        # === Step 1: Compress. c_kv and k_rope are the only per-token state cached ===
        c_kv = self.w_dkv(x)                                             # [B, S, d_c]
        k_rope = self._apply_rope(self.w_kr(x).unsqueeze(2), start_pos)  # [B, S, 1, d_r]
        c_q = self.w_dq(x)                                               # [B, S, d_c']

        # === Step 2: KV cache management (RoPE is applied before caching) ===
        new_entries = torch.cat([c_kv, k_rope.squeeze(2)], dim=-1)       # [B, S, d_c + d_r]
        if kv_cache is not None:
            new_kv_cache = torch.cat([kv_cache, new_entries], dim=1)
        else:
            new_kv_cache = new_entries
        T = new_kv_cache.shape[1]                                        # keys visible to x
        c_all, k_rope_all = new_kv_cache.split([self.d_c, self.d_r], dim=-1)

        # === Step 3: Queries = content part + decoupled RoPE part ===
        q_content = self.w_uq(c_q).view(B, S, H, self.d_k)
        q_rope = self._apply_rope(self.w_qr(c_q).view(B, S, H, self.d_r), start_pos)
        q = torch.cat([q_content, q_rope], dim=-1)                       # [B, S, H, d_k + d_r]

        # === Step 4: Decompress keys/values for all T positions (naive path) ===
        k_content = self.w_uk(c_all).view(B, T, H, self.d_k)
        k_rope_all = k_rope_all.unsqueeze(2).expand(B, T, H, self.d_r)  # shared by every head
        k = torch.cat([k_content, k_rope_all], dim=-1)                   # [B, T, H, d_k + d_r]
        v = self.w_uv(c_all).view(B, T, H, self.d_v)

        # === Step 5: Causal attention in [B, n_heads, seq, dim] layout ===
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k + self.d_r)  # [B, H, S, T]
        q_pos = torch.arange(T - S, T, device=x.device).unsqueeze(1)       # absolute query positions
        k_pos = torch.arange(T, device=x.device).unsqueeze(0)
        scores = scores.masked_fill(k_pos > q_pos, float("-inf"))          # no attending to the future
        attn_output = torch.softmax(scores, dim=-1) @ v                    # [B, H, S, d_v]

        # === Step 6: Output projection ===
        attn_output = attn_output.transpose(1, 2).reshape(B, S, H * self.d_v)
        return self.w_o(attn_output), new_kv_cache


# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes(n_layers: int = 60):
    """Demonstrate the KV cache savings of MLA over MHA at DeepSeek-V3's dimensions."""
    n_heads = 128
    d_k = 128
    d_c = 512              # DeepSeek-V3 latent dim
    d_r = 64               # decoupled RoPE key dim (one key shared by all heads)
    seq_len = 65536        # 64K context
    bytes_per_element = 2  # FP16/BF16
    gib = 1024 ** 3

    # MHA: cache K and V for all heads
    mha_cache_per_token = 2 * n_heads * d_k                # 32768
    mha_total = mha_cache_per_token * seq_len * bytes_per_element / gib

    # MLA: cache only c_KV plus the single shared RoPE key
    mla_cache_per_token = d_c + d_r                        # 576
    mla_total = mla_cache_per_token * seq_len * bytes_per_element / gib

    print(f"MHA KV cache (64K ctx): {mha_total:.2f} GiB per layer")
    print(f"MLA KV cache (64K ctx): {mla_total:.2f} GiB per layer")
    print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
    print(f"\nFor {n_layers} layers:")
    print(f"  MHA: {mha_total * n_layers:.1f} GiB")
    print(f"  MLA: {mla_total * n_layers:.1f} GiB")
    print(f"  Savings: {(mha_total - mla_total) * n_layers:.1f} GiB")


if __name__ == "__main__":
    torch.manual_seed(0)
    # Small test configuration
    mla = MultiHeadLatentAttention(
        d_model=4096, n_heads=8, d_k=64, d_v=64,
        d_c=128, d_c_prime=256, d_r=32,
    )
    x = torch.randn(2, 11, 4096)  # batch=2, 11 tokens

    # Prefill the first 10 tokens
    output, cache = mla(x[:, :10])
    print(f"Output shape: {output.shape}")  # [2, 10, 4096]
    print(f"Cache shape: {cache.shape}")    # [2, 10, 160]: d_c + d_r, no per-head K/V

    # Autoregressive step for the 11th token
    output2, cache2 = mla(x[:, 10:], kv_cache=cache, start_pos=10)
    print(f"Output2 shape: {output2.shape}")  # [2, 1, 4096]
    print(f"Cache2 shape: {cache2.shape}")    # [2, 11, 160], grew by 1

    # The cached decode step should agree with a full causal pass over all 11 tokens
    full_output, _ = mla(x)
    print("Decode matches full pass:", torch.allclose(output2, full_output[:, 10:], atol=1e-5))

    print("\n--- Cache Comparison ---")
    compare_cache_sizes()
```
Connections
Prerequisites
- kv-cache — You must understand standard KV caching before understanding why MLA compresses it
- paged-attention — MLA changes what gets paged (latent vectors, not KV pairs)
- flash-attention — MLA's absorbed attention can be fused into FlashAttention-style kernels
- attention-mechanism — Foundation for understanding attention computation
Directly Related
- kimi-k2-6 — Kimi K2.6 uses MLA + MoE, the target model for Siraj's spec-coder project
- mha2mla-conversion — Techniques for converting MHA models to MLA
- ktransformers — CPU/GPU hybrid MoE inference that must handle MLA's latent cache
- arkv-adaptive-kv-cache — Adaptive KV cache management (MLA enables further compression)
- oaken-hybrid-kv-cache — Online-offline hybrid quantization for KV cache, works with MLA
- pikv-moe-kv-cache — KV cache management specifically for MoE + MLA architectures
Next Steps
- sglang — SGLang's RadixAttention implementation with MLA support
- vllm-omni-disaggregated-serving — Disaggregated serving architectures that benefit from MLA's cache savings
- ragged-paged-attention-tpu — TPU kernels that can be adapted for MLA's non-standard attention pattern
References
DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — Liu et al., 2024. arxiv:2405.04434 — Original MLA paper introducing the latent compression and decoupled RoPE strategy.
DeepSeek-V3 Technical Report — DeepSeek-AI, 2024. arxiv:2412.19437 — Scales MLA to 671B MoE with auxiliary-loss-free routing. Details the multi-token prediction (MTP) modules, which, like EAGLE-style draft heads, can be used for speculative decoding.
vLLM MLA Implementation — github.com/vllm-project/vllm — Production MLA kernel with weight absorption and FlashAttention integration.
FlashInfer MLA Attention — github.com/flashinfer-ai/flashinfer — Custom CUDA kernels for MLA that support both prefill and decode phases with batched latent cache.