DEV Community

TildAlice

Originally published at tildalice.io

Paper Review: GQA — Grouped Query Attention for Faster LLM Inference

Why Multi-Head Attention Has a Memory Problem

Here's a number that might surprise you: in a 65B-parameter LLM with 80 attention heads, the key-value (KV) cache used during inference can eat up 40GB of memory just to hold context. And that is on top of the memory already taken by the model weights.

The original paper is Ainslie et al. (2023), "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints."

Multi-Head Attention (MHA), the backbone of every Transformer since Vaswani et al. (2017), stores separate key and value projections for each head. During autoregressive decoding, you cache these KV pairs for every token in the context (prompt plus generated tokens). With 80 heads and a context length of 8K tokens, you're caching $80 \times 8192 \times d_{head} \times 2$ values per layer, where the factor of 2 accounts for storing both keys and values. Multiply by 80 layers and suddenly your GPU is sweating.
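A back-of-the-envelope sketch of that arithmetic follows. The head dimension of 128, the 2-byte (fp16/bf16) cache entries, and the helper name `kv_cache_bytes` are illustrative assumptions, not figures taken from the paper:

```python
def kv_cache_bytes(num_kv_heads: int, seq_len: int, d_head: int,
                   num_layers: int, bytes_per_value: int = 2,
                   batch_size: int = 1) -> int:
    """Approximate KV-cache size in bytes for autoregressive decoding.

    Per layer, every cached token stores one key and one value vector
    (the factor of 2) of length d_head for each KV head.
    """
    values_per_layer = num_kv_heads * seq_len * d_head * 2
    return values_per_layer * num_layers * bytes_per_value * batch_size


# Assumed configuration: 80 heads, 8K context, d_head = 128, 80 layers, fp16.
size = kv_cache_bytes(num_kv_heads=80, seq_len=8192, d_head=128, num_layers=80)
print(f"{size / 1e9:.1f} GB per sequence")
```

Under these assumptions a single 8K-token sequence needs roughly 26.8 GB of cache. The total grows linearly with batch size and with bytes per value (fp32 doubles it), which is how long-context serving quickly reaches the tens of gigabytes.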

Multi-Query Attention (MQA), proposed by Shazeer (2019), took the nuclear option: share a single key-value head across all query heads. The KV cache shrinks by a factor equal to the number of heads, so anywhere from 8x to 80x depending on the model. Inference flies. But there's a catch: quality degrades. The GQA authors report a noticeable quality gap when MQA replaces MHA, particularly in larger models. You save memory but lose capability.

GQA sits in the middle. Instead of $H$ key-value heads (MHA) or a single one (MQA), you get $G$ KV heads: the $H$ query heads are split into $G$ groups, and each group of query heads shares one KV head. It's embarrassingly simple.
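To make the grouping concrete, here is a minimal pure-Python sketch of a single decoding step of grouped-query attention. It assumes $H$ is divisible by $G$ and that consecutive query heads share a KV head (query head `h` reads KV head `h // (H // G)`). That mapping convention and the names `kv_head_for` and `gqa_decode_step` are illustrative assumptions; real implementations do this with batched tensor ops rather than loops:

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def kv_head_for(query_head: int, num_heads: int, num_groups: int) -> int:
    # Assumed convention: consecutive query heads share a KV head, so with
    # H=8, G=2 heads 0-3 use KV head 0 and heads 4-7 use KV head 1.
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    return query_head // (num_heads // num_groups)


def gqa_decode_step(queries: List[Vector],
                    keys: List[List[Vector]],
                    values: List[List[Vector]]) -> List[Vector]:
    """Attention output for one new token.

    queries: H query vectors (one per query head), each of length d_head.
    keys, values: G cached KV heads, each a list of T vectors of length d_head.
    Returns H output vectors, one per query head.
    """
    num_heads, num_groups = len(queries), len(keys)
    d_head = len(queries[0])
    outputs = []
    for h, q in enumerate(queries):
        g = kv_head_for(h, num_heads, num_groups)
        # Scaled dot-product scores against every cached key of group g.
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_head)
                  for k in keys[g]]
        weights = softmax(scores)
        # Weighted sum of group g's cached values.
        out = [sum(w * v[i] for w, v in zip(weights, values[g]))
               for i in range(d_head)]
        outputs.append(out)
    return outputs
```

Only `G` key and value lists are ever stored, which is where the memory saving comes from. Passing as many KV heads as query heads reproduces MHA, and passing a single KV head reproduces MQA.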

$$\text{Memory}_{\text{GQA}} = \frac{G}{H} \times \text{Memory}_{\text{MHA}}$$
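The ratio is easy to sanity-check. The sketch below assumes $H = 64$ query heads purely for illustration, and `gqa_memory_ratio` is a hypothetical helper; setting $G = H$ recovers MHA and $G = 1$ recovers MQA:

```python
from fractions import Fraction


def gqa_memory_ratio(num_groups: int, num_heads: int) -> Fraction:
    """KV-cache size of GQA relative to MHA, i.e. G / H."""
    if not 1 <= num_groups <= num_heads or num_heads % num_groups:
        raise ValueError("need 1 <= G <= H with H divisible by G")
    return Fraction(num_groups, num_heads)


H = 64  # assumed number of query heads
for G in (64, 8, 1):  # MHA, an intermediate GQA setting, MQA
    ratio = gqa_memory_ratio(G, H)
    print(f"G={G:>2}: {ratio} of the MHA cache, {1 / ratio}x smaller")
```

Using `Fraction` keeps the ratio exact, so the $G/H$ relationship is visible directly in the printed output instead of as a rounded float.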


