DEV Community

Sirajuddin Shaik


Multi-Head Latent Attention (MLA)

Compressing KV cache via low-rank projections — the attention mechanism behind DeepSeek-V2/V3 and Kimi K2.x

Why This Matters

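Multi-Head Latent Attention (MLA) is the attention variant that replaces standard Multi-Head Attention (MHA) in DeepSeek-V2, DeepSeek-V3, and Kimi K2.x models. Instead of caching full keys and values for every head, MLA caches one low-dimensional latent vector per token (plus a small positional key) and reconstructs per-head keys and values from it. For DeepSeek-V3's dimensions this shrinks the per-token KV cache by well over an order of magnitude relative to MHA, with minimal quality loss.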

  • MLA changes how prefix caching, chunked prefill, and paged attention must be implemented

Formal Definition

Standard Multi-Head Attention (MHA)

For input $\mathbf{X} \in \mathbb{R}^{n \times d}$, MHA computes per-head projections:

$$\mathbf{Q}_h = \mathbf{X} \mathbf{W}_Q^{(h)}, \quad \mathbf{K}_h = \mathbf{X} \mathbf{W}_K^{(h)}, \quad \mathbf{V}_h = \mathbf{X} \mathbf{W}_V^{(h)}$$

where $\mathbf{W}_Q^{(h)} \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_K^{(h)} \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_V^{(h)} \in \mathbb{R}^{d \times d_v}$.

KV cache size per token (per layer): $n_h (d_k + d_v)$ elements, i.e. $2 \times n_h \times d_k$ when $d_v = d_k$.
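A minimal PyTorch sketch of the MHA projections and what the cache must hold. All sizes are toy assumptions, and fusing the per-head matrices into one `nn.Linear` per projection is a common implementation choice rather than something the formulas require. The head-1 check shows that slicing the fused weight recovers $\mathbf{X} \mathbf{W}_Q^{(h)}$ for a single head.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy sizes, chosen only so the sketch runs instantly (hypothetical)
n, d, n_h, d_k, d_v = 6, 64, 4, 16, 16

X = torch.randn(n, d)
# One fused Linear per projection; rows h*d_k:(h+1)*d_k of .weight hold W^{(h)} transposed
W_Q = nn.Linear(d, n_h * d_k, bias=False)
W_K = nn.Linear(d, n_h * d_k, bias=False)
W_V = nn.Linear(d, n_h * d_v, bias=False)

with torch.no_grad():
    Q = W_Q(X).view(n, n_h, d_k)  # Q[:, h] = X W_Q^{(h)}
    K = W_K(X).view(n, n_h, d_k)
    V = W_V(X).view(n, n_h, d_v)

    # Head 1 computed directly from its own slice of the fused weight
    q_head1 = X @ W_Q.weight[d_k:2 * d_k].T

# MHA caches K and V for every head of every token
kv_cache = torch.cat([K.reshape(n, -1), V.reshape(n, -1)], dim=-1)
print(kv_cache.shape[-1])  # 128 == 2 * n_h * d_k (d_v == d_k here)
```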

MLA: Low-Rank Latent Projection

MLA replaces the per-head KV projections with a shared low-rank latent compression:

Compression (KV → Latent):

$$\mathbf{c}^{KV} = \mathbf{X} \mathbf{W}_{DKV} \in \mathbb{R}^{n \times d_c}$$

where $\mathbf{W}_{DKV} \in \mathbb{R}^{d \times d_c}$ is the down-projection matrix and $d_c \ll n_h \times d_k$.

Decompression (Latent → KV):

$$\mathbf{K}_h = \mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)}, \quad \mathbf{V}_h = \mathbf{c}^{KV} \mathbf{W}_{UV}^{(h)}$$

where $\mathbf{W}_{UK}^{(h)} \in \mathbb{R}^{d_c \times d_k}$ and $\mathbf{W}_{UV}^{(h)} \in \mathbb{R}^{d_c \times d_v}$ are up-projection matrices.

KV cache per token: the latent $\mathbf{c}^{KV} \in \mathbb{R}^{d_c}$, a single vector of dimension $d_c$, replaces the per-head keys and values (the decoupled RoPE key introduced below adds a few more elements).
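A toy sketch of compression and decompression, with sizes as assumptions: only the $d_c$-dimensional latent is kept, and per-head K and V are rebuilt from it when attention runs. Because every head's keys are linear functions of the same latent, the stacked keys have rank at most $d_c$, which is exactly the low-rank constraint MLA imposes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy sizes (hypothetical); the point is d_c << n_h * d_k
n, d, n_h, d_k, d_v, d_c = 12, 64, 4, 16, 16, 8

X = torch.randn(n, d)
W_DKV = nn.Linear(d, d_c, bias=False)         # down-projection W_DKV
W_UK = nn.Linear(d_c, n_h * d_k, bias=False)  # all W_UK^{(h)} stacked
W_UV = nn.Linear(d_c, n_h * d_v, bias=False)  # all W_UV^{(h)} stacked

with torch.no_grad():
    latent_cache = W_DKV(X)  # [n, d_c]: the only per-token state kept

    # At attention time, per-head K and V are rebuilt from the cached latent
    K = W_UK(latent_cache).view(n, n_h, d_k)  # K_h = c^KV W_UK^{(h)}
    V = W_UV(latent_cache).view(n, n_h, d_v)  # V_h = c^KV W_UV^{(h)}

# Every head's keys are linear in the same d_c-dim latent, so rank <= d_c
rank = torch.linalg.matrix_rank(K.reshape(n, -1)).item()
print(latent_cache.shape[-1], 2 * n_h * d_k, rank)  # 8 cached vs 128 for MHA
```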

Compression Ratio

For a model with $n_h$ heads and head dimension $d_k$:

$$\text{Compression Ratio} = \frac{2 \cdot n_h \cdot d_k}{d_c}$$

In DeepSeek-V3, $n_h = 128$, $d_k = 128$, $d_c = 512$:

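$$\frac{2 \times 128 \times 128}{512} = 64\times \text{ compression}$$

This counts the latent alone. Including the decoupled RoPE key that is cached next to it ($d_r = 64$, covered below), each token stores $576$ elements and the ratio becomes $32768 / 576 \approx 57\times$.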
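The same arithmetic as a small Python helper. The optional `d_r` argument is an illustrative addition that accounts for the RoPE key cached alongside the latent; passing 0 reproduces the latent-only formula.

```python
def compression_ratio(n_h: int, d_k: int, d_c: int, d_r: int = 0) -> float:
    """MHA cache elements per token divided by MLA cache elements per token.

    d_r: width of the decoupled RoPE key cached next to the latent
         (0 reproduces the latent-only ratio).
    """
    return (2 * n_h * d_k) / (d_c + d_r)


print(compression_ratio(128, 128, 512))                 # 64.0 (latent only)
print(round(compression_ratio(128, 128, 512, 64), 1))   # 56.9 (latent + RoPE key)
```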

Query Compression (Optional)

MLA also compresses queries for training efficiency:

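$$\mathbf{c}^Q = \mathbf{X} \mathbf{W}_{DQ} \in \mathbb{R}^{n \times d_c'}$$

$$\mathbf{Q}_h = \mathbf{c}^Q \mathbf{W}_{UQ}^{(h)}$$

where $d_c'$ is the query latent dimension (1536 in DeepSeek-V2/V3) and $\mathbf{W}_{UQ}^{(h)} \in \mathbb{R}^{d_c' \times d_k}$.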

This doesn't affect the KV cache but reduces the activation memory during training.
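A toy sketch of the query path, with sizes as assumptions. Nothing here touches the KV cache; the factorization only changes how Q is produced. The parameter count is included to show the factorization is genuinely low-rank, not to reproduce DeepSeek's training-memory figures.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy sizes (hypothetical)
n, d, n_h, d_k, d_c_prime = 6, 64, 4, 16, 24

X = torch.randn(n, d)
W_DQ = nn.Linear(d, d_c_prime, bias=False)          # query down-projection W_DQ
W_UQ = nn.Linear(d_c_prime, n_h * d_k, bias=False)  # all W_UQ^{(h)} stacked

with torch.no_grad():
    c_q = W_DQ(X)                    # [n, d_c']
    Q = W_UQ(c_q).view(n, n_h, d_k)  # Q_h = c^Q W_UQ^{(h)}


def n_params(*layers: nn.Module) -> int:
    """Total number of weights across the given layers."""
    return sum(p.numel() for layer in layers for p in layer.parameters())


low_rank = n_params(W_DQ, W_UQ)  # d*d_c' + d_c'*n_h*d_k
full_rank = d * n_h * d_k        # a single unfactored W_Q
print(low_rank, full_rank)       # 3072 4096
```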

Rotary Position Embedding (RoPE) Handling

RoPE cannot be applied to the content keys $\mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)}$ without breaking weight absorption (see Core Concepts below), so MLA carries positional information in a separate, narrow key computed directly from the input and shared by all heads:

$$\mathbf{K}^{rope} = \text{RoPE}(\mathbf{X} \mathbf{W}_{KR})$$

where $\mathbf{W}_{KR} \in \mathbb{R}^{d \times d_r}$ and $d_r < d_k$ ($d_r = 64$ in DeepSeek-V2/V3). Each query head gets its own rotary component, $\mathbf{Q}_h^{rope} = \text{RoPE}(\mathbf{c}^Q \mathbf{W}_{QR}^{(h)})$. The latent $\mathbf{c}^{KV}$ stays position-agnostic, and the already-rotated $\mathbf{K}^{rope}$ is cached next to it, so each token costs $d_c + d_r$ cache elements.
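A sketch of the decoupled RoPE branch, assuming the DeepSeek-V2 layout (one rope key per token from X, shared across heads; one rope query per head from the query latent) and the interleaved-pair RoPE variant; sizes and random weights are placeholders. The final check confirms the standard RoPE property that scores depend only on the offset between query and key positions, which is why each key can be rotated once at its own position and cached.

```python
import torch


def apply_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate interleaved channel pairs of x ([seq, ..., d_r]) by each token's position."""
    d_r = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d_r, 2, dtype=x.dtype) / d_r))
    angles = positions.to(x.dtype)[:, None] * inv_freq                 # [seq, d_r/2]
    angles = angles.view(angles.shape[0], *([1] * (x.dim() - 2)), -1)  # broadcast over heads
    x1, x2 = x[..., ::2], x[..., 1::2]
    rotated = torch.stack([x1 * angles.cos() - x2 * angles.sin(),
                           x1 * angles.sin() + x2 * angles.cos()], dim=-1)
    return rotated.flatten(-2)


torch.manual_seed(0)
n, d, n_h, d_c_prime, d_r = 5, 32, 4, 12, 8  # toy sizes (hypothetical)
X = torch.randn(n, d, dtype=torch.float64)
c_q = torch.randn(n, d_c_prime, dtype=torch.float64)          # stands in for c^Q
W_KR = torch.randn(d, d_r, dtype=torch.float64)               # ONE key projection for all heads
W_QR = torch.randn(d_c_prime, n_h, d_r, dtype=torch.float64)  # one query projection per head


def rope_scores(offset: int) -> torch.Tensor:
    """Rope-part scores for every head when all positions are shifted by offset."""
    pos = torch.arange(n) + offset
    k_rope = apply_rope(X @ W_KR, pos)                                 # [n, d_r], cached per token
    q_rope = apply_rope(torch.einsum("nc,chr->nhr", c_q, W_QR), pos)  # [n, n_h, d_r]
    return torch.einsum("qhr,kr->hqk", q_rope, k_rope)                # each head vs shared key


# Shifting every position by the same offset leaves the scores unchanged
print(torch.allclose(rope_scores(0), rope_scores(7)))  # True
```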


Core Concepts

1. Weight Absorption (The Key Trick)

The critical insight in MLA is that the up-projection matrices $\mathbf{W}_{UK}^{(h)}$ can be absorbed into the query projection during attention computation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{d_k}}\right) \mathbf{V}_h$$

Substituting the decompressed forms:

$$\mathbf{Q}_h \mathbf{K}_h^T = (\mathbf{c}^Q \mathbf{W}_{UQ}^{(h)})(\mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)})^T = \mathbf{c}^Q \left(\mathbf{W}_{UQ}^{(h)} {\mathbf{W}_{UK}^{(h)}}^T\right) {\mathbf{c}^{KV}}^T$$

If we define $\mathbf{W}_{absorbed}^{(h)} = \mathbf{W}_{UQ}^{(h)} {\mathbf{W}_{UK}^{(h)}}^T \in \mathbb{R}^{d_c' \times d_c}$, then:

$$\mathbf{Q}_h \mathbf{K}_h^T = \mathbf{c}^Q \mathbf{W}_{absorbed}^{(h)} {\mathbf{c}^{KV}}^T$$

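This means the attention scores can be computed directly from the latent representations, without explicitly decompressing K. Values can be handled the same way: each head's output $\text{softmax}(\cdot)\, \mathbf{c}^{KV} \mathbf{W}_{UV}^{(h)}$ is followed by the output projection, so the attention weights can be applied to $\mathbf{c}^{KV}$ first and $\mathbf{W}_{UV}^{(h)}$ folded into $\mathbf{W}_O$ (DeepSeek-V2 describes both absorptions).

Practical implication: During decoding, neither the full K nor the full V matrix has to be materialized. Attention runs over the cached $d_c$-dimensional latents, and only the per-head attention outputs (one $d_c$-dimensional vector per head and query) are projected back up.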
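A numerical check of both absorptions for a single head, using random float64 matrices and toy sizes (all assumptions). `W_O_h` stands for this head's slice of the output projection. The absorbed path never builds K or V; the sketch verifies the algebra only and says nothing about performance.

```python
import torch

torch.manual_seed(0)
# Toy sizes (hypothetical); one head is enough to check the algebra
n_q, n_kv, d_c_prime, d_c, d_k, d_v, d = 3, 7, 12, 10, 6, 5, 16

c_q = torch.randn(n_q, d_c_prime, dtype=torch.float64)  # query latents c^Q
c_kv = torch.randn(n_kv, d_c, dtype=torch.float64)      # cached latents c^KV
W_UQ = torch.randn(d_c_prime, d_k, dtype=torch.float64)
W_UK = torch.randn(d_c, d_k, dtype=torch.float64)
W_UV = torch.randn(d_c, d_v, dtype=torch.float64)
W_O_h = torch.randn(d_v, d, dtype=torch.float64)        # this head's slice of W_O

# Explicit path: materialize K and V, then attend
Q, K, V = c_q @ W_UQ, c_kv @ W_UK, c_kv @ W_UV
attn = torch.softmax(Q @ K.T / d_k ** 0.5, dim=-1)
out_explicit = (attn @ V) @ W_O_h

# Absorbed path: fold W_UK into the query side and W_UV into the output side
W_absorbed = W_UQ @ W_UK.T                      # [d_c', d_c]
scores = c_q @ W_absorbed @ c_kv.T              # scores without building K
attn_absorbed = torch.softmax(scores / d_k ** 0.5, dim=-1)
W_VO = W_UV @ W_O_h                             # [d_c, d]
out_absorbed = (attn_absorbed @ c_kv) @ W_VO    # attend over latents, never build V

print(torch.allclose(out_explicit, out_absorbed))  # True
```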

2. Decoupled RoPE Strategy

Applying RoPE to the content key would insert a position-dependent rotation between $\mathbf{W}_{UQ}^{(h)}$ and ${\mathbf{W}_{UK}^{(h)}}^T$ in the product above, so $\mathbf{W}_{UK}^{(h)}$ could no longer be absorbed into the query side. MLA solves this with a decoupled key:

  • Content key: $\mathbf{K}_h^{content} = \mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)}$, stored only in latent form (as $\mathbf{c}^{KV}$) and never rotated
  • Position key: $\mathbf{K}^{rope} = \text{RoPE}(\mathbf{X} \mathbf{W}_{KR})$, small, position-aware, shared across heads, and cached separately

The attention score becomes:

$$\text{score}_h(q, k) = \frac{\mathbf{Q}_h^{content} {\mathbf{K}_h^{content}}^T + \mathbf{Q}_h^{rope} {\mathbf{K}^{rope}}^T}{\sqrt{d_k + d_r}}$$

Practical implication: The KV cache stores both $\mathbf{c}^{KV}$ (latent) and $\mathbf{K}^{rope}$ (decoupled RoPE key). Total cache per token: $d_c + d_r$ elements, i.e. $512 + 64 = 576$ for DeepSeek-V3.
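A small check, with random tensors and toy sizes as assumptions, that the two-part score with one joint $1/\sqrt{d_k + d_r}$ scale equals a single dot product over concatenated [content | rope] vectors. This equivalence is what lets an implementation concatenate the two parts and reuse an ordinary attention computation.

```python
import torch

torch.manual_seed(0)
n_q, n_kv, d_k, d_r = 3, 7, 16, 8  # toy sizes (hypothetical)

q_content = torch.randn(n_q, d_k, dtype=torch.float64)   # Q_h^content for one head
k_content = torch.randn(n_kv, d_k, dtype=torch.float64)  # K_h^content (rebuilt or absorbed)
q_rope = torch.randn(n_q, d_r, dtype=torch.float64)      # Q_h^rope, already rotated
k_rope = torch.randn(n_kv, d_r, dtype=torch.float64)     # shared K^rope from the cache

# Sum of the two score components under a single joint scale
score_split = (q_content @ k_content.T + q_rope @ k_rope.T) / (d_k + d_r) ** 0.5

# Same thing as one dot product over concatenated [content | rope] vectors
q = torch.cat([q_content, q_rope], dim=-1)
k = torch.cat([k_content, k_rope], dim=-1)
score_concat = q @ k.T / (d_k + d_r) ** 0.5

print(torch.allclose(score_split, score_concat))  # True
```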

3. MLA vs GQA vs MHA

| Property | MHA | GQA | MLA |
|---|---|---|---|
| KV groups | $n_h$ | $n_h / g$ | 1 (latent) |
| Cache per token | $2 n_h d_k$ | $2 (n_h/g) d_k$ | $d_c + d_r$ |
| Quality | Baseline | Slight drop | Comparable |
| Attention score | $QK^T$ | $QK^T$ (shared K) | Latent $QK^T$ |
| RoPE compatibility | Native | Native | Decoupled |

GQA reduces cache by sharing KV heads across query groups. MLA reduces cache more aggressively by projecting to a shared latent. The quality difference is minimal because the up-projection matrices are learned and can reconstruct head-specific information.
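The cache-per-token row as a Python helper, assuming $d_v = d_k$ and the DeepSeek-V2/V3 MLA layout (latent plus one shared RoPE key). The GQA configuration ($g = 16$, i.e. 8 KV heads) is a hypothetical example chosen for comparison, not a figure from any specific model.

```python
def kv_elements_per_token(variant: str, n_h: int, d_k: int, g: int = 1,
                          d_c: int = 0, d_r: int = 0) -> int:
    """Cached elements per token per layer (assumes d_v == d_k)."""
    if variant == "mha":
        return 2 * n_h * d_k          # K and V for every head
    if variant == "gqa":
        return 2 * (n_h // g) * d_k   # n_h / g shared KV heads
    if variant == "mla":
        return d_c + d_r              # latent + shared RoPE key
    raise ValueError(f"unknown variant: {variant}")


sizes = {v: kv_elements_per_token(v, n_h=128, d_k=128, g=16, d_c=512, d_r=64)
         for v in ("mha", "gqa", "mla")}
print(sizes)  # {'mha': 32768, 'gqa': 2048, 'mla': 576}
```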

4. Impact on Batched Serving

MLA dramatically changes the memory-vs-compute tradeoff in serving:

Memory-bound decoding phase: With MHA, long contexts exhaust GPU HBM because of the KV cache. MLA's compression allows:

  • Longer context windows (a fixed KV-cache budget holds roughly compression-ratio times more tokens)
  • Larger batch sizes (more concurrent requests)
  • Better prefix caching hit rates (smaller cache entries, so more prefixes stay resident)

Compute-bound prefill phase: MLA adds decompression overhead, but this is amortized:

  • Prefill is already compute-heavy (O(n²) attention)
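  • The additional up-projection matmuls cost O(n · d_c · n_h · (d_k + d_v)) per layer, linear in sequence length
  • Net effect: a small prefill slowdown in exchange for far less KV-cache memory traffic per decoding step, which is usually what bounds decode throughput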
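Capacity arithmetic for the decoding bullets above. The 40 GiB KV budget, 60 layers, and BF16 storage are assumptions for illustration; weights and activations are assumed to be budgeted separately. The resulting token count can be spent on longer contexts or on more concurrent requests.

```python
def max_cached_tokens(budget_gib: float, elements_per_token: int,
                      n_layers: int, bytes_per_element: int = 2) -> int:
    """Tokens whose KV cache (all layers) fits in budget_gib."""
    bytes_per_token = elements_per_token * n_layers * bytes_per_element
    return int(budget_gib * 1024**3 // bytes_per_token)


# Hypothetical budget: 40 GiB for KV cache, 60 layers, 2 bytes per element
mha_tokens = max_cached_tokens(40, 2 * 128 * 128, n_layers=60)
mla_tokens = max_cached_tokens(40, 512 + 64, n_layers=60)
print(mha_tokens, mla_tokens)  # 10922 621378
```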

5. MLA + Speculative Decoding

This is where it gets interesting for Siraj's EAGLE-3 work:

Draft model constraints:

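  • The drafter keeps its own KV cache, separate from the target's latent cache, so it is not required to use MLA or to match the target's projections
  • A smaller MHA/GQA drafter is therefore workable, but its per-token cache layout differs from the target's latent layout, which complicates sharing one paged memory pool between the two
  • EAGLE-3's tree-based speculation needs the target's MLA attention path (including absorbed kernels) to support tree-shaped attention masks during verification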

Verification with MLA:

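  1. Draft tokens are generated by the draft model
  2. The target model verifies them in a single forward pass, running full MLA attention over its latent cache plus the draft tokens
  3. That forward pass computes $\mathbf{c}^{KV}$ (and the RoPE key) for every draft token, so accepted tokens' entries stay in the latent cache and rejected ones are discarded; nothing is written to a full per-head KV cache
  4. This means the drafter never needs to predict in latent space or project its KV outputs into the target's latent space, because the target produces its own latents during verification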
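A sketch of the target-side bookkeeping in steps 3 and 4. `commit_verified` is a hypothetical helper, not a vLLM or EAGLE API. It assumes each cache row is the target's own [c_KV | RoPE key] for one token, computed during verification, and, as is common for tree drafts, that each node's RoPE key was rotated at position cache_len + depth, so rows on the accepted path stay valid after compaction. Real engines write into pre-allocated cache slots rather than concatenating.

```python
import torch


def commit_verified(latent_cache: torch.Tensor, verify_entries: torch.Tensor,
                    accepted_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep the target's cache rows for accepted draft tokens.

    latent_cache:   [cache_len, d_c + d_r] rows ([c_KV | RoPE key]) already committed
    verify_entries: [n_draft, d_c + d_r] rows the target computed while verifying
    accepted_idx:   indices of accepted draft tokens (a prefix for chain drafts,
                    the accepted root-to-leaf path for tree drafts)
    """
    return torch.cat([latent_cache, verify_entries[accepted_idx]], dim=0)


torch.manual_seed(0)
d_c, d_r = 128, 32                        # toy sizes (hypothetical)
cache = torch.randn(10, d_c + d_r)        # 10 committed tokens
draft_rows = torch.randn(6, d_c + d_r)    # target-side rows for 6 tree nodes
accepted = torch.tensor([0, 2, 5])        # accepted path through the draft tree

new_cache = commit_verified(cache, draft_rows, accepted)
print(new_cache.shape)  # torch.Size([13, 160]); rejected nodes are simply dropped
```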

vLLM implementation challenge: vLLM's PagedAttention was designed for MHA/GQA. MLA requires:

  • KV cache blocks that store one latent vector plus the shared RoPE key per token ($d_c + d_r$ elements) instead of per-head K and V
  • Custom attention kernels for the absorbed + decoupled-RoPE computation
  • Integration with CUDAGraph captures for the decompression path
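A toy model of what paging latents means. `PagedLatentCache` and `BLOCK_SIZE` are hypothetical and do not reflect vLLM's actual classes or memory layout; the point is that block-table bookkeeping is unchanged and only the per-slot payload shrinks to one $d_c + d_r$ row (576 elements with DeepSeek-V3-like dims).

```python
import torch

BLOCK_SIZE = 16  # tokens per block (hypothetical)


class PagedLatentCache:
    """Toy paged cache: each slot holds one [c_KV | RoPE key] row of width d_c + d_r."""

    def __init__(self, num_blocks: int, entry_dim: int):
        self.blocks = torch.zeros(num_blocks, BLOCK_SIZE, entry_dim)
        self.free = list(range(num_blocks))              # physical blocks not in use
        self.block_tables: dict[int, list[int]] = {}     # seq_id -> physical block ids
        self.lengths: dict[int, int] = {}                # seq_id -> tokens stored

    def append(self, seq_id: int, entry: torch.Tensor) -> None:
        table = self.block_tables.setdefault(seq_id, [])
        pos = self.lengths.get(seq_id, 0)
        if pos % BLOCK_SIZE == 0:                        # last block full (or none yet)
            table.append(self.free.pop())
        self.blocks[table[pos // BLOCK_SIZE], pos % BLOCK_SIZE] = entry
        self.lengths[seq_id] = pos + 1

    def gather(self, seq_id: int) -> torch.Tensor:
        """Return the sequence's rows in logical order: [seq_len, entry_dim]."""
        table = torch.tensor(self.block_tables[seq_id])
        rows = self.blocks[table].reshape(-1, self.blocks.shape[-1])
        return rows[: self.lengths[seq_id]]


d_c, d_r = 512, 64
cache = PagedLatentCache(num_blocks=8, entry_dim=d_c + d_r)
entries = torch.randn(20, d_c + d_r)
for e in entries:
    cache.append(seq_id=0, entry=e)
print(len(cache.block_tables[0]), torch.equal(cache.gather(0), entries))  # 2 True
```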

Implementation

```python
import math
from typing import Optional

import torch
import torch.nn as nn


class MultiHeadLatentAttention(nn.Module):
    """
    Simplified MLA attention layer following the DeepSeek-V2/V3 formulation
    (also used by Kimi K2.x). The RMSNorms DeepSeek applies to the latents
    are omitted.

    Key features:
    - Low-rank KV compression: the cache holds c_KV, not per-head K/V
    - Decoupled RoPE: one d_r-wide rope key per token, shared by all heads
      and cached next to c_KV
    - Reference path only: K and V are decompressed explicitly here;
      optimized kernels absorb W_UK / W_UV instead of materializing them
    """

    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 128,
        d_k: int = 128,
        d_v: int = 128,
        d_c: int = 512,           # KV latent dimension (compression target)
        d_c_prime: int = 1536,    # Query latent dimension
        d_r: int = 64,            # Decoupled RoPE dimension (must be even)
        max_seq_len: int = 8192,  # kept for reference; not enforced here
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_c = d_c
        self.d_c_prime = d_c_prime
        self.d_r = d_r
        self.max_seq_len = max_seq_len

        # === Down-projections (compression) ===
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)       # KV latent
        self.w_dq = nn.Linear(d_model, d_c_prime, bias=False)  # Q latent

        # === Up-projections (decompression) ===
        # KV up-projections: latent -> per-head K and V
        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)

        # Q up-projection: latent -> per-head Q
        self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)

        # === Decoupled RoPE projections ===
        self.w_kr = nn.Linear(d_model, d_r, bias=False)              # shared rope key, from x
        self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False)  # per-head rope query

        # === Output projection ===
        self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)

        # RoPE frequencies
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
        self.register_buffer('inv_freq', inv_freq)

    def _apply_rope(self, x: torch.Tensor, start_pos: int) -> torch.Tensor:
        """Rotate x of shape [batch, seq, ..., d_r]; token i gets position start_pos + i."""
        seq_len = x.shape[1]
        t = torch.arange(start_pos, start_pos + seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)  # [seq, d_r//2]
        # Broadcast over batch and, if present, the head dimension
        shape = (1, seq_len) + (1,) * (x.dim() - 3) + (freqs.shape[-1],)
        cos = freqs.cos().view(shape)
        sin = freqs.sin().view(shape)

        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos,
        ], dim=-1).flatten(-2)
        return rotated

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: Optional[torch.Tensor] = None,
        start_pos: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
            kv_cache: Cached [c_KV | k_rope] rows of previous tokens [batch, cache_len, d_c + d_r]
            start_pos: Position of the first token in x (should equal cache_len)

        Returns:
            output: [batch, seq_len, d_model]
            new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c + d_r]
        """
        B, S, _ = x.shape
        H = self.n_heads

        # === Step 1: Compress to latent space and build this step's cache rows ===
        c_kv = self.w_dkv(x)                                # [B, S, d_c]
        c_q = self.w_dq(x)                                  # [B, S, d_c']
        k_rope = self._apply_rope(self.w_kr(x), start_pos)  # [B, S, d_r], shared by all heads
        new_rows = torch.cat([c_kv, k_rope], dim=-1)        # [B, S, d_c + d_r]: what gets cached

        # === Step 2: KV cache management (append latent rows) ===
        new_kv_cache = new_rows if kv_cache is None else torch.cat([kv_cache, new_rows], dim=1)
        T = new_kv_cache.shape[1]                           # tokens visible to attention
        c_kv_all, k_rope_all = new_kv_cache.split([self.d_c, self.d_r], dim=-1)

        # === Step 3: Decompress K and V from the latents (no absorption here) ===
        k_content = self.w_uk(c_kv_all).view(B, T, H, self.d_k)
        v = self.w_uv(c_kv_all).view(B, T, H, self.d_v)
        k = torch.cat([k_content, k_rope_all.unsqueeze(2).expand(B, T, H, self.d_r)], dim=-1)

        # === Step 4: Queries = content part + per-head rope part ===
        q_content = self.w_uq(c_q).view(B, S, H, self.d_k)
        q_rope = self._apply_rope(self.w_qr(c_q).view(B, S, H, self.d_r), start_pos)
        q = torch.cat([q_content, q_rope], dim=-1)          # [B, S, H, d_k + d_r]

        # === Step 5: Causal attention over cached + new tokens ===
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)  # [B, H, len, dim]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k + self.d_r)
        # Query i is token T - S + i and may only attend to keys at or before it
        q_pos = torch.arange(T - S, T, device=x.device).unsqueeze(1)
        k_pos = torch.arange(T, device=x.device).unsqueeze(0)
        scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)         # [B, H, S, d_v]

        # === Step 6: Output projection ===
        attn_output = attn_output.transpose(1, 2).reshape(B, S, H * self.d_v)
        output = self.w_o(attn_output)

        return output, new_kv_cache


# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes(n_layers: int = 60):
    """Demonstrate the KV cache savings of MLA over MHA (DeepSeek-V3-like dims)."""
    n_heads = 128
    d_k = 128
    d_c = 512    # DeepSeek-V3 latent dim
    d_r = 64     # Decoupled RoPE key dim (one key per token, shared by all heads)
    seq_len = 65536  # 64K context
    bytes_per_element = 2  # FP16/BF16

    # MHA: cache K and V for all heads
    mha_cache_per_token = 2 * n_heads * d_k  # K + V
    mha_total = mha_cache_per_token * seq_len * bytes_per_element / (1024**3)

    # MLA: cache only c_KV + the shared decoupled rope key
    mla_cache_per_token = d_c + d_r
    mla_total = mla_cache_per_token * seq_len * bytes_per_element / (1024**3)

    print(f"MHA KV cache (64K ctx): {mha_total:.2f} GiB per layer")
    print(f"MLA KV cache (64K ctx): {mla_total:.2f} GiB per layer")
    print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
    print(f"\nFor {n_layers} layers:")
    print(f"  MHA: {mha_total * n_layers:.1f} GiB")
    print(f"  MLA: {mla_total * n_layers:.1f} GiB")
    print(f"  Savings: {(mha_total - mla_total) * n_layers:.1f} GiB")


if __name__ == "__main__":
    torch.manual_seed(0)
    # Test MLA forward pass
    mla = MultiHeadLatentAttention(
        d_model=4096, n_heads=8, d_k=64, d_v=64,
        d_c=128, d_c_prime=256, d_r=32,
    )

    x = torch.randn(2, 10, 4096)  # batch=2, seq=10 (prefill)
    output, cache = mla(x)
    print(f"Output shape: {output.shape}")   # torch.Size([2, 10, 4096])
    print(f"Cache shape: {cache.shape}")     # torch.Size([2, 10, 160]): d_c + d_r per token

    # Autoregressive step
    x2 = torch.randn(2, 1, 4096)
    output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)
    print(f"Output2 shape: {output2.shape}")  # torch.Size([2, 1, 4096])
    print(f"Cache2 shape: {cache2.shape}")    # torch.Size([2, 11, 160]): grew by 1

    # Sanity check: incremental decoding matches one causal pass over all 11 tokens
    full_out, _ = mla(torch.cat([x, x2], dim=1))
    print(torch.allclose(full_out[:, -1:], output2, atol=1e-5))  # True

    print("\n--- Cache Comparison ---")
    compare_cache_sizes()
```

Connections

Prerequisites

  • kv-cache — You must understand standard KV caching before understanding why MLA compresses it
  • paged-attention — MLA changes what gets paged (latent vectors, not KV pairs)
  • flash-attention — MLA's absorbed attention can be fused into FlashAttention-style kernels
  • attention-mechanism — Foundation for understanding attention computation

Directly Related

  • kimi-k2-6 — Kimi K2.6 uses MLA + MoE, the target model for Siraj's spec-coder project
  • mha2mla-conversion — Techniques for converting MHA models to MLA
  • ktransformers — CPU/GPU hybrid MoE inference that must handle MLA's latent cache
  • arkv-adaptive-kv-cache — Adaptive KV cache management (MLA enables further compression)
  • oaken-hybrid-kv-cache — Online-offline hybrid quantization for KV cache, works with MLA
  • pikv-moe-kv-cache — KV cache management specifically for MoE + MLA architectures

Next Steps

  • sglang — SGLang's RadixAttention implementation with MLA support
  • vllm-omni-disaggregated-serving — Disaggregated serving architectures that benefit from MLA's cache savings
  • ragged-paged-attention-tpu — TPU kernels that can be adapted for MLA's non-standard attention pattern

References

  1. DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. Liu et al., 2024. arXiv:2405.04434. Original MLA paper introducing the latent compression and decoupled RoPE strategy.

  2. DeepSeek-V3 Technical Report. DeepSeek-AI, 2024. arXiv:2412.19437. Scales MLA to a 671B-parameter MoE with auxiliary-loss-free load balancing, and introduces multi-token prediction (MTP) modules that can be repurposed for speculative decoding, similar in spirit to EAGLE-style draft heads.

  3. vLLM MLA implementation. github.com/vllm-project/vllm. Production MLA kernel with weight absorption and FlashAttention integration.

  4. FlashInfer MLA attention. github.com/flashinfer-ai/flashinfer. Custom CUDA kernels for MLA that support both prefill and decode phases with a batched latent cache.
