You run Qwen3-VL on a single prompt. You record the output logprobs to full precision. Then you run the exact same prompt again, batched with 15 others. Same model, same weights, same GPU, same code. The logprobs are different.
Not catastrophically -- the top token usually agrees -- but the numbers have shifted at the seventh decimal place, and in an autoregressive loop, that hairline fracture propagates. By output position 8, 29% of your tokens have diverged.
You are not going crazy. vLLM silently changed the algorithm.
The Crime Scene
Setup: Qwen3-VL 2B on an NVIDIA H200. Identical prompts at BS=1 (one at a time) vs BS=16 (sixteen at once). VLLM_BATCH_INVARIANT=1 is enabled -- all GEMMs are deterministic via persistent Triton kernels. Yet:
| Metric | BS=1 vs BS=16 |
|---|---|
| Bitwise logprob match | 6.1% (30/490) |
| Top-1 token match | 78.6% |
| Semantic agreement | 100% |
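The first row is stricter than it looks. A pair counts only if every logprob is identical bit for bit, not merely close. Below is a minimal sketch of how such a metric could be computed. It assumes each run has already been reduced to per-request lists of logprobs. `float_bits`, `bitwise_equal` and `bitwise_match_rate` are hypothetical helpers, not part of vLLM.

```python
import struct
from typing import Sequence

def float_bits(x: float) -> bytes:
    """Raw IEEE 754 double bytes, so even a 1-ulp change (or 0.0 vs -0.0) counts."""
    return struct.pack("<d", x)

def bitwise_equal(a: Sequence[float], b: Sequence[float]) -> bool:
    """True only if both logprob sequences have the same length and identical bits."""
    return len(a) == len(b) and all(float_bits(x) == float_bits(y) for x, y in zip(a, b))

def bitwise_match_rate(runs_a, runs_b) -> float:
    """Fraction of request pairs whose logprob sequences match bit for bit."""
    if len(runs_a) != len(runs_b) or not runs_a:
        raise ValueError("need the same, non-zero number of requests in both runs")
    matches = sum(bitwise_equal(a, b) for a, b in zip(runs_a, runs_b))
    return matches / len(runs_a)

# Toy numbers: the second pair differs only in the last few bits of one logprob.
bs1 = [[-0.25, -1.5], [-0.1, -2.0]]
bs16 = [[-0.25, -1.5], [-0.1 + 1e-16, -2.0]]
print(bitwise_match_rate(bs1, bs16))  # 0.5
```

A plain `==` on floats would treat `0.0` and `-0.0` as equal. Comparing the packed bytes avoids that, which is what "bitwise" should mean here.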
The profiler gives the first clue: flash attention takes 5.15x longer per call in BS=16 (6.17ms vs 1.20ms). Same kernel name. Same call count (392). If it were the same algorithm processing more data, you'd expect it to scale with tokens -- not blow up 5x per call.
This is not a scaling problem. This is a different recipe.
What Attention Does (60 Seconds)
For each new token, attention looks back at every previous token, computes a relevance score for each, normalizes them (softmax), and takes a weighted average:
output = softmax(Q @ K^T / sqrt(d)) @ V
The critical observation: softmax involves summing over all previous tokens. If you have 640 tokens, that is 640 numbers being added. The order of that summation will matter shortly.
The Optimization That Changes Everything
Imagine 16 requests sharing the same 400-token system prompt. Without optimization, each request scans those 400 tokens independently: 16 x 400 = 6,400 prefix KV reads, 6,000 of them redundant.
vLLM's cascade attention splits the work:
- Prefix (done ONCE): Flash attention over the shared 400 tokens for all 16 queries. Produces partial output + Log-Sum-Exp (LSE) statistic.
- Suffix (per-request): Flash attention over each request's unique tokens (~100 each). Produces partial output + LSE.
- Merge: Combine using LSE-weighted rebalancing.
```python
# flash_attn.py:1040 -- Cascade Implementation (abridged; "..." marks elided arguments)
def cascade_attention(output, query, key_cache, value_cache, ...):
    # Step 1: Process shared prefix ONCE
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query, k=key_cache, v=value_cache,
        max_seqlen_k=common_prefix_len,
        block_table=block_table[:1],   # prefix blocks are shared, so one row suffices
        causal=False,                  # every query sees the whole prefix
        return_softmax_lse=True,
    )
    # Step 2: Process each request's unique suffix
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query, k=key_cache, v=value_cache,
        max_seqlen_k=max_kv_len - common_prefix_len,
        block_table=block_table[:, num_common_kv_blocks:],  # skip the shared blocks
        causal=True,
        return_softmax_lse=True,
    )
    # Step 3: Merge -- THIS is where determinism breaks
    merge_attn_states(output, prefix_output, prefix_lse,
                      suffix_output, suffix_lse)
```
Total KV reads: 400 + 16 x 100 = 2,000, versus 16 x (400 + 100) = 8,000 without cascade. A 4x bandwidth reduction.
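The read counts reduce to two formulas. This sketch uses the example's simplifying assumptions: every request has the same suffix length, and one "read" means reading one token's KV once. `kv_reads` is a hypothetical helper.

```python
def kv_reads(num_reqs: int, prefix_len: int, suffix_len: int, cascade: bool) -> int:
    """Token-level KV reads for one attention pass over a shared-prefix batch."""
    if cascade:
        # Prefix read once for all queries, then each request reads its own suffix.
        return prefix_len + num_reqs * suffix_len
    # Without cascade, every request re-reads the shared prefix plus its suffix.
    return num_reqs * (prefix_len + suffix_len)

without = kv_reads(16, 400, 100, cascade=False)       # 8000
with_cascade = kv_reads(16, 400, 100, cascade=True)   # 2000
print(without / with_cascade)                         # 4.0
```

The ratio grows with the number of requests and with the prefix's share of each context. The saving is therefore largest for many requests behind a long shared prompt.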
Mathematically, this produces the identical result. But we don't live in the world of exact arithmetic.
Why 1 + 2 + 3 Does Not Equal 3 + 2 + 1
This is the section that explains everything.
IEEE 754 floating-point addition is not associative:
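```python
>>> import numpy as np
>>> a, b, c = np.float32(1e-7), np.float32(1.0), np.float32(-1.0)
>>> (a + b) + c   # 1.1920928955078125e-07 (= 2**-23, one float32 ulp at 1.0)
>>> a + (b + c)   # 1e-07
```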
Same inputs. Same operations. Different answers. This is not a bug. It is IEEE 754 working as specified: every addition rounds to the nearest representable value, so the order of operations decides which roundings happen.
Connect this to attention:
- Single-pass (BS=1): `softmax([t1, t2, ..., t640]) @ V` -- one summation, one rounding chain
- Cascade (BS=16): `merge(softmax([t1,...,t512]) @ V_prefix, softmax([t513,...,t640]) @ V_suffix)` -- two summations + LSE merge
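Three numbers are enough to show how splitting one sum into two changes the result. This is a minimal float32 sketch, not vLLM code. It compares adding values one at a time (like a single pass) against summing a "prefix" and a "suffix" separately and then combining them. The helpers are hypothetical. The values are chosen so that the difference is guaranteed under round-to-nearest-ties-to-even.

```python
import numpy as np

def sequential_sum(xs):
    """Left-to-right float32 accumulation: one rounding chain."""
    acc = np.float32(0.0)
    for x in xs:
        acc = np.float32(acc + x)  # each addition rounds to float32
    return acc

def split_sum(xs, split):
    """Sum xs[:split] and xs[split:] separately, then combine: two chains plus a merge."""
    return np.float32(sequential_sum(xs[:split]) + sequential_sum(xs[split:]))

tiny = np.float32(2.0 ** -24)  # exactly half a float32 ulp at 1.0
xs = [np.float32(1.0), tiny, tiny]

one_pass = sequential_sum(xs)      # 1.0 + tiny is a tie, rounds to even -> 1.0 (twice)
two_pass = split_sum(xs, split=1)  # tiny + tiny = 2**-23 is a full ulp, so it survives
print(one_pass, two_pass, one_pass == two_pass)
```

Real attention sums are longer and their values less contrived, but the mechanism is the same: a different grouping produces different roundings.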
The merge math:
```python
# Simplified merge_attn_states logic
max_lse = max(prefix_lse, suffix_lse)  # numerical stability
prefix_weight = exp(prefix_lse - max_lse)
suffix_weight = exp(suffix_lse - max_lse)
output = (prefix_weight * prefix_output
          + suffix_weight * suffix_output) / (prefix_weight + suffix_weight)
# Math: IDENTICAL to single-pass
# IEEE 754: DIFFERENT by ~1e-7 (FP32) or ~1e-3 (FP16), relative
```
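The merge is exact in the reals but only approximately equal in float32. The NumPy sketch below demonstrates this for a single query. It is not vLLM's kernel: `attend` and `merge` are hypothetical helpers. One query attends over 640 random keys, and the prefix/suffix split at 512 mirrors the cascade example above.

```python
import numpy as np

def attend(q, K, V):
    """Single-query softmax attention over K/V; returns (output, log-sum-exp)."""
    scale = np.float32(np.sqrt(q.shape[-1]))
    scores = (K @ q) / scale          # one float32 score per key
    m = scores.max()                  # subtract the max for stability
    w = np.exp(scores - m)
    lse = m + np.log(w.sum())         # log(sum(exp(scores)))
    return (w / w.sum()) @ V, lse

def merge(p_out, p_lse, s_out, s_lse):
    """LSE-weighted merge of two partial attention outputs."""
    max_lse = max(p_lse, s_lse)
    p_w = np.exp(p_lse - max_lse)
    s_w = np.exp(s_lse - max_lse)
    return (p_w * p_out + s_w * s_out) / (p_w + s_w)

rng = np.random.default_rng(0)
d, n, split = 64, 640, 512
q = rng.standard_normal(d).astype(np.float32)
K = rng.standard_normal((n, d)).astype(np.float32)
V = rng.standard_normal((n, d)).astype(np.float32)

single, full_lse = attend(q, K, V)                # one pass over all 640 keys
p_out, p_lse = attend(q, K[:split], V[:split])    # "prefix": keys 0..511
s_out, s_lse = attend(q, K[split:], V[split:])    # "suffix": keys 512..639
cascade = merge(p_out, p_lse, s_out, s_lse)

print("max abs diff:", np.max(np.abs(single - cascade)))
print("bitwise equal:", np.array_equal(single, cascade))
```

The two outputs agree to within float32 tolerance. The printed maximum difference is usually tiny but need not be zero; the exact value depends on the data and the platform.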
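~1e-7 per element sounds negligible, but the error does not stay put. Each of ~28 transformer layers feeds its output to the next, and each generated token is fed back as input for the next step. Once the accumulated error flips a near-tie between two candidate tokens, the two runs condition on different tokens, and everything after that point can diverge: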
- Position 5: 17% token divergence
- Position 8: 29% token divergence
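Greedy decoding only looks at which logit is largest, so a near-tie can flip on a relative change of about 1e-7. This is a toy float32 illustration with made-up logit values, not a model of how error actually grows through the layers.

```python
import numpy as np

# Two candidate tokens whose float32 logits differ by one ulp: a near-tie.
logits_bs1 = np.array([10.0, 10.0], dtype=np.float32)
logits_bs1[1] = np.nextafter(logits_bs1[0], np.float32(np.inf))

# Stand-in for the other run: token 0's logit nudged up by two ulps
# (a relative change of about 2e-7), as batch-dependent rounding might do.
logits_bs16 = logits_bs1.copy()
logits_bs16[0] = np.nextafter(logits_bs1[1], np.float32(np.inf))

print(np.argmax(logits_bs1), np.argmax(logits_bs16))  # greedy picks differ
```

From the flipped step onward, the two runs condition on different tokens, so later positions can differ arbitrarily.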
The Three Gates
When does vLLM activate cascade? Silently, based on three runtime conditions:
```python
# flash_attn.py:962 (abridged; "..." marks elided arguments)
def use_cascade_attention(common_prefix_len, query_lens, ...):
    # Gate 1: Is shared prefix long enough?
    if common_prefix_len < 256:
        return False  # "Not worth the overhead"
    # Gate 2: Are there enough requests?
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False  # "Too few to benefit"
    # Gate 3: Performance heuristic
    # (the inputs below come from a cost model elided from this excerpt)
    cascade_time = cascade_waves * num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
    return cascade_time < flash_decoding_time
```
BS=1 always fails Gate 2. It uses single-pass attention.
BS=16 with a system prompt passes all three. It uses cascade attention.
Your BS=1 benchmark is literally running a different algorithm than your BS=16 production system.
In our 490-pair experiment, exactly 30 pairs matched bitwise. They were always the last 3 requests of each batch. As a batch of 50 is processed, requests finish and leave. Once fewer than 8 remain, Gate 2 closes, and the last requests revert to single-pass attention and match BS=1 perfectly.

The Smoking Gun: 30 Matching Samples

| Batch Position | Active Requests | Cascade? | Matches BS=1? |
|---|---|---|---|
| 1-47 | 50 down to 8 | Yes | No |
| 48 | 7 | No | Yes |
| 49 | 6 | No | Yes |
| 50 | 5 | No | Yes |
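Gate 2 alone explains the pattern in the table. The sketch below covers only the two size gates, with thresholds copied from the `use_cascade_attention` excerpt above. Gate 3's cost model is omitted. The 400-token prefix is an assumption borrowed from the earlier example, not the experiment's actual prompt.

```python
MIN_PREFIX_LEN = 256  # Gate 1 threshold from the excerpt
MIN_NUM_REQS = 8      # Gate 2 threshold from the excerpt

def passes_size_gates(common_prefix_len: int, num_reqs: int) -> bool:
    """Gates 1 and 2 only; real cascade also requires Gate 3, ignored here."""
    return common_prefix_len >= MIN_PREFIX_LEN and num_reqs >= MIN_NUM_REQS

# Active requests as a batch drains, with a hypothetical 400-token shared prefix.
for active in (50, 16, 8, 7, 6, 5, 1):
    print(f"{active:>2} active -> cascade possible: {passes_size_gates(400, active)}")
```

Passing these two gates is necessary but not sufficient. Failing either one, as BS=1 always does, guarantees single-pass attention.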
The Fork in the Code
The branch at flash_attn.py:673:
```python
if not attn_metadata.use_cascade:
    # SINGLE PASS: one call over all KV tokens
    flash_attn_varlen_func(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        ...
        num_splits=attn_metadata.max_num_splits,
    )
    return output

# CASCADE: two calls + merge
cascade_attention(
    output[:num_actual_tokens],
    query[:num_actual_tokens],
    key_cache, value_cache,
    ...
    common_prefix_len=attn_metadata.common_prefix_len,
)
```
Same function signature. Completely different execution.
vLLM already knows that cascade attention and batch invariance conflict. When VLLM_BATCH_INVARIANT=1 is set, it auto-disables cascade:
```python
# vllm/config/vllm.py:994
if vllm_is_batch_invariant() and not self.model_config.disable_cascade_attn:
    self.model_config.disable_cascade_attn = True
    logger.warning_once(
        "Disabling cascade attention when VLLM_BATCH_INVARIANT is enabled."
    )
```
The Fix: Three Options
Option 1: Surgical (Recommended)
```python
from vllm import LLM

llm = LLM(model="your-model", disable_cascade_attn=True)
```
Forces single-pass for all batch sizes. 5-15% throughput loss for shared-prefix workloads.
Option 2: Remove Prefix Detection
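```python
from vllm import LLM

llm = LLM(model="your-model", enable_prefix_caching=False)
```
Gate 1 never opens. The trade-off is that you also lose prefix caching's own benefit: reusing cached KV across requests that share a prefix.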
Option 3: Full Determinism
```shell
export VLLM_BATCH_INVARIANT=1
```
Replaces all cuBLAS GEMMs with batch-invariant kernels and auto-disables cascade. Cost: roughly 2.4x slower.
| Option | Cascade | GEMM Determinism | Cost |
|---|---|---|---|
| Default | Enabled | Non-deterministic | Baseline |
| `disable_cascade_attn=True` | Disabled | Non-deterministic | ~5-15% |
| `enable_prefix_caching=False` | Disabled | Non-deterministic | Medium |
| `VLLM_BATCH_INVARIANT=1` | Auto-disabled | Deterministic | ~2.4x |
The Broader Lesson
Cascade attention is not a bug. It is a well-engineered bandwidth optimization. The issue is the silence.
This pattern recurs across GPU inference:
- Flash Decoding splits attention across thread blocks -- same associativity issue
- cuBLAS GEMM selects different tile sizes by matrix shape -- same op, different rounding
- torch.compile fuses operations that eager mode runs as separate kernels -- same model, different graph
Every time a framework says "mathematically equivalent," ask: equivalent in the reals, or in IEEE 754?
The ghost in the batch is not malicious. It is an optimization doing its job. But now you know it is there, you know when it activates, and you know how to control it.
Key Takeaways
- BS=1 and BS>=8 (with a shared prefix of 256+ tokens) can run different attention algorithms in vLLM: single-pass vs cascade, by design.
- Cascade cuts KV reads by processing shared prefixes once (4x in the 16-request example). The merge step introduces FP divergence.
- Three silent gates control activation: prefix >= 256 tokens, num_reqs >= 8, perf heuristic.
- One flag fixes it: `disable_cascade_attn=True` or `VLLM_BATCH_INVARIANT=1`.
- "Mathematically equivalent" != "numerically identical." This applies across all GPU ML.
Key files: flash_attn.py:673 (fork), flash_attn.py:962 (gates), flash_attn.py:1040 (cascade), merge_attn_states.py (merge), vllm/config/vllm.py:994 (auto-disable)