The Problem Nobody Talks About
You profiled your LLM inference pipeline, swapped the attention bottleneck for FlashAttention-2, and expected a consistent 2-3x speedup over vanilla attention. Instead, the first batch takes 600ms while subsequent batches finish in 200ms. Your P99 latency metrics look terrible, users complain about cold-start delays, and you're stuck explaining why "the fast attention is slow."
This isn't a FlashAttention bug. It's CUDA kernel compilation happening at runtime.
When you call torch.nn.functional.scaled_dot_product_attention() with FlashAttention-2 enabled, PyTorch compiles optimized CUDA kernels on-demand based on your exact tensor shapes, dtypes, and GPU architecture. That first compilation can take 300-500ms. Every cold start—new container, model reload, shape change—triggers recompilation.
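One practical mitigation is to pay that one-time cost before real traffic arrives. Below is a minimal warmup sketch: it runs a dummy forward pass through scaled_dot_product_attention at the shape you expect to serve, so any on-demand kernel work happens at startup rather than on the first user request. The shapes and the `warm_up_attention` helper are illustrative, and it assumes PyTorch is installed (it falls back to CPU/fp32 when no CUDA device is present):

```python
import time
import torch
import torch.nn.functional as F

def warm_up_attention(batch, heads, seq_len, head_dim, dtype=torch.float16):
    """Run a dummy attention forward pass at the serving shape so any
    one-time kernel compilation/selection happens before real traffic."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        dtype = torch.float32  # fused GPU paths expect fp16/bf16; CPU does not
    q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    start = time.perf_counter()
    F.scaled_dot_product_attention(q, k, v)
    if device == "cuda":
        torch.cuda.synchronize()  # wait for the kernel; GPU launches are async
    return time.perf_counter() - start

# First call pays any one-time cost; the second reflects steady state.
cold = warm_up_attention(1, 8, 128, 64)
warm = warm_up_attention(1, 8, 128, 64)
print(f"cold: {cold * 1e3:.1f} ms, warm: {warm * 1e3:.1f} ms")
```

Because the cached kernels are keyed by shape, dtype, and device, warming up with one sequence length does not necessarily cover another; in practice you would loop this over each (batch, seq_len) bucket your server admits.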
Why Runtime Compilation Exists