DEV Community

Susilo harjo

Originally published at susiloharjo.web.id

Lighthouse Attention: The Training-Time Hierarchy That Makes Quadratic Attention Practical Again

TL;DR:

  • 1.4–1.7× pretraining wall-clock speedup against dense SDPA at 32K–128K context — no inference overhead, no architectural changes.
  • Symmetric pyramid pooling compresses queries, keys, and values together, unlike prior sparse methods (NSA, HISA, DSA, MoBA) that pool only keys and values, so the attention call costs O(S²d) instead of O(NSd), where S ≪ N.
  • Two-stage training with empirically demonstrated recoverability: Stage 1 trains with Lighthouse selection, Stage 2 recovers under dense SDPA, and the final loss beats the dense-from-scratch baseline.
  • Selection lives entirely outside the attention kernel, reusing stock FlashAttention on a contiguous gathered sub-sequence — no custom sparse kernels, no entangled selection logic.

Why Attention Gets Expensive — and Why Sparse Methods Haven't Solved It

FlashAttention solved the memory problem. It did not solve the compute problem. Scaled dot-product attention still scales as Θ(N²): double the context, quadruple the FLOPs. At 512K context on a single NVIDIA B200, the dense attention forward and backward passes dominate the cost of a training step. Frontier models targeting million-token windows need 32 B200 GPUs for attention alone.
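To make the quadratic scaling concrete, here is a back-of-the-envelope FLOP counter for one attention layer. It uses the common convention of 2 FLOPs per multiply-add, assumes the backward pass costs about twice the forward, and plugs in illustrative head sizes that are not from the article; treat the output as a rough estimate, not a B200 measurement.

```python
def attention_flops(n_tokens: int, head_dim: int, n_heads: int, causal: bool = True) -> float:
    """Rough matmul FLOPs for one attention layer, forward plus backward.

    Forward: QK^T and PV each cost 2 * N^2 * d FLOPs per head (N^2 * d multiply-adds).
    Backward: assumed to be about 2x the forward (ignores FlashAttention's recompute).
    Causal masking skips roughly half of the score matrix.
    """
    forward = 4 * n_tokens * n_tokens * head_dim * n_heads
    total = 3 * forward  # forward + roughly 2x forward for the backward pass
    return total / 2 if causal else float(total)


# Illustrative sizes only: head_dim=128 and n_heads=16 are assumptions, not from the article.
for n in (32_768, 65_536, 131_072, 524_288):
    tflops = attention_flops(n, head_dim=128, n_heads=16) / 1e12
    print(f"{n:>7} tokens: {tflops:10.1f} TFLOPs per attention layer")
```

Each doubling of `n` multiplies the estimate by four, which is the Θ(N²) behavior FlashAttention leaves untouched.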

Existing sparse attention methods (NSA, HISA, DSA, MoBA) share two design choices that cause problems for pretraining. First, they pool only keys and values, keeping queries at full resolution, so the attention call stays O(NSd): still linear in N. Second, they embed selection logic inside custom attention kernels, so teams cannot reuse optimized FlashAttention kernels. On top of both sits the hardest requirement: a training-time sparsifier must produce weights that still work as a competent dense-attention model at inference. Most prior methods never test this.
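The gap between dense attention, K/V-only pooling, and pooling all three of Q, K, V is easiest to see in the size of the score matrix. This sketch counts multiply-adds for QK^T only (the PV product scales the same way); the values of N, S and d are illustrative assumptions, not figures from the article.

```python
def score_matrix_cost(n_queries: int, n_keys: int, head_dim: int) -> int:
    """Multiply-adds for one QK^T product: (n_queries x d) @ (d x n_keys)."""
    return n_queries * n_keys * head_dim


N = 131_072  # full sequence length (illustrative)
S = 4_096    # compressed / selected length, S << N (illustrative)
d = 128      # head dimension (illustrative)

dense = score_matrix_cost(N, N, d)      # O(N^2 d): full attention
kv_only = score_matrix_cost(N, S, d)    # O(N S d): queries stay at full resolution
symmetric = score_matrix_cost(S, S, d)  # O(S^2 d): queries are compressed as well

print(dense // kv_only, kv_only // symmetric)  # → 32 32 (each ratio is N / S)
```

K/V-only pooling removes one factor of N, but the cost still grows linearly with sequence length because every full-resolution query is scored.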

The Lighthouse Approach: Symmetric Pooling + External Selection

Lighthouse makes two decisive departures. First, queries, keys, and values are all pooled symmetrically into an L-level pyramid, turning the attention call from O(NSd) into O(S²d), where S ≪ N. At 512K context, the forward pass becomes 21× faster.
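A minimal sketch of symmetric pyramid pooling, assuming mean pooling over non-overlapping windows of p positions with level 0 kept at full resolution. The article does not specify the pooling operator or how levels are indexed, so `mean_pool`, `build_pyramid` and `symmetric_pyramids` are hypothetical plain-Python helpers, not the authors' implementation.

```python
from typing import List

Vector = List[float]


def mean_pool(seq: List[Vector], p: int) -> List[Vector]:
    """Average non-overlapping windows of p consecutive vectors; a ragged tail is dropped."""
    pooled = []
    for start in range(0, len(seq) - p + 1, p):
        window = seq[start:start + p]
        pooled.append([sum(column) / p for column in zip(*window)])
    return pooled


def build_pyramid(seq: List[Vector], levels: int, p: int) -> List[List[Vector]]:
    """Level 0 is the input; level l holds len(seq) // p**l pooled vectors."""
    pyramid = [seq]
    for _ in range(levels - 1):
        pyramid.append(mean_pool(pyramid[-1], p))
    return pyramid


def symmetric_pyramids(q, k, v, levels: int = 3, p: int = 4):
    """Pool Q, K and V on the same schedule so pooled positions stay aligned."""
    return tuple(build_pyramid(x, levels, p) for x in (q, k, v))


seq = [[float(i), 1.0] for i in range(64)]  # toy sequence of 64 two-dim vectors
q_pyr, k_pyr, v_pyr = symmetric_pyramids(seq, seq, seq)
print([len(level) for level in q_pyr])  # → [64, 16, 4]
print(q_pyr[1][0])                      # → [1.5, 1.0], the mean of positions 0..3
```

Using one schedule for all three tensors is what lets a pooled query attend to pooled keys covering the same span, which is the property that shrinks both sides of the score matrix.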

Second, selection sits entirely outside the attention kernel. A four-stage pipeline (pyramid pooling, parameter-free ℓ₂-norm scoring, chunked-bitonic top-K selection, and FlashAttention on the gathered sub-sequence) wraps around standard SDPA without modifying it. The top-K step is deliberately non-differentiable: gradients flow only through the gathered Q, K, V entries into the projection weights, which teaches the model to produce values that are useful when selected rather than scores that are good at selecting.
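The score, select, gather and attend stages can be sketched in plain Python. Assumptions: the inputs are already-pooled, position-aligned Q/K/V candidates; each candidate is scored by the ℓ₂ norm of its key vector; and a naive softmax loop stands in for the FlashAttention call. All helper names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = List[float]


def l2_scores(candidates: Sequence[Vector]) -> List[float]:
    """Parameter-free score: the Euclidean norm of each candidate vector."""
    return [math.sqrt(sum(x * x for x in vec)) for vec in candidates]


def top_k_indices(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest scores, returned in ascending position order.

    This is pure index selection; in an autograd framework no gradient flows
    through the choice itself, only through the values gathered afterwards.
    """
    best = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return sorted(best)


def softmax_attention(q: Sequence[Vector], k: Sequence[Vector], v: Sequence[Vector]) -> List[Vector]:
    """Plain non-causal softmax attention; a stand-in for the FlashAttention call."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        logits = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        peak = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(x - peak) for x in logits]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))])
    return out


def select_and_attend(q, k, v, budget: int) -> Tuple[List[int], List[Vector]]:
    """Score -> select -> gather -> attend over pooled, position-aligned Q/K/V."""
    idx = top_k_indices(l2_scores(k), budget)              # score keys by norm (assumption)
    gq, gk, gv = ([x[i] for i in idx] for x in (q, k, v))  # gather one short dense batch
    return idx, softmax_attention(gq, gk, gv)


pooled = [[3.0, 4.0], [0.1, 0.0], [1.0, 0.0], [0.0, 6.0]]
idx, out = select_and_attend(pooled, pooled, pooled, budget=2)
print(idx)  # → [0, 3]
```

Because `top_k_indices` returns only integers, an autograd framework would propagate gradients through the gathered Q, K and V values but not through the selection itself, which matches the non-differentiable top-K described above.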

The chunked-bitonic top-K produces stratified selection, not strict global top-K — preventing attention collapse onto a narrow span. The coarsest pyramid level is always retained in full, guaranteeing every position gets at least one contributor.
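A sketch of why chunked selection is stratified: taking the best entries inside each contiguous chunk forces picks from every region, while a global top-K can pile onto one span. The even per-chunk budget (k // n_chunks, remainder dropped) and Python's sort in place of a bitonic sorting network are simplifying assumptions, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def chunked_top_k(scores: Sequence[float], k: int, n_chunks: int) -> List[int]:
    """Stratified top-k: keep the k // n_chunks best scores inside each contiguous chunk.

    Python's sort stands in for the per-chunk bitonic sorting network a GPU kernel would use.
    """
    chunk_len = -(-len(scores) // n_chunks)  # ceiling division
    per_chunk = k // n_chunks
    picked: List[int] = []
    for start in range(0, len(scores), chunk_len):
        chunk = range(start, min(start + chunk_len, len(scores)))
        picked += sorted(chunk, key=lambda i: scores[i], reverse=True)[:per_chunk]
    return sorted(picked)


def stratified_selection(fine_scores: Sequence[float], n_coarsest: int,
                         k: int, n_chunks: int) -> Tuple[List[int], List[int]]:
    """Return (selected fine-level indices, coarsest-level indices).

    The coarsest pyramid level is kept in full, so every position keeps at
    least one contributor regardless of what the fine-level selection picks.
    """
    return chunked_top_k(fine_scores, k, n_chunks), list(range(n_coarsest))


scores = [9.0, 8.0, 7.0, 6.0, 1.0, 3.0, 2.0, 4.0]  # high scores clustered at the front
global_top = sorted(sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:4])
print(global_top)                             # → [0, 1, 2, 3]
print(chunked_top_k(scores, k=4, n_chunks=2))  # → [0, 1, 5, 7]
```

The global top-4 collapses onto the first half of the sequence, while the chunked version spends half its budget on the second half even though those scores are lower.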

Recovery Works — and the Model Gets Better

The acid test is recoverability. A 530M-parameter Llama-3-style decoder was trained on C4 at 98K context with Lighthouse in 26 of its 30 layers (the first two and last two stay dense). Within a 16,000-step budget (~50.3B tokens), three Lighthouse→dense split points were tested. At each resume, loss spikes transiently by 1.12–1.57 nats, then recovers within ~1,000–1,500 SDPA steps and crosses below the dense baseline.
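The layer layout and two-stage schedule from this experiment can be written as a small routing function. The 30-layer, dense-at-the-edges layout comes from the article; `attention_mode` and the split step of 8,000 are hypothetical, since the three split points actually tested are not listed here. The last line derives tokens per step from the reported totals.

```python
def attention_mode(layer: int, step: int, split_step: int,
                   n_layers: int = 30, n_dense_edge: int = 2) -> str:
    """Return "dense" or "lighthouse" for one layer at one training step.

    Stage 1 (step < split_step): interior layers use Lighthouse selection while
    the first and last n_dense_edge layers stay dense.
    Stage 2 (step >= split_step): every layer resumes dense SDPA to recover.
    """
    if step >= split_step:
        return "dense"
    if layer < n_dense_edge or layer >= n_layers - n_dense_edge:
        return "dense"
    return "lighthouse"


SPLIT_STEP = 8_000  # hypothetical split point; the article does not list the three it tested
stage1 = [attention_mode(layer, 0, SPLIT_STEP) for layer in range(30)]
print(stage1.count("lighthouse"))  # → 26

tokens_per_step = 50.3e9 / 16_000  # ≈ 3.14M tokens per step, derived from the reported totals
```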

By step 16,000, all Lighthouse runs reach final losses of 0.6980–0.7102 vs. the dense baseline's 0.7237 — while using 22.5–27.0 wall-clock hours instead of 37.9. On Needle-in-a-Haystack retrieval (4K–96K context), Lighthouse with k=2048 matches or beats the dense baseline's retrieval rate. Context parallelism scales cleanly to 1M tokens across 32 B200 GPUs with no kernel changes.
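The headline numbers can be checked with simple arithmetic. The article reports ranges without saying which wall-clock figure goes with which loss, so this sketch computes each range's endpoints separately rather than pairing them.

```python
DENSE_HOURS = 37.9
DENSE_LOSS = 0.7237
lighthouse_hours = (22.5, 27.0)       # reported wall-clock range, in hours
lighthouse_losses = (0.6980, 0.7102)  # reported final-loss range

speedups = [DENSE_HOURS / hours for hours in lighthouse_hours]
loss_gaps = [DENSE_LOSS - loss for loss in lighthouse_losses]

print([f"{s:.2f}x" for s in speedups])  # → ['1.68x', '1.40x']
print([f"{g:.4f}" for g in loss_gaps])  # → ['0.0257', '0.0135']
```

The resulting 1.40–1.68× range is consistent with the 1.4–1.7× speedup quoted in the TL;DR, and both loss gaps are positive, so every Lighthouse run finishes below the dense baseline.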

Engineering Takeaways

Lighthouse is not a universal accelerator. At short contexts, pyramid overhead dominates and it provides no benefit. At 32K+ tokens, it is a drop-in pretraining optimization: no architectural changes, no inference penalty, no custom sparse kernels to maintain. The two-stage recipe is essential; skipping Stage 2 recovery leaves the model unable to perform dense attention. The ablations identify a best configuration (L=3, p=4, k=1536, projection-norm scorer) that serves as a tested starting point. Lighthouse also integrates with existing context-parallelism infrastructure without sparse-aware collectives.
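The reported best setting can be captured as a small config object. The readings of L, p and k (pyramid depth, pooling factor, top-K budget), the assumption that level 0 is full resolution, and the 32,768-token cutoff standing in for "32K+" are all assumptions of this hypothetical sketch.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class LighthouseConfig:
    """Best ablation setting reported in the article; field meanings are assumptions."""
    levels: int = 3               # L: pyramid depth
    pool_factor: int = 4          # p: assumed per-level pooling factor
    top_k: int = 1536             # k: selection budget
    scorer: str = "projection_norm"
    min_context: int = 32_768     # hypothetical cutoff for "32K+"

    def level_lengths(self, seq_len: int) -> List[int]:
        """Entries per pyramid level, assuming level 0 is full resolution."""
        return [seq_len // self.pool_factor ** level for level in range(self.levels)]

    def worth_enabling(self, seq_len: int) -> bool:
        """Below the cutoff, pyramid overhead outweighs the attention savings."""
        return seq_len >= self.min_context


cfg = LighthouseConfig()
print(cfg.level_lengths(98_304))                               # → [98304, 24576, 6144]
print(cfg.worth_enabling(8_192), cfg.worth_enabling(131_072))  # → False True
```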

The one clear limitation: it is training-only. Autoregressive decoding presents one query at a time, violating the assumption that all queries co-occur. For teams whose bottleneck is pretraining throughput at long context, Lighthouse is a proven, recoverable speedup, provided they budget for the Stage 2 dense recovery phase.
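A hypothetical dispatch shows where the training-only constraint bites. The phase labels and the rule that Lighthouse needs at least p co-occurring queries are illustrative assumptions built on the article's statement that decoding sees one query at a time.

```python
def choose_attention(phase: str, n_queries: int, pool_factor: int = 4) -> str:
    """Pick the attention path for one forward call.

    phase is one of "stage1", "stage2" or "decode" (hypothetical labels).
    Symmetric pooling needs a whole block of co-occurring queries: with a single
    query there is nothing to pool (1 // pool_factor == 0), so decoding, like the
    Stage 2 recovery phase, always runs dense SDPA.
    """
    if phase == "stage1" and n_queries // pool_factor > 0:
        return "lighthouse"
    return "dense"


print(choose_attention("stage1", 98_304))  # → lighthouse
print(choose_attention("decode", 1))       # → dense
```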


For the complete architectural breakdown — including the four-stage pipeline internals, the chunked-bitonic top-K mechanism that prevents attention collapse, and the full ablation grid across pyramid depths, top-K budgets, and scorer types — read the full analysis at susiloharjo.web.id:

🔗 https://susiloharjo.web.id/lighthouse-attention-nous-research/


