IndexCache: Killing the Indexer's O(NL²) Bottleneck in DeepSeek Sparse Attention
Notes from my notebook on GLM-5.2 / DeepSeek Sparse Attention (DSA), reconstructed from the IndexCache paper (Bai, Dong et al., Tsinghua + Z.ai, 2026) — the mechanism behind GLM-5.2's "IndexShare."
1. Why this exists — the bottleneck nobody talks about
DSA's whole pitch is: don't do full O(L²) attention, instead let a cheap lightning indexer look at all preceding tokens and pick the top-k (k=2048) that actually matter, then do real attention only on those. That drops core attention from O(L²) → O(Lk).
Great — except I missed this the first time I read DSA: the indexer itself is still O(L²). It has to score every preceding token against the query to decide who's in the top-k. So across N layers you've traded one O(L²) cost for N separate O(L²) costs — total O(NL²). At long context this indexer becomes the dominant cost, not the attention it was supposed to fix.
The indexer is what puts DSA "on steroids": it kills the one real bottleneck (full attention). But in doing so, it grows a bottleneck of its own. The indexer is cheap per-FLOP (few heads, low-rank, FP8), but it still runs at every single layer.
The fix the paper proposes isn't a smarter indexer: it's to stop running the indexer at every layer.
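To get a feel for the scaling, here is a rough count of query-key pairs per layer. It is only a sketch: it counts pairs rather than FLOPs (an indexer pair is far cheaper than an attention pair), it assumes causal masking, and the function names are made up for illustration. L = 200K is the context length from the benchmark later in these notes, and k = 2048 is the DSA top-k.

```python
def indexer_pairs(seq_len: int) -> int:
    """Pairs scored by one layer's indexer: token t scores all t candidates up to itself."""
    return seq_len * (seq_len + 1) // 2


def sparse_attn_pairs(seq_len: int, k: int) -> int:
    """Pairs attended by one layer's sparse attention: token t attends to min(t, k) tokens."""
    if seq_len <= k:
        return seq_len * (seq_len + 1) // 2
    return k * (k + 1) // 2 + (seq_len - k) * k


L, k = 200_000, 2048
ratio = indexer_pairs(L) / sparse_attn_pairs(L, k)
print(f"per layer, the indexer scores {ratio:.1f}x more pairs than sparse attention attends to")
```

The ratio grows roughly like L / (2k), which is why the indexer ends up dominating at long context even though each of its pairs is cheap.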
2. The core insight: adjacent layers pick almost the same tokens
If you measure pairwise overlap between the top-k token sets selected by each layer's indexer, adjacent layers share 70–100% of their picks. The heatmap even shows block structure — clusters of layers (e.g. layers 3–5, 17–30, etc.) that all converge on roughly the same "important" tokens.
So most of the O(NL²) indexer cost is redundant computation of the same answer.
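A minimal sketch of how such an overlap heatmap could be computed, assuming you have already dumped each layer's top-k indices for one query position. The function name and the toy indices are hypothetical, not taken from the paper.

```python
import numpy as np


def topk_overlap_matrix(topk_per_layer: list[np.ndarray]) -> np.ndarray:
    """overlap[a, b] = |T_a ∩ T_b| / k for the top-k index sets of layers a and b."""
    sets = [set(idx.tolist()) for idx in topk_per_layer]
    n, k = len(sets), len(sets[0])
    overlap = np.zeros((n, n))
    for a in range(n):
        for b in range(n):
            overlap[a, b] = len(sets[a] & sets[b]) / k
    return overlap


# Toy data: 3 layers, k = 4 (made-up indices, not real model output).
toy = [np.array([0, 3, 5, 9]), np.array([0, 3, 5, 8]), np.array([1, 2, 5, 9])]
print(topk_overlap_matrix(toy))
```

In a real measurement you would average this matrix over many query positions and prompts before plotting it.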
This motivates IndexCache: split the N layers into two roles —
- F (Full) layers — run their own indexer, compute fresh top-k, cache it.
- S (Shared) layers — skip the indexer entirely, just reuse the nearest preceding F layer's cached top-k.
The first layer is always F (has to seed the cache).
Inference loop comparison
Standard DSA:
    for l = 1 to N:
        I⁽ˡ⁾ ← Indexer_l(X)
        T⁽ˡ⁾ ← top-k(I⁽ˡ⁾)
        X ← SparseAttn_l(X, T⁽ˡ⁾)
        X ← FFN_l(X)    # + norm, residual
IndexCache:
    for l = 1 to N:
        if c_l == F:
            I⁽ˡ⁾ ← Indexer_l(X)
            T⁽ˡ⁾ ← top-k(I⁽ˡ⁾)
            T_cache ← T⁽ˡ⁾
        else:    # c_l == S
            T⁽ˡ⁾ ← T_cache    # reuse
        X ← SparseAttn_l(X, T⁽ˡ⁾)
        X ← FFN_l(X)
T_cache is just a temp buffer holding the current index tensor — it gets overwritten at every F layer, so it adds zero extra GPU memory over standard DSA. The only real change to the loop is one if/else branch. That's the whole elegance of this method — no architecture surgery, just a routing decision.
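The IndexCache loop above can be sketched in runnable form. This is a toy under stated assumptions: `run_indexcache`, `toy_indexer`, and `toy_layer` are hypothetical stand-ins, there is no causal mask, and the "layer" is a residual mean over the selected rows rather than real sparse attention plus FFN. The point is only the F/S routing.

```python
from typing import Callable, Sequence

import numpy as np

Indexer = Callable[[np.ndarray], np.ndarray]            # X (L, d) -> scores (L, L)
Layer = Callable[[np.ndarray, np.ndarray], np.ndarray]  # (X, top-k indices) -> X


def run_indexcache(
    x: np.ndarray,
    pattern: Sequence[str],
    indexers: Sequence[Indexer],
    layers: Sequence[Layer],
    k: int,
) -> tuple[np.ndarray, int]:
    """Run every layer in order; return the output and the number of indexer calls."""
    if pattern[0] != "F":
        raise ValueError("the first layer must be F to seed the cache")
    t_cache = None
    indexer_calls = 0
    for role, indexer, layer in zip(pattern, indexers, layers):
        if role == "F":
            scores = indexer(x)
            t_cache = np.argsort(-scores, axis=-1)[:, :k]  # fresh top-k per query
            indexer_calls += 1
        # S layers fall through and reuse whatever the last F layer cached.
        x = layer(x, t_cache)
    return x, indexer_calls


def toy_indexer(x: np.ndarray) -> np.ndarray:
    return x @ x.T


def toy_layer(x: np.ndarray, topk: np.ndarray) -> np.ndarray:
    return x + x[topk].mean(axis=1)  # residual + mean of the selected rows


x0 = np.random.default_rng(0).standard_normal((6, 4))
pattern = list("FSSSFSSS")
out, calls = run_indexcache(x0, pattern, [toy_indexer] * 8, [toy_layer] * 8, k=3)
print(out.shape, calls)
```

With the pattern F S S S F S S S, the indexer runs twice instead of eight times, and `t_cache` is the only extra state carried between layers.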
3. Finding top-k (the indexer mechanics, cleaned up)
This part is just DSA's own lightning indexer, for reference since it's what gets shared:
Compatibility between query q and each candidate position i, per block/head:
- s_i = q · W_i + b_i: raw score for position i
- g_i = max(0, s_i): ReLU gate (this is the "lightning" part: cheap, no softmax needed before selection)
- T = top-k_i(g_i) over all i: keep the k highest-scoring positions
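The three steps above, written out in NumPy. This follows the simplified single-head notation of these notes, not the production kernel: `select_topk`, the per-position rows of `W`, and the bias `b` are illustrative assumptions, and the multiple indexer heads mentioned earlier are ignored.

```python
import numpy as np


def select_topk(q: np.ndarray, W: np.ndarray, b: np.ndarray, k: int) -> np.ndarray:
    """Return the indices of the k highest gated scores, sorted by position."""
    s = W @ q + b                         # s_i = q · W_i + b_i, one score per position
    g = np.maximum(0.0, s)                # ReLU gate
    k = min(k, g.shape[0])                # early tokens may have fewer than k candidates
    top = np.argpartition(-g, k - 1)[:k]  # k largest scores, in arbitrary order
    return np.sort(top)


q = np.array([1.0, 0.0])
W = np.array([[2.0, 0.0], [-1.0, 0.0], [0.5, 0.0], [3.0, 0.0]])  # 4 candidate positions
b = np.zeros(4)
print(select_topk(q, W, b, k=2))
```

`np.argpartition` avoids a full sort, which matters when k = 2048 is picked out of hundreds of thousands of candidates. Ties (for example several positions gated to exactly 0) are broken arbitrarily in this sketch.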
This sits underneath MLA (Multi-head Latent Attention). The reason MLA matters here: instead of every head keeping its own full KV, MLA squeezes all heads' KV into one shared low-rank latent vector — latent = x·W^D (down-projection). The indexer scores against this compressed representation, which is part of why it's so much cheaper per-FLOP than the main attention.
4. Two ways to find the F/S pattern
The question is: which layers do you keep as F? Two answers, training-free and training-aware — and notably, the "obvious" third answer (similarity-based) fails. Order of discovery matters here, so I'm keeping it in the order the paper actually tried things.
4.1 Why the naive static pattern fails
The dumbest idea: just alternate uniformly, e.g. F S S S F S S S ... (1 F every 4 layers). This doesn't work well. Why: indexer "importance" is not uniform across depth. Some layers — especially early/transitional ones — are way more sensitive to losing their own indexer than others. A fixed period can easily land an F on a redundant layer and an S on a critical one. You need the model (or data) to tell you which layers are safe to share.
4.2 Training-free IndexCache — greedy search
No weight updates at all. Just:
- Start with all layers = F.
- Pool of candidate layers = {2, 3, ..., N} (layer 1 is always F, since it has to seed the cache).
- Pick a small calibration dataset (cached batches from training data; the same batches are reused for every candidate evaluation, so loss differences come purely from the pattern, not data noise).
- For each step: try flipping every remaining F layer to S, one at a time, measure resulting LM loss on the calibration set, and commit whichever flip increases the loss the least.
- Repeat for K steps, where K = target number of S layers (e.g. K = 3N/4 to keep only 1/4 of indexers).
This is literally a greedy "convert layers one-by-one, always pick the one with minimum loss increase" search. Each of the K steps evaluates every remaining candidate, so the full search costs O(N²) forward passes. If you've got pipeline-parallel stages (P of them), you can split layers into P blocks and search them in parallel, cutting total passes by roughly P×.
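A minimal sketch of that greedy loop. `eval_loss` is a hypothetical callable that runs the model with a given F/S pattern on the fixed calibration batches and returns the LM loss; here it is replaced by a toy that just sums made-up per-layer sensitivities.

```python
from typing import Callable, List


def greedy_pattern_search(
    n_layers: int,
    n_shared: int,
    eval_loss: Callable[[List[str]], float],
) -> List[str]:
    """Flip n_shared layers from F to S, one per step, always taking the cheapest flip."""
    if not 0 <= n_shared <= n_layers - 1:
        raise ValueError("layer 1 must stay F, so at most n_layers - 1 layers can be S")
    pattern = ["F"] * n_layers
    candidates = set(range(1, n_layers))  # index 0 (layer 1) always stays F
    for _ in range(n_shared):
        best_layer, best_loss = None, float("inf")
        for layer in sorted(candidates):
            trial = pattern.copy()
            trial[layer] = "S"
            loss = eval_loss(trial)  # same calibration batches for every candidate
            if loss < best_loss:
                best_layer, best_loss = layer, loss
        pattern[best_layer] = "S"    # commit the flip that hurts the least
        candidates.remove(best_layer)
    return pattern


# Toy loss: each S layer adds a fixed, made-up penalty (real losses are not additive).
sensitivity = [9.0, 0.1, 5.0, 0.2, 0.3, 4.0]


def toy_loss(pattern: List[str]) -> float:
    return sum(s for s, role in zip(sensitivity, pattern) if role == "S")


print("".join(greedy_pattern_search(6, 3, toy_loss)))
```

In the real setting each `eval_loss` call is a forward pass over the calibration set, which is where the O(N²) cost comes from.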
What you get out of this (empirically, from the paper's 30B DSA model + GLM-5):
- The searched pattern reliably beats uniform interleaving at the same retention ratio.
- The per-step loss curve has a visible kink — first ~20 layers are "easy" (cheap to convert), the rest are "critical" (loss jumps fast). So there's a real ordering of indexer importance baked into the model, not noise.
- This ranking is stable across different calibration sets — it's an intrinsic property of the trained model, not a calibration artifact.
- Retaining only 1/4 of indexers (75% removed) with the searched pattern matches the original model's downstream performance almost exactly.
4.3 Training-aware IndexCache — multi-layer distillation
If you're willing to retrain (continued pretraining, not from scratch), you can go further: force the indexer to actually learn to serve multiple layers, instead of hoping a pattern search finds layers that happen to tolerate sharing.
Standard DSA already trains each layer's indexer via KL-divergence distillation against that same layer's aggregated attention distribution p_t⁽ˡ⁾. The extension here: if layer ℓ is F and serves S layers ℓ+1, ..., ℓ+m, train its indexer against all of them jointly:
L_multi = Σ_{j=0}^{m} [ 1/(m+1) · Σ_t D_KL( p_t^(ℓ+j) || q_t^(ℓ) ) ]
where:
- q_t⁽ˡ⁾ = indexer's own output distribution (softmax of its scores) at layer ℓ
- p_t⁽ˡ⁾ = the real aggregated attention distribution at layer ℓ (averaged across heads)
- 1/(m+1) = just averaging over however many layers reuse this same index
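As a concrete reading of the formula, here is L_multi for a single F layer in NumPy. The array shapes and function names are my own assumptions for illustration; in actual training the targets p would carry a stop-gradient and q would come from the indexer's softmax.

```python
import numpy as np


def kl(p: np.ndarray, q: np.ndarray, eps: float = 1e-12) -> float:
    """Σ_t D_KL(p_t || q_t) for arrays of shape (T, S): T queries, S key positions."""
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))


def multi_layer_loss(p_layers: np.ndarray, q: np.ndarray) -> float:
    """L_multi for one F layer.

    p_layers: (m+1, T, S) attention targets of the F layer and the m S layers it serves.
    q:        (T, S) output distribution of the F layer's indexer.
    """
    m_plus_1 = p_layers.shape[0]
    return sum(kl(p_layers[j], q) for j in range(m_plus_1)) / m_plus_1


q = np.array([[0.5, 0.5]])                   # one query, two key positions
p_layers = np.array([[[0.9, 0.1]], [[0.5, 0.5]]])  # layer ℓ and one served S layer
print(multi_layer_loss(p_layers, q))
```

The second target already equals q, so only the first layer contributes, and the 1/(m+1) factor halves its KL term.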
Important note (training detail I almost missed): you don't do this from random init. A randomly initialized model's attention distribution has no real structure yet — forcing the indexer to chase an undefined target just injects noise. So this is always done as continued pretraining / fine-tuning on top of an already-trained DSA model, in two stages: a frozen "dense warm-up" that trains only the indexer, then a "sparse training" phase that activates top-k and trains everything jointly.
5. The proof: L_multi and L_avg give the exact same gradient
This is the part of my notes that was the messiest, so here's the clean derivation.
Define the averaged target distribution across the m+1 served layers:
p̄_t = Σ_{j=0}^{m} [ 1/(m+1) · p_t^(ℓ+j) ]
and the single-target loss using that averaged target:
L_avg = Σ_t D_KL( p̄_t || q_t^(ℓ) )
Claim: ∇_θ L_multi = ∇_θ L_avg.
Proof. The key trick: in D_KL(p || q), only q depends on the trainable parameters θ (p is just data — the real attention distribution, treated as a fixed target with stop-gradient). So when you differentiate KL divergence w.r.t. θ, the entropy term of p (which doesn't depend on θ) vanishes entirely. What's left is just the cross-entropy term:
∇_θ D_KL(p || q_t^(ℓ)) = -∇_θ Σ_s p(s) · log q_t^(ℓ)(s)
This is the step I got stuck on in my notebook — I wasn't sure why only the log q term survives. The answer is straightforward once you write KL out fully:
D_KL(p || q) = Σ_s p(s) log p(s) − Σ_s p(s) log q(s)
- First term, Σ_s p(s) log p(s): the entropy term of p (its negative entropy). It has no θ dependence, so its gradient is 0.
- Second term, −Σ_s p(s) log q(s): the cross-entropy term. It is the only term with θ, via q = softmax(indexer).
Now apply this to L_multi:
∇_θ L_multi = - Σ_{j=0}^{m} [1/(m+1)] Σ_t ∇_θ Σ_s p_t^(ℓ+j)(s) log q_t^(ℓ)(s)
Both sums are linear, and log q_t^(ℓ)(s) does not depend on j, so move the sum over j inside the sum over s and factor the log term out:
= - Σ_t ∇_θ Σ_s [ Σ_{j=0}^{m} (1/(m+1)) p_t^(ℓ+j)(s) ] · log q_t^(ℓ)(s)
The bracketed sum is exactly p̄_t(s), so:
= - Σ_t ∇_θ Σ_s p̄_t(s) log q_t^(ℓ)(s)
= ∇_θ L_avg. ∎
So averaging before taking KL and summing the KL terms after are mathematically identical at the gradient level — the indexer ends up being pulled toward the centroid of all the attention distributions it serves, not toward any one layer.
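The claim is easy to sanity-check numerically. The sketch below uses central finite differences on random data for a single query position (the sum over t is linear, so one position is enough). Everything here is a hypothetical toy: `z` stands in for the indexer's logits, and the targets are random Dirichlet samples rather than real attention distributions.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())
    return e / e.sum()


def kl(p: np.ndarray, q: np.ndarray) -> float:
    return float(np.sum(p * (np.log(p) - np.log(q))))


def l_multi(z: np.ndarray, targets: np.ndarray) -> float:
    q = softmax(z)
    return sum(kl(p, q) for p in targets) / len(targets)


def l_avg(z: np.ndarray, targets: np.ndarray) -> float:
    return kl(np.mean(targets, axis=0), softmax(z))


def numeric_grad(f, z: np.ndarray, h: float = 1e-6) -> np.ndarray:
    """Central finite-difference gradient of f at z."""
    grad = np.zeros_like(z)
    for i in range(z.size):
        step = np.zeros_like(z)
        step[i] = h
        grad[i] = (f(z + step) - f(z - step)) / (2 * h)
    return grad


rng = np.random.default_rng(0)
z = rng.standard_normal(5)                   # indexer logits over 5 key positions
targets = rng.dirichlet(np.ones(5), size=3)  # p for m + 1 = 3 served layers
g_multi = numeric_grad(lambda v: l_multi(v, targets), z)
g_avg = numeric_grad(lambda v: l_avg(v, targets), z)
print("max gradient gap:", np.max(np.abs(g_multi - g_avg)))
print("loss values:", l_multi(z, targets), l_avg(z, targets))
```

The two loss values themselves are not equal: they differ by a constant built from the entropies of the targets, which is exactly the part that has no θ dependence and so drops out of the gradient.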
Then why use L_multi in practice if they're equivalent? Pure memory/engineering reason: with L_multi, each S layer only needs to send its own predicted q value backward. With L_avg, you'd need to pass both p and q for every served layer to compute the average first — which means extra memory overhead and extra runtime cost for no actual gain, since the gradient comes out identical either way.
My takeaway after sitting with this for a while: a lot of "novel" architecture papers ultimately reduce to "design the right loss function for what you want, and let the network figure out the rest." This derivation is a good concrete example — the multi-layer trick isn't a new optimization method, it's just an equivalent (and cheaper) way to write the same gradient.
6. Performance (30B DSA model, 200K context)
| Metric | Standard DSA | + IndexCache (1/4 retained) |
|---|---|---|
| Prefill latency | 19.5 s | 10.7 s (1.82× speedup) |
| Decode throughput (per request) | 58 tok/s | 86 tok/s (1.48× speedup) |
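A quick check of the speedup arithmetic in the table. Latency is lower-is-better and throughput is higher-is-better, so the two ratios are taken in opposite directions. The variable names are mine; the four numbers are the ones in the table.

```python
prefill_baseline_s, prefill_indexcache_s = 19.5, 10.7
decode_baseline_tps, decode_indexcache_tps = 58, 86

prefill_speedup = prefill_baseline_s / prefill_indexcache_s   # old latency / new latency
decode_speedup = decode_indexcache_tps / decode_baseline_tps  # new throughput / old throughput
print(f"{prefill_speedup:.2f}x prefill, {decode_speedup:.2f}x decode")
```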
Why even a uniform pattern works once you train for it, when it fails training-free: the greedy search has to avoid sensitive layers because the model was never trained to tolerate sharing. Without retraining, certain layers are tightly coupled to their own indexer's exact top-k, and feeding them someone else's indices causes a distribution shift that breaks things. Once you train with the multi-layer distillation loss, the S layers themselves learn to adapt to inherited indices, and the F layer's indexer learns to produce a selection that generalizes across all the layers it serves. That joint adaptation is what makes even a dumb uniform pattern work fine after training: the layer-specific sensitivity just disappears.
Extra structural note from the overlap heatmap: the first layer is always kept as a full F layer (it has to seed the index cache, and early layers attend to a fundamentally different token subset than later ones — overlap with deep layers is ≤0.4). The strongest, most similar index regions cluster near the diagonal — i.e., a layer's indexer output looks most like its immediate neighbors, decaying as you move further away.
7. The failure case — and why it's actually an important negative result
Before landing on the greedy LM-loss search, the natural-seeming alternative was tried: pick the sharing pattern by directly maximizing cosine similarity between attention outputs, since that's cheaper to compute than running full LM-loss evaluations.
Build an N×N similarity matrix S[i][j] = cosine similarity between layer i's attention output using its own indexer vs. using layer j's indexer instead. Then solve for the best F/S assignment with dynamic programming:
    dp[i][k] = max over j < i, c_j = F of:
        dp[j][k-1] + Σ_{m=j+1}^{i-1} S[m][j]
— i.e., find the best previous F layer to "branch" from, accumulating similarity scores for every S layer that would reuse it. Solvable exactly by backtracking through the DP table.
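For reference, here is one way the recurrence could be implemented. It rests on my own reading of it: k counts the F layers placed so far, layers are 0-indexed, and `sim[m, j]` is the similarity of layer m's output when it reuses layer j's indices. The closing term for the S layers after the last F layer is my addition, since the recurrence as quoted stops at layer i. The names and toy matrix are hypothetical.

```python
import numpy as np


def best_pattern_by_similarity(sim: np.ndarray, n_full: int) -> list[str]:
    """Place exactly n_full F layers to maximize the total similarity of the S layers."""
    n = sim.shape[0]
    if not 1 <= n_full <= n:
        raise ValueError("n_full must be between 1 and the number of layers")
    dp = np.full((n, n_full + 1), -np.inf)  # dp[i, k]: layer i is the k-th F layer
    parent = np.full((n, n_full + 1), -1, dtype=int)
    dp[0, 1] = 0.0                          # layer 1 (index 0) is always the first F

    def span(j: int, i: int) -> float:
        """Similarity gained when layers j+1 .. i-1 all reuse F layer j."""
        return float(sim[j + 1:i, j].sum())

    for i in range(1, n):
        for k in range(2, n_full + 1):
            for j in range(i):
                cand = dp[j, k - 1] + span(j, i)
                if cand > dp[i, k]:
                    dp[i, k], parent[i, k] = cand, j

    # Close the sequence: every layer after the last F layer reuses it.
    last = max(range(n), key=lambda i: dp[i, n_full] + span(i, n))
    pattern, k = ["S"] * n, n_full
    while last >= 0:                        # backtrack through the chosen F layers
        pattern[last] = "F"
        last, k = parent[last, k], k - 1
    return pattern


# Toy 4-layer similarity matrix (made-up numbers): layer 3 tracks layer 2 closely.
sim = np.zeros((4, 4))
sim[1, 0], sim[2, 0], sim[3, 0] = 0.9, 0.2, 0.1
sim[2, 1], sim[3, 1] = 0.3, 0.2
sim[3, 2] = 0.95
print("".join(best_pattern_by_similarity(sim, n_full=2)))
```

On the toy matrix the DP keeps layers 1 and 3 (indices 0 and 2) as F, since that pairing scores 0.9 + 0.95, higher than either alternative.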
This failed. The similarity-optimal pattern performed about the same as plain uniform interleaving — both clearly worse than the greedy LM-loss search. The reason is the core insight of the whole negative result:
Cosine similarity is a local metric — it only tells you how well-preserved a single layer's output is in isolation. It can't see how small token-selection mismatches propagate and compound through all the downstream layers. Two layers can have near-identical attention outputs (similarity ≈ 1) yet differ in exactly the handful of tokens that turn out to matter several layers later. Those subtle errors accumulate — and a layer-local similarity score has no way to predict that.
The LM-loss-based greedy search avoids this because it's a global, end-to-end signal — it measures the actual downstream effect of a sharing decision on the whole model's output, not just on one layer's local activation. This is the real lesson: local geometric similarity is a tempting cheap proxy, but for anything where errors compound across depth, you need an end-to-end metric.
My summary of the idea in one line
DSA's indexer recomputes "who matters" from scratch at every layer even though the answer barely changes between adjacent layers — IndexCache just caches that answer and reuses it, and the only real engineering question is which layers are allowed to skip recomputation, which can be found either by greedy search (no training) or learned directly via a provably-equivalent averaged-KL loss (with training).
If you find any mismatched detail in this post, or want to contribute a paper or working code for IndexCache, please open an issue on
github.link