DEV Community

Elise Moreau

The SDXL VAE overflow that decoded black images in fp16

TL;DR: The SDXL VAE decoder can push activations past 65504, the largest value fp16 can hold. When that happens, an intermediate decoder activation overflows to inf and you get a fully black image. At Photoroom we hit this on roughly 1 in 600 product renders before we caught it. The fix is to give only the VAE more numerical headroom (fp32 or bf16) or to swap in rescaled decoder weights. Dropping the whole pipeline to fp32 is not necessary.

We run SDXL-based pipelines for product photography. A customer uploads a sneaker on a kitchen table, we cut it out, then generate a clean studio background around it. Hundreds of thousands of renders a day, mostly on A10G and A100 GPUs, with the UNet in fp16 to keep the per-image latency under our budget.

The bug showed up as a thin stream of complaints. Black image. No error, no stack trace, no NaN warning in the logs. Just a 1024x1024 PNG of pure black where a render should be.

What was actually happening

I pulled 40 of the failing seeds and replayed them with hooks on every module in the VAE decoder. The UNet output was fine. Latents looked normal, values in the usual range. The decode was where it died.

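To be precise, the overflow lives in the decoder's mid and up blocks. SDXL's VAE has a few residual layers where the post-convolution activations spike hard for certain inputs. fp16 tops out at 65504. On one failing seed I logged a max activation of 3.1e5 inside one of the up_blocks resblocks. fp16 cannot represent that value, so it only shows up as a number when the decoder runs in higher precision; in fp16 it is already inf. Once a single value hits inf, the next GroupNorm computes its mean and variance over the whole group, both statistics become inf or NaN, and every value in that group turns into NaN. The following convolutions spread the NaNs across the rest of the feature map, and the NaN output comes out of post-processing as pure black.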
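Here is a stdlib-only sketch of that chain, so you can see why one spike wipes out a whole group. It is simplified in two hypothetical ways. `fp16_overflow` only models the overflow to inf and ignores fp16 rounding of in-range values. `group_norm` normalizes a single flat group rather than channels times spatial positions.

```python
import math

FP16_MAX = 65504.0  # largest finite fp16 value

def fp16_overflow(x: float) -> float:
    """Simplified fp16 cast: values beyond the fp16 max become +/-inf.
    Rounding of in-range values is deliberately ignored here."""
    return x if abs(x) <= FP16_MAX else math.copysign(math.inf, x)

def group_norm(values, eps=1e-5):
    """One GroupNorm group: mean and variance are taken over every value
    in the group, so a single inf poisons both statistics."""
    mean = sum(values) / len(values)
    var = sum((v - mean) * (v - mean) for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

# Hypothetical resblock outputs with one spike like the 3.1e5 we logged
activations = [0.5, -1.2, 3.1e5, 2.0]
print(group_norm(activations))  # all finite when the spike fits
print(group_norm([fp16_overflow(v) for v in activations]))  # [nan, nan, nan, nan]
```

With enough range the spike just dominates the statistics. After the fp16 overflow the mean is inf, `inf - inf` is NaN, and the NaN variance turns every normalized value into NaN.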

The nuance here is that it's input-dependent. Most latents never come close to the ceiling. High-contrast scenes with bright speculars, like a glossy bottle on white, are the ones that tip over. That's why our QA never saw it and production did.

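import torch

# Forward hook that flags decoder modules whose output reaches the fp16 ceiling.
def watch(name):
    def hook(module, inputs, out):
        if not isinstance(out, torch.Tensor):
            return  # skip modules that return tuples or other objects
        if not torch.isfinite(out).all():
            # already inf/NaN; the first module printed is where it started
            print(f"{name}: non-finite output (inf/NaN)")
            return
        m = out.abs().max().item()
        if m > 6e4:  # fp16 max is 65504
            print(f"{name}: max activation {m:.1f}")
    return hook

# `pipe` is an already-loaded StableDiffusionXLPipeline
handles = [
    mod.register_forward_hook(watch(n))
    for n, mod in pipe.vae.decoder.named_modules()
]
# ...replay the failing seeds, then detach the hooks:
for h in handles:
    h.remove()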

That printout is what pointed me at the exact resblock instead of guessing.

The options we weighed

There's no single right answer here, and the trade-off is VRAM and latency against correctness. We measured four approaches on the same 500-seed batch.

| Approach | Fixes overflow | VAE decode latency | Extra VRAM | Notes |
|---|---|---|---|---|
| Full pipeline in fp32 | Yes | +210% | ~2x | Kills our latency budget |
| force_upcast (VAE in fp32) | Yes | +18% | +1.1 GB | Only the VAE runs in fp32 |
| VAE in bf16 | Yes | +6% | +0.1 GB | Needs Ampere or newer |
| fp16-fix decoder weights | Yes | +0% | +0 GB | Rescaled weights, stays in fp16 |

Full fp32 was off the table. It doubled memory and blew past the latency we promise. The other three all hold up.

force_upcast is the diffusers default for a reason. It keeps the UNet in fp16 and runs only the VAE in fp32. One flag, and the overflow is gone because fp32 has the headroom.

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
# Set explicitly, in case a swapped-in VAE config turns it off
pipe.vae.config.force_upcast = True  # VAE runs fp32, UNet stays fp16

We landed on bf16 for the VAE on our A100 boxes. bf16 has the same 8-bit exponent as fp32, so the 3.1e5 activation fits without issue, and the decode cost was 6% instead of 18%. On the A10G boxes, where the bf16 path didn't work out for us, we use the rescaled fp16-fix decoder weights. They shift the activation magnitudes down so they never reach the ceiling in the first place.
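You can see the range difference without a GPU. Below is a stdlib-only sketch that rounds Python floats the way fp16 and bf16 would. `to_fp16` and `to_bf16` are hypothetical helpers for illustration and assume finite inputs. On real tensors you would just call `.half()` or `.bfloat16()` in PyTorch.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE half-precision value; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct raises past the fp16 max instead of returning inf
        return math.copysign(math.inf, x)

def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16: keep the top 16 bits of
    its float32 encoding, rounding the dropped 16 bits half-to-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round half to even
    bits &= 0xFFFF0000                   # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

for name, cast in (("fp16", to_fp16), ("bf16", to_bf16)):
    print(name, cast(3.1e5), cast(1.2))
# fp16 inf 1.2001953125
# bf16 309248.0 1.203125
```

The 3.1e5 spike survives in bf16 with some rounding, while fp16 overflows. The second column shows the cost: fp16 lands closer to 1.2 than bf16 does. That precision trade-off is discussed under trade-offs below.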

One detail that bit us: pipe.enable_vae_tiling() for large outputs does not fix this. Tiling decodes the latent in overlapping tiles, but every tile still runs through the same resblocks in whatever dtype the VAE is in. With an fp16 VAE, a tile can still overflow. Tiling changes peak memory, not numerical range.

Where the gateway fits

A side note, since people ask how the text side of this connects. Before the diffusion step, we rewrite the user's scene description into a cleaner prompt with an LLM, and we generate alt-text captions after. Those LLM calls go through Bifrost, an open-source gateway that gives us one OpenAI-compatible endpoint with automatic failover across providers. It has nothing to do with the VAE overflow. It just means when one provider has a bad afternoon, the caption step doesn't take the render pipeline down with it.

Trade-offs and limitations

bf16 is not a free win. It has the range of fp32 but only 7 explicit mantissa bits, against fp16's 10, so you trade overflow safety for some precision. On our renders there was no visible difference, but I would not assume that for every model. Measure SSIM against an fp32 reference before you ship.
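That SSIM check can be sketched in plain Python. This is a simplified, single-window SSIM: the standard formula with the usual constants K1 = 0.01 and K2 = 0.03, evaluated over the whole image instead of a sliding Gaussian window. Treat it as a smoke test; for real numbers, `skimage.metrics.structural_similarity` is the usual tool. `global_ssim`, `SSIM_THRESHOLD` and the pixel values are hypothetical.

```python
from typing import Sequence

def global_ssim(x: Sequence[float], y: Sequence[float], data_range: float = 255.0) -> float:
    """Single-window SSIM over two equal-length, flattened grayscale images."""
    if len(x) != len(y) or not x:
        raise ValueError("images must be non-empty and the same size")
    n = len(x)
    c1 = (0.01 * data_range) ** 2  # stabilizes the luminance term
    c2 = (0.03 * data_range) ** 2  # stabilizes the contrast/structure term
    mx, my = sum(x) / n, sum(y) / n
    vx = sum((a - mx) * (a - mx) for a in x) / n
    vy = sum((b - my) * (b - my) for b in y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y)) / n
    return ((2 * mx * my + c1) * (2 * cov + c2)) / ((mx * mx + my * my + c1) * (vx + vy + c2))

# Hypothetical threshold; calibrate it on your own fp32-vs-candidate renders
SSIM_THRESHOLD = 0.98

reference = [10, 200, 30, 250, 90, 140]  # fp32 decode of a latent
candidate = [11, 199, 30, 251, 89, 141]  # bf16 decode of the same latent
score = global_ssim(reference, candidate)
print(f"SSIM {score:.4f}", "ok" if score >= SSIM_THRESHOLD else "investigate")
```

Decode the same latents with the fp32 reference VAE and with the candidate setup, then score each pair. The renders that fall under the threshold are the ones worth looking at.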

The fp16-fix weights are a community rescaling, not an official release. They work well, and we validated them on 2000 renders, but you're trusting a third-party checkpoint. Pin the exact revision.

And none of this helps if your latents themselves are out of distribution. We saw two black images that were not VAE overflow at all; a bad LoRA was producing extreme latents. The hook above tells you which failure you're looking at, so put it in your eval harness, not only in debugging.
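Below is a minimal sketch of that eval-harness check. It assumes two changes to the hook setup: each module's max activation is stored in a dict instead of printed, and the latents' absolute max is recorded before decoding. `triage_black_render` and `latent_limit` are hypothetical, and the limit has to come from your own known-good renders. Latents are checked first, because extreme latents also blow up the decoder, and in that case the VAE is not the stage to blame.

```python
import math
from typing import Mapping

def triage_black_render(
    latent_abs_max: float,
    decoder_maxima: Mapping[str, float],
    latent_limit: float,
) -> str:
    """Classify a black render from values recorded by forward hooks.

    latent_limit is a hypothetical per-model bound taken from the latent
    ranges seen on known-good renders.
    """
    # Extreme latents would also overflow the decoder, so rule them out first.
    if not math.isfinite(latent_abs_max) or latent_abs_max > latent_limit:
        return "latents out of distribution"
    # Forward hooks fire in execution order and dicts keep insertion order,
    # so this reports the first module that overflowed.
    for name, peak in decoder_maxima.items():
        if not math.isfinite(peak) or peak > 6e4:  # near the fp16 max of 65504
            return f"VAE overflow in {name}"
    return "unexplained: neither extreme latents nor VAE overflow"

print(triage_black_render(
    4.1,
    {"mid_block": 812.0, "up_blocks.1.resnets.2": math.inf},
    latent_limit=50.0,
))  # VAE overflow in up_blocks.1.resnets.2
```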
