DEV Community

Elise Moreau


torch.compile recompiled our SDXL UNet 38 times in production

TL;DR: torch.compile gave us a 2.3x speedup on our SDXL pipeline in benchmarks, then quietly recompiled 38 times across the first 100 production requests because every customer uploads a product photo at a different resolution. The fix wasn't turning compile off. It was understanding what counts as a guard, bucketing inputs to fixed shapes, and reading the recompilation logs PyTorch 2.3 gives you for free.

The benchmark that lied to me

At Photoroom we run diffusion models for product photography. Someone uploads a sneaker on a kitchen table, and the model gives it a clean studio background. The UNet is the heavy part, so when PyTorch 2.3 promised free speedups through torch.compile, I spent a week wiring it in.

The benchmark looked great. Fixed 1024x1024 input, batch size 4, an A10G. 2.3x faster than eager mode after warmup. I shipped it to a 5% canary.

p99 latency went up. Not by a little. Some requests took 70 seconds longer than before the change.

What a guard actually is

torch.compile traces a graph specialized to one set of input properties: tensor shapes, dtypes, devices, and certain scalar values. Dynamo wraps that graph in what it calls guards, which are cheap runtime checks that say "this compiled graph is valid only if the next input matches." Miss a guard, and it recompiles.

Compiling the SDXL UNet takes 40 to 90 seconds on an A10G. Paying that once, at startup, is fine. The catch is that compilation happens lazily, inside the first request that fails a guard, so the recompile lands in the middle of a customer waiting for their image.
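To make that failure mode concrete, here is a toy model of a guard-keyed cache. It is not how Dynamo is implemented. It only mimics the behavior described above: each compiled artifact is valid for one exact input shape and dtype, and the first call that matches no cached entry pays the compile cost inline. `GuardedCache` and the 60-second cost are illustrative assumptions.

```python
from dataclasses import dataclass, field

# Assumed compile cost, picked from the 40-90 s range we saw on an A10G.
COMPILE_SECONDS = 60.0

@dataclass
class GuardedCache:
    """Toy stand-in for Dynamo's per-function cache (illustrative only)."""
    entries: dict = field(default_factory=dict)
    compiles: int = 0

    def call(self, shape: tuple, dtype: str) -> float:
        """Return the extra latency (seconds) this call pays for compilation."""
        key = (shape, dtype)  # the "guards": valid only for this exact shape and dtype
        if key in self.entries:
            return 0.0  # guard hit: reuse the compiled graph
        # Guard miss: compile lazily, inside the request that triggered it.
        self.entries[key] = f"graph for {shape} {dtype}"
        self.compiles += 1
        return COMPILE_SECONDS

cache = GuardedCache()
requests = [(1, 4, 128, 128), (1, 4, 128, 128), (1, 4, 96, 128), (1, 4, 150, 150)]
stalls = [cache.call(shape, "float16") for shape in requests]
print(cache.compiles, stalls)
# → 3 [60.0, 0.0, 60.0, 60.0]
```

Only the repeated shape gets a free ride. Every new shape stalls the request that brings it.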

And product photos do not have a fixed shape. Phones shoot 3024x4032, someone crops to 800x600, a Shopify export is 1200x1200. Every new resolution is a new shape, a new guard miss, another recompile mid-request.
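SDXL's VAE downsamples each side by 8 and produces 4 latent channels, so the UNet sees a tensor of shape (batch, 4, H/8, W/8). The sketch below maps the resolutions above to latent shapes. It assumes each side is floored to a multiple of 8 before encoding, which is an assumption about preprocessing rather than something every pipeline does. `latent_shape` is a hypothetical helper.

```python
VAE_SCALE = 8        # SDXL's VAE downsamples each spatial side by 8
LATENT_CHANNELS = 4  # SDXL latents have 4 channels

def latent_shape(height: int, width: int, batch: int = 1) -> tuple:
    """Shape of the UNet input for an image of the given pixel size.

    Assumption: each side is floored to a multiple of 8 before encoding.
    """
    return (batch, LATENT_CHANNELS, height // VAE_SCALE, width // VAE_SCALE)

# (height, width) pairs for the uploads mentioned above.
uploads = [(4032, 3024), (600, 800), (1200, 1200)]
for h, w in uploads:
    print((h, w), "->", latent_shape(h, w))
# → (4032, 3024) -> (1, 4, 504, 378)
# → (600, 800) -> (1, 4, 75, 100)
# → (1200, 1200) -> (1, 4, 150, 150)
```

Even if the pipeline downsizes large photos first, any resize that preserves aspect ratio still produces a distinct latent shape for each aspect ratio.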

Watching the recompiles happen

The thing that saved me was not a profiler. It was one environment variable.

TORCH_LOGS="recompiles" python serve.py

That prints a message every time Dynamo has to recompile a function, along with the guard that failed:

Recompiling function forward in unet.py:412
    triggered by the following guard failure(s):
    - tensor 'L['sample']' size mismatch at index 2.
      expected 128, actual 96

Index 2 is the latent height. SDXL's VAE downsamples by 8, so a 1024px-tall image becomes a latent 128 rows high and a 768px-tall one becomes 96. Different shape, recompile. I counted 38 recompiles across the first 100 requests, and the count never plateaued because new resolutions kept arriving.
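Counting recompiles by hand gets old quickly. The small script below tallies a captured log by function and by mismatched dimension. The regexes are written against the message format shown above. Other PyTorch versions may word these messages differently, so treat the patterns as an assumption. The second log entry in `SAMPLE_LOG` is made up to illustrate counting.

```python
import re
from collections import Counter

# Header line Dynamo logs for each recompile, e.g.
# "Recompiling function forward in unet.py:412". search/findall semantics mean
# a logging prefix (timestamp, logger name) in front of it doesn't matter.
RECOMPILE_RE = re.compile(r"Recompiling function (\S+) in (\S+)")
# Size-mismatch guard failure; \s+ tolerates the message wrapping onto a new line.
SIZE_RE = re.compile(r"size mismatch at index (\d+)\.\s+expected (\d+), actual (\d+)")

def summarize_recompiles(log_text: str):
    """Count recompiles per function and size mismatches per tensor dimension."""
    per_function = Counter(f"{fn} at {loc}" for fn, loc in RECOMPILE_RE.findall(log_text))
    per_dim = Counter(int(idx) for idx, _expected, _actual in SIZE_RE.findall(log_text))
    return per_function, per_dim

SAMPLE_LOG = """\
Recompiling function forward in unet.py:412
    triggered by the following guard failure(s):
    - tensor 'L['sample']' size mismatch at index 2.
      expected 128, actual 96
Recompiling function forward in unet.py:412
    triggered by the following guard failure(s):
    - tensor 'L['sample']' size mismatch at index 3.
      expected 128, actual 100
"""

fns, dims = summarize_recompiles(SAMPLE_LOG)
print(fns.most_common(), dict(dims))
# → [('forward at unet.py:412', 2)] {2: 1, 3: 1}
```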

Three ways to stop it

I tested three approaches over a week against real traffic from our logs. Here's what held up.

Approach | Recompiles after warmup | Speedup kept | Main cost
--- | --- | --- | ---
torch.compile(model, dynamic=True) | 0 | ~1.6x | More general kernels, slower per step
Resolution bucketing | 0 (3 compiles at boot) | ~2.1x | Padding pixels wasted through the VAE
Fixed canonical resolution | 0 | 2.3x | Quality loss on extreme aspect ratios

torch.compile(model, dynamic=True) tells Dynamo to assume shapes vary from the start. No per-shape recompiles, but you pay with a more general kernel. We measured 1.6x instead of 2.3x. Honest, predictable, leaves speed on the table.

Bucketing won. We resize and pad every input so the long edge lands on one of {768, 1024, 1280}, then compile each bucket once at boot before the readiness probe goes green.

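import torch

BUCKETS = [768, 1024, 1280]

def to_bucket(img):
    # Pick the bucket closest to the image's long edge.
    long_edge = max(img.height, img.width)
    target = min(BUCKETS, key=lambda b: abs(b - long_edge))
    # resize_and_pad is our own helper (not shown): it scales the long edge
    # to `target` and pads to a square, matching the square warmup latents below.
    return resize_and_pad(img, target)

# Warm all three at startup, before serving traffic. compiled_unet is the
# torch.compile-wrapped UNet. The dummy latent must also match the production
# dtype, batch size, and grad mode, or the first real request still misses
# a guard and recompiles.
for b in BUCKETS:
    dummy = torch.zeros(1, 4, b // 8, b // 8, device="cuda")
    compiled_unet(dummy, t=torch.tensor([1.0], device="cuda"))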

Three compiles total, all before the pod takes traffic. Cache hits forever after.
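To check that bucketing really bounds the number of compiled shapes, here's a PIL-free sketch of the same nearest-bucket rule operating on (width, height) pairs. It assumes the resize helper scales the long edge to the bucket and pads to a square, which is consistent with the square dummy latents in the warmup loop. `bucket_for` and `padded_latent_shape` are hypothetical names, and the traffic list is illustrative.

```python
BUCKETS = [768, 1024, 1280]

def bucket_for(width: int, height: int) -> int:
    """Same rule as to_bucket(): the bucket nearest to the long edge."""
    long_edge = max(width, height)
    return min(BUCKETS, key=lambda b: abs(b - long_edge))

def padded_latent_shape(width: int, height: int) -> tuple:
    """Assumed: resize so the long edge equals the bucket, then pad to a square."""
    side = bucket_for(width, height)
    return (1, 4, side // 8, side // 8)

# Illustrative mix of upload sizes (width, height).
traffic = [(3024, 4032), (1200, 1200), (1080, 1350), (640, 640), (1024, 683), (2000, 500)]
shapes = {padded_latent_shape(w, h) for w, h in traffic}
print(sorted(shapes))
# → [(1, 4, 96, 96), (1, 4, 128, 128), (1, 4, 160, 160)]
```

However varied the uploads are, the UNet only ever sees at most `len(BUCKETS)` shapes, and every one of them was compiled before the readiness probe passed.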

The part that isn't the UNet

There's a small LLM step before the diffusion model even runs. It rewrites the user's text prompt into something the model handles better, turning "make it look nice" into a structured scene description. Those calls go out to an external provider, and when that provider has a bad minute the whole render stalls behind it.

We route that step through an AI gateway so a provider hiccup fails over to a backup instead of blocking. We use Bifrost for this, though several gateways do automatic fallback. It's one box in the pipeline diagram, not the interesting one, but it kept the prompt-rewrite step from becoming a single point of failure once compile fixed the UNet side.
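At its core, the gateway's job in this spot is ordered failover with a timeout. The sketch below shows that idea in plain Python. It is not Bifrost's implementation or API. `rewrite_with_fallback`, `primary`, `flaky`, and `backup` are hypothetical stand-ins for real provider calls.

```python
import concurrent.futures as cf
import time
from typing import Callable, Sequence

Provider = Callable[[str], str]  # hypothetical: prompt in, rewritten prompt out

# Shared pool. A hung provider call keeps its worker busy until it returns,
# so size this above the number of concurrent stalls you expect.
_POOL = cf.ThreadPoolExecutor(max_workers=8)

def rewrite_with_fallback(prompt: str, providers: Sequence[Provider],
                          timeout_s: float = 5.0) -> str:
    """Try each provider in order and return the first answer that arrives in time."""
    failures = []
    for provider in providers:
        name = getattr(provider, "__name__", repr(provider))
        future = _POOL.submit(provider, prompt)
        try:
            return future.result(timeout=timeout_s)
        except cf.TimeoutError:
            # Stop waiting. The slow call keeps running in the background.
            failures.append(f"{name}: timed out")
        except Exception as exc:  # provider raised, so move on to the next one
            failures.append(f"{name}: {exc!r}")
    raise RuntimeError("all providers failed: " + "; ".join(failures))

def primary(prompt: str) -> str:
    time.sleep(1.0)  # simulates a provider having a bad minute
    return "primary: " + prompt

def flaky(prompt: str) -> str:
    raise ConnectionError("provider returned 503")

def backup(prompt: str) -> str:
    return "backup: " + prompt

print(rewrite_with_fallback("make it look nice", [primary, backup], timeout_s=0.2))
# → backup: make it look nice
print(rewrite_with_fallback("make it look nice", [flaky, backup]))
# → backup: make it look nice
```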

Trade-offs and Limitations

Bucketing wastes compute on padding. An 800x600 image padded to 1024 pushes roughly 60% extra pixels through the VAE decode. For us that's worth a stable p99. For a catalog that's all square crops, it'd be pure overhead, and dynamic=True would be the saner call.

Warmup adds about 4 minutes to pod startup. Our Kubernetes readinessProbe has initialDelaySeconds set to wait it out, which slows autoscaling response during traffic spikes. You feel this when a flash sale hits.

By default the Inductor compile cache lives on local disk (under the system temp directory), so every new pod recompiles from scratch. You can point TORCHINDUCTOR_CACHE_DIR at a shared volume to persist it, but the cache is keyed loosely on environment, and a mismatched CUDA driver version across nodes once gave us a silent fallback to eager. Test that path before you trust it.
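One way to make the shared-volume path safer is to give each environment its own subdirectory, so nodes with different drivers never read each other's artifacts. This is a sketch under our own assumptions, not something PyTorch does for you. `inductor_cache_dir`, the mount path, and the version strings are placeholders. In a real deployment you would read the values at startup, for example from `importlib.metadata.version("torch")` and `nvidia-smi`.

```python
import hashlib
import os

def inductor_cache_dir(shared_root: str, torch_version: str, cuda_version: str,
                       driver_version: str, gpu_name: str) -> str:
    """Return one cache subdirectory per (torch, CUDA, driver, GPU) combination."""
    key = "|".join([torch_version, cuda_version, driver_version, gpu_name])
    digest = hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
    return os.path.join(shared_root, digest)

# Set this before torch is imported, so nothing can compile against the default dir.
os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache_dir(
    "/mnt/inductor-cache",  # placeholder shared volume
    torch_version="2.3.0", cuda_version="12.1",
    driver_version="535.104.05", gpu_name="NVIDIA A10G",  # placeholder values
)
print(os.environ["TORCHINDUCTOR_CACHE_DIR"])
```

A node on a different driver then gets a cold cache instead of someone else's artifacts. That means a slow first boot rather than a silent fallback.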

And none of this helps if your bottleneck is the VAE or the scheduler loop. Profile first. I burned two days compiling a UNet that was only 40% of wall-clock.
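Profiling first doesn't take much. Below is a minimal per-stage wall-clock timer, plus Amdahl's law to turn a stage's share of runtime into an end-to-end number. If the 2.3x benchmark speedup carried over to a UNet that is 40% of wall-clock, the whole request would get only about 1.29x faster. `StageTimer`, `end_to_end_speedup`, and the stage names are hypothetical. The sleeps stand in for real pipeline stages.

```python
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Callable, Optional

class StageTimer:
    """Accumulates wall-clock time per pipeline stage."""

    def __init__(self, sync: Optional[Callable[[], None]] = None):
        # For GPU stages, pass sync=torch.cuda.synchronize. CUDA launches are
        # asynchronous, so without a sync the time lands in the wrong stage.
        self.sync = sync or (lambda: None)
        self.totals = defaultdict(float)

    @contextmanager
    def stage(self, name: str):
        self.sync()
        start = time.perf_counter()
        try:
            yield
        finally:
            self.sync()
            self.totals[name] += time.perf_counter() - start

    def fractions(self) -> dict:
        total = sum(self.totals.values())
        return {name: t / total for name, t in self.totals.items()}

def end_to_end_speedup(fraction: float, stage_speedup: float) -> float:
    """Amdahl's law: overall speedup when only `fraction` of runtime gets faster."""
    return 1.0 / ((1.0 - fraction) + fraction / stage_speedup)

timer = StageTimer()
with timer.stage("unet"):
    time.sleep(0.04)  # stand-in for the UNet steps
with timer.stage("vae_decode"):
    time.sleep(0.06)  # stand-in for the VAE decode
print({name: round(share, 2) for name, share in timer.fractions().items()})
print(round(end_to_end_speedup(0.4, 2.3), 2))
# → 1.29
```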

