DEV Community

Kotcherla Murali Krishna


What Happens Inside an LLM During Inference: Tokens, KV Cache, and GPU Execution Explained


You type a prompt. You hit Enter. In under two seconds, a response starts streaming back — word by word, almost like a human typing in real time.

But what actually happens between your keypress and that first token appearing on screen?

Inside the server, a sequence of events unfolds: tokenization, billions of matrix multiplications, carefully scheduled GPU kernel launches, a memory structure called the KV cache, and a probabilistic sampling process, all taking only milliseconds per generated token.

This article takes you inside the machine. We’ll trace a single inference request from raw text to streamed response, layer by layer, operation by operation. By the end, you’ll understand not just what LLMs do, but how modern systems execute them at scale — and why it’s so expensive.

Tokenization Explained

Before any neural network sees your input, it has to become numbers. But it doesn’t convert character by character — it converts subword chunks called tokens.

Modern LLMs use Byte Pair Encoding (BPE) tokenizers, often byte-level, built with libraries such as SentencePiece or tiktoken. The vocabulary (typically 32K–128K tokens) is learned from a text corpus before model training, by repeatedly merging the most frequent adjacent pair of symbols until a target vocabulary size is reached.
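To make the merge loop concrete, here is a toy BPE trainer over characters. It is a simplified sketch: production tokenizers work on bytes, pre-split text with regular expressions, and train on far larger corpora, and `train_bpe` is a hypothetical helper, not a library API.

```python
from collections import Counter

def train_bpe(words, num_merges):
    """Toy BPE: repeatedly merge the most frequent adjacent symbol pair."""
    corpus = Counter(tuple(w) for w in words)   # each word as a tuple of symbols, with its count
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for symbols, freq in corpus.items():
            for a, b in zip(symbols, symbols[1:]):
                pairs[(a, b)] += freq
        if not pairs:
            break
        best = max(pairs, key=pairs.get)        # most frequent pair (first seen wins ties)
        merges.append(best)
        new_corpus = Counter()
        for symbols, freq in corpus.items():
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(symbols[i] + symbols[i + 1])   # replace the pair with one new symbol
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            new_corpus[tuple(out)] += freq
        corpus = new_corpus
    return merges, corpus

words = ["low"] * 5 + ["lower"] * 2 + ["newest"] * 6 + ["widest"] * 3
merges, corpus = train_bpe(words, num_merges=3)
print(merges)  # [('e', 's'), ('es', 't'), ('l', 'o')]
```

Frequent fragments like "est" become single symbols after a few merges, which is the same frequency-driven behavior that decides how a real vocabulary splits words.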

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
text = "What happens inside an LLM during inference?"
tokens = tokenizer(text)
print(tokens["input_ids"])
# [128000, 3923, 8741, 4871, 459, 445, 11237, 2391, 45478, 30]
print(tokenizer.convert_ids_to_tokens(tokens["input_ids"]))
# ['<|begin_of_text|>', 'What', ' happens', ' inside', ' an', ' L', 'LM', ' during', ' inference', '?']

💡 Key insight: “LLM” becomes two tokens: L and LM. Tokenization is not intuitive — it's frequency-driven. Common words are single tokens; rare or technical terms split.

The output is a sequence of integer IDs — the true input to the model. Token count determines compute cost directly: more tokens = more compute.


Embeddings and Vector Representations

Token IDs are integers, but neural networks operate on continuous-valued vectors. The embedding layer is a giant lookup table: a matrix of shape [vocab_size × d_model], where d_model is typically 4096 for 7B/8B-class models and 8192 for 70B-class models.

import torch
import torch.nn as nn

vocab_size = 128_256  # LLaMA-3 vocabulary size
d_model = 4096

embedding = nn.Embedding(vocab_size, d_model)

# Token IDs → embedding vectors (a row lookup, not a matrix multiply)
token_ids = torch.tensor([3923, 8741, 4871, 459])
vectors = embedding(token_ids)  # shape: [4, 4096]

Each token becomes a 4096-dimensional vector. After training, tokens with related meanings tend to lie close together in this space: “king” sits nearer to “queen” than to “potato.”

Self-attention on its own has no notion of token order, so position information must be injected separately:

# Rotary Positional Embedding (RoPE): used in LLaMA, Mistral, Gemma
# Encodes position by rotating pairs of query/key dimensions through
# angles proportional to the token's position (one frequency per pair)
# No separate positional embedding matrix is needed

🔑 Modern LLMs use RoPE (Rotary Position Embedding) instead of the fixed sinusoidal embeddings from the original “Attention Is All You Need” paper. Because queries and keys are rotated rather than offset by an added vector, the query-key dot product depends only on the relative distance between two tokens, and RoPE's frequencies can be rescaled to stretch a model to longer contexts than it was trained on.
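A minimal RoPE sketch for a single head follows. It assumes the "rotate-half" pairing (dimension i paired with i + head_dim/2) and the conventional base of 10000; `rope` is an illustrative helper, not a library function. The usage lines check the relative-position property: the same 2-token offset gives the same score at very different absolute positions.

```python
import torch

def rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate x ([seq, head_dim], head_dim even) by position-dependent angles."""
    half = x.shape[-1] // 2
    # One rotation frequency per dimension pair: base^(-2i / head_dim)
    inv_freq = base ** (-torch.arange(half, dtype=torch.float32) * 2 / x.shape[-1])
    angles = positions.float()[:, None] * inv_freq[None, :]   # [seq, half]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # Standard 2D rotation applied to each (x1_i, x2_i) pair
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

torch.manual_seed(0)
q, k = torch.randn(1, 64), torch.randn(1, 64)
score_near = (rope(q, torch.tensor([5])) * rope(k, torch.tensor([3]))).sum()
score_far = (rope(q, torch.tensor([45])) * rope(k, torch.tensor([43]))).sum()
print(score_near.item(), score_far.item())  # equal up to float rounding
```

Rotations also preserve vector length, so RoPE changes only the angle between queries and keys, never their magnitude.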

Transformer Layers Explained Visually

A transformer is a stack of identical layers — GPT-3 has 96, LLaMA-3 8B has 32, LLaMA-3 70B has 80. Each layer has two main sublayers:

  1. Multi-Head Self-Attention (MHSA)
  2. Feed-Forward Network (FFN)

Each sublayer is wrapped in a residual connection and a normalization step (pre-norm, i.e. normalization before the sublayer, in most modern models).


💡 Why residual connections? They give gradients a direct path through the network during training, mitigating vanishing gradients. During inference, they mean each layer adds a refinement to the representation rather than replacing it, so the model builds its understanding incrementally.

Modern LLMs replace standard LayerNorm with RMSNorm (root mean square normalization), which skips the mean subtraction and is cheaper to compute while working about as well in practice. The FFN uses a SwiGLU activation instead of the ReLU or GELU of earlier models; its gating mechanism has been found to improve quality at a similar parameter count.
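Putting these pieces together, a pre-norm decoder layer with RMSNorm and a SwiGLU FFN can be sketched in PyTorch as below. This is a simplified illustration: it reuses `nn.MultiheadAttention` (so it has no RoPE and no GQA), uses toy dimensions, and the class names are hypothetical rather than taken from any model's source code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Divide by the root mean square; unlike LayerNorm, no mean subtraction and no bias
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class SwiGLU(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate = nn.Linear(dim, hidden, bias=False)
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        # SiLU-activated gate multiplies the "up" projection elementwise
        return self.down(F.silu(self.gate(x)) * self.up(x))

class DecoderLayer(nn.Module):
    def __init__(self, dim: int = 256, n_heads: int = 4, ffn_hidden: int = 688):
        super().__init__()
        self.attn_norm = RMSNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, bias=False, batch_first=True)
        self.ffn_norm = RMSNorm(dim)
        self.ffn = SwiGLU(dim, ffn_hidden)

    def forward(self, x):  # x: [batch, seq, dim]
        seq = x.size(1)
        # True above the diagonal means "may not attend": blocks future tokens
        causal = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
        h = self.attn_norm(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=causal, need_weights=False)
        x = x + attn_out                    # residual around attention
        x = x + self.ffn(self.ffn_norm(x))  # residual around the FFN
        return x

layer = DecoderLayer()
x = torch.randn(2, 10, 256)
print(layer(x).shape)  # torch.Size([2, 10, 256])
```

Stacking 32 or 80 of these (with real dimensions, RoPE, and GQA) gives the layer stack described above.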

Self-Attention Mechanism

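Self-attention is the heart of the transformer. It lets each token look at the other tokens in the sequence and decide which ones to attend to; in decoder-only LLMs, a causal mask limits each token to itself and the tokens before it.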

For each token, three vectors are computed via learned linear projections:

  • Q (Query) — “What am I looking for?”
  • K (Key) — “What do I contain?”
  • V (Value) — “What do I output if attended to?”
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)

    # Attention scores: how much each token attends to every other
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax over keys dimension → attention weights
    weights = F.softmax(scores, dim=-1)

    # Weighted sum of values
    return torch.matmul(weights, V), weights

# Multi-head: run attention H times in parallel, concat results
# Each head learns different relationship types

Multi-head attention runs this H times in parallel (e.g., 32 heads for a 7B/8B model), each head with a smaller dimension d_k = d_model / H (4096 / 32 = 128 for LLaMA-3 8B). Heads tend to specialize: interpretability studies have found heads that track syntactic relationships, positional patterns, or coreference.
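A from-scratch sketch of the head splitting, assuming square projection matrices `W_q`, `W_k`, `W_v`, `W_o` of shape [d_model, d_model] and a causal (decoder-style) mask; the function name is illustrative. Splitting is just a reshape: each head sees its own d_k-sized slice of the projected vectors, and the heads are concatenated back before the output projection.

```python
import math
import torch
import torch.nn.functional as F

def multi_head_attention(x, W_q, W_k, W_v, W_o, n_heads):
    """x: [batch, seq, d_model]; each W: [d_model, d_model]."""
    b, n, d_model = x.shape
    d_k = d_model // n_heads

    def split_heads(t):  # [b, n, d_model] -> [b, n_heads, n, d_k]
        return t.view(b, n, n_heads, d_k).transpose(1, 2)

    Q, K, V = split_heads(x @ W_q), split_heads(x @ W_k), split_heads(x @ W_v)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)          # [b, heads, n, n]
    future = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))         # causal mask
    out = F.softmax(scores, dim=-1) @ V                        # [b, heads, n, d_k]
    out = out.transpose(1, 2).reshape(b, n, d_model)           # concatenate heads
    return out @ W_o

torch.manual_seed(0)
d_model, n_heads = 64, 8   # d_k = 64 / 8 = 8 per head
x = torch.randn(2, 5, d_model)
weights = [torch.randn(d_model, d_model) / math.sqrt(d_model) for _ in range(4)]
print(multi_head_attention(x, *weights, n_heads=n_heads).shape)  # torch.Size([2, 5, 64])
```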


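⚠️ Complexity: Self-attention is O(n²) in sequence length, so doubling the context roughly quadruples the attention computation. FlashAttention does not change that FLOP count, but it avoids materializing the full n × n score matrix, which keeps attention memory linear in sequence length and is a key enabler for long-context models (128K+ tokens).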

FlashAttention (Dao et al., 2022) restructures the computation to stay within GPU SRAM rather than repeatedly reading/writing to HBM, achieving 2–4× speedup without changing the math.
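FlashAttention's central trick is to visit K and V in blocks while keeping a running row maximum and a running softmax denominator (the "online softmax"), so the full score matrix never has to exist at once. The sketch below reproduces that bookkeeping in plain PyTorch for one unmasked head; it illustrates the math only, not the fused CUDA kernel, and `tiled_attention` is a hypothetical name.

```python
import math
import torch

def tiled_attention(Q, K, V, block_size=4):
    """Exact softmax(Q K^T / sqrt(d)) V for one head, visiting K/V one block at a time."""
    n, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    out = torch.zeros(n, V.shape[-1])            # unnormalized running output
    row_max = torch.full((n, 1), float("-inf"))  # running max score per query row
    row_sum = torch.zeros(n, 1)                  # running softmax denominator
    for start in range(0, K.shape[0], block_size):
        K_blk, V_blk = K[start:start + block_size], V[start:start + block_size]
        s = (Q @ K_blk.T) * scale                          # only an [n, block] tile of scores
        new_max = torch.maximum(row_max, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - new_max)                         # block weights, shifted for stability
        correction = torch.exp(row_max - new_max)          # rescale earlier blocks to the new max
        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ V_blk
        row_max = new_max
    return out / row_sum

torch.manual_seed(0)
Q, K, V = torch.randn(6, 16), torch.randn(10, 16), torch.randn(10, 16)
reference = torch.softmax(Q @ K.T / math.sqrt(16), dim=-1) @ V
print(torch.allclose(tiled_attention(Q, K, V, block_size=3), reference, atol=1e-5))
```

The result matches ordinary attention, which is why FlashAttention is described as exact: only the order of operations and the memory traffic change.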

Matrix Multiplication on GPUs

Every linear projection in a transformer (Q, K, V projections, FFN weights, output projections) is a matrix multiplication (matmul). For a batch of tokens and a weight matrix:

Output [batch × seq × d_out] = Input [batch × seq × d_in] × W [d_in × d_out]

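For LLaMA-3 8B with d_model = 4096, the query and output projections each use a 4096×4096 weight matrix (the K and V projections are smaller because of grouped-query attention, covered below), so each of those projections costs about 16.8 million multiply-accumulate operations per token.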

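GPUs are built for exactly this. An H100 SXM delivers roughly 989 TFLOPS of dense BF16 tensor core throughput (about 1,979 TFLOPS with 2:4 structured sparsity). The secret is tensor cores: specialized units that each perform a small matrix multiply-accumulate per clock cycle (4×4×4 on the original Volta design, larger tiles on newer generations).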


During decode, memory bandwidth, not compute, is the bottleneck. Each generation step must stream the model's weight matrices from HBM, yet performs only about 2 FLOPs (one multiply-accumulate) per weight loaded for each sequence in the batch. This is called being memory-bound (low arithmetic intensity).

📊 Arithmetic intensity = FLOPs performed / bytes moved from memory

Prefill phase: high intensity → compute-bound

Decode phase: low intensity → memory-bound

This asymmetry is why prefill and decode need different optimization strategies.
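A back-of-envelope calculation shows the asymmetry for one 4096×4096 BF16 projection. The sketch counts each matrix as moved from HBM exactly once (real kernels may move more) and assumes approximate H100 SXM figures of 989 TFLOPS dense BF16 and 3.35 TB/s of HBM bandwidth; the function name is hypothetical.

```python
def matmul_arithmetic_intensity(tokens, d_in, d_out, bytes_per_element=2):
    """FLOPs per byte for [tokens x d_in] @ [d_in x d_out], counting each matrix moved once."""
    flops = 2 * tokens * d_in * d_out        # multiply + add per multiply-accumulate
    bytes_moved = bytes_per_element * (
        d_in * d_out                         # weights (dominant when tokens is small)
        + tokens * d_in                      # input activations
        + tokens * d_out                     # output activations
    )
    return flops / bytes_moved

# Assumed H100 SXM specs: ~989e12 dense BF16 FLOP/s and ~3.35e12 bytes/s of HBM bandwidth
ridge_point = 989e12 / 3.35e12               # FLOPs/byte needed before compute becomes the limit

decode = matmul_arithmetic_intensity(tokens=1, d_in=4096, d_out=4096)
prefill = matmul_arithmetic_intensity(tokens=1000, d_in=4096, d_out=4096)
print(f"decode {decode:.1f}, prefill {prefill:.0f}, ridge {ridge_point:.0f} FLOPs/byte")
# decode 1.0, prefill 672, ridge 295 FLOPs/byte
```

Decode sits far below the ridge point (the GPU waits on memory), while a 1000-token prefill sits above it (the GPU waits on tensor cores). Batching many decode sequences raises `tokens` and pushes decode toward the compute-bound regime.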

CUDA Kernels and Tensor Operations

When PyTorch executes torch.matmul(), it doesn't run one monolithic computation; it dispatches a CUDA kernel: a compiled GPU function that runs in parallel across thousands of threads.

import torch

A = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
B = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)

# This dispatches a cuBLAS GEMM kernel under the hood.
# 2 × 4096³ ≈ 137 billion FLOPs; the launch is asynchronous, so call
# torch.cuda.synchronize() before timing it.
C = torch.matmul(A, B)

The GPU execution stack for a single matmul:

PyTorch op → dispatches to cuBLAS → cuBLAS selects a GEMM kernel
→ Output matrix divided into tiles → Tiles assigned to Streaming Multiprocessors (SMs)
→ Each SM stages its input tiles in shared memory (SRAM)
→ Tensor cores compute matrix-fragment multiply-accumulates
→ Results written back to HBM
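The tiling step can be mimicked on the CPU with a blocked matrix multiply. This is only a conceptual sketch: in a real GEMM kernel each output tile is owned by a thread block on an SM and the inner products run on tensor cores, whereas here every tile is just a smaller PyTorch matmul, and the tile size of 64 is arbitrary.

```python
import torch

def tiled_matmul(A: torch.Tensor, B: torch.Tensor, tile: int = 64) -> torch.Tensor:
    M, K = A.shape
    K_b, N = B.shape
    assert K == K_b, "inner dimensions must match"
    C = torch.zeros(M, N, dtype=A.dtype)
    for i in range(0, M, tile):              # output tile row
        for j in range(0, N, tile):          # output tile column (roughly one thread block)
            acc = torch.zeros(min(tile, M - i), min(tile, N - j), dtype=A.dtype)
            for k in range(0, K, tile):      # stream K-slices, as if through shared memory
                acc += A[i:i + tile, k:k + tile] @ B[k:k + tile, j:j + tile]
            C[i:i + tile, j:j + tile] = acc  # write each finished tile back once
    return C

A, B = torch.randn(130, 70), torch.randn(70, 90)
print(torch.allclose(tiled_matmul(A, B), A @ B, atol=1e-4))
```

Each output tile only needs one row-strip of A and one column-strip of B, which is what lets a GPU keep the working set in fast on-chip memory.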

Kernel fusion is a critical optimization: instead of launching separate kernels for attention scores, softmax, and matmul, fused kernels like FlashAttention do it all in one pass — dramatically reducing HBM traffic.

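# FlashAttention via PyTorch SDPA (scaled dot product attention)
import torch
from torch.nn.functional import scaled_dot_product_attention

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
Q, K, V = (torch.randn(1, 32, 1024, 128, device=device, dtype=dtype) for _ in range(3))  # [batch, heads, seq, head_dim]

# PyTorch picks a backend (FlashAttention, memory-efficient, or plain math) based on hardware, dtype, and shapes
output = scaled_dot_product_attention(Q, K, V, is_causal=True)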

KV Cache During Inference

This is one of the most important engineering decisions in LLM inference.

During generation, the model processes the full prompt once — but for each new token generated, it only needs to run attention for that one new token against all prior tokens. Without caching, you’d recompute K and V vectors for the entire history on every step.

The KV cache stores the Key and Value tensors for all previously processed tokens:

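import torch
import torch.nn.functional as F

# Toy sizes for illustration (LLaMA-3 8B has 32 layers, 8 KV heads, head_dim 128)
batch_size, num_heads, head_dim = 1, 8, 64
num_layers, max_seq_len = 4, 2048

# Simplified KV cache: preallocated K and V buffers per layer
kv_cache = {
    layer_idx: {
        "k": torch.zeros(batch_size, num_heads, max_seq_len, head_dim),
        "v": torch.zeros(batch_size, num_heads, max_seq_len, head_dim),
        "length": 0,  # number of positions filled so far
    }
    for layer_idx in range(num_layers)
}

def attention_with_cache(Q_new, K_new, V_new, layer_idx, cache):
    # Q_new, K_new, V_new: [batch, heads, 1, head_dim] for the newest token
    layer = cache[layer_idx]
    pos = layer["length"]

    # Append the new token's K and V at the next free slot
    layer["k"][:, :, pos:pos + 1, :] = K_new
    layer["v"][:, :, pos:pos + 1, :] = V_new
    layer["length"] += 1

    # Attend over the full cached history; no causal mask is needed
    # because every cached position precedes (or is) the new token
    K_full = layer["k"][:, :, :pos + 1, :]
    V_full = layer["v"][:, :, :pos + 1, :]
    return F.scaled_dot_product_attention(Q_new, K_full, V_full)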

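💾 Memory cost of the KV cache (per sequence):

2 × num_layers × num_kv_heads × head_dim × seq_len × bytes_per_element

(The leading 2 counts K and V; num_kv_heads equals the number of attention heads unless the model shares K/V heads, as with GQA below.)

For LLaMA-3 70B in BF16 (80 layers, 8 KV heads of dimension 128): about 2.7 GB for a single 8K-token sequence. Without head sharing, its 64 heads would need about 21 GB.

This is why long-context inference is extremely memory-intensive.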
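The KV cache formula is easy to turn into a capacity-planning calculator. The sketch assumes BF16 (2 bytes per element) and plugs in LLaMA-3 70B's published shapes; `num_kv_heads` counts only the heads that actually store K and V, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   bytes_per_element=2, batch_size=1):
    # Factor of 2: each layer stores one K and one V tensor
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_element * batch_size

# LLaMA-3 70B shapes: 80 layers, 64 query heads, 8 KV heads (GQA), head_dim 128
gqa = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=8192)
mha = kv_cache_bytes(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=8192)  # hypothetical: no sharing
print(f"GQA: {gqa / 1e9:.2f} GB, full MHA: {mha / 1e9:.2f} GB")
# GQA: 2.68 GB, full MHA: 21.47 GB
```

Multiplying by `batch_size` shows why concurrency, not just context length, drives memory: 32 concurrent 8K-token requests would need roughly 86 GB of KV cache even with GQA.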

Grouped Query Attention (GQA) — used in LLaMA-3, Mistral, and others — reduces this by sharing K/V heads across groups of Q heads, cutting KV cache size by 4–8×.
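A minimal GQA sketch, assuming K and V are stored with fewer heads than Q and expanded with `repeat_interleave` only at attention time, so that only the compact K/V tensors would need to live in the KV cache. The function name is illustrative; production kernels typically avoid the physical copy.

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(Q, K, V):
    """Q: [batch, n_q_heads, seq, head_dim]; K, V: [batch, n_kv_heads, seq, head_dim]."""
    group_size = Q.size(1) // K.size(1)   # query heads that share each K/V head
    # Expand each K/V head to serve its group of query heads: heads 0..group_size-1
    # use K/V head 0, the next group uses K/V head 1, and so on
    K = K.repeat_interleave(group_size, dim=1)
    V = V.repeat_interleave(group_size, dim=1)
    return F.scaled_dot_product_attention(Q, K, V, is_causal=True)

# LLaMA-3 8B-style ratio: 32 query heads share 8 K/V heads (4 per group)
Q = torch.randn(1, 32, 16, 128)
K, V = torch.randn(1, 8, 16, 128), torch.randn(1, 8, 16, 128)
out = grouped_query_attention(Q, K, V)
print(out.shape)              # torch.Size([1, 32, 16, 128])
print(K.numel() / Q.numel())  # 0.25: the cached K is 4x smaller than a full-head K
```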

Prefill vs Decode Phase

Inference has two fundamentally different phases with different compute profiles:


Prefill processes all prompt tokens in one large parallel forward pass. It is compute-intensive: the large batched matmuls keep the GPU's tensor cores busy. For a 1000-token prompt, this typically takes tens to a few hundred milliseconds on H100-class hardware, depending on model size.

Decode generates tokens one by one. Each decode step:

  1. Runs a forward pass for one new token
  2. Reads the full KV cache (all prior tokens) from HBM
  3. Appends new K/V entries to cache
  4. Samples one token from the output distribution

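Decode is slow because every step must stream the model weights and the KV cache from HBM, which is limited by memory bandwidth rather than compute. As a rough upper bound for a single sequence, a 70B model in BF16 (about 140 GB of weights) split across two H100 SXM GPUs (about 3.35 TB/s of HBM bandwidth each) can generate at most around 48 tokens per second, far below what its FLOP peak would suggest. Batching many sequences amortizes each weight read, which is how servers reach much higher aggregate throughput.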

Time To First Token (TTFT) is dominated by prefill.

Inter-Token Latency (ITL) is determined by decode throughput.

These are the two key SLAs in production inference systems.
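Both SLAs can be bounded with back-of-envelope arithmetic. The sketch assumes a dense model, roughly 2 FLOPs per parameter per prompt token for prefill, one full read of the weights per decode step, and approximate H100 SXM figures (989 TFLOPS dense BF16, 3.35 TB/s HBM3). It ignores attention FLOPs, KV-cache reads, inter-GPU communication, and kernel inefficiency, so real latencies will be higher; `estimate_latency` is a hypothetical helper.

```python
def estimate_latency(num_params, prompt_tokens, num_gpus,
                     peak_flops_per_gpu=989e12,          # assumed dense BF16 peak per H100 SXM
                     hbm_bytes_per_sec_per_gpu=3.35e12,  # assumed HBM3 bandwidth per H100 SXM
                     bytes_per_param=2):
    """Idealized lower bounds on TTFT and ITL, in seconds."""
    # Prefill (compute-bound): ~2 FLOPs per parameter per prompt token
    ttft = 2 * num_params * prompt_tokens / (peak_flops_per_gpu * num_gpus)
    # Decode (memory-bound): every step streams all weights from HBM once
    itl = num_params * bytes_per_param / (hbm_bytes_per_sec_per_gpu * num_gpus)
    return ttft, itl

ttft, itl = estimate_latency(70e9, prompt_tokens=1000, num_gpus=2)
print(f"TTFT >= {ttft * 1e3:.0f} ms, ITL >= {itl * 1e3:.1f} ms (~{1 / itl:.0f} tokens/s)")
# TTFT >= 71 ms, ITL >= 20.9 ms (~48 tokens/s)
```

Under these assumptions, a 200-token response would take at least about 71 ms + 200 × 20.9 ms ≈ 4.3 s end to end, which is why decode optimizations dominate user-perceived latency for long answers.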

Token Sampling Strategies

After the final transformer layer, a linear projection maps d_model → vocab_size, followed by softmax to produce a probability distribution over all tokens. Then sampling:

import torch
import torch.nn.functional as F

def sample_token(logits, temperature=0.8, top_p=0.9, top_k=50):
    # Greedy decoding: temperature 0 means always pick the most likely token
    if temperature == 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    # Temperature scaling: higher = flatter distribution = more random
    logits = logits / temperature

    # Top-K filtering: keep only the K highest-scoring tokens
    if top_k > 0:
        top_k_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        threshold = top_k_values[..., -1, None]
        logits = logits.masked_fill(logits < threshold, float('-inf'))

    # Top-P (nucleus): keep the smallest set of tokens whose probability sums to >= p
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Drop a token if the higher-ranked tokens before it already reach top_p
        sorted_to_remove = cumulative_probs - sorted_probs > top_p
        sorted_logits = sorted_logits.masked_fill(sorted_to_remove, float('-inf'))

        # Undo the sort so positions line up with token IDs again (works for batched logits too)
        logits = torch.empty_like(logits).scatter_(-1, sorted_indices, sorted_logits)

    # Sample from the filtered distribution
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


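🎲 Practical defaults: temperature 0.6–0.8 for chat, top-p 0.9, top-k 40–100. Coding tasks often use lower temperatures (0.2–0.4) for more deterministic output. Temperature 0 corresponds to greedy decoding (always picking the most likely token), which is deterministic in principle, although batched GPU execution can still cause small run-to-run differences.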
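To see what temperature actually does, here is a dependency-free sketch that applies softmax at several temperatures to the same made-up logits (the four scores are hypothetical). Dividing by a small temperature exaggerates the gaps between scores, so probability mass piles onto the top token; a large temperature shrinks the gaps and flattens the distribution.

```python
import math

def softmax_with_temperature(logits, temperature):
    scaled = [x / temperature for x in logits]
    m = max(scaled)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, -1.0]               # hypothetical scores for 4 tokens
for t in (0.2, 0.7, 1.0, 1.5):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```

As the temperature approaches 0, the top token's probability approaches 1, which is why temperature 0 is treated as greedy decoding rather than computed literally.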

Streaming Responses

When a chat interface like ChatGPT “types” a response, it isn’t buffering the full answer: the server streams each token as it is generated, commonly over Server-Sent Events (SSE).

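# FastAPI streaming endpoint (simplified)
import asyncio
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class GenerateRequest(BaseModel):
    prompt: str

async def fake_generate_stream(prompt: str):
    # Hypothetical stand-in for a real model's decode loop
    for token in ["Inference", " is", " streaming", "."]:
        await asyncio.sleep(0.05)  # pretend each decode step takes 50 ms
        yield token

async def token_stream(prompt: str):
    async for token in fake_generate_stream(prompt):
        # SSE frame: "data: <payload>\n\n"; JSON-encode so newlines inside a token can't break framing
        yield f"data: {json.dumps(token)}\n\n"
    yield "data: [DONE]\n\n"

@app.post("/generate")
async def generate(req: GenerateRequest):
    return StreamingResponse(token_stream(req.prompt), media_type="text/event-stream")

On the client side:

const response = await fetch('/generate', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ prompt }),
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
let buffer = '';

while (true) {
    const { done, value } = await reader.read();
    if (done) break;

    // A network chunk can end mid-event, so buffer until a blank line completes it
    buffer += decoder.decode(value, { stream: true });
    const events = buffer.split('\n\n');
    buffer = events.pop(); // keep the trailing incomplete event for the next read

    for (const event of events) {
        if (!event.startsWith('data: ')) continue;
        const payload = event.slice(6);
        if (payload === '[DONE]') continue;
        displayToken(JSON.parse(payload)); // displayToken: your app's UI append function
    }
}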

The user experience of “streaming” emerges from the model’s decode loop — each iteration generates one token (or a few), which is immediately sent over the wire. The GPU is generating while the client is rendering.

Inference Engines

Raw PyTorch is not how production LLMs are served. Specialized inference engines add layers of optimization:

vLLM

One of the most widely deployed open-source inference engines. Key innovations:

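  • PagedAttention: Inspired by OS virtual memory paging, the KV cache is split into fixed-size blocks (“pages”) that are allocated on demand and need not be contiguous. This nearly eliminates fragmentation (waste is limited to each sequence's last, partially filled block); the vLLM paper reports 2–4× higher throughput than earlier systems such as FasterTransformer and Orca.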
  • Continuous batching (detailed below)
  • Tensor parallelism across multiple GPUs
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", tensor_parallel_size=2)

sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=512)
outputs = llm.generate(["Explain LLM inference in detail"], sampling_params)

for output in outputs:
    print(output.outputs[0].text)
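The bookkeeping behind PagedAttention can be sketched as a block table per sequence. This is a toy model of the idea, not vLLM's implementation: it tracks only which physical block and slot each token's K/V would occupy, and `PagedKVAllocator` is a hypothetical class.

```python
class PagedKVAllocator:
    """Toy block-table bookkeeping in the spirit of PagedAttention (indices only, no tensors)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # pool of physical KV blocks
        self.block_tables = {}                      # seq_id -> list of physical block IDs
        self.lengths = {}                           # seq_id -> tokens stored so far

    def append_token(self, seq_id):
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:                # last block full (or none yet): grab a new one
            if not self.free_blocks:
                raise MemoryError("out of KV blocks; a real scheduler would preempt a sequence")
            table.append(self.free_blocks.pop())    # any free block works: no contiguity needed
        self.lengths[seq_id] = n + 1
        return table[n // self.block_size], n % self.block_size  # (physical block, slot)

    def free(self, seq_id):
        # A finished request returns all its blocks to the pool immediately
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

alloc = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(20):
    alloc.append_token("req-A")                     # 20 tokens -> 2 blocks (16 + 4 slots used)
print(alloc.block_tables["req-A"], len(alloc.free_blocks))  # [7, 6] 6
alloc.free("req-A")
print(len(alloc.free_blocks))                       # 8
```

Because memory is claimed one block at a time, the only waste per sequence is the unused tail of its last block, instead of a full max-length reservation.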

TensorRT-LLM

NVIDIA’s production inference engine. Compiles models to optimized TensorRT engines:

  • Kernel fusion (combines multiple ops into single CUDA kernels)
  • INT8/FP8 quantization with calibration
  • In-flight batching
  • Multi-GPU tensor + pipeline parallelism
  • Specifically tuned for NVIDIA hardware (A100/H100)

Hugging Face TGI (Text Generation Inference)

The most developer-accessible production server:

  • Flash Attention and Paged Attention backends
  • Continuous batching
  • Token streaming over HTTP via Server-Sent Events (gRPC is used internally between the router and the model shards)
  • Used internally at Hugging Face for the Inference API
# Deploy LLaMA-3 8B with TGI
docker run --gpus all -p 8080:80 \
  -v $PWD/models:/data \
  ghcr.io/huggingface/text-generation-inference:2.0 \
  --model-id meta-llama/Meta-Llama-3-8B-Instruct \
  --num-shard 1 \
  --max-input-length 4096 \
  --max-total-tokens 8192

Continuous Batching

Traditional static batching groups requests into a fixed batch and runs it until its longest sequence finishes. Short requests sit idle in their slots, and requests that arrive mid-generation must wait for the whole batch to complete, which is terrible for latency.

Continuous batching (also called in-flight batching) processes requests at the iteration level:


In continuous batching, every decode iteration can add new requests and retire finished ones. GPU utilization stays high regardless of heterogeneous sequence lengths. vLLM, TGI, and TensorRT-LLM all implement this.
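A small simulation makes the difference concrete by counting decode iterations for the same requests under both policies. It is a deliberately simplified model: every iteration costs the same, prefill time for newly admitted requests is ignored, and the output lengths are hypothetical.

```python
from collections import deque

def static_batching_steps(output_lengths, batch_size):
    """Each fixed batch runs until its longest request finishes."""
    steps = 0
    for i in range(0, len(output_lengths), batch_size):
        steps += max(output_lengths[i:i + batch_size])
    return steps

def continuous_batching_steps(output_lengths, batch_size):
    """Finished requests leave after every iteration; waiting ones take the free slots."""
    waiting = deque(output_lengths)
    running = []                                   # remaining tokens per active request
    steps = 0
    while waiting or running:
        while waiting and len(running) < batch_size:
            running.append(waiting.popleft())      # admit into a free slot
        steps += 1                                 # one decode iteration for the whole batch
        running = [r - 1 for r in running if r > 1]  # each emits a token; finished ones retire
    return steps

lengths = [10, 200, 15, 12, 180, 8, 20, 9]         # hypothetical output lengths in tokens
print(static_batching_steps(lengths, 4), continuous_batching_steps(lengths, 4))  # 380 200
```

With static batching, the 180-token request cannot start until the 200-token request in the first batch finishes; with continuous batching it starts as soon as a short request retires, so total work collapses toward the single longest request.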

📈 Throughput impact: continuous batching can deliver several times the throughput of static batching at the same latency target; the exact gain depends on how much output lengths vary across requests.

GPU Memory Bottlenecks

GPU memory is the scarce resource in LLM inference. An H100 SXM5 has 80 GB of HBM3. For a 70B-parameter model in BF16:

Model weights: 70B × 2 bytes = 140 GB → at least 2× H100s (tensor parallel)
KV cache (8K ctx, GQA): ~2.7 GB per request
Activations: ~1–2 GB per request
CUDA overhead: ~1–2 GB

Quantization is the primary lever for fitting larger models:


GPTQ and AWQ (activation-aware weight quantization) are two widely used post-training, 4-bit, weight-only quantization methods for inference. At 4 bits, a 70B model's weights shrink to about 35 GB, small enough to fit on a single 80 GB H100 while preserving most of the original quality.
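The storage format these methods target can be illustrated with the simplest possible scheme: symmetric round-to-nearest 4-bit quantization with one scale per group of 128 weights. This is a baseline sketch, not GPTQ or AWQ (both choose quantized values more carefully to reduce error), and the helper names are hypothetical.

```python
import torch

def quantize_int4(w: torch.Tensor, group_size: int = 128):
    """Symmetric round-to-nearest 4-bit quantization with one scale per group."""
    groups = w.reshape(-1, group_size)
    # Map each group's largest |weight| to 7; clamp avoids dividing by zero for all-zero groups
    scale = (groups.abs().amax(dim=1, keepdim=True) / 7.0).clamp_min(1e-8)
    q = torch.clamp(torch.round(groups / scale), -8, 7).to(torch.int8)  # int4 values held in int8
    return q, scale

def dequantize_int4(q: torch.Tensor, scale: torch.Tensor, shape) -> torch.Tensor:
    return (q.float() * scale).reshape(shape)

torch.manual_seed(0)
w = torch.randn(1024, 1024)
q, scale = quantize_int4(w)
w_hat = dequantize_int4(q, scale, w.shape)
print(f"mean abs error: {(w - w_hat).abs().mean():.4f}")
# Storage: 4 bits per weight plus one scale per 128 weights, vs 16 bits in BF16
print(f"70B weights at 4 bits: ~{70e9 * 0.5 / 1e9:.0f} GB")  # ~35 GB
```

Real inference kernels pack two 4-bit values per byte and dequantize on the fly inside the matmul, which also cuts the HBM traffic that makes decode memory-bound.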

Speculative decoding is another technique: a smaller “draft” model generates candidate tokens quickly; the large “verifier” model checks them in parallel. If accepted, you get multiple tokens per large-model forward pass — reducing the memory-bandwidth bottleneck of decode.
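The accept/reject logic can be sketched for the greedy case, with toy functions standing in for the two models. Two assumptions to flag: this greedy variant keeps a drafted token only if it equals the target's own argmax (sampling-based speculative decoding uses a probabilistic acceptance rule instead), and a real system verifies all drafted positions in one batched target forward pass, while this sketch calls the target per position for readability. All names are hypothetical.

```python
from typing import Callable, List

NextToken = Callable[[List[int]], int]  # hypothetical greedy model: context -> next token ID

def speculative_decode_greedy(target: NextToken, draft: NextToken,
                              prompt: List[int], num_new: int, k: int = 4) -> List[int]:
    out = list(prompt)
    while len(out) - len(prompt) < num_new:
        # 1) The cheap draft model proposes k tokens autoregressively
        proposal: List[int] = []
        for _ in range(k):
            proposal.append(draft(out + proposal))
        # 2) The target verifies the proposal position by position
        accepted: List[int] = []
        for tok in proposal:
            expected = target(out + accepted)
            accepted.append(expected)   # the target's token is always the one kept
            if tok != expected:
                break                   # first mismatch: discard the rest of the draft
        else:
            accepted.append(target(out + accepted))  # all k matched: one bonus token
        out.extend(accepted)
    return out[:len(prompt) + num_new]

def target(ctx):  # toy deterministic "large model"
    return (sum(ctx) * 31 + len(ctx)) % 50

def draft(ctx):   # toy "small model" that usually agrees with the target
    return target(ctx) if len(ctx) % 5 else 0

print(speculative_decode_greedy(target, draft, [3, 1, 4], num_new=10))
```

Every token in the output is one the target model itself would have chosen, so the result is identical to plain greedy decoding with the target; the speedup comes from verifying several tokens per expensive forward pass.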

Scaling Inference to Millions of Users

Serving a single model to millions of concurrent users requires distributed systems engineering layered on top of GPU optimization:


Key scaling strategies:

Tensor Parallelism (TP): Split each weight matrix across multiple GPUs. Each GPU computes a shard; results are all-reduced via NVLink. Used within a node.

Pipeline Parallelism (PP): Split transformer layers across GPUs, e.g. GPU 1 runs layers 1–20, GPU 2 runs layers 21–40, and so on. Often used across nodes, where interconnect bandwidth is lower than NVLink within a node.

Data Parallelism: Run multiple full model replicas, each serving a subset of requests. The simplest form — scale out by adding replicas.

Prompt caching / prefix caching: For requests sharing a long system prompt (e.g., all users of the same chatbot share the same 2000-token system prompt), the KV cache for that prefix is computed once and reused. Anthropic, OpenAI, and Google all offer this as a feature — reducing cost and latency for shared prefixes.
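One way to implement prefix reuse is to key cached KV blocks by a hash of the entire token prefix up to each block boundary, so a block is reused only when everything before it matches too. The sketch below shows that lookup logic with labels in place of real KV tensors; it is a toy in the spirit of block-level prefix caching, not any provider's implementation, and `PrefixCache` is a hypothetical class.

```python
import hashlib

class PrefixCache:
    """Toy block-level prefix cache: hash of a token prefix -> cached KV block (a label here)."""

    def __init__(self, block_size: int = 16):
        self.block_size = block_size
        self.blocks = {}

    def _prefix_hashes(self, tokens):
        # Each hash covers every token from the start through the end of that block,
        # so a block is reused only when the ENTIRE preceding prefix matches
        h = hashlib.sha256()
        full = len(tokens) - len(tokens) % self.block_size
        for i in range(0, full, self.block_size):
            h.update(repr(tokens[i:i + self.block_size]).encode())
            yield h.hexdigest()

    def prefill(self, tokens):
        """Return how many leading tokens skip prefill because their KV blocks are cached."""
        reused, matching = 0, True
        for key in self._prefix_hashes(tokens):
            if matching and key in self.blocks:
                reused += self.block_size
            else:
                matching = False
                self.blocks[key] = f"kv-block-{len(self.blocks)}"  # compute and store
        return reused

cache = PrefixCache(block_size=16)
system_prompt = list(range(1000, 1040))          # 40 shared "system prompt" token IDs
print(cache.prefill(system_prompt + [1, 2, 3]))  # 0: first request, nothing cached yet
print(cache.prefill(system_prompt + [7, 8]))     # 32: two full shared blocks reused
```

Only complete blocks are cached here, so the last 8 system-prompt tokens still go through prefill; smaller blocks trade higher hit rates for more bookkeeping.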

KV cache offloading: Move KV cache entries for paused requests from GPU HBM to CPU RAM or NVMe SSD, freeing GPU memory for active requests. Trades latency for capacity.

🏗️ Production deployment stack (typical):

Kubernetes + GPU operator → vLLM / TRT-LLM serving pods → Prometheus + Grafana monitoring → OpenTelemetry tracing → Autoscaling on queue depth

Final Architecture Walkthrough

Let’s trace a single request end-to-end through a production LLM inference system:


[Figure: timing breakdown for a 1000-token prompt producing a 200-token response on an H100]

🔮 The big bet: The industry is moving toward inference-time compute as a primary scaling axis — not just training. This means inference systems will bear an increasingly large share of the total AI compute budget, making every optimization described in this article more important, not less.

Closing

When you send a prompt to an LLM, you’re setting off a cascade of engineering decisions made by hundreds of researchers and engineers — from the BPE tokenizer vocabulary to the CUDA kernel that fuses attention, from the paging algorithm managing KV cache memory to the SSE stream delivering tokens to your browser.

The transformer is elegant math. Production inference is brutal systems engineering.

Understanding both levels — the math and the metal — is what separates engineers who use LLMs from engineers who can build, optimize, and scale them.
