Sirajuddin Shaik

Multi-Head Latent Attention (MLA)

Compressing KV cache via low-rank projections - the attention mechanism behind DeepSeek-V2/V3 and Kimi K2.x

Why This Matters

Multi-Head Latent Attention (MLA) is the attention variant that replaces standard Multi-Head Attention (MHA) in DeepSeek-V2, DeepSeek-V3, and Kimi K2.x models. Instead of caching full KV pairs per head, MLA projects them into a low-dimensional latent space, achieving 3-5× KV cache compression with minimal quality loss.

  • MLA changes how prefix caching, chunked prefill, and paged attention must be implemented

Formal Definition

Standard Multi-Head Attention (MHA)

For input $\mathbf{X} \in \mathbb{R}^{n \times d}$, MHA computes per-head projections:

$$\mathbf{Q}_h = \mathbf{X} \mathbf{W}_Q^{(h)}, \quad \mathbf{K}_h = \mathbf{X} \mathbf{W}_K^{(h)}, \quad \mathbf{V}_h = \mathbf{X} \mathbf{W}_V^{(h)}$$

where $\mathbf{W}_Q^{(h)} \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_K^{(h)} \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_V^{(h)} \in \mathbb{R}^{d \times d_v}$.

KV cache size per token: $2 \times n_h \times d_k$ elements.

MLA: Low-Rank Latent Projection

MLA replaces the per-head KV projections with a shared low-rank latent compression:

Compression (KV → Latent):

$$\mathbf{c}^{KV} = \mathbf{X} \mathbf{W}_{DKV} \in \mathbb{R}^{n \times d_c}$$

where $\mathbf{W}_{DKV} \in \mathbb{R}^{d \times d_c}$ is the down-projection matrix and $d_c \ll n_h \times d_k$.

Decompression (Latent → KV):

$$\mathbf{K}_h = \mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)}, \quad \mathbf{V}_h = \mathbf{c}^{KV} \mathbf{W}_{UV}^{(h)}$$

where $\mathbf{W}_{UK}^{(h)} \in \mathbb{R}^{d_c \times d_k}$ and $\mathbf{W}_{UV}^{(h)} \in \mathbb{R}^{d_c \times d_v}$ are up-projection matrices.

KV cache per token: Only $\mathbf{c}^{KV} \in \mathbb{R}^{d_c}$ is stored - a single vector of dimension $d_c$.
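
The compress-then-decompress step above can be checked on toy tensors. The following is a minimal sketch, with made-up sizes and random weights standing in for trained ones: it caches only the latent, rebuilds the keys for all heads on demand, and confirms that the rebuilt keys have rank at most $d_c$, which is what "low-rank" means here.

```python
import torch

torch.manual_seed(0)
# Toy sizes chosen for illustration only; d_c is smaller than n_h * d_k.
n, d, n_h, d_k, d_c = 16, 64, 4, 8, 6

X = torch.randn(n, d, dtype=torch.float64)
W_dkv = torch.randn(d, d_c, dtype=torch.float64)         # shared down-projection
W_uk = torch.randn(n_h, d_c, d_k, dtype=torch.float64)   # one up-projection per head

c_kv = X @ W_dkv                                # [n, d_c], the only tensor that is cached
K = torch.einsum("nc,hck->nhk", c_kv, W_uk)     # [n, n_h, d_k], rebuilt when needed

# Place all heads side by side: the result cannot have rank above d_c.
K_flat = K.reshape(n, n_h * d_k)
rank = torch.linalg.matrix_rank(K_flat).item()

print(f"cached per token: {c_kv.shape[1]} values, rebuilt per token: {K_flat.shape[1]} values")
print(f"rank of rebuilt keys: {rank}")
```

The rebuilt keys occupy 32 values per token here, but they are fully determined by 6 cached values, so the per-head keys are not independent of each other.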

Query Compression (Optional)

MLA also compresses queries for training efficiency:

$$\mathbf{c}^Q = \mathbf{X} \mathbf{W}_{DQ} \in \mathbb{R}^{n \times d_c'}$$

$$\mathbf{Q}_h = \mathbf{c}^Q \mathbf{W}_{UQ}^{(h)}$$

This doesn't affect the KV cache but reduces the activation memory during training.

Rotary Position Embedding (RoPE) Handling

RoPE cannot simply be applied to the decompressed keys: a position-dependent rotation would sit between the query and key up-projections and prevent them from being merged (see Weight Absorption below). To keep the KV cache small, MLA instead applies RoPE to a separate, narrow "decoupled" key projection:

$$\hat{\mathbf{K}}_h = \text{RoPE}(\mathbf{c}^{KV} \mathbf{W}_{KR}^{(h)})$$

where $\mathbf{W}_{KR}^{(h)} \in \mathbb{R}^{d_c \times d_r}$ with $d_r < d_k$ is a narrow projection that carries positional information. The latent $\mathbf{c}^{KV}$ itself stays position-agnostic. The RoPE key $\hat{\mathbf{K}}_h$ can either be cached next to it, which is what the cache accounting below assumes, or recomputed from the cached latent at attention time, which is what the reference implementation below does.
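
To see why a narrow rotated key is enough to carry positional information, it helps to check the defining property of RoPE: the dot product of a rotated query and key depends only on the distance between their positions. The sketch below is a standalone illustration; the helper `rope` and the toy vectors are assumptions for this example, using the same interleaved-pair rotation as the implementation later in this post.

```python
import torch

def rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate consecutive feature pairs of x [seq, d_r] by position-dependent angles."""
    d_r = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d_r, 2, dtype=x.dtype) / d_r))
    angles = torch.outer(positions.to(x.dtype), inv_freq)   # [seq, d_r // 2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten(-2)

torch.manual_seed(0)
q = torch.randn(1, 8, dtype=torch.float64)   # one toy query with d_r = 8
k = torch.randn(1, 8, dtype=torch.float64)   # one toy key with d_r = 8

def score(q_pos: int, k_pos: int) -> float:
    """Dot product of the query rotated to q_pos and the key rotated to k_pos."""
    q_rot = rope(q, torch.tensor([q_pos]))
    k_rot = rope(k, torch.tensor([k_pos]))
    return (q_rot @ k_rot.T).item()

# Same offset (3) at very different absolute positions gives the same score.
same_offset = abs(score(5, 2) - score(105, 102)) < 1e-9
# A different offset changes the score.
diff_offset = abs(score(5, 2) - score(5, 4)) > 1e-6
print(same_offset, diff_offset)
```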


Compression Ratio

For a model with $n_h$ heads and head dimension $d_k$, the total MLA KV cache per token stores:

  1. The latent vector $\mathbf{c}^{KV}$ of dimension $d_c$
  2. The decoupled RoPE keys for all heads: $n_h \times d_r$

$$\text{Compression Ratio} = \frac{2 \cdot n_h \cdot d_k}{d_c + n_h \cdot d_r}$$

Approximating DeepSeek-V3 dimensions (see paper Table 1 for exact values), with $n_h = 128$, $d_k = 128$, $d_c = 512$, $d_r = 64$:

$$\frac{2 \times 128 \times 128}{512 + 128 \times 64} = \frac{32{,}768}{8{,}704} \approx 3.76\times \text{ compression}$$

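Real-world impact: In FP16, a 60-layer model at 64K context needs about 240 GiB of KV cache per sequence with MHA and about 64 GiB with this MLA layout (4 GiB versus 1.06 GiB per layer). At a batch of four such sequences that is roughly 960 GiB versus 255 GiB. The former exceeds the 640 GB of HBM on a single 8×H100 node; the latter fits, before accounting for model weights.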

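Note: The 3.76× figure follows this article's accounting, in which every head has its own decoupled RoPE key. In the DeepSeek-V2 paper the RoPE key is computed from the layer input and shared across all heads, so the cache holds only $d_c + d_r = 576$ elements per token and the ratio against MHA is $32{,}768 / 576 \approx 57\times$. Reported compression figures therefore vary widely with the dimension choices and with the baseline (MHA or GQA) being compared.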
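
The arithmetic above is easy to reproduce and to vary. This is a small sketch using the dimensions quoted in this section; the function names and the `shared_rope_key` flag are made up for illustration. The flag switches between one RoPE key per head, as counted above, and a single RoPE key shared by all heads, which is how the DeepSeek-V2 paper describes the cache.

```python
def kv_elements_mha(n_h: int, d_k: int) -> int:
    """Elements cached per token per layer by standard MHA (K and V for every head)."""
    return 2 * n_h * d_k

def kv_elements_mla(n_h: int, d_c: int, d_r: int, shared_rope_key: bool = False) -> int:
    """Elements cached per token per layer by MLA: the latent plus the RoPE key(s)."""
    rope_elements = d_r if shared_rope_key else n_h * d_r
    return d_c + rope_elements

n_h, d_k, d_c, d_r = 128, 128, 512, 64

mha = kv_elements_mha(n_h, d_k)
per_head = kv_elements_mla(n_h, d_c, d_r)
shared = kv_elements_mla(n_h, d_c, d_r, shared_rope_key=True)

print(f"MHA: {mha} elements per token")
print(f"MLA, RoPE key per head: {per_head} elements, {mha / per_head:.2f}x smaller")   # 3.76x
print(f"MLA, shared RoPE key:   {shared} elements, {mha / shared:.1f}x smaller")       # 56.9x
```

The gap between the two variants shows that the RoPE keys, not the latent, dominate the cache whenever they are stored per head.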


Core Concepts

1. Weight Absorption (The Key Trick)

The critical insight in MLA is that the up-projection matrices $\mathbf{W}_{UK}^{(h)}$ can be absorbed into the query projection during attention computation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{d_k}}\right) \mathbf{V}_h$$

Substituting the decompressed forms:

$$\mathbf{Q}_h \mathbf{K}_h^T = (\mathbf{c}^Q \mathbf{W}_{UQ}^{(h)})(\mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)})^T = \mathbf{c}^Q \left(\mathbf{W}_{UQ}^{(h)} {\mathbf{W}_{UK}^{(h)}}^T\right) {\mathbf{c}^{KV}}^T$$

If we define $\mathbf{W}_{absorbed}^{(h)} = \mathbf{W}_{UQ}^{(h)} {\mathbf{W}_{UK}^{(h)}}^T \in \mathbb{R}^{d_c' \times d_c}$, then:

$$\mathbf{Q}_h \mathbf{K}_h^T = \mathbf{c}^Q \mathbf{W}_{absorbed}^{(h)} {\mathbf{c}^{KV}}^T$$

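This means the content part of the attention score can be computed directly from the latent representations, without decompressing K for every cached token. V does not enter the score at all; for the output, $\mathbf{W}_{UV}^{(h)}$ can be applied once to the attention-weighted sum of latents (or folded into the output projection) instead of being applied to every cached token.

Practical implication: During decoding, attention can run over the cached latents without materializing the full K or V matrices. The decoupled RoPE term is not absorbed and is added to the score separately.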
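
The identity is simple to verify numerically. The sketch below covers one head with toy sizes and random weights, and it leaves out the RoPE term; all names are local to this example. It compares the explicit path (decompress, then attend) with the absorbed path (attend over latents), and applies the value up-projection once to the attention-weighted latent, which is valid because matrix multiplication is associative.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
n_q, n_kv = 3, 7                      # query tokens and cached tokens
d_cq, d_c, d_k, d_v = 12, 6, 8, 8     # toy sizes for a single head

c_q = torch.randn(n_q, d_cq, dtype=dt)     # query latents
c_kv = torch.randn(n_kv, d_c, dtype=dt)    # cached KV latents
W_uq = torch.randn(d_cq, d_k, dtype=dt)
W_uk = torch.randn(d_c, d_k, dtype=dt)
W_uv = torch.randn(d_c, d_v, dtype=dt)

# Path 1: decompress Q, K, V explicitly, then attend.
Q, K, V = c_q @ W_uq, c_kv @ W_uk, c_kv @ W_uv
scores_explicit = Q @ K.T
out_explicit = torch.softmax(scores_explicit / d_k**0.5, dim=-1) @ V

# Path 2: absorb W_uk into the query side and attend over the latents.
W_absorbed = W_uq @ W_uk.T                     # [d_cq, d_c], can be precomputed once
scores_absorbed = c_q @ W_absorbed @ c_kv.T
weights = torch.softmax(scores_absorbed / d_k**0.5, dim=-1)
out_absorbed = (weights @ c_kv) @ W_uv         # up-project once per query, not per cached token

scores_match = torch.allclose(scores_explicit, scores_absorbed)
outputs_match = torch.allclose(out_explicit, out_absorbed)
print(scores_match, outputs_match)
```

In the absorbed path neither K nor V is ever formed for the 7 cached tokens; only the 6-dimensional latents are read.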

2. Decoupled RoPE Strategy

RoPE requires position-dependent keys, which conflicts with caching a position-agnostic latent. MLA solves this with a decoupled key:

  • Content key: $\mathbf{K}_h^{content} = \mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)}$ - cached in latent form
  • Position key: $\mathbf{K}_h^{rope} = \text{RoPE}(\mathbf{c}^{KV} \mathbf{W}_{KR}^{(h)})$ - small, position-aware, cached separately

The attention score becomes:

$$\text{score}(q, k) = \frac{\mathbf{Q}_h^{content} {\mathbf{K}_h^{content}}^T + \mathbf{Q}_h^{rope} {\mathbf{K}_h^{rope}}^T}{\sqrt{d_k + d_r}}$$

Both terms share one scale, $\sqrt{d_k + d_r}$, because the content and RoPE parts are concatenated into a single query and key per head, matching the implementation below.

Practical implication: The KV cache stores both $\mathbf{c}^{KV}$ (latent) and $\mathbf{K}_h^{rope}$ (decoupled rope key). Total cache per token: $d_c + n_h \times d_r$.
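
The implementation later in this post does not compute the two dot products separately; it concatenates the content and RoPE parts and runs one matmul. The following toy check, with made-up sizes and random vectors, shows that the unscaled scores are the same either way.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
d_k, d_r = 8, 4   # toy content and RoPE widths for one head

q_content, k_content = torch.randn(d_k, dtype=dt), torch.randn(d_k, dtype=dt)
q_rope, k_rope = torch.randn(d_r, dtype=dt), torch.randn(d_r, dtype=dt)

# Two separate dot products, one per part.
sum_of_parts = q_content @ k_content + q_rope @ k_rope

# One dot product over the concatenated vectors of width d_k + d_r.
q_full = torch.cat([q_content, q_rope])
k_full = torch.cat([k_content, k_rope])
concatenated = q_full @ k_full

parts_match = torch.allclose(sum_of_parts, concatenated)
print(parts_match)
```

Concatenation keeps the attention kernel unchanged: it sees an ordinary head of width $d_k + d_r$.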

3. MLA vs GQA vs MHA

| Property | MHA | GQA | MLA |
|---|---|---|---|
| KV groups | $n_h$ | $n_h / g$ | 1 (latent) |
| Cache per token | $2 n_h d_k$ | $2 (n_h/g) d_k$ | $d_c + n_h d_r$ |
| Quality | Baseline | Slight drop | Comparable |
| Attention score | $QK^T$ | $QK^T$ (shared K) | Latent $QK^T$ |
| RoPE compatibility | Native | Native | Decoupled |

GQA reduces cache by sharing KV heads across query groups. MLA reduces cache more aggressively by projecting to a shared latent. The quality difference is minimal because the up-projection matrices are learned and can reconstruct head-specific information.

4. Impact on Batched Serving

MLA dramatically changes the memory-vs-compute tradeoff in serving:

Memory-bound decoding phase: With MHA, long contexts exhaust GPU HBM due to KV cache. MLA's compression allows:

  • Longer context windows (~4× more tokens in same memory)
  • Larger batch sizes (more concurrent requests)
  • Better prefix caching hit rates (smaller cache entries)

Compute-bound prefill phase: MLA adds decompression overhead, but this is amortized:

  • Prefill is already compute-heavy (O(n²) attention)
  • The additional up-projection matmuls cost O(n × d_c × n_h × (d_k + d_v)) per layer, which is linear in n
  • Net effect: minor prefill slowdown, massive decoding speedup
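
The memory side of this tradeoff can be turned into a token budget. The sketch below is back-of-the-envelope only: the 40 GiB budget and the helper name are hypothetical, the dimensions are the ones used earlier in this post, and it ignores weights, activations, and allocator overhead.

```python
def max_cached_tokens(
    hbm_budget_gib: float,
    elements_per_token: int,
    n_layers: int,
    bytes_per_element: int = 2,   # FP16
) -> int:
    """Number of tokens whose KV cache fits in the given HBM budget, across all layers."""
    budget_bytes = hbm_budget_gib * 1024**3
    bytes_per_token = elements_per_token * bytes_per_element * n_layers
    return int(budget_bytes // bytes_per_token)

n_h, d_k, d_c, d_r, n_layers = 128, 128, 512, 64, 60
budget_gib = 40   # hypothetical share of HBM reserved for the KV cache

mha_tokens = max_cached_tokens(budget_gib, 2 * n_h * d_k, n_layers)
mla_tokens = max_cached_tokens(budget_gib, d_c + n_h * d_r, n_layers)

print(f"MHA: {mha_tokens} tokens, MLA: {mla_tokens} tokens")
print(f"MLA fits {mla_tokens / mha_tokens:.2f}x more tokens in the same budget")
```

The same budget can be spent on longer contexts or on more concurrent requests; the total token count is what the cache size fixes.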

Implementation

import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
    """
    Simplified MLA attention layer in the style of DeepSeek-V2/V3 and Kimi K2.x.

    Key features:
    - Low-rank KV compression (cache only c_KV latent vector)
    - Decoupled RoPE for position-aware attention
    - Explicit K/V decompression for readability (weight absorption is not applied here)
    """

    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 128,
        d_k: int = 128,
        d_v: int = 128,
        d_c: int = 512,          # KV latent dimension (compression target)
        d_c_prime: int = 1536,   # Query latent dimension
        d_r: int = 64,           # Decoupled RoPE key dimension per head
        max_seq_len: int = 8192,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_c = d_c
        self.d_c_prime = d_c_prime
        self.d_r = d_r

        # === Down-projections (compression) ===
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)        # KV latent
        self.w_dq = nn.Linear(d_model, d_c_prime, bias=False)   # Q latent

        # === Up-projections (decompression) ===
        # KV up-projections: latent -> per-head K and V
        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)

        # Q up-projection: latent -> per-head Q
        self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)

        # === Decoupled RoPE projections ===
        self.w_kr = nn.Linear(d_c, n_heads * d_r, bias=False)   # Rope key from latent
        self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False)  # Rope query from latent

        # === Output projection ===
        self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)

        # RoPE frequencies
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
        self.register_buffer('inv_freq', inv_freq)

    def _apply_rope(self, x: torch.Tensor, end_pos: int) -> torch.Tensor:
        """Apply rotary position embedding to tensor of shape [batch, seq, n_heads, d_r].

        The seq tokens are assumed to occupy absolute positions
        end_pos - seq .. end_pos - 1, so a single decode token gets exactly one position.
        """
        seq = x.shape[1]
        t = torch.arange(end_pos - seq, end_pos, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)              # [seq, d_r//2]
        cos = freqs.cos().unsqueeze(0).unsqueeze(2)        # [1, seq, 1, d_r//2]
        sin = freqs.sin().unsqueeze(0).unsqueeze(2)

        # Rotate each (even, odd) feature pair by its position-dependent angle
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos,
        ], dim=-1).flatten(-2)
        return rotated

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: torch.Tensor = None,
        start_pos: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
            kv_cache: Cached c_KV from previous tokens [batch, cache_len, d_c]
            start_pos: Position offset for RoPE

        Returns:
            output: [batch, seq_len, d_model]
            new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c]
        """
        B, S, _ = x.shape

        # === Step 1: Compress to latent space ===
        c_kv = self.w_dkv(x)     # [B, S, d_c] - THIS is what gets cached
        c_q = self.w_dq(x)       # [B, S, d_c']

        # === Step 2: Decompress for attention computation ===
        # K, V up-projection from latent
        k_content = self.w_uk(c_kv)    # [B, S, n_heads * d_k]
        v = self.w_uv(c_kv)            # [B, S, n_heads * d_v]
        q_content = self.w_uq(c_q)     # [B, S, n_heads * d_k]

        # Reshape to multi-head format
        q_content = q_content.view(B, S, self.n_heads, self.d_k)
        k_content = k_content.view(B, S, self.n_heads, self.d_k)
        v = v.view(B, S, self.n_heads, self.d_v)

        # === Step 3: Decoupled RoPE ===
        # Project to rope-specific dimensions and apply RoPE
        k_rope = self.w_kr(c_kv).view(B, S, self.n_heads, self.d_r)
        q_rope = self.w_qr(c_q).view(B, S, self.n_heads, self.d_r)

        k_rope = self._apply_rope(k_rope, start_pos + S)
        q_rope = self._apply_rope(q_rope, start_pos + S)

        # Concatenate content + rope for full key and query
        q = torch.cat([q_content, q_rope], dim=-1)   # [B, S, n_heads, d_k + d_r]
        k = torch.cat([k_content, k_rope], dim=-1)   # [B, S, n_heads, d_k + d_r]

        # === Step 4: KV cache management ===
        if kv_cache is not None:
            # Append new latent to cache
            new_kv_cache = torch.cat([kv_cache, c_kv], dim=1)
            # Decompress full cache for attention
            k_cache = self.w_uk(kv_cache).view(B, -1, self.n_heads, self.d_k)
            k_cache_rope = self._apply_rope(
                self.w_kr(kv_cache).view(B, -1, self.n_heads, self.d_r),
                start_pos  # cache already has positions 0..start_pos-1
            )
            k = torch.cat([
                torch.cat([k_cache, k_cache_rope], dim=-1),
                k
            ], dim=1)
            v_cache = self.w_uv(kv_cache).view(B, -1, self.n_heads, self.d_v)
            v = torch.cat([v_cache, v], dim=1)
        else:
            new_kv_cache = c_kv

        # === Step 5: Compute attention ===
        # Transpose for attention: [B, n_heads, seq, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

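        d_attn = self.d_k + self.d_r
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_attn)
        # Causal mask: query i sits at absolute position T - S + i and may only see keys up to it
        T = k.size(-2)
        causal = torch.ones(S, T, dtype=torch.bool, device=x.device).tril(diagonal=T - S)
        attn_weights = attn_weights.masked_fill(~causal, float("-inf"))
        attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, v)   # [B, n_heads, S, d_v]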

        # === Step 6: Output projection ===
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, -1)
        output = self.w_o(attn_output)

        return output, new_kv_cache


# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes():
    """Demonstrate the KV cache savings of MLA over MHA."""
    n_heads = 128
    d_k = 128
    d_c = 512     # DeepSeek-V3 latent dim
    d_r = 64      # Decoupled rope dim per head
    seq_len = 65536   # 64K context
    bytes_per_element = 2   # FP16

    # MHA: cache K and V for all heads
    mha_cache_per_token = 2 * n_heads * d_k   # K + V
    mha_total = mha_cache_per_token * seq_len * bytes_per_element / (1024**3)

    # MLA: cache c_KV latent + decoupled rope keys for all heads
    mla_cache_per_token = d_c + n_heads * d_r
    mla_total = mla_cache_per_token * seq_len * bytes_per_element / (1024**3)

    print(f"MHA KV cache (64K ctx): {mha_total:.2f} GB per layer")
    print(f"MLA KV cache (64K ctx): {mla_total:.2f} GB per layer")
    print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
    print(f"\nFor 60 layers:")
    print(f"  MHA: {mha_total * 60:.1f} GB")
    print(f"  MLA: {mla_total * 60:.1f} GB")
    print(f"  Savings: {(mha_total - mla_total) * 60:.1f} GB")


if __name__ == "__main__":
    # Test MLA forward pass
    mla = MultiHeadLatentAttention(
        d_model=4096, n_heads=8, d_k=64, d_v=64,
        d_c=128, d_c_prime=256, d_r=32,
    )

    x = torch.randn(2, 10, 4096)   # batch=2, seq=10
    output, cache = mla(x)
    print(f"Output shape: {output.shape}")   # [2, 10, 4096]
    print(f"Cache shape: {cache.shape}")     # [2, 10, 128] - only d_c!

    # Autoregressive step
    x2 = torch.randn(2, 1, 4096)
    output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)
    print(f"Output2 shape: {output2.shape}")   # [2, 1, 4096]
    print(f"Cache2 shape: {cache2.shape}")     # [2, 11, 128] - grew by 1

    print("\n--- Cache Comparison ---")
    compare_cache_sizes()

Prerequisites

  • kv-cache - You must understand standard KV caching before understanding why MLA compresses it
  • paged-attention - MLA changes what gets paged (latent vectors, not KV pairs)
  • flash-attention - MLA's absorbed attention can be fused into FlashAttention-style kernels
  • attention-mechanism - Foundation for understanding attention computation

References

  1. DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model - Liu et al., 2024. arxiv:2405.04434 - Original MLA paper introducing the latent compression and decoupled RoPE strategy.

  2. DeepSeek-V3 Technical Report - DeepSeek-AI, 2024. arxiv:2412.19437 - Scales MLA to 671B MoE with auxiliary-loss-free routing. Also describes the multi-token prediction (MTP) training objective.

  3. vLLM MLA Implementation - github.com/vllm-project/vllm - Production MLA kernel with weight absorption and FlashAttention integration.

  4. FlashInfer MLA Attention - github.com/flashinfer-ai/flashinfer - Custom CUDA kernels for MLA that support both prefill and decode phases with batched latent cache.
