DEV Community

TildAlice

Posted on • Originally published at tildalice.io

FlashAttention-2 Warmup: Fix 3x Slower First Batch

The Problem Nobody Talks About

You profiled your LLM inference pipeline, identified attention as the bottleneck, and swapped in FlashAttention-2 expecting a consistent 2-3x speedup over vanilla attention. Instead, the first batch takes 600ms while subsequent batches finish in 200ms. Your P99 latency metrics look terrible, users complain about cold-start delays, and you're stuck explaining why "the fast attention is slow."

This isn't a FlashAttention bug. It's CUDA kernel compilation happening at runtime.

When you call torch.nn.functional.scaled_dot_product_attention() with FlashAttention-2 enabled, PyTorch compiles optimized CUDA kernels on-demand based on your exact tensor shapes, dtypes, and GPU architecture. That first compilation can take 300-500ms. Every cold start—new container, model reload, shape change—triggers recompilation.
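A common mitigation is to pay that compilation cost once at startup, before real traffic arrives. The sketch below is an assumed warmup pattern, not an official PyTorch API: it runs a throwaway forward pass through scaled_dot_product_attention with the exact shapes, dtype, and device your production batches will use, so the kernel selection and compilation happen during deployment rather than on the first user request.

```python
import torch
import torch.nn.functional as F

def warmup_attention(batch, heads, seq_len, head_dim,
                     dtype=torch.float16, device="cuda"):
    """Run a throwaway attention pass so kernel compilation and
    selection happen before serving traffic.

    Hypothetical helper for illustration; the shapes and dtype must
    match what production batches will actually use, since kernels
    are specialized per shape/dtype/architecture.
    """
    q = torch.randn(batch, heads, seq_len, head_dim,
                    dtype=dtype, device=device)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = F.scaled_dot_product_attention(q, k, v)
    if device == "cuda":
        # Kernel launches are async; block until the work finishes
        # so the warmup cost is actually paid here, not later.
        torch.cuda.synchronize()
    return out.shape

# Call once per (batch, seq_len) bucket you expect to serve, e.g.:
# warmup_attention(8, 32, 2048, 128)
```

Because kernels are cached per shape, warming one shape does not help a different one; if your server buckets requests by sequence length, warm each bucket.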


Photo by Kapa65 on Pixabay

Why Runtime Compilation Exists


Continue reading the full article on TildAlice
