Swapping dot‑product attention for RBF attention sounds like an architectural revolution. In Raphael Pisoni’s experiment, it turned out to be something stranger: a one‑line algebraic tweak that silently reproduces half the “mysterious” behaviors of modern Transformers — and breaks the hardware stack in the process.
TL;DR
- RBF attention is just dot‑product attention plus an explicit squared‑L2 penalty on keys; the “new” geometry is already latent in SDPA.
- Changing the metric forces you to confront everything your stack has hard‑coded about dot products: RoPE, attention sinks, fused kernels, even how you debug training.
- The right way to use RBF is as a diagnostic scalpel: borrow its inductive biases (norm penalties, distance‑based similarity) without paying the full engineering tax of a wholesale swap.
RBF Attention Is Just Dot‑Product + a Key L2 Penalty
Pisoni’s math trick is the key: start from a distance‑based score instead of a dot product,
\[
\text{score}(q,k) = -\gamma\lVert q - k\rVert^2
\]
and expand:
\[
-\lVert q - k\rVert^2 = -\lVert q\rVert^2 + 2\,q\cdot k - \lVert k\rVert^2
\]
Softmax over keys is translation‑invariant, so for a fixed query the \(-\lVert q\rVert^2\) term is just a constant offset and disappears. What remains is:
\[
\text{softmax}(-\gamma\lVert q - k\rVert^2) \equiv \text{softmax}(2\gamma\,q\cdot k - \gamma\lVert k\rVert^2)
\]
So scaled RBF attention is algebraically:
- the usual dot‑product attention term \(q\cdot k\), scaled by \(2\gamma\), plus
- a built‑in L2 penalty \(-\gamma\lVert k\rVert^2\) that depresses scores for big‑norm keys.
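The identity above is easy to check numerically. The following is a minimal NumPy sketch, not code from Pisoni's repo; the function names and the choice \(\gamma = 1/(2\sqrt{d_k})\) are assumptions made here so that the dot‑product term lines up with the usual \(q\cdot k/\sqrt{d_k}\) scaling.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def rbf_attention_weights(q, k, gamma):
    # Direct form: softmax over keys of -gamma * ||q - k||^2.
    sq_dist = ((q[:, None, :] - k[None, :, :]) ** 2).sum(-1)
    return softmax(-gamma * sq_dist)

def dot_plus_penalty_weights(q, k, gamma):
    # Expanded form: 2*gamma*q.k - gamma*||k||^2 (the ||q||^2 term is dropped).
    key_sq_norm = (k ** 2).sum(-1)
    return softmax(2.0 * gamma * (q @ k.T) - gamma * key_sq_norm[None, :])

rng = np.random.default_rng(0)
d = 16
q = rng.normal(size=(5, d))
k = rng.normal(size=(7, d))
# Assumption for illustration: this gamma makes 2*gamma*q.k equal q.k/sqrt(d).
gamma = 1.0 / (2.0 * np.sqrt(d))

w_rbf = rbf_attention_weights(q, k, gamma)
w_dot = dot_plus_penalty_weights(q, k, gamma)
max_abs_diff = float(np.abs(w_rbf - w_dot).max())
print(max_abs_diff)  # differs only by floating-point rounding
```

The two weight matrices agree up to rounding, which is the whole point: the query‑norm term never needs to be computed.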
The first non‑obvious implication: you don’t need a new attention primitive to get the “RBF effect”. You can keep SDPA and add an explicit \(-\lambda\lVert k\rVert^2\) score bias per key. PyTorch’s FlexAttention exists precisely so you can inject this kind of score modification and still compile down to something FlashAttention‑like.
That means “distance‑based attention” is less a separate family of methods than a reparameterization of dot‑product attention’s existing geometry.
The real question becomes: what does that key‑norm penalty actually do to the model’s behavior?
Why Penalizing Key Norms Changes Attention Geometry
In standard dot‑product attention, the score is
\[
\text{score}(q,k) = \frac{q\cdot k}{\sqrt{d_k}}
\]
Geometrically, that’s “cosine similarity times magnitudes.” A key with a huge norm can “win” even if it isn’t well‑aligned with the query — the “magnitude bullying” Pisoni calls out.
We already know the network leans on this. The attention sinks phenomenon (e.g. the sinks described in StreamingLLM, and punctuation sinks) shows models deliberately grow massive norms on a few tokens to hoover up probability mass when nothing else is relevant. That’s a hack that only exists because the scoring rule rewards large norms.
RBF attention quietly revokes that hack:
- Scores are upper‑bounded (max at distance 0), so you can’t arbitrarily crank norms to get unbounded logits.
- Large‑norm keys are penalized everywhere, not just for one query. Outlier keys pay a global tax.
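A toy example makes the contrast concrete. This is an illustrative sketch with hand‑picked 2‑D vectors (the values and the `gamma=0.5` setting are assumptions chosen for readability, not taken from the experiment): one key is perfectly aligned with the query, the other is misaligned but has ten times the norm.

```python
import numpy as np

def dot_logits(q, keys):
    # Scaled dot-product scores: q.k / sqrt(d).
    return keys @ q / np.sqrt(q.shape[0])

def rbf_logits(q, keys, gamma):
    # RBF scores: -gamma * ||q - k||^2, never greater than 0.
    return -gamma * ((keys - q) ** 2).sum(-1)

q = np.array([1.0, 0.0])
angle = np.deg2rad(60.0)
keys = np.stack([
    np.array([1.0, 0.0]),                             # aligned, unit norm
    10.0 * np.array([np.cos(angle), np.sin(angle)]),  # misaligned, large norm
])

dot_scores = dot_logits(q, keys)
rbf_scores = rbf_logits(q, keys, gamma=0.5)
dot_winner = int(np.argmax(dot_scores))  # index of the key dot-product prefers
rbf_winner = int(np.argmax(rbf_scores))  # index of the key RBF prefers
print(dot_winner, rbf_winner)
```

Under the dot product the large‑norm key wins despite being 60 degrees off; under RBF the aligned key wins, and no key can score above zero however large its norm.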
This does two things to the “shape” of attention:
- Locality becomes literal. “Similar” now really means “close in Euclidean space,” not “large projection along some axis.” You’re biasing the model toward clustered, metrically meaningful representations.
- Selectivity shifts from K to Q. In SDPA, both queries and keys control sharpness via their norms. With RBF, the key’s ability to dominate is limited; selectivity is more in the query’s hands (or pushed into value scaling and downstream layers).
You can see the same story from the opposite direction: many recent tricks, from QK‑Norm to learnable query gains, are ways of manually managing these norms. RBF attention just bakes one particular management policy — “don’t let keys blow up” — directly into the scoring rule.
So RBF is not “attention but different.” It’s attention with one very specific inductive bias surfaced and made non‑optional: keys should live on a reasonably sized shell, not spike into infinity.
The Engineering Tax: RoPE, Fused Kernels, and the Hardware Lottery
On paper, swapping the metric is a two‑line refactor.
On actual GPUs, it’s a cascade failure.
Pisoni’s experiment ran into three kinds of breakage that are more interesting than the final TinyStories loss curves.
1. Memory: N×N the hard way
Naively computing pairwise distances with torch.cdist materializes the full \(N\times N\) distance matrix, which is an instant OOM for any realistic context.
FlashAttention avoided this for dot‑products by treating attention as a tiled matmul with streaming softmax, never materializing full score matrices. As soon as your score is “dot product plus custom norm penalty,” you’re off the happy path.
The Triton RBF attention kernel in the repo has to re‑implement that tiling logic just to insert “\(-\lVert k\rVert^2\)” inside the fused loop. That’s effort most teams don’t want to repeat.
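To show what that tiling logic amounts to, here is a plain NumPy sketch of a streaming (online) softmax over key blocks with the key‑norm penalty folded into the logits. It is not the Triton kernel from the repo; the function names and the `block` parameter are assumptions for illustration, and a real kernel would also tile over queries and run on the GPU.

```python
import numpy as np

def rbf_attention_reference(q, k, v, gamma):
    # Naive version: materializes every pairwise distance at once.
    sq_dist = ((q[:, None, :] - k[None, :, :]) ** 2).sum(-1)
    logits = -gamma * sq_dist
    logits = logits - logits.max(-1, keepdims=True)
    w = np.exp(logits)
    return (w / w.sum(-1, keepdims=True)) @ v

def rbf_attention_tiled(q, k, v, gamma, block=4):
    # Streaming softmax: only one (n_q, block) tile of logits exists at a time.
    n_q = q.shape[0]
    running_max = np.full(n_q, -np.inf)   # running row-wise max of the logits
    running_sum = np.zeros(n_q)           # running softmax denominator
    acc = np.zeros((n_q, v.shape[1]))     # running unnormalized output
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        # Dot-product tile plus per-key penalty; the ||q||^2 term is dropped.
        logits = 2.0 * gamma * (q @ kb.T) - gamma * (kb ** 2).sum(-1)[None, :]
        new_max = np.maximum(running_max, logits.max(-1))
        correction = np.exp(running_max - new_max)  # rescale earlier tiles
        p = np.exp(logits - new_max[:, None])
        running_sum = running_sum * correction + p.sum(-1)
        acc = acc * correction[:, None] + p @ vb
        running_max = new_max
    return acc / running_sum[:, None]

rng = np.random.default_rng(0)
q = rng.normal(size=(6, 8))
k = rng.normal(size=(10, 8))
v = rng.normal(size=(10, 3))
out_ref = rbf_attention_reference(q, k, v, gamma=0.3)
out_tiled = rbf_attention_tiled(q, k, v, gamma=0.3, block=4)
print(float(np.abs(out_ref - out_tiled).max()))
```

The only change relative to a dot‑product streaming softmax is the single line that subtracts the key norms, which is why expressing it as a score modification is attractive.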
PyTorch’s FlexAttention is essentially a recognition that we’ve hit a software lottery: if your attention variant fits the built‑in kernels, it’s fast; if not, enjoy the OOMs. FlexAttention’s score_mod API is a way to write “2·qk − ||k||²” in Python and let torch.compile generate the fused kernel.
But until that kind of path is standard and stable, any deviation from SDPA is an explicit tax on memory‑efficient training and deployment.
2. Positional embeddings: RoPE is dot‑product‑native
RoPE works by rotating pairs of Q and K dimensions in 2‑D (complex) planes so that relative position becomes a phase shift in the dot product. It’s beautifully matched to cosine‑style similarity.
Euclidean distance reacts to those rotations differently. A rotation shared by Q and K preserves both the dot product and the distance, but RoPE rotates each by its own position‑dependent angle, so a key that coincides with the query at zero offset sits at a nonzero distance at other offsets. Pisoni finds that RoPE “interferes poorly” with RBF: you’re injecting position in a way that scrambles the distance geometry the score now cares about.
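The geometry can be checked on a single 2‑D pair. In this sketch (the frequency `FREQ`, the vectors, and the helper names are all hypothetical, and real RoPE uses many pairs with different frequencies), a rotation applied to both vectors leaves the dot product and the distance unchanged, while RoPE‑style rotation by different position‑dependent angles changes the \(q\cdot k\) cross term and therefore the distance.

```python
import numpy as np

def rotate(x, theta):
    """Rotate a 2-D vector counter-clockwise by theta radians."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([c * x[0] - s * x[1], s * x[0] + c * x[1]])

def dot_and_sqdist(a, b):
    # Returns (a.b, ||a - b||^2) as plain floats.
    return float(a @ b), float(((a - b) ** 2).sum())

FREQ = 0.1  # hypothetical rotation frequency for this single 2-D pair

def rope_pair(q, k, pos_q, pos_k):
    # RoPE-style: each vector is rotated by an angle set by its own position.
    return dot_and_sqdist(rotate(q, FREQ * pos_q), rotate(k, FREQ * pos_k))

q = np.array([1.0, 0.5])
k = np.array([1.0, 0.5])  # same content as q

unrotated = dot_and_sqdist(q, k)
shared_rotation = dot_and_sqdist(rotate(q, 0.7), rotate(k, 0.7))
offset_5 = rope_pair(q, k, pos_q=3, pos_k=8)
offset_5_shifted = rope_pair(q, k, pos_q=10, pos_k=15)
print(unrotated, shared_rotation, offset_5, offset_5_shifted)
```

Both quantities depend only on the relative offset, but under an RBF score that offset shows up as extra distance between vectors with identical content, which is one way positional rotation and distance‑based similarity can pull against each other.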
His fix is Subspace Sinusoidal Embeddings (SuSiE) — sinusoids in a subspace with a learnable scale, no rotations. That’s an architectural fork:
- Same model size and data,
- but completely different positional semantics,
- so you can’t cleanly attribute any performance change to “RBF vs dot product.”
If you’ve standardized on RoPE across a fleet of models, “just try RBF attention” actually means “also invent and validate a new positional scheme.”
3. Sinks and register tokens: rebuilding a lost behavior
Once you penalize key norms, the model loses its cheap attention sinks. That’s not just cosmetic; it affects optimization. Early in training, it’s often safer for a query to “look away” to a generic token than to make a sharp but wrong choice.
RBF attention removes that escape valve, so Pisoni introduces Register Tokens:
- Learned tokens prepended to the sequence,
- zero‑initialized so their keys live at the origin,
- giving the model a well‑behaved place to dump attention mass.
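Why a key at the origin works as a sink follows from the expanded score: its logit is \(2\gamma\,q\cdot 0 - \gamma\lVert 0\rVert^2 = 0\), while any key that is unaligned with the query pays its norm penalty. The sketch below uses hand‑built orthogonal keys so the outcome is deterministic; the vectors and `gamma=1.0` are assumptions for illustration, and it models only the key at the origin, not the learned register tokens themselves.

```python
import numpy as np

def rbf_weights(q, keys, gamma=1.0):
    # Softmax over keys of -gamma * ||q - k||^2 for a single query.
    logits = -gamma * ((keys - q) ** 2).sum(-1)
    e = np.exp(logits - logits.max())
    return e / e.sum()

d = 4
register_key = np.zeros((1, d))      # register token: key at the origin
content_keys = 2.0 * np.eye(d)[1:]   # three content keys of norm 2
keys = np.concatenate([register_key, content_keys])

q_unrelated = np.eye(d)[0]           # orthogonal to every content key
q_matching = content_keys[0].copy()  # identical to the first content key

w_unrelated = rbf_weights(q_unrelated, keys)
w_matching = rbf_weights(q_matching, keys)
print(w_unrelated.round(3), w_matching.round(3))
```

When nothing matches, most of the mass lands on the register (index 0); when a content key matches the query, that key wins and the register recedes.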
Again, this isn’t “RBF vs SDPA” in isolation. It’s a new architectural component whose presence or absence may matter more than the metric swap.
When you add up custom position encodings, register tokens, and custom kernels, the metric change has turned into an entire mini‑model family.
That’s the real engineering lesson: changing a single algebraic primitive in attention forces you to reveal how much of the stack — from RoPE to FlashAttention — is contingent on that choice.
What RBF Actually Buys You — And When To Use It
So is RBF attention “better”?
On TinyStories‑scale experiments, Pisoni reports slightly better convergence than SDPA, but nothing like a phase change. Given the number of knobs moved (metric, position, extra tokens, kernel), that’s almost a wash.
The more interesting conclusion is different: RBF attention is most valuable as a probe, not a permanent replacement.
Three concrete uses:
- Diagnose norm pathologies. If RBF attention stabilizes training or fixes weird behaviors, you’ve discovered your model is over‑relying on key norms, perhaps for sinks, perhaps as a crude multi‑head routing signal. You can then:
  - add explicit key‑norm penalties inside standard SDPA (via FlexAttention),
  - or adjust Q/K normalization and gains rather than rewrite attention wholesale.
- Borrow the inductive bias without the kernel tax. Instead of committing to RBF everywhere, you can:
  - use a key‑norm score penalty only in early layers,
  - or only in modules where you know magnitude bullying is harmful (e.g., cross‑attention over noisy memory). You keep FlashAttention and RoPE for the rest.
- Exploit kernelization where it matters. RBF kernels have product‑of‑features representations; works like Performers show how to linearize RBF attention at inference using random features. That suggests a hybrid:
  - train with conventional SDPA plus an RBF‑style bias (using flexible kernels),
  - then distill or approximate certain heads with linearized RBF for latency‑sensitive serving.
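As a sketch of the “product‑of‑features” idea, the snippet below approximates the RBF kernel \(\exp(-\gamma\lVert q-k\rVert^2)\) with classic random Fourier features. This is a swapped‑in, simpler technique than the positive random features used in Performers, and the feature count, \(\gamma\), and input scale are assumptions chosen for illustration.

```python
import numpy as np

def random_fourier_features(x, w, b):
    # phi(x) = sqrt(2/D) * cos(W x + b), so phi(q).phi(k) ~ exp(-gamma*||q-k||^2).
    n_features = w.shape[0]
    return np.sqrt(2.0 / n_features) * np.cos(x @ w.T + b)

rng = np.random.default_rng(0)
d, n_features, gamma = 8, 4096, 0.25
# For the kernel exp(-gamma * ||delta||^2), frequencies are drawn from N(0, 2*gamma*I).
w = rng.normal(scale=np.sqrt(2.0 * gamma), size=(n_features, d))
b = rng.uniform(0.0, 2.0 * np.pi, size=n_features)

q = 0.5 * rng.normal(size=(4, d))
k = 0.5 * rng.normal(size=(6, d))

exact = np.exp(-gamma * ((q[:, None, :] - k[None, :, :]) ** 2).sum(-1))
approx = random_fourier_features(q, w, b) @ random_fourier_features(k, w, b).T
max_err = float(np.abs(exact - approx).max())
print(max_err)  # shrinks roughly like 1/sqrt(n_features)
```

Because the kernel factors into feature maps, the key‑side product with the values can be computed once and reused across queries, which is where the linear‑time behavior comes from. Plain cosine features can produce small negative kernel estimates, so a serving path would likely need the positive‑feature variants or some other safeguard.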
The pattern here is familiar from other transitions (e.g., from naive PyTorch ops to custom CUDA, or from dense matmuls to sparsity): the “new” method is initially most powerful as a way to see what the old method was secretly doing, not as a wholesale drop‑in.
With RBF attention, that secret is simple: dot‑product attention is already a geometry over norms; we just hid the policy choice (how much to reward big keys) behind opaque training dynamics and hardware‑driven assumptions.
Key Takeaways
- RBF attention is algebra, not magic. It’s mathematically equivalent to dot‑product attention with a per‑key squared‑L2 penalty; you can implement that directly in SDPA‑compatible kernels.
- The geometry you change is norm geometry. Penalizing key norms removes “magnitude bullying” and built‑in attention sinks, biasing the model toward more metrically local behavior.
- The engineering cost is real. RoPE, FlashAttention, and much of the current stack are tuned for dot products; swapping metrics drags in new positional schemes, register tokens, and custom Triton attention kernels.
- Treat RBF as a diagnostic scalpel. Use it to reveal and then selectively import useful inductive biases, instead of betting your entire model family on a metric swap.
- Tooling like FlexAttention is the enabler. The more we can express “dot‑product plus bias” in high‑level code and compile to fused kernels, the less the hardware lottery will dictate which attention geometries we’re allowed to explore.
Further Reading
- Scaled RBF Attention: Trading Dot Products for Euclidean Distance — Pisoni’s detailed write‑up on RBF attention, including math, Triton kernel, and TinyStories experiments.
- 4rtemi5/rbf_attention — GitHub repo with the Triton attention kernel, model configs, and implementation details.
- FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention — PyTorch’s proposal for expressing custom score modifications while retaining fused‑kernel performance.
- Rethinking Attention with Performers — Introduces product‑of‑features approximations for kernels like RBF, relevant if you want linear‑time approximate attention.
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — The canonical description of IO‑aware, fused SDPA kernels that current hardware and software stacks are optimized around.
- Zero‑Copy Graph Engine for Memory‑Efficient Training — How kernel choices and data movement shape the real cost of attention variants.
- CUDA Agents and the Custom Kernel Tradeoff — When writing your own Triton or CUDA kernels pays off — and when you should lean on compiler‑generated ones instead.
In a world tuned for dot products, RBF attention is less a new engine and more a diagnostic port: plug it in, see how your model behaves when you turn off norm‑based hacks, then decide which of those hacks you actually wanted all along.
Originally published on novaknown.com