DEV Community

Mayank Ketkar

The Ghost in the Batch: How vLLM Silently Switches Algorithms

You run Qwen3-VL on a single prompt. You record the output logprobs to full precision. Then you run the exact same prompt again, batched with 15 others. Same model, same weights, same GPU, same code. The logprobs are different.

Not catastrophically -- the top token usually agrees -- but the numbers have shifted at the seventh decimal place, and in an autoregressive loop, that hairline fracture propagates. By output position 8, 29% of your tokens have diverged.

You are not going crazy. vLLM silently changed the algorithm.

The Crime Scene

Setup: Qwen3-VL 2B on an NVIDIA H200. Identical prompts at BS=1 (one at a time) vs BS=16 (sixteen at once). VLLM_BATCH_INVARIANT=1 is enabled -- all GEMMs are deterministic via persistent Triton kernels. Yet:

| Metric | BS=1 vs BS=16 |
|---|---|
| Bitwise logprob match | 6.1% (30/490) |
| Top-1 token match | 78.6% |
| Semantic agreement | 100% |
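To reproduce metrics like these, you could compute them along the lines of this minimal sketch. Several things are assumptions, not the original harness: the data layout (each run is a list of samples, and each sample is a list of (token_id, logprob) pairs from greedy decoding) and the definitions (a sample matches bitwise only if every token and every logprob bit pattern agree, and top-1 match is counted per output position).

```python
import struct

def float_bits(x: float) -> bytes:
    """IEEE 754 double bytes, so 0.0 and -0.0 count as different."""
    return struct.pack("<d", x)

def samples_bitwise_equal(sample_a, sample_b) -> bool:
    """True if both samples have the same tokens and bit-identical logprobs."""
    return len(sample_a) == len(sample_b) and all(
        tok_a == tok_b and float_bits(lp_a) == float_bits(lp_b)
        for (tok_a, lp_a), (tok_b, lp_b) in zip(sample_a, sample_b)
    )

def match_rates(run_a, run_b):
    """Return (bitwise sample match rate, per-position token match rate).

    Hypothetical layout: each run is a list of samples, each sample a list of
    (token_id, logprob) pairs recorded under greedy decoding.
    """
    if len(run_a) != len(run_b):
        raise ValueError("runs must contain the same number of samples")
    bitwise = sum(samples_bitwise_equal(a, b) for a, b in zip(run_a, run_b))
    token_pairs = [
        (tok_a, tok_b)
        for a, b in zip(run_a, run_b)
        for (tok_a, _), (tok_b, _) in zip(a, b)
    ]
    top1 = sum(tok_a == tok_b for tok_a, tok_b in token_pairs) / len(token_pairs)
    return bitwise / len(run_a), top1
```

Comparing raw bytes instead of using `==` makes "bitwise" literal. For example, 0.0 and -0.0 compare equal as floats but count as a mismatch here.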

The profiler gives the first clue: flash attention takes 5.15x longer per call in BS=16 (6.17ms vs 1.20ms). Same kernel name. Same call count (392). If it were the same algorithm processing more data, you'd expect it to scale with tokens -- not blow up 5x per call.

This is not a scaling problem. This is a different recipe.

What Attention Does (60 Seconds)

For each new token, attention looks back at every previous token, computes a relevance score for each, normalizes them (softmax), and takes a weighted average:

output = softmax(Q @ K^T / sqrt(d)) @ V

The critical observation: softmax involves summing over all previous tokens. If you have 640 tokens, that is 640 numbers being added. The order of that summation will matter shortly.
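For reference, here is that formula for a single decode step in NumPy: one query attending over 640 cached keys. The sizes and random data are illustrative assumptions. Production kernels fuse and tile these steps, and that tiling decides the summation order.

```python
import numpy as np

def attention_one_query(q, K, V):
    """softmax(q @ K^T / sqrt(d)) @ V for a single query vector.

    q: (d,), K: (n, d), V: (n, d_v). Returns (d_v,).
    """
    d = q.shape[-1]
    scores = (K @ q) / np.sqrt(np.float32(d))  # one relevance score per previous token
    scores = scores - scores.max()              # shift for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum()           # softmax: a sum over all n tokens
    return weights @ V                          # weighted average of value vectors

rng = np.random.default_rng(0)
n_tokens, d = 640, 64
q = rng.standard_normal(d).astype(np.float32)
K = rng.standard_normal((n_tokens, d)).astype(np.float32)
V = rng.standard_normal((n_tokens, d)).astype(np.float32)

out = attention_one_query(q, K, V)
print(out.shape, out.dtype)
```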

The Optimization That Changes Everything

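Imagine 16 requests sharing the same 400-token system prompt, each followed by ~100 unique tokens. Without optimization, each request scans those 400 prefix tokens independently: 16 x 400 = 6,400 prefix KV reads, of which 6,000 repeat work already done once (75% of the 8,000 total reads).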

vLLM's cascade attention splits the work:

  1. Prefix (done ONCE): Flash attention over the shared 400 tokens for all 16 queries. Produces partial output + Log-Sum-Exp (LSE) statistic.
  2. Suffix (per-request): Flash attention over each request's unique tokens (~100 each). Produces partial output + LSE.
  3. Merge: Combine using LSE-weighted rebalancing.
# flash_attn.py:1040 -- Cascade Implementation
def cascade_attention(output, query, key_cache, value_cache, ...):
    # Step 1: Process shared prefix ONCE
    prefix_output, prefix_lse = flash_attn_varlen_func(
        q=query, k=key_cache, v=value_cache,
        max_seqlen_k=common_prefix_len,
        block_table=block_table[:1],
        causal=False,
        return_softmax_lse=True,
    )

    # Step 2: Process each request's unique suffix
    suffix_output, suffix_lse = flash_attn_varlen_func(
        q=query, k=key_cache, v=value_cache,
        max_seqlen_k=max_kv_len - common_prefix_len,
        block_table=block_table[:, num_common_kv_blocks:],
        causal=True,
        return_softmax_lse=True,
    )

    # Step 3: Merge -- THIS is where determinism breaks
    merge_attn_states(output, prefix_output, prefix_lse,
                      suffix_output, suffix_lse)

Total KV reads: 400 + 16x100 = 2,000 (vs 8,000). A 4x bandwidth reduction.

Mathematically, this produces the identical result. But we don't live in the world of exact arithmetic.
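The read counts generalize with a few lines of arithmetic. This helper is a simplification. It counts one read per (request, KV token) pair and ignores block granularity and the cost of the merge.

```python
def kv_reads(num_reqs: int, prefix_len: int, suffix_len: int) -> tuple[int, int]:
    """KV token reads without and with a shared-prefix (cascade) split."""
    naive = num_reqs * (prefix_len + suffix_len)   # every request rescans the prefix
    cascade = prefix_len + num_reqs * suffix_len   # prefix once, suffixes per request
    return naive, cascade

naive, cascade = kv_reads(num_reqs=16, prefix_len=400, suffix_len=100)
print(naive, cascade, naive / cascade)  # 8000 2000 4.0
```

With only 2 requests, the same prefix gives 1,000 vs 600 reads. That is a much smaller win, consistent with the minimum-batch-size gate described later.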

Why 1 + 2 + 3 Does Not Equal 3 + 2 + 1

This is the section that explains everything.

IEEE 754 floating-point addition is not associative:

>>> import numpy as np
>>> a, b, c = np.float32(1e-7), np.float32(1.0), np.float32(-1.0)
>>> print((a + b) + c)    # 1.1920929e-07  (a + b rounds up to 1 + 2**-23)
>>> print(a + (b + c))    # 1e-07

Same inputs. Same operations. Different answers. This is the IEEE 754 spec -- finite precision means rounding depends on order.

Connect this to attention:

  • Single-pass (BS=1): softmax([t1, t2, ..., t640]) @ V -- one summation, one rounding chain
  • Cascade (BS=16): merge(softmax([t1,...,t512]) @ V_prefix, softmax([t513,...,t640]) @ V_suffix) -- two summations + LSE merge
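The same effect shows up in a plain sum. This NumPy sketch accumulates 640 float32 values, which stand in for the softmax numerators. It sums them once as a single left-to-right chain, and once as a 512-term prefix chain plus a 128-term suffix chain. Whether the two results differ in the last bits depends on the data, so the script prints the comparison instead of assuming it.

```python
import numpy as np

def sequential_sum(xs):
    """Left-to-right float32 accumulation: a single rounding chain."""
    acc = np.float32(0.0)
    for x in xs:
        acc = acc + x  # float32 + float32 stays float32
    return acc

rng = np.random.default_rng(0)
# Stand-ins for the 640 softmax numerators exp(score_i), in float32
terms = np.exp(rng.standard_normal(640)).astype(np.float32)

single_pass = sequential_sum(terms)      # t1..t640 in one chain
prefix = sequential_sum(terms[:512])     # t1..t512
suffix = sequential_sum(terms[512:])     # t513..t640
split = prefix + suffix                  # two chains, then combined
reference = float(terms.astype(np.float64).sum())  # float64 reference

print(single_pass, split, single_pass == split)
print(abs(float(single_pass) - reference), abs(float(split) - reference))
```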

The merge math:

import numpy as np

# Simplified merge_attn_states logic (NumPy sketch; assumes each LSE array
# has the output's shape minus its trailing head_dim axis)
def merge_attn_states(prefix_output, prefix_lse, suffix_output, suffix_lse):
    max_lse = np.maximum(prefix_lse, suffix_lse)  # numerical stability

    prefix_weight = np.exp(prefix_lse - max_lse)[..., None]
    suffix_weight = np.exp(suffix_lse - max_lse)[..., None]

    # Math: IDENTICAL to single-pass
    # IEEE 754: DIFFERENT by ~1e-7 (FP32) or ~1e-3 (FP16)
    return (prefix_weight * prefix_output
            + suffix_weight * suffix_output) \
        / (prefix_weight + suffix_weight)
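Here is the merge in context, as a self-contained float32 NumPy sketch. One query attends over 640 keys twice. The first pass computes attention in a single pass. The second splits the keys into a 512-key prefix chunk and a 128-key suffix chunk, then combines them with an LSE-weighted merge. This models the idea only. vLLM's real path runs CUDA kernels with their own tiling, and the sizes here are assumptions chosen to match the t1..t512 / t513..t640 split.

```python
import numpy as np

def partial_attention(q, K, V):
    """Attention over one KV chunk; returns (normalized output, log-sum-exp)."""
    scores = (K @ q) / np.sqrt(np.float32(q.shape[-1]))
    m = scores.max()
    w = np.exp(scores - m)
    s = w.sum()
    return (w @ V) / s, m + np.log(s)

def merge(prefix_out, prefix_lse, suffix_out, suffix_lse):
    """LSE-weighted merge of two partial attention outputs."""
    max_lse = np.maximum(prefix_lse, suffix_lse)
    pw = np.exp(prefix_lse - max_lse)
    sw = np.exp(suffix_lse - max_lse)
    return (pw * prefix_out + sw * suffix_out) / (pw + sw)

rng = np.random.default_rng(0)
n, split_at, d = 640, 512, 64
q = rng.standard_normal(d).astype(np.float32)
K = rng.standard_normal((n, d)).astype(np.float32)
V = rng.standard_normal((n, d)).astype(np.float32)

single, _ = partial_attention(q, K, V)                  # BS=1 path: one pass
p_out, p_lse = partial_attention(q, K[:split_at], V[:split_at])
s_out, s_lse = partial_attention(q, K[split_at:], V[split_at:])
cascade = merge(p_out, p_lse, s_out, s_lse)             # cascade path

print("max |diff|:", np.abs(single - cascade).max())
print("bitwise identical:", np.array_equal(single, cascade))
```

The maximum difference should be on the order of float32 rounding error. Exact bitwise equality is not guaranteed.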

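~1e-7 per element sounds negligible. But the error does not stay put. Each attention output feeds the next of ~28 transformer layers, and each generated token becomes input to the next decoding step. When the top two candidate tokens are nearly tied, a tiny shift in the logits flips the choice. From that position on, the two sequences are conditioned on different text. That 1e-7 compounds: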

  • Position 5: 17% token divergence
  • Position 8: 29% token divergence

The Three Gates

When does vLLM activate cascade? Silently, based on three runtime conditions:

# flash_attn.py:962
def use_cascade_attention(common_prefix_len, query_lens, ...):
    # Gate 1: Is shared prefix long enough?
    if common_prefix_len < 256:
        return False  # "Not worth the overhead"

    # Gate 2: Are there enough requests?
    num_reqs = len(query_lens)
    if num_reqs < 8:
        return False  # "Too few to benefit"

    # Gate 3: Performance heuristic
    cascade_time = cascade_waves * num_prefix_tiles
    flash_decoding_time = cdiv(flash_decoding_ctas, num_sms)
    return cascade_time < flash_decoding_time

BS=1 always fails Gate 2. It uses single-pass attention.

BS=16 with a system prompt passes all three. It uses cascade attention.

Your BS=1 benchmark is literally running a different algorithm than your BS=16 production system.
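The first two gates are simple enough to restate as a runnable check. This is a hypothetical sketch using the thresholds from the excerpt above (256 tokens, 8 requests). Gate 3's inputs are elided there, so it is stubbed as a boolean flag rather than reimplemented.

```python
def passes_cascade_gates(common_prefix_len: int, num_reqs: int,
                         heuristic_prefers_cascade: bool = True) -> bool:
    """Hypothetical restatement of the cascade gates (not vLLM's code)."""
    if common_prefix_len < 256:        # Gate 1: shared prefix too short
        return False
    if num_reqs < 8:                   # Gate 2: too few requests in the batch
        return False
    return heuristic_prefers_cascade   # Gate 3: perf heuristic, stubbed as a flag

# A 400-token system prompt at different batch sizes
for num_reqs in (1, 7, 8, 16):
    print(num_reqs, passes_cascade_gates(400, num_reqs))
# 1 False / 7 False / 8 True / 16 True
```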

The Smoking Gun: 30 Matching Samples

In our 490-pair experiment, exactly 30 pairs matched bitwise, and they were always the last 3 requests per batch. As a batch of 50 is processed, requests finish and leave. When fewer than 8 remain, Gate 2 closes, so the last requests revert to single-pass attention and match BS=1 perfectly.

| Batch position | Active requests | Cascade? | Matches BS=1? |
|---|---|---|---|
| 1-47 | 50 down to 8 | Yes | No |
| 48 | 7 | No | Yes |
| 49 | 6 | No | Yes |
| 50 | 5 | No | Yes |

The Fork in the Code

The branch at flash_attn.py:673:

if not attn_metadata.use_cascade:
    # SINGLE PASS: one call over all KV tokens
    flash_attn_varlen_func(
        q=query[:num_actual_tokens],
        k=key_cache,
        v=value_cache,
        ...
        num_splits=attn_metadata.max_num_splits,
    )
    return output

# CASCADE: two calls + merge
cascade_attention(
    output[:num_actual_tokens],
    query[:num_actual_tokens],
    key_cache, value_cache,
    ...
    common_prefix_len=attn_metadata.common_prefix_len,
)

Same call site, same inputs. Completely different execution.

vLLM already knows these conflict. When VLLM_BATCH_INVARIANT=1 is set, it auto-disables cascade:

# vllm/config/vllm.py:994
if vllm_is_batch_invariant() and not self.model_config.disable_cascade_attn:
    self.model_config.disable_cascade_attn = True
    logger.warning_once(
        "Disabling cascade attention when VLLM_BATCH_INVARIANT is enabled."
    )

The Fix: Three Options

Option 1: Surgical (Recommended)

from vllm import LLM

llm = LLM(model="your-model", disable_cascade_attn=True)

Forces single-pass for all batch sizes. 5-15% throughput loss for shared-prefix workloads.

Option 2: Remove Prefix Detection

from vllm import LLM

llm = LLM(model="your-model", enable_prefix_caching=False)

Without prefix caching, no common prefix is detected, so Gate 1 never opens. The trade-off is that you also lose KV-cache reuse for repeated prompts.

Option 3: Full Determinism

export VLLM_BATCH_INVARIANT=1

Replaces all cuBLAS GEMMs with batch-invariant Triton kernels and auto-disables cascade. ~2.4x performance cost.

| Option | Cascade | GEMM determinism | Cost |
|---|---|---|---|
| Default | Enabled | Non-deterministic | Baseline |
| disable_cascade_attn=True | Disabled | Non-deterministic | ~5-15% |
| enable_prefix_caching=False | Disabled | Non-deterministic | Medium |
| VLLM_BATCH_INVARIANT=1 | Auto-disabled | Deterministic | ~2.4x |

The Broader Lesson

Cascade attention is not a bug. It is a well-engineered bandwidth optimization. The issue is the silence.

This pattern recurs across GPU inference:

  • Flash Decoding splits attention across thread blocks -- same associativity issue
  • cuBLAS GEMM selects different tile sizes by matrix shape -- same op, different rounding
  • torch.compile fuses differently between eager/compiled -- same model, different graph

Every time a framework says "mathematically equivalent," ask: equivalent in the reals, or in IEEE 754?

The ghost in the batch is not malicious. It is an optimization doing its job. But now you know it is there, you know when it activates, and you know how to control it.

Key Takeaways

  1. BS=1 and BS>=8 can run different attention algorithms in vLLM: single-pass vs cascade, by design, when a long shared prefix is present.
  2. In the 16-request example, cascade cuts KV reads 4x by processing the shared prefix once. The merge step introduces FP divergence.
  3. Three silent gates control activation: prefix >= 256 tokens, num_reqs >= 8, perf heuristic.
  4. One flag removes the switch: disable_cascade_attn=True. VLLM_BATCH_INVARIANT=1 also disables cascade and, in addition, makes GEMMs deterministic.
  5. "Mathematically equivalent" != "numerically identical." This applies across all GPU ML.

Key files: flash_attn.py:673 (fork), flash_attn.py:962 (gates), flash_attn.py:1040 (cascade), merge_attn_states.py (merge), vllm/config/vllm.py:994 (auto-disable)
