The Problem Nobody Warns You About
Most LLM inference guides talk about model quantization and batch sizes. What they don't tell you: the KV cache will eat your VRAM before the model weights do.
I loaded Llama 2 70B on a single RTX 4090 (24GB VRAM) using 4-bit quantization. The model fit. Started inference with a batch size of 1, context length 2048. Three requests in, I hit OOM. The culprit? Not the 70 billion parameters — the key-value cache from the attention mechanism.
Here's the math. In a transformer, every layer stores a key vector $k$ and a value vector $v$ for each token so that later tokens can attend to it. For a model with $L$ layers, hidden dimension $d_h$, and sequence length $n$, the KV cache memory is:
$$M_{\text{KV}} = 2 \cdot L \cdot n \cdot d_h \cdot b \cdot \text{bytes}$$
where $b$ is batch size and bytes depends on precision (2 for FP16, 4 for FP32). For Llama 2 70B ($L=80$, $d_h=8192$), a single sequence at $n=2048$ with FP16 precision requires:
$$M_{\text{KV}} = 2 \cdot 80 \cdot 2048 \cdot 8192 \cdot 1 \cdot 2 \text{ bytes} \approx 5.37 \text{ GB}$$
That's more than a third of my available VRAM, and I haven't even started batching yet.
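To make the scaling concrete, here is a minimal sketch of that formula as a function. The function name and parameters are my own labels, not from any library, and it assumes standard multi-head attention where keys and values span the full hidden dimension (models using grouped-query attention cache proportionally less):

```python
# Back-of-the-envelope KV cache size for standard multi-head attention.
# Assumes keys and values each span the full hidden dimension.
def kv_cache_bytes(num_layers: int, seq_len: int, hidden_dim: int,
                   batch_size: int = 1, bytes_per_value: int = 2) -> int:
    # Factor of 2: one key vector and one value vector per token, per layer.
    return 2 * num_layers * seq_len * hidden_dim * batch_size * bytes_per_value

# Llama 2 70B: L=80, d_h=8192, n=2048, FP16 (2 bytes per value)
size = kv_cache_bytes(num_layers=80, seq_len=2048, hidden_dim=8192)
print(f"{size / 1e9:.2f} GB")  # → 5.37 GB
```

Note that the cache grows linearly in both sequence length and batch size, so doubling either doubles the memory bill.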
