How to replace 10+ PyTorch operations with a single GPU kernel while keeping the output identical to the original model – down to the last decimal.
If you’ve ever profiled a small Transformer on a consumer GPU, you know the pain: every decode step launches a swarm of tiny kernels, and Python dispatch overhead eats into your token rate. The solution is kernel fusion – but getting it right, especially with Rotary Position Embeddings, isn’t trivial.
This post walks through triton_fused_attention_v3.py, a self‑contained Triton kernel that fuses QKV projection + RoPE + KV cache write into a single launch for Qwen 2.5‑0.5B. It delivers a 4.5–5× speedup while maintaining cosine similarity = 1.000000 against the reference HuggingFace output. No special hardware needed – this runs on an RTX 3050 laptop GPU.
We’ll cover:

- Why the kernel exists
- The two design rules that guarantee bit‑perfect output
- A line‑by‑line walkthrough of the kernel
- The benchmark setup that proves both speed and accuracy
## Why Fuse? The Death‑by‑a‑Thousand‑Kernels Problem
Here’s what happens when you call a standard PyTorch attention block:
```python
q = linear(hidden, W_q)  # kernel launch 1
k = linear(hidden, W_k)  # launch 2
v = linear(hidden, W_v)  # launch 3
q, k = rope(q, k)        # 2 more launches (reshape + complex op)
# ...then KV cache write, SDPA, O projection
```
For Qwen 2.5‑0.5B (24 layers), that’s roughly 240 separate GPU kernel launches per token. Each launch costs ~80 µs of CPU–GPU handshaking – not much alone, but 240 × 80 µs ≈ 19 ms of pure overhead. The RTX 3050 can theoretically process a token in under 6 ms; we were spending over 30 ms because the GPU sat idle waiting for work.
The fix is to pack as many of those operations as possible into one kernel. And that’s exactly what fused_qkv_rope_v3 does.
## What the Kernel Does (In One Shot)
Inside a single Triton program we:

1. Run a tiled matrix‑vector multiply against a concatenated weight matrix [W_q ; W_k ; W_v] → Q, K, V
2. Apply RoPE to the Q and K heads using the local accumulator (no separate store‑reload)
3. Write the rotated K and V directly into the persistent KV cache
4. Write the rotated Q into an output buffer
After that, only two more operations remain: the attention softmax (done with PyTorch’s scaled_dot_product_attention) and the output GEMV. The whole decode step collapses from 10+ kernels to just 3.
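To make the data flow concrete, here is a NumPy sketch of what the fused step computes, using the post’s pairing convention (adjacent elements 2k, 2k+1 within each head). The sizes, variable names, and base‑10000 frequency formula are illustrative assumptions, not the kernel’s code:

```python
import numpy as np

# Toy sizes; the real model uses d = 896 and head_dim = 64.
d, head_dim = 8, 4
q_dim, k_dim, v_dim = 2 * head_dim, head_dim, head_dim
rng = np.random.default_rng(0)

W = rng.standard_normal((q_dim + k_dim + v_dim, d)).astype(np.float32)
x = rng.standard_normal(d).astype(np.float32)
qkv = W @ x                              # one fused mat-vec gives Q, K and V
q, k, v = np.split(qkv, [q_dim, q_dim + k_dim])

def rope_interleaved(t, cos, sin):
    """Rotate adjacent pairs (2k, 2k+1) within one head."""
    even, odd = t[0::2], t[1::2]
    out = np.empty_like(t)
    out[0::2] = even * cos - odd * sin
    out[1::2] = even * sin + odd * cos
    return out

pos = 5
freqs = 1.0 / (10000.0 ** (np.arange(head_dim // 2) / (head_dim // 2)))
cos, sin = np.cos(pos * freqs), np.sin(pos * freqs)

q_roped = np.concatenate([rope_interleaved(h, cos, sin)
                          for h in q.reshape(-1, head_dim)])
k_roped = rope_interleaved(k, cos, sin)   # V is cached unrotated
```

The kernel does the same work per head‑sized row block, but without materializing `qkv` in global memory between the mat‑vec and the rotation.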
And we don’t stop there: the script captures all three into a CUDA Graph, so subsequent tokens are fired with a single graph.replay() – virtually zero dispatch overhead.
## The Two Critical Design Rules
Fusing RoPE inside a projection kernel is where things get slippery. RoPE rotates pairs of elements (2k, 2k+1) within each attention head. If you break those pairs across thread blocks, or let the partner value get rounded to FP16, you introduce numerical drift. This kernel uses two clean rules to avoid that completely.
### Rule 1: BLOCK_M = head_dim – Never Split a RoPE Pair
Qwen 2.5‑0.5B has a head dimension of 64. So we set:
```python
BLOCK_M = hd  # 64
```
This ensures that every Triton program instance processes one entire attention head (or one head‑sized chunk of V). All 32 RoPE pairs inside that head are adjacent in the same local accumulator – zero cross‑block reads.
The grid is simply M_total // BLOCK_M = 18 blocks, which maps beautifully onto the GPU’s SMs.
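The same arithmetic, spelled out. The head counts (14 query heads, 2 KV heads, grouped‑query attention) are assumed from the model’s published config; head_dim = 64 and M = 1152 are the numbers in this post:

```python
head_dim = 64
n_q_heads, n_kv_heads = 14, 2    # assumed Qwen 2.5-0.5B GQA layout

q_rows = n_q_heads * head_dim    # 896
k_rows = n_kv_heads * head_dim   # 128
v_rows = n_kv_heads * head_dim   # 128
M_total = q_rows + k_rows + v_rows

BLOCK_M = head_dim               # Rule 1: one whole head per program
grid = M_total // BLOCK_M
print(M_total, grid)             # 1152 18
```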
### Rule 2: Use an FP32 Scratchpad for RoPE Partner Values
To compute a rotation for element 2k you need the partner value at 2k+1 (and vice versa). In Triton you can’t just index acc[i+1] directly – you have to shift the data. The naive approach stores the accumulator to an FP16 buffer and reloads a shifted version. That FP16 round‑trip adds a small error that grows with sequence length.
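To see the effect concretely, here is a minimal NumPy stand‑in comparing the two paths; the two code paths differ only in whether the partner value makes a round‑trip through float16 before the rotation. The specific values are arbitrary:

```python
import numpy as np

acc = np.float32(0.333333)        # element at index 2k (FP32 accumulator)
partner = np.float32(0.777777)    # element at 2k+1, needed for the rotation
cos_v = np.float32(np.cos(0.5))
sin_v = np.float32(np.sin(0.5))

exact = acc * cos_v - partner * sin_v                          # FP32 scratchpad
lossy = acc * cos_v - np.float32(np.float16(partner)) * sin_v  # FP16 round-trip
err = abs(float(exact) - float(lossy))
```

A per‑element error on the order of 1e‑5 looks harmless, but it is injected at every layer and every position, which is exactly the drift that keeps cosine similarity from reading 1.000000.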
Instead, the kernel uses a tiny FP32 temporary buffer allocated before the kernel launch:
```python
fp32_tmp = torch.empty(q_dim + k_dim, device='cuda', dtype=torch.float32)
```
Inside the kernel, after computing QKV values into the FP32 accumulator, we store them without rounding:
```python
tl.store(fp32_tmp_ptr + rows, acc, mask=mask)  # full FP32, no rounding
```
Then we compute the partner indices (even reads rows + 1, odd reads rows - 1) and load from that FP32 buffer:
```python
partner_rows = tl.where(is_even, rows + 1, rows - 1)
partner = tl.load(fp32_tmp_ptr + partner_rows, mask=mask)  # FP32 load
```
Now both acc and partner are in full precision for the rotation:
```python
roped = tl.where(is_even,
                 acc * cos_val - partner * sin_val,
                 partner * sin_val + acc * cos_val)
```
Only the final rotated value is cast to FP16 when written to the output or KV cache. This eliminates the precision loss entirely.
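This even/odd selection is exactly a 2‑D rotation of each pair, which you can sanity‑check against complex multiplication. A NumPy sketch mirroring the tl.where logic (array names and the base‑10000 frequency schedule are assumptions for the sketch):

```python
import numpy as np

head_dim = 8
rng = np.random.default_rng(1)
acc = rng.standard_normal(head_dim).astype(np.float32)

rows = np.arange(head_dim)
is_even = rows % 2 == 0
partner_rows = np.where(is_even, rows + 1, rows - 1)
partner = acc[partner_rows]            # the FP32 "scratchpad" read

half = head_dim // 2
freqs = 1.0 / (10000.0 ** (np.arange(half) / half))
angle = 3 * freqs                      # position 3, one angle per pair
cos_val = np.repeat(np.cos(angle), 2)  # both pair elements share an angle
sin_val = np.repeat(np.sin(angle), 2)

roped = np.where(is_even,
                 acc * cos_val - partner * sin_val,
                 partner * sin_val + acc * cos_val)

# Reference: rotate each (2k, 2k+1) pair as one complex number.
z = (acc[0::2] + 1j * acc[1::2]) * np.exp(1j * angle)
ref = np.empty_like(acc)
ref[0::2], ref[1::2] = z.real, z.imag
```

Both formulations agree to float32 precision, and both preserve each pair’s norm, since the rotation matrix is orthogonal.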
## A Walk Through the Kernel Launch Parameters
The kernel is annotated with @triton.jit and takes a long list of parameters, but here’s what matters:

- `W_ptr`: the concatenated QKV weight matrix of shape (q_dim + k_dim + v_dim, d).
- `x_ptr`: the input hidden state of size d.
- `b_ptr`: concatenated bias of the same shape as the output rows.
- `q_ptr`, `k_cache_ptr`, `v_cache_ptr`: output buffers for Q and the KV cache.
- `cos_ptr`, `sin_ptr`: pre‑computed RoPE frequencies for the current position.
- `fp32_tmp_ptr`: the FP32 scratchpad used only for RoPE partner access.
- `M`: total number of Q+K+V rows (1152 for this model).
- `head_dim`: 64. `half_hd`: 32.
- `BLOCK_M`: set to 64 (head‑aligned).
- `BLOCK_K`: tile size along the input dimension (128).
The kernel iterates over the input dimension in chunks of BLOCK_K, accumulating the dot‑product results into registers.
After the accumulation, it checks whether the current block of rows belongs to Q, K, or V, applies RoPE if needed, and writes to the appropriate destination. The code uses masks (is_q, is_k, is_v) to steer the output without branches that diverge heavily.
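A plain‑NumPy sketch of those masks, using the row layout described in this post (896 Q rows, then 128 K rows, then 128 V rows, M = 1152 total):

```python
import numpy as np

q_dim, k_dim, v_dim = 896, 128, 128
rows = np.arange(q_dim + k_dim + v_dim)

# Each row block belongs to exactly one of the three destinations.
is_q = rows < q_dim
is_k = (rows >= q_dim) & (rows < q_dim + k_dim)
is_v = rows >= q_dim + k_dim
```

Because BLOCK_M = 64 divides each of the three section sizes, every program instance sees exactly one of these masks fully set, so the "branching" is uniform across the block.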
## How the Script Benchmarks and Proves Correctness
The main block in the file runs a thorough comparison for four cache lengths: 64, 128, 256, 512. For each length it:

1. Runs a `standard()` function that implements the exact HuggingFace attention logic in pure PyTorch, timing it over 300 iterations after 30 warmup runs and reporting the median kernel time in milliseconds.
2. Runs `triton_v3()` in eager mode – the three separate kernel launches – and times it identically. It computes the speedup and, crucially, the cosine similarity between the standard output and the Triton output. The result is 1.000000 at every tested length.
3. Captures a CUDA Graph of the same three operations. After a few warmup replays, it times 500 graph replays and reports the even lower latency.
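The accuracy metric itself is a one‑liner. A standalone NumPy version (the arrays here are illustrative stand‑ins, not the script’s tensors):

```python
import numpy as np

def cosine_similarity(a, b):
    """Flatten both outputs and compare direction in float64."""
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

ref = np.array([0.5, -1.25, 3.0], dtype=np.float32)
out = ref.copy()                                   # identical outputs
print(f"cos: {cosine_similarity(ref, out):.6f}")   # cos: 1.000000
```

Accumulating the dot products in float64 matters here: at FP16 scale, a naive float16 reduction can itself mask (or fake) small divergences.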
A representative output snippet:
```text
--- cache_len=64 ---
Standard PyTorch:   2.123ms
Triton V3 (eager):  0.468ms
Speedup: 4.54x | cos: 1.000000
Triton V3 + Graph:  0.393ms (5.40x)
```
The graph‑captured version consistently adds an extra 15–20% improvement over the eager mode, because it eliminates the remaining Python transitions between the three fused kernels.
## Key Takeaways for Triton Developers

- **Tile alignment is a correctness concern, not just a performance one.** Setting BLOCK_M = head_dim ensured RoPE pairs were never split, which was the foundation for bit‑perfect output.
- **A tiny FP32 buffer (a few KB) can save you from silent precision drift.** When the correctness of a rotation depends on a partner value, don’t let that value pass through FP16 – keep it in full precision until the final write.
- **CUDA Graphs amplify the benefit of fusion.** After fusing from 10 kernels to 3, a graph capture removes the last bit of Python dispatch overhead between the remaining launches.
## Try the Code Yourself
The file is self‑contained. Install the dependencies and run:
```bash
pip install torch triton transformers
python triton_fused_attention_v3.py
```
It will download Qwen 2.5‑0.5B (if not already cached), print the model dimensions, and run the four‑length benchmark. No other setup needed.
The full repository – including a CuPy‑based Windows kernel and a batched throughput exploit – is on GitHub, but the script we’ve explored here is the beating heart: a clean demonstration of how to fuse attention with RoPE while preserving 100% of the reference model’s output.
Did you spot an optimisation I missed? Or have your own fused‑kernel war story? Drop a comment – I’d love to hear what the dev.to GPU community is building.