DEV Community

Vector

Understanding SGLang's Radix Cache, the LeetCode Way

Overview

What is Radix Cache?

When an LLM processes a prompt, every attention layer computes a Key vector and a Value vector for each token; together these form the KV cache. If many requests share the same system prompt, recomputing its KV cache from scratch each time is wasteful. Radix Cache stores these computed prefixes in a Radix Tree and reuses them across requests, which is one of the main reasons SGLang achieves high throughput.
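To make the saving concrete, here is a minimal sketch that only counts tokens. The helper name shared_prefix_len and the token ids are made up for illustration, and no real KV tensors are modeled.

```python
def shared_prefix_len(a: list[int], b: list[int]) -> int:
    """Number of leading token ids that two requests have in common."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


# Hypothetical token ids: a 6-token system prompt followed by user text.
system_prompt = [101, 7, 7, 42, 9, 13]
request_a = system_prompt + [55, 56]
request_b = system_prompt + [90, 91, 92]

# If request_a was served first, request_b can reuse the shared prefix.
reused = shared_prefix_len(request_a, request_b)
recomputed = len(request_b) - reused
print(f"reuse KV for {reused} tokens, compute {recomputed} new ones")
```

With a long system prompt and short user turns, nearly all of the prefill work falls into the reused part.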

Why Read mini-sglang Instead of SGLang Directly?

SGLang's codebase spans tens of thousands of lines. The Radix Cache logic is tangled with CUDA kernels, async scheduling, and multiple backends — easy to get lost in.

mini-sglang is an official reference implementation from the same team, covering all the key ideas in ~8,000 lines of clean Python. The Radix Cache is a single self-contained file of ~200 lines. Start here, then SGLang becomes readable.

This Article

radix_cache.py combines three classic algorithms. Instead of reading it cold, this article builds up to it step by step using LeetCode problems as the foundation — starting from scratch.


Step 1: LRU Cache — Understanding Eviction

LeetCode 146 - LRU Cache

The Problem

Design a cache with a fixed capacity. It supports two operations:

  • get(key) — return the value if it exists, otherwise -1
  • put(key, value) — insert or update the value; if the cache is full, evict the least recently used entry first

"Least recently used" means: among all entries currently in the cache, remove the one that was accessed (via get or put) the longest time ago.

Both operations must run in O(1) time.

Why O(1) is the Challenge

A naive solution might keep a list sorted by access time. But re-sorting on every access is O(n log n). We need something smarter.

The classic solution combines two data structures:

  • A HashMap for O(1) key lookup
  • A doubly linked list to track access order — most recently used at the tail, least recently used at the head

Every time a key is accessed, its node is moved to the tail. When eviction is needed, remove from the head.

class ListNode:
    def __init__(self, key=0, val=0):
        self.key = key
        self.val = val
        self.prev = None
        self.next = None


class LRUCache:
    def __init__(self, capacity: int):
        self.cap = capacity
        self.cache = {}  # key -> ListNode

        # dummy head and tail to avoid edge case checks
        self.head = ListNode()
        self.tail = ListNode()
        self.head.next = self.tail
        self.tail.prev = self.head

    def _remove(self, node: ListNode) -> None:
        node.prev.next = node.next
        node.next.prev = node.prev

    def _append_to_tail(self, node: ListNode) -> None:
        node.prev = self.tail.prev
        node.next = self.tail
        self.tail.prev.next = node
        self.tail.prev = node

    def get(self, key: int) -> int:
        if key not in self.cache:
            return -1
        node = self.cache[key]
        self._remove(node)
        self._append_to_tail(node)  # mark as most recently used
        return node.val

    def put(self, key: int, value: int) -> None:
        if key in self.cache:
            node = self.cache[key]
            node.val = value
            self._remove(node)
            self._append_to_tail(node)
        else:
            if len(self.cache) == self.cap:
                lru = self.head.next  # least recently used
                self._remove(lru)
                del self.cache[lru.key]
            node = ListNode(key, value)
            self.cache[key] = node
            self._append_to_tail(node)
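For comparison, Python's collections.OrderedDict already pairs a hash map with an ordering that can be updated in O(1), so the same contract fits in a few lines. This is a sketch of an alternative, not the solution above; the class name OrderedDictLRU is made up here.

```python
from collections import OrderedDict


class OrderedDictLRU:
    """Same contract as LeetCode 146, built on OrderedDict."""

    def __init__(self, capacity: int):
        self.cap = capacity
        self.data = OrderedDict()  # oldest entry first, newest last

    def get(self, key: int) -> int:
        if key not in self.data:
            return -1
        self.data.move_to_end(key)  # mark as most recently used
        return self.data[key]

    def put(self, key: int, value: int) -> None:
        if key in self.data:
            self.data.move_to_end(key)
        elif len(self.data) == self.cap:
            self.data.popitem(last=False)  # evict the least recently used
        self.data[key] = value


cache = OrderedDictLRU(2)
cache.put(1, 1)
cache.put(2, 2)
first = cache.get(1)   # key 1 becomes the most recently used
cache.put(3, 3)        # cache is full, so key 2 is evicted
second = cache.get(2)  # key 2 is gone
print(first, second)
```

Writing the linked list by hand is still the better exercise, because the radix cache needs to attach its own bookkeeping to each node.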

Mapping to radix_cache.py

In the Radix Cache, the "entries" are tree nodes that reference KV cache pages. The same LRU logic applies:

| LRU Cache | radix_cache.py |
| --- | --- |
| Move node to tail on access | node.timestamp = time.monotonic_ns() |
| len(cache) == capacity → evict head | evictable_size exceeds limit → call evict() |
| Node is being used → don't evict | node.ref_count > 0 → skip during eviction |
| Remove from head | heappop the node with the smallest timestamp |

One difference: radix_cache uses a timestamp + heap instead of a linked list for LRU order. The reason is explained in Step 2.


Step 2: Kth Largest Element in a Stream — Understanding the Heap

LeetCode 703 - Kth Largest Element in a Stream

The Problem

Design a class that finds the k-th largest element in a stream of numbers. Each time a new number is added, return the k-th largest so far.

k = 3, initial = [4, 5, 8, 2]
add(3)  → 4   (stream: [2,3,4,5,8], 3rd largest = 4)
add(5)  → 5   (stream: [2,3,4,5,5,8], 3rd largest = 5)
add(10) → 5   (stream: [2,3,4,5,5,8,10], 3rd largest = 5)

The Min-Heap Solution

A min-heap capped at size k always holds the k largest elements seen so far. Once at least k numbers have arrived, the root (the minimum of the heap) is the k-th largest.

When a new number arrives:

  • Push it into the heap
  • If the heap grows beyond size k, pop the minimum

The root is always the answer.

import heapq

class KthLargest:
    def __init__(self, k: int, nums: list[int]):
        self.k = k
        self.heap = []
        for n in nums:
            self.add(n)

    def add(self, val: int) -> int:
        heapq.heappush(self.heap, val)
        if len(self.heap) > self.k:
            heapq.heappop(self.heap)  # remove the smallest
        return self.heap[0]  # root = k-th largest

How a Min-Heap Works

Python's heapq is a min-heap: the smallest element is always at index 0. heappush and heappop both run in O(log n).

For heapq to compare custom objects, define __lt__:

class Node:
    def __init__(self, timestamp):
        self.timestamp = timestamp

    def __lt__(self, other):
        return self.timestamp < other.timestamp  # smaller timestamp = older = higher eviction priority
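A short usage sketch shows the effect of __lt__ on pop order. The Entry class and its timestamps are hypothetical and stand in for any object ordered by last access.

```python
import heapq


class Entry:
    """Hypothetical cache entry ordered by last-access timestamp."""

    def __init__(self, name: str, timestamp: int):
        self.name = name
        self.timestamp = timestamp

    def __lt__(self, other: "Entry") -> bool:
        return self.timestamp < other.timestamp


entries = [Entry("c", 30), Entry("a", 10), Entry("b", 20)]
heapq.heapify(entries)                  # O(n), ordering comes from __lt__
heapq.heappush(entries, Entry("d", 5))  # O(log n)

# Popping repeatedly yields entries from oldest to newest.
eviction_order = [heapq.heappop(entries).name for _ in range(len(entries))]
print(eviction_order)  # ['d', 'a', 'b', 'c']
```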

Mapping to radix_cache.py

The eviction in radix_cache.py uses the same min-heap pattern — but for finding the oldest node instead of the k-th largest:

# excerpt from RadixTreeNode
def __lt__(self, other: RadixTreeNode) -> bool:
    return self.timestamp < other.timestamp  # older = evicted first

# excerpt from evict()
leave_nodes = self._collect_leave_nodes_for_evict()
heapq.heapify(leave_nodes)

while evicted_size < size:
    node = heapq.heappop(leave_nodes)  # pop the oldest leaf
    # ... (omitted in this excerpt: freeing the node, updating evicted_size,
    #      and looking up the node's parent)
    # after removing this leaf, its parent might become a new leaf
    if parent.is_leaf() and parent.ref_count == 0:
        heapq.heappush(leave_nodes, parent)  # chain eviction upward

Why a heap instead of a linked list (unlike LRU Cache)?

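In a doubly linked list, moving a node to the tail on access is O(1) because you have direct pointers. In a tree, though, only unreferenced leaves are evictable, they are scattered across the tree, and the set keeps changing: a leaf stops being evictable while a request holds it, and an inner node becomes evictable once its last child is removed. Keeping a linked list in sync with all of that on every access would be fiddly. Instead, an access only stamps the node with a new timestamp, and at eviction time all evictable leaves are collected and heapified (O(n)), then popped oldest-first (O(log n) each). It is a batch operation rather than a per-access operation.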
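The batch pattern can be sketched on a toy tree. Everything here (ToyNode, collect_evictable_leaves, the integer timestamps) is hypothetical and only mirrors the idea; it is not mini-sglang's code and ignores sizes and page indices.

```python
import heapq
from typing import Optional


class ToyNode:
    """Hypothetical stand-in for a radix tree node."""

    def __init__(self, name: str, timestamp: int, parent: Optional["ToyNode"] = None):
        self.name = name
        self.timestamp = timestamp
        self.ref_count = 0
        self.parent = parent
        self.children = []
        if parent is not None:
            parent.children.append(self)

    def is_leaf(self) -> bool:
        return not self.children

    def __lt__(self, other: "ToyNode") -> bool:
        return self.timestamp < other.timestamp


def collect_evictable_leaves(root: ToyNode) -> list:
    """Walk the whole tree once and keep the unreferenced leaves."""
    leaves, stack = [], list(root.children)
    while stack:
        node = stack.pop()
        if not node.is_leaf():
            stack.extend(node.children)
        elif node.ref_count == 0:
            leaves.append(node)
    return leaves


def evict(root: ToyNode, count: int) -> list:
    """Evict up to `count` nodes, oldest evictable leaf first."""
    heap = collect_evictable_leaves(root)
    heapq.heapify(heap)
    evicted = []
    while heap and len(evicted) < count:
        node = heapq.heappop(heap)
        parent = node.parent
        parent.children.remove(node)
        evicted.append(node.name)
        # The parent may have just become an evictable leaf (the root never is).
        if parent is not root and parent.is_leaf() and parent.ref_count == 0:
            heapq.heappush(heap, parent)
    return evicted


root = ToyNode("root", 0)
a = ToyNode("a", 4, root)
b = ToyNode("b", 1, a)     # oldest leaf, sits under a
c = ToyNode("c", 2, root)
d = ToyNode("d", 3, root)
d.ref_count = 1            # in use, must survive

order = evict(root, 3)
print(order)  # ['b', 'c', 'a']
```

Node a only enters the heap after its child b is gone, which is the "chain eviction upward" step from the excerpt.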


Step 3: Implement Trie — Understanding Prefix Trees

LeetCode 208 - Implement Trie (Prefix Tree)

The Problem

Design a data structure that stores a set of strings and supports:

  • insert(word) — add a word
  • search(word) — return True if the exact word exists
  • starts_with(prefix) — return True if any word starts with this prefix

Why Not a HashSet?

A HashSet supports search in O(1), but cannot answer starts_with efficiently — you would have to scan all stored words.

A Trie organizes words by their shared prefixes. Each node represents a character, and the path from root to a node spells out a prefix. All words sharing a prefix share the same initial path.

Insert "app", "apple", "application":

root
└── 'a'
    └── 'p'
        └── 'p' (is_end=True)         ← "app"
            └── 'l'
                ├── 'e' (is_end=True)  ← "apple"
                └── 'i'
                    └── 'c' → 'a' → 't' → 'i' → 'o' → 'n' (is_end=True)  ← "application"

Implementation

class TrieNode:
    def __init__(self):
        self.children: dict[str, TrieNode] = {}
        self.is_end: bool = False


class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, word: str) -> None:
        node = self.root
        for ch in word:
            if ch not in node.children:
                node.children[ch] = TrieNode()
            node = node.children[ch]
        node.is_end = True

    def search(self, word: str) -> bool:
        node = self.root
        for ch in word:
            if ch not in node.children:
                return False
            node = node.children[ch]
        return node.is_end

    def starts_with(self, prefix: str) -> bool:
        node = self.root
        for ch in prefix:
            if ch not in node.children:
                return False
            node = node.children[ch]
        return True  # prefix exhausted = match found, regardless of is_end

Mapping to radix_cache.py

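The Radix Cache is built on the same structure, with token ids as the "characters". Nodes no longer need an is_end flag, because any matched prefix is reusable whether or not an earlier request ended exactly there; the per-node bookkeeping is ref_count instead. The _tree_walk method is the core traversal, equivalent to the for ch in word loops above: it walks from the root as far as the input token ids match.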
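As a bridge, here is a minimal sketch of a trie keyed by token id whose lookup returns how far the match got instead of a boolean. The names TokenTrieNode, insert, and match_len are hypothetical; _tree_walk itself is not reproduced here.

```python
class TokenTrieNode:
    """Hypothetical trie node keyed by token id instead of character."""

    def __init__(self):
        self.children = {}  # token id -> TokenTrieNode


def insert(root: TokenTrieNode, token_ids: list[int]) -> None:
    node = root
    for tok in token_ids:
        if tok not in node.children:
            node.children[tok] = TokenTrieNode()
        node = node.children[tok]


def match_len(root: TokenTrieNode, token_ids: list[int]) -> int:
    """Walk from the root as far as the ids match; return the matched length."""
    node, matched = root, 0
    for tok in token_ids:
        if tok not in node.children:
            break
        node = node.children[tok]
        matched += 1
    return matched


root = TokenTrieNode()
insert(root, [1, 2, 3, 4, 5])         # a previously cached prompt
hit = match_len(root, [1, 2, 3, 9])   # shares the first three tokens
miss = match_len(root, [7, 8])        # shares nothing
print(hit, miss)  # 3 0
```

A cache lookup cares about the length of the match, since that is the number of tokens whose KV entries can be reused.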


Step 4: Radix Tree — Compressing the Trie

A Trie creates one node per character. For long strings (like a 1024-token prompt), this means 1024 nodes in a single chain. A Radix Tree compresses chains of single-child nodes into one node with a multi-character label:

Trie for "application":
root → 'a' → 'p' → 'p' → 'l' → 'i' → 'c' → 'a' → 't' → 'i' → 'o' → 'n'  (11 nodes)

Radix Tree for "application":
root → "application"  (1 node)

When a second word is inserted that shares a prefix, the node is split:

Insert "app" into a tree that already has "application":

Before:  root → "application" (is_end=True)
After:   root → "app" (is_end=True) → "lication" (is_end=True)

Implementation

class RadixTreeNode:
    def __init__(self, label: str = "", is_end: bool = False):
        self.children: dict[str, RadixTreeNode] = {}  # first char of label -> child
        self.label: str = label
        self.is_end: bool = is_end


class RadixTree:
    def __init__(self):
        self.root = RadixTreeNode()

    def _common_prefix_len(self, s1: str, s2: str) -> int:
        for i in range(min(len(s1), len(s2))):
            if s1[i] != s2[i]:
                return i
        return min(len(s1), len(s2))

    def _insert_node(self, node: RadixTreeNode, word: str) -> None:
        k = self._common_prefix_len(node.label, word)
        n = len(node.label)

        if k == n:
            # this edge is fully matched
            if len(word) == n:
                node.is_end = True  # exact match, mark end
            else:
                tail = word[n:]
                if tail[0] in node.children:
                    self._insert_node(node.children[tail[0]], tail)
                else:
                    new_node = RadixTreeNode(tail, True)
                    node.children[tail[0]] = new_node
        else:
            # partial match: split this node at position k
            #
            # Before: node.label = "application"
            # Insert "app" (k=3):
            #   node.label = "app"       (the common prefix, stays in place)
            #   rest.label = "lication"  (the remainder, becomes a child)
            #
            # node keeps its slot in the parent's children dict because its
            # first character is unchanged; rest inherits the old is_end flag
            # and the old children.
            rest = RadixTreeNode(node.label[k:], node.is_end)
            rest.children = node.children

            node.label = node.label[:k]
            node.is_end = False
            node.children = {rest.label[0]: rest}

            # now insert the remainder of word under the split node
            word_tail = word[k:]
            if word_tail:
                node.children[word_tail[0]] = RadixTreeNode(word_tail, True)
            else:
                node.is_end = True  # word == common prefix

    def insert(self, word: str) -> None:
        if not word:
            return
        if word[0] in self.root.children:
            self._insert_node(self.root.children[word[0]], word)
        else:
            new_node = RadixTreeNode(word, True)
            self.root.children[word[0]] = new_node

    def search(self, word: str) -> bool:
        node = self.root
        while word:
            child = node.children.get(word[0])
            if child is None:
                return False
            k = self._common_prefix_len(child.label, word)
            if k < len(child.label):
                return False  # partial match only
            word = word[k:]
            node = child
        return node.is_end

    def starts_with(self, prefix: str) -> bool:
        node = self.root
        while prefix:
            child = node.children.get(prefix[0])
            if child is None:
                return False
            k = self._common_prefix_len(child.label, prefix)
            if k < len(child.label) and k < len(prefix):
                return False
            prefix = prefix[k:]
            node = child
        return True

Key Design Decision: Why First-Character Keys?

The children dict is keyed by the first character of each child's label. This guarantees that at most one child can match any given input character — so lookup is always O(1), no scanning needed.


Step 5: radix_cache.py — Putting It All Together

Now the 200 lines make sense. Every concept maps directly:

| Radix Tree | radix_cache.py |
| --- | --- |
| label (string) | _key (token id tensor) |
| node payload | _value (KV cache page indices) |
| is_end | ref_count (0 = evictable, >0 = in use) |
| _common_prefix_len | fast_compare_key (same logic, runs on GPU) |
| node split on insert | split_at() called inside _tree_walk |
| LRU eviction | evict() using min-heap on timestamp |

What _value stores:

The actual KV cache tensors (the large matrices) live in MHAKVCache. The Radix Tree only stores page indices — it is an index, not the data store itself. Think of it as a library card catalog: the card tells you which shelf the book is on; the book itself is elsewhere.

Page alignment:

KV cache is allocated in fixed-size pages (page_size tokens per page). A partial page cannot be reused, so prefix matching only counts at complete page boundaries:

match_len = align_down(match_len, self.page_size)
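The body of align_down is not shown in this article. A plausible definition, assuming it simply rounds down to a multiple of page_size, would look like this:

```python
def align_down(n: int, page_size: int) -> int:
    """Round n down to the nearest multiple of page_size (assumed behavior)."""
    return (n // page_size) * page_size


page_size = 16  # hypothetical page size, chosen only for the example
for match_len in (0, 15, 16, 37):
    print(match_len, "->", align_down(match_len, page_size))
```

With 16-token pages, a 37-token match is trimmed to 32 reusable tokens, and a 15-token match yields no reuse at all.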

Split on lookup, not just insert:

Unlike our Radix Tree where splitting only happens during insert, radix_cache.py splits nodes during _tree_walk (which is called by both match_prefix and insert_prefix). This eagerly prepares the tree for future matches.
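A toy version of splitting during the walk might look like the sketch below. SketchNode, common_len, and walk are hypothetical names, labels are plain tuples instead of tensors, and the split_at body is an assumption about how such a split could work rather than a copy of radix_cache.py.

```python
from typing import Optional


class SketchNode:
    """Hypothetical radix node whose label is a tuple of token ids."""

    def __init__(self, label: tuple = (), parent: Optional["SketchNode"] = None):
        self.label = label
        self.parent = parent
        self.children = {}  # first token id of child's label -> child

    def split_at(self, k: int) -> "SketchNode":
        """Insert a new parent holding label[:k]; self keeps label[k:]."""
        head = SketchNode(self.label[:k], self.parent)
        self.parent.children[self.label[0]] = head  # head takes over the slot
        self.label = self.label[k:]
        self.parent = head
        head.children[self.label[0]] = self
        return head


def common_len(a: tuple, b: tuple) -> int:
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


def walk(root: SketchNode, tokens: tuple) -> tuple:
    """Return (deepest fully matched node, matched length), splitting if needed."""
    node, matched = root, 0
    while matched < len(tokens):
        child = node.children.get(tokens[matched])
        if child is None:
            break
        k = common_len(child.label, tokens[matched:])
        if k < len(child.label):
            # The match ends inside this label: split so it ends on a node.
            return child.split_at(k), matched + k
        node, matched = child, matched + k
    return node, matched


root = SketchNode()
cached = SketchNode((1, 2, 3, 4, 5, 6), root)
root.children[1] = cached

node, matched = walk(root, (1, 2, 3, 9, 9))  # a lookup, not an insert
print(matched, node.label, cached.label)
```

After the lookup, the shared part (1, 2, 3) sits on its own node, so the caller gets back a node that covers exactly the matched prefix.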


The Full Learning Path

LeetCode 146 (LRU Cache)
     eviction policy, O(1) access tracking, lock/unlock semantics
    ↓
LeetCode 703 (Kth Largest in Stream)
     min-heap, __lt__ for custom comparison, heap-based priority eviction
    ↓
LeetCode 208 (Implement Trie)
     prefix tree structure, children dict, is_end flag, starts_with vs search
    ↓
Implement Radix Tree (custom exercise)
     compressed trie, edge labels, node splitting, O(1) child lookup
    ↓
radix_cache.py (~200 lines)
     all of the above applied to token id tensors with KV cache page management

Each step introduces exactly one new idea. By the time you reach radix_cache.py, there is nothing unfamiliar — just the same patterns applied to tensors instead of strings, and page indices instead of boolean flags.
