DEV Community

Elise Moreau

Why your diffusion model is slow at batch size 1 (and what actually helps)

TL;DR: Single-image diffusion inference is bottlenecked by kernel launch overhead and attention memory traffic, not raw FLOPs. torch.compile with mode="reduce-overhead", a fused attention backend, and CFG batching get you most of the way before you reach for distillation.

I spend a lot of time looking at flame graphs from production diffusion pipelines. The pattern is almost always the same. The team profiles their model, sees 50 steps of a UNet or DiT, and assumes the path to lower latency is fewer steps. So they try LCM, then TCD, then some flavor of consistency distillation, and the quality drops in ways the product team notices.

The nuance here is that at batch size 1, your GPU is mostly idle. You are not compute-bound. You are launch-bound and memory-bound. Distillation helps eventually, but only after you have fixed the boring things.

What the profiler actually shows

Run a vanilla SDXL or a 1B-parameter DiT at 1024x1024, batch 1, on an H100. Capture a trace with torch.profiler and zoom into a single denoising step.
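A minimal sketch of capturing that single-step trace. `step_fn` is a hypothetical name for whatever callable runs one denoising step in your pipeline. The sketch warms up once so first-call setup stays out of the trace, records one step under a `record_function` label, and writes a Chrome-format trace that you can open in a trace viewer to see the gaps between kernels.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def profile_one_step(step_fn, *args, trace_path="denoise_step_trace.json"):
    """Profile one call to step_fn (hypothetical: a single denoising step).

    Writes a Chrome trace to trace_path and returns a text summary table.
    """
    use_cuda = torch.cuda.is_available()
    activities = [ProfilerActivity.CPU]
    if use_cuda:
        activities.append(ProfilerActivity.CUDA)

    step_fn(*args)  # warm-up call keeps one-time setup out of the trace
    if use_cuda:
        torch.cuda.synchronize()

    with profile(activities=activities) as prof:
        with record_function("denoise_step"):
            step_fn(*args)
        if use_cuda:
            torch.cuda.synchronize()  # make sure queued kernels finish inside the trace

    prof.export_chrome_trace(trace_path)
    return prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=15)


if __name__ == "__main__":
    # Toy stand-in for a denoising step: a matmul followed by GELU
    x = torch.randn(256, 256)
    print(profile_one_step(lambda t: torch.nn.functional.gelu(t @ t), x))
```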

You will see something like this, roughly:

  • ~30-40% of wall time inside attention kernels
  • ~20-25% inside conv and linear layers
  • ~15-20% in layernorm, GELU, residual adds
  • The rest: kernel launch gaps, host-to-device syncs, Python overhead

That last bucket is the embarrassing one. On an H100 a kernel launch costs ~5 microseconds. A UNet step fires hundreds of kernels. A 50-step sample fires tens of thousands. You are paying for the privilege of dispatching work, not for the work itself.
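The back-of-envelope version of that arithmetic, as a small helper. The inputs are assumptions: 500 kernels per step is a placeholder inside the "hundreds" range above (count yours from the trace), and 5 microseconds is the per-launch figure from the text. The result is the dispatch bill; it only becomes wall time when kernels finish faster than the host can queue new ones, which is the batch-1 regime.

```python
def launch_overhead_ms(kernels_per_step: int, steps: int,
                       launch_us: float = 5.0, model_calls_per_step: int = 1) -> float:
    """Total host-side kernel-launch cost for one sample, in milliseconds.

    model_calls_per_step is 2 when classifier-free guidance runs the model
    twice sequentially, and 1 when both passes share one batched call.
    """
    launches = kernels_per_step * steps * model_calls_per_step
    return launches * launch_us / 1000.0


if __name__ == "__main__":
    # Placeholder kernel count; replace with the number from your own trace.
    print(launch_overhead_ms(500, 50))                          # 125.0 (25,000 launches)
    print(launch_overhead_ms(500, 50, model_calls_per_step=2))  # 250.0 (sequential CFG)
```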

To put a number on it: the same model at batch 8 often runs in less than 2x the wall time it takes at batch 1. That gap, eight times the work for under twice the time, is your overhead bill.
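Measuring that batch-1 vs batch-8 gap needs a timer that waits for the GPU, because CUDA calls return as soon as kernels are queued. A minimal harness, assuming you pass `torch.cuda.synchronize` as `sync` when timing GPU work:

```python
import statistics
import time
from typing import Callable, Optional


def median_latency_ms(fn: Callable[[], object], warmup: int = 3, iters: int = 10,
                      sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock latency of fn() in milliseconds.

    `sync` should block until queued GPU work is done (e.g. torch.cuda.synchronize);
    without it you are timing kernel launches, not kernels.
    """
    for _ in range(warmup):  # warm-up also absorbs compilation / graph capture
        fn()
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)
```

Usage, with `unet`, `inputs_b1`, and `inputs_b8` as hypothetical placeholders for your model and its batch-1 and batch-8 inputs: time `lambda: unet(*inputs_b1)` and `lambda: unet(*inputs_b8)` with `sync=torch.cuda.synchronize`. A ratio `t8 / t1` far below 8 means batch 1 is mostly overhead.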

Step one: torch.compile, but the right mode

The default torch.compile(model) call uses mode="default", which balances speed against compile time and memory use and does not use CUDA graphs. For batch-1 inference you want:

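```python
import torch

# `unet` is your already-loaded denoiser (UNet or DiT), on the GPU in eval mode
unet = torch.compile(
    unet,
    mode="reduce-overhead",  # use CUDA graphs to cut per-kernel launch cost
    fullgraph=True,          # raise on any graph break instead of silently splitting
    dynamic=False,           # specialize on fixed input shapes; CUDA graphs want static shapes
)
```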

reduce-overhead enables CUDA graphs, which replay a captured sequence of kernels in one launch. This is the single largest win for batch 1 diffusion on modern GPUs. In my measurements on PyTorch 2.3, this alone takes a 1024x1024 SDXL UNet step from ~42ms to ~28ms on H100. No quality change, no architecture change.

The catch: fullgraph=True will yell at you about any graph break. CFG implementations that branch on guidance_scale need rewriting. Custom samplers that touch .item() between steps will break CUDA graph capture. Plan for a day of fighting this.
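One cheap way to find graph breaks before fighting them on the GPU is to compile with `fullgraph=True` and `backend="eager"`, which runs Dynamo's tracing without Inductor code generation, so it works on a CPU-only machine. The two CFG combine functions are illustrative toys, not taken from any particular pipeline: the first branches on a tensor-valued guidance scale via `.item()`, the second takes the scale as a plain Python float.

```python
import torch


def cfg_combine_branchy(noise_uncond, noise_cond, guidance_scale):
    # guidance_scale is a 0-d tensor here; .item() copies it to the host and the
    # branch depends on its value, so Dynamo cannot keep this in one graph.
    if guidance_scale.item() <= 1.0:
        return noise_cond
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


def cfg_combine_static(noise_uncond, noise_cond, guidance_scale: float):
    # guidance_scale is a plain Python float fixed per request: no host sync and
    # no data-dependent branch, so the whole function traces as one graph.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


def traces_as_one_graph(fn, *args) -> bool:
    """True if torch.compile(fullgraph=True) can trace fn without a graph break."""
    compiled = torch.compile(fn, fullgraph=True, backend="eager")
    try:
        compiled(*args)
        return True
    except Exception:  # under fullgraph=True, Dynamo raises on the first graph break
        return False


if __name__ == "__main__":
    u, c = torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)
    print(traces_as_one_graph(cfg_combine_branchy, u, c, torch.tensor(7.5)))
    print(traces_as_one_graph(cfg_combine_static, u, c, 7.5))
```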

Step two: pick an attention backend on purpose

PyTorch's scaled_dot_product_attention dispatches to one of several backends. The defaults are not always right for high-resolution diffusion.

Backend | Best for | Notes
--- | --- | ---
FlashAttention-2 | Long sequences, H100/A100 | Default on most setups, good general choice
FlashAttention-3 | H100 only | ~1.5x faster than FA2 on Hopper, requires manual install
xFormers memory-efficient | Older GPUs (V100, T4) | Lower memory, slower than Flash on modern hardware
Math (fallback) | Debugging only | Never ship this

For DiT-style models at 2K resolution, the sequence length per attention block reaches 16K+ tokens, and FA3 on H100 makes a real difference there. I have seen an 18% end-to-end latency drop on a 2B DiT just from switching from FA2 to FA3 via torch.nn.attention.sdpa_kernel.
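The 16K figure follows from the latent geometry. A worked version, assuming an SD-family VAE with 8x spatial downsampling and a DiT patch size of 2 (common choices, but check your model's config). The second function shows why a fused kernel matters: materializing one head's full attention score matrix at that length would take 512 MiB in fp16.

```python
def dit_tokens(height_px: int, width_px: int, vae_downsample: int = 8, patch_size: int = 2) -> int:
    """Tokens per attention block: pixels -> latent (VAE) -> patches (DiT)."""
    latent_h, latent_w = height_px // vae_downsample, width_px // vae_downsample
    return (latent_h // patch_size) * (latent_w // patch_size)


def score_matrix_mib(tokens: int, bytes_per_element: int = 2) -> float:
    """Size of one head's unfused QK^T score matrix (fp16 by default), in MiB."""
    return tokens * tokens * bytes_per_element / 2**20


if __name__ == "__main__":
    print(dit_tokens(1024, 1024))   # 4096
    print(dit_tokens(2048, 2048))   # 16384
    print(score_matrix_mib(16384))  # 512.0
```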

Step three: batch your CFG

Classifier-free guidance runs the model twice per step, once conditional and once unconditional. Most reference implementations call the UNet twice sequentially. Do not do this.

Concatenate the unconditional and conditional inputs (two copies of the latents plus both text embeddings) into one batch of 2, run one forward pass, and split the output. At batch 1 this nearly halves your per-step latency because you were leaving the GPU idle anyway. The memory cost is negligible at typical inference resolutions.
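The batched version next to the sequential one it replaces. `model(latents, t, emb)` is an assumed diffusers-style call signature, and `toy_model` is a hypothetical stand-in for a UNet/DiT; the uncond-first ordering is a convention, not a requirement. What matters is that the split uses the same order as the concatenation.

```python
import torch


def cfg_sequential(model, latents, t, cond_emb, uncond_emb, guidance_scale):
    # Two forward passes per step: the pattern to get rid of.
    noise_uncond = model(latents, t, uncond_emb)
    noise_cond = model(latents, t, cond_emb)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


def cfg_batched(model, latents, t, cond_emb, uncond_emb, guidance_scale):
    # One forward pass on a batch of 2, ordered [unconditional, conditional].
    latent_in = torch.cat([latents, latents], dim=0)
    emb_in = torch.cat([uncond_emb, cond_emb], dim=0)
    noise_uncond, noise_cond = model(latent_in, t, emb_in).chunk(2, dim=0)
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


def toy_model(x, t, emb):
    # Hypothetical stand-in: each output sample depends only on its own inputs.
    return x * emb.mean(dim=(1, 2)).view(-1, 1, 1, 1) + t


if __name__ == "__main__":
    latents = torch.randn(1, 4, 64, 64)
    cond, uncond = torch.randn(1, 77, 768), torch.randn(1, 77, 768)
    t = torch.tensor(500.0)
    a = cfg_sequential(toy_model, latents, t, cond, uncond, 7.5)
    b = cfg_batched(toy_model, latents, t, cond, uncond, 7.5)
    print(torch.allclose(a, b))
```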

This is a 3-line change, and yet the sequential version still lives in maybe 60% of the codebases I review.

Step four, only now: think about steps

After the above, a 50-step SDXL sample on H100 is in the 1.2-1.5 second range. If your product needs sub-second, then yes, look at LCM, Hyper-SD, or DMD2. But evaluate quality on your own data, not on the curated examples in the paper. Distilled models lose the most quality on the long tail of prompts your users actually send, particularly text rendering and fine compositional structure.
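A quick way to see how far the step count has to fall for a given latency target, using the ~28 ms compiled step from earlier as the per-step cost. `fixed_ms` stands for everything that does not scale with steps (text encoding, VAE decode); the 150 ms used below is a made-up placeholder, not a measurement.

```python
def max_steps_within_budget(budget_ms: float, step_ms: float, fixed_ms: float = 0.0) -> int:
    """Largest number of denoising steps that fits in budget_ms after fixed costs."""
    remaining_ms = budget_ms - fixed_ms
    if remaining_ms <= 0:
        return 0
    return int(remaining_ms // step_ms)


if __name__ == "__main__":
    print(50 * 28)                                          # 1400 ms for 50 steps
    print(max_steps_within_budget(1000, 28))                # 35
    print(max_steps_within_budget(1000, 28, fixed_ms=150))  # 30 (placeholder fixed cost)
```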

Trade-offs and limitations

CUDA graphs hate dynamic shapes. If your service accepts arbitrary aspect ratios you will recompile constantly. Either bucket aspect ratios into a small set of fixed shapes, or accept the warmup cost on cold paths.
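A minimal bucketing sketch: snap each request to the bucket with the closest aspect ratio, generate at that size, then resize or crop to what the user asked for. The bucket list is an assumption (a set of roughly one-megapixel resolutions commonly used with SDXL, all multiples of 64). With `dynamic=False`, each bucket should cost one compile and capture, so warm them all up at startup.

```python
import math
from typing import List, Tuple

# Hypothetical bucket set, (width, height), all multiples of 64.
BUCKETS: List[Tuple[int, int]] = [
    (1024, 1024), (1152, 896), (896, 1152),
    (1216, 832), (832, 1216), (1344, 768), (768, 1344),
]


def snap_to_bucket(width: int, height: int,
                   buckets: List[Tuple[int, int]] = BUCKETS) -> Tuple[int, int]:
    """Return the bucket whose aspect ratio is closest (in log space) to the request."""
    target = math.log(width / height)
    return min(buckets, key=lambda wh: abs(math.log(wh[0] / wh[1]) - target))


if __name__ == "__main__":
    for request in [(1920, 1080), (1080, 1920), (1000, 1000), (800, 600)]:
        print(request, "->", snap_to_bucket(*request))
```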

reduce-overhead mode increases memory usage because it caches workspace memory between runs instead of reallocating it (CUDA graph replay needs its buffers at fixed addresses). On a 24GB consumer card this can push you over the edge with larger models. Profile before deploying.

FlashAttention-3 requires building from source against a specific CUDA version. If your deployment runs across mixed GPU generations, the version matrix becomes painful. Pick one backend per deployment target.
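One way to pick one backend per deployment target is to resolve it once at startup from the GPU's compute capability and log the result. This policy function is a hypothetical sketch mirroring the backend table (Hopper is compute capability 9.x, Ampere and Ada are 8.x, V100 and T4 are 7.x). The returned strings are labels for your own dispatch code, not PyTorch identifiers.

```python
from typing import Tuple


def choose_attention_backend(capability: Tuple[int, int], fa3_installed: bool) -> str:
    """Map a CUDA compute capability (major, minor) to an attention backend label."""
    major, _minor = capability
    if major == 9 and fa3_installed:
        return "flash_attention_3"  # Hopper only, and only if the source build is present
    if 8 <= major <= 9:
        return "flash_attention_2"  # Ampere, Ada, Hopper
    # Volta/Turing, or anything not yet validated: the memory-efficient kernel
    return "memory_efficient"


if __name__ == "__main__":
    # On a real box: choose_attention_backend(torch.cuda.get_device_capability(), fa3_installed=...)
    for cap in [(9, 0), (8, 0), (7, 5)]:
        print(cap, choose_attention_backend(cap, fa3_installed=True))
```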

And the obvious one: none of this fixes a slow VAE decode. If you are generating at 2K, the VAE can dominate. Tiled VAE decoding or a distilled decoder like TAESD is a separate fight.
