vaibhav ahluwalia

Caching Strategies for LLM Systems (Part 3): Multi-Query Attention and Memory-Efficient Decoding

Figure: Multi-Query Attention (MQA) shares a single set of keys and values across all attention heads while keeping queries head-specific, reducing KV-cache memory during autoregressive decoding.

In Part 2, we saw how KV caching transforms autoregressive decoding by eliminating redundant attention computation. By storing keys and values from previous tokens, transformers reduce per-token compute from quadratic to linear in sequence length.

However, KV caching introduces a new bottleneck.

As models scale, KV cache memory becomes the dominant cost of inference, often exceeding model weights for long contexts. This post examines Multi-Query Attention (MQA)—an architectural modification that directly attacks this memory bottleneck by changing how attention heads share representations.


The Scaling Problem: KV Cache Grows with Head Count

In standard Multi-Head Attention (MHA), each head has its own key and value projections.

For a model with:

  • L transformer layers
  • H attention heads
  • sequence length T
  • head dimension dₕ

the KV cache memory scales as:

\mathcal{O}(L \cdot H \cdot T \cdot d_h)

KV caching removes redundant computation, but does nothing to reduce memory growth with respect to the number of heads.

For modern LLMs with 32–128 heads and long context windows, KV cache memory and bandwidth quickly become the limiting factor in inference throughput.
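The scaling relation above is easy to turn into a back-of-the-envelope calculator. A minimal sketch (function name and example model are my own, not from any library):

```python
# KV-cache size per sequence: 2 (K and V) * L * H * T * d_h * bytes/element.
def kv_cache_bytes(layers, heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes of KV cache for one sequence; FP16 by default (2 bytes)."""
    return 2 * layers * heads * seq_len * head_dim * bytes_per_elem

# A hypothetical 32-layer, 32-head model with d_h = 128 at 8k context:
print(kv_cache_bytes(32, 32, 8192, 128) / 2**30)  # 4.0 GiB per sequence
```

Note that the result scales linearly in H: halving the number of KV heads halves the cache.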

This leads to a fundamental question:

Do attention heads really need independent keys and values?


Multi-Query Attention: Core Architectural Change

Multi-Query Attention (MQA) answers this by imposing a strong but deliberate constraint:

All attention heads have independent queries, but share a single set of keys and values.

Formally:

Q_i = X W_{Q_i}, \quad K = X W_K, \quad V = X W_V

Each head computes:

\text{Attention}_i = \text{softmax}\left( \frac{Q_i K^\top}{\sqrt{d_h}} \right) V

Important clarifications

  • Keys and values are shared across heads
  • Keys are not equal to values
  • W_K \neq W_V; they remain distinct projections

This single design decision collapses the KV cache size by a factor of H.
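To make the sharing concrete, here is a minimal NumPy sketch of the equations above: per-head query projections, but a single shared W_K and W_V. Names are mine, and the causal mask is omitted for brevity:

```python
import numpy as np

def mqa_attention(X, W_Q, W_K, W_V):
    """
    X:   (T, d)          input sequence
    W_Q: (H, d, d_h)     per-head query projections
    W_K: (d, d_h)        single shared key projection
    W_V: (d, d_h)        single shared value projection
    returns (H, T, d_h)  per-head attention outputs
    """
    H, d, d_h = W_Q.shape
    K = X @ W_K                      # (T, d_h) — computed once, shared
    V = X @ W_V                      # (T, d_h) — computed once, shared
    outputs = []
    for h in range(H):
        Q = X @ W_Q[h]               # (T, d_h) — head-specific
        scores = Q @ K.T / np.sqrt(d_h)                  # (T, T)
        w = np.exp(scores - scores.max(axis=-1, keepdims=True))
        w /= w.sum(axis=-1, keepdims=True)               # row-wise softmax
        outputs.append(w @ V)        # every head retrieves from the same V
    return np.stack(outputs)
```

Only K and V need to be cached during decoding, and there is exactly one of each per layer, regardless of H.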


Weight Matrix Geometry

Let the model dimension be d.

Multi-Head Attention (MHA)

  • W_Q \in \mathbb{R}^{d \times (H d_h)}
  • W_K \in \mathbb{R}^{d \times (H d_h)}
  • W_V \in \mathbb{R}^{d \times (H d_h)}

Multi-Query Attention (MQA)

  • W_Q \in \mathbb{R}^{d \times (H d_h)}
  • W_K \in \mathbb{R}^{d \times d_h}
  • W_V \in \mathbb{R}^{d \times d_h}
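These shapes also imply a parameter-count difference, though that is a side effect rather than the main goal. A quick sketch, with illustrative numbers of my own choosing:

```python
# Projection parameter counts implied by the shapes above (biases ignored).
def qkv_params_mha(d, H, d_h):
    return 3 * d * (H * d_h)            # W_Q, W_K, W_V each d x (H*d_h)

def qkv_params_mqa(d, H, d_h):
    return d * (H * d_h) + 2 * d * d_h  # full W_Q, small shared W_K and W_V

# e.g. d = 4096, H = 32, d_h = 128:
print(qkv_params_mha(4096, 32, 128))    # 50331648
print(qkv_params_mqa(4096, 32, 128))    # 17825792
```

The savings in weights are modest compared to the savings in cache, because W_Q keeps its full size; the KV cache is where the factor-of-H reduction shows up.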

KV Cache Memory Comparison

| Attention Type | KV Cache per Layer |
| --- | --- |
| Multi-Head Attention | H × T × dₕ |
| Multi-Query Attention | 1 × T × dₕ |

For a 32-head model, MQA yields a 32× reduction in KV cache memory and memory bandwidth during decoding.

KV Cache Memory: MHA vs MQA (Illustrative Example)

Assume:

  • Layers (L): 80
  • Attention heads (H): 64
  • Head dimension (dₕ): 128
  • Context length (T): 2048
  • Precision: FP16 (2 bytes per element)
| Attention Type | KV Cache Formula | KV Cache per Sequence |
| --- | --- | --- |
| Multi-Head Attention (MHA) | 2 × L × H × T × dₕ × 2 bytes | ~5.4 GB |
| Multi-Query Attention (MQA) | 2 × L × 1 × T × dₕ × 2 bytes | ~84 MB |
| Reduction | | 64× smaller |

The leading 2 × accounts for storing both keys and values.
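Plugging the table's parameters into the formulas directly, as a quick sanity check:

```python
# Sizes in bytes; FP16 = 2 bytes per element; leading 2 covers K and V.
L, H, T, d_h, fp16 = 80, 64, 2048, 128, 2

mha_bytes = 2 * L * H * T * d_h * fp16   # K and V for every head
mqa_bytes = 2 * L * 1 * T * d_h * fp16   # one shared K/V pair per layer

print(f"MHA: {mha_bytes / 1e9:.2f} GB")          # MHA: 5.37 GB
print(f"MQA: {mqa_bytes / 1e6:.2f} MB")          # MQA: 83.89 MB
print(f"Reduction: {mha_bytes // mqa_bytes}x")   # Reduction: 64x
```

The reduction factor is exactly H, independent of layer count, context length, and precision.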


What MQA Actually Changes (Research View)

A common explanation claims:

“Most attention diversity comes from queries.”

This is incomplete and misleading.

The real story is about inductive bias and representational collapse.


Expressiveness in Multi-Head Attention

In MHA, each head has independent projections:

Qi=XWQi,Ki=XWKi,Vi=XWVi Q_i = X W_{Q_i}, \quad K_i = X W_{K_i}, \quad V_i = X W_{V_i}

This allows each head to learn a distinct attention subspace:

  • Different similarity metrics via KiK_i
  • Different retrieval semantics via ViV_i
  • Different alignment objectives via QiQ_i

From a geometric perspective, MHA spans multiple low-rank attention operators, enabling the model to represent competing relational views of the same sequence.

This is what enables heads to specialize in syntax, long-range dependency, positional bias, or coreference.


What MQA Removes

MQA enforces:

K1==KH=K,V1==VH=V K_1 = \dots = K_H = K, \quad V_1 = \dots = V_H = V

As a result:

  • All heads score relevance in the same key space
  • All heads retrieve from the same value manifold
  • Head diversity exists only through queries

This collapses the attention operator from H independent subspaces into a single shared memory with multiple query routers.


The True Inductive Bias of MQA

MQA assumes:

A single shared representation of context is sufficient,
and attention diversity mainly arises from routing, not representation.

This is a non-trivial constraint on the hypothesis space.

It reduces the rank and diversity of attention mappings, limiting the model’s ability to represent multiple incompatible interpretations simultaneously.


Where Expressiveness Is Lost

Compared to MHA, MQA loses:

  • Per-head similarity metrics
  • Per-head semantic abstractions
  • Independent relational subspaces

This directly impacts the model’s ability to:

  • View the same token from different semantic angles
  • Encode orthogonal linguistic features in parallel
  • Maintain head-level specialization

In short:

MQA reduces the model’s “point-of-view capacity.”

This follows directly from the reduced rank and shared representation imposed by MQA.


Why MQA Still Works at Scale

Despite this loss, large models trained with MQA often show minimal degradation because:

  1. Redundancy in MHA heads: many attention heads learn correlated or weakly distinct patterns.
  2. Compensation by depth and width: feed-forward layers absorb representational burden.
  3. Training adapts to the constraint: models trained from scratch with MQA learn robust shared KV spaces.
  4. Inference dominates deployment cost: memory bandwidth, not expressiveness, becomes the bottleneck.

This explains why PaLM and inference-optimized LLMs adopt MQA successfully.


Autoregressive Inference Implications

During decoding:

  1. Queries are recomputed per token
  2. Keys and values are loaded from cache
  3. Attention is computed

With MHA, step (2) loads H KV tensors per layer.
With MQA, only one KV tensor is loaded.

This dramatically reduces:

  • Memory traffic
  • Cache pressure
  • Token latency
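A single decode step under MQA can be sketched as follows: all H query vectors for the new token score against one cached K and retrieve from one cached V. This is an illustrative NumPy sketch with names of my own, not a real inference-engine API:

```python
import numpy as np

def decode_step(q_heads, k_new, v_new, k_cache, v_cache):
    """
    q_heads:          (H, d_h)  queries for the new token, one per head
    k_new, v_new:     (d_h,)    this token's shared key and value
    k_cache, v_cache: (t, d_h)  shared cache for the t previous tokens
    returns per-head outputs (H, d_h) and the grown caches.
    """
    k_cache = np.vstack([k_cache, k_new])    # append this token's K
    v_cache = np.vstack([v_cache, v_new])    # append this token's V
    d_h = k_cache.shape[1]
    scores = q_heads @ k_cache.T / np.sqrt(d_h)      # (H, t+1)
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)               # softmax per head
    out = w @ v_cache                # (H, d_h) — ONE value tensor for all heads
    return out, k_cache, v_cache
```

Per layer, the memory traffic for step (2) is one K tensor and one V tensor, versus H of each under MHA; the query-side compute is unchanged.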

Summary: Compute vs Representation Trade-off

| Aspect | MHA | MQA |
| --- | --- | --- |
| Attention subspaces | Many | One |
| KV diversity | Per-head | Shared |
| Expressiveness | Higher | Lower |
| KV cache size | \mathcal{O}(H T d_h) | \mathcal{O}(T d_h) |
| Inference efficiency | Lower | Much higher |

MQA is not a free optimization.
It is a deliberate architectural trade-off favoring inference scalability over maximal expressiveness.


Connect with me:

LinkedIn: https://www.linkedin.com/in/vaibhav-ahluwalia-83887a227/
