DEV Community

shinji shimizu

Originally published at kotonia.ai

Cutting LTX-2 22B Peak VRAM by 40% with fp8_cast — and Why optimum-quanto Was a Trap

Introduction

LTX-2.3 is a video generation model from Lightricks with audio support. In A2V (Audio-to-Video) mode, it takes a single image, an audio track, and a prompt, and generates lip sync, facial expressions, and head/hair motion in one pass. Unlike lip-sync-only models such as MuseTalk, it can animate the entire scene, which makes it a powerful directing tool.

The catch: the base checkpoint is 22B parameters (43 GB), and keeping both stage transformers resident in bf16 takes ~86 GiB at idle. On an RTX PRO 6000 Blackwell with 96 GiB, that leaves almost nothing for the TTS, Ditto-TalkingHead, and Qwen3-TTS-vLLM services running alongside it.
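A quick sanity check on those figures: weight memory is parameter count times bytes per parameter. This back-of-envelope sketch counts transformer weights only (no text encoder, VAE, or activations) and assumes both stages hold a full 22B copy, as the two-stage setup implies.

```python
GIB = 1024 ** 3


def weight_gib(num_params: float, bytes_per_param: int) -> float:
    """Memory needed to hold num_params weights at the given width, in GiB."""
    return num_params * bytes_per_param / GIB


PARAMS = 22e9  # LTX-2.3 22B transformer (per stage)

for label, width in [("bf16", 2), ("fp8", 1)]:
    one_stage = weight_gib(PARAMS, width)
    print(f"{label}: {one_stage:.1f} GiB per transformer, {2 * one_stage:.1f} GiB for two stages")
```

Two bf16 copies come to roughly 82 GiB, close to the ~86 GiB observed; the remainder is presumably the text encoder and other components. Storing weights in fp8 halves the weight term, which is where the expected ~50% reduction comes from.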

After testing several quantization approaches, I used LTX-2's native fp8_cast to cut peak VRAM from 40 GiB to 24 GiB (A2V cold-start, 768×512 / 97 frames). optimum-quanto int8/fp8, on the other hand, has a compatibility issue with the LTX-2 transformer and simply doesn't work. This post documents the debugging and the decisions made along the way.


Environment

  • GPU: NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition (96 GiB)
  • PyTorch: 2.9.1 + CUDA 12.8
  • Models: LTX-2.3 22B-dev (base) + 22B-distilled-lora-384 (stage_2) + Gemma-3-12B text encoder (bnb 4bit)
  • Deployment: A2V served via scripts/persistent_a2v_server.py --cold-start. Each request does build → run → free; idle is 0 GiB.

I use cold-start because A2V is only called occasionally (conversation is the main workload) and it must coexist with TTS and Ditto. Details are in a separate post.


Four Candidates

Looking at the LTX-2 codebase, there are actually two quantization paths, which together give four candidate modes:

1. LTX-2 Native: QuantizationPolicy

packages/ltx-core/src/ltx_core/quantization/policy.py:

@dataclass(frozen=True)
class QuantizationPolicy:
    sd_ops: SDOps | None = None              # weight transform at state dict load
    module_ops: tuple[ModuleOps, ...] = ()   # module rewrite after load

    @classmethod
    def fp8_cast(cls) -> "QuantizationPolicy":
        """Load weights as float8_e4m3fn, upcast to bf16 during forward"""
        return cls(
            sd_ops=TRANSFORMER_LINEAR_DOWNCAST_MAP,
            module_ops=(UPCAST_DURING_INFERENCE,),
        )

    @classmethod
    def fp8_scaled_mm(cls) -> "QuantizationPolicy":
        """FP8 scaled MM (requires tensorrt_llm)"""

The implementation behind fp8_cast is Fp8CastLinear:

class Fp8CastLinear(torch.nn.Linear):
    def forward(self, input):
        w_up = _upcast_and_round(self.weight, input.dtype, ...)
        b_up = _upcast_and_round(self.bias, input.dtype, ...) if self.bias is not None else None
        return torch.nn.functional.linear(input, w_up, b_up)

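Rather than replacing modules, it reassigns __class__ on the existing nn.Linear instances, so the loaded parameters stay in place and only forward() changes. Weights are stored in fp8 and upcast to the input dtype (bf16 here) on every forward pass. In practice, the fp8 → bf16 cast cost is essentially noise on Blackwell.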
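The pattern itself is easy to see outside PyTorch. Below is a stdlib-only toy, not LTX-2 code: ToyLinear, CastOnForwardLinear, and convert_in_place are hypothetical names, and IEEE half precision (struct's "e" format) stands in for float8_e4m3fn because the standard library has no fp8 type. The point is that the object's identity and attributes survive; only its class, and therefore forward(), changes.

```python
import struct


def _affine(weight_rows, bias, x):
    """y = W @ x + b on plain Python lists."""
    out = [sum(w * xi for w, xi in zip(row, x)) for row in weight_rows]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


class ToyLinear:
    """Stand-in for nn.Linear; weight is a list of float rows (one per output)."""

    def __init__(self, weight, bias=None):
        self.weight = weight
        self.bias = bias

    def forward(self, x):
        return _affine(self.weight, self.bias, x)


def _downcast(row):
    # Pack a row as IEEE half precision: 2 bytes per value instead of 8.
    return struct.pack(f"<{len(row)}e", *row)


def _upcast(buf):
    # Unpack half-precision bytes back into full-precision Python floats.
    return list(struct.unpack(f"<{len(buf) // 2}e", buf))


class CastOnForwardLinear(ToyLinear):
    """Same attributes as ToyLinear, but weight rows are half-precision bytes."""

    def forward(self, x):
        w_up = [_upcast(row) for row in self.weight]  # upcast on every call
        return _affine(w_up, self.bias, x)


def convert_in_place(layer):
    """Downcast the stored weights, then swap the class of the same object."""
    layer.weight = [_downcast(row) for row in layer.weight]
    layer.__class__ = CastOnForwardLinear
    return layer


layer = ToyLinear([[0.5, -1.25], [2.0, 0.1]], bias=[0.1, 0.0])
before_id, before_out = id(layer), layer.forward([1.0, 2.0])
convert_in_place(layer)
print(id(layer) == before_id, type(layer).__name__)
print(before_out, layer.forward([1.0, 2.0]))
```

The first output is unchanged because 0.5, -1.25, and 2.0 are exact in half precision; the second shifts slightly because 0.1 rounds to 0.0999755859375. That rounding is the same kind of precision loss fp8 storage trades for memory.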

2. optimum-quanto

The LTX-2 trainer package (packages/ltx-trainer) has a general-purpose quantization path using optimum-quanto, supporting int8-quanto / int4-quanto / fp8-quanto:

def quantize_model(model, precision, ...):
    if hasattr(model, "transformer_blocks"):
        _quantize_blockwise(model, ...)   # move one block at a time to GPU, quantize → freeze → CPU
    else:
        quantize(model, weights=..., exclude=EXCLUDE_PATTERNS)
        freeze(model)
    return model

This looks like it could slot right in after _build_transformer().

Candidate Matrix

| Mode | Path | Expected |
| --- | --- | --- |
| fp8-cast | LTX-2 native; sd_ops loads weights as float8_e4m3fn | ~50% memory reduction, near-identical speed |
| fp8-scaled-mm | LTX-2 native; requires tensorrt_llm | Faster throughput |
| int8-quanto | optimum-quanto, applied post-build | ~50% memory reduction, speed may shift either way |
| fp8-quanto | optimum-quanto, fp8 variant | Could use native FP8 compute on Blackwell |

fp8-scaled-mm is out — no tensorrt_llm in this environment. I implemented the remaining three.


Stepping on a Mine with int8-quanto

The implementation is straightforward:

from ltx_trainer.quantization import quantize_model

transformer_1 = self.pipeline.stage_1._build_transformer()
transformer_1 = quantize_model(transformer_1, "int8-quanto", device=self.device)
self.transformer_stage_1 = _freeze(transformer_1)

The server starts fine. Idle VRAM looks promising:

[load] stage_1 transformer (no distilled LoRA)
[quantize] stage_1 -> int8-quanto
[quantize] stage_1 done in 0.71s
[cuda] after stage_1 transformer: allocated=31.28GiB ...
[load] stage_2 transformer (with distilled LoRA)
[quantize] stage_2 -> int8-quanto
[quantize] stage_2 done in 0.52s
[cuda] after stage_2 transformer: allocated=49.40GiB ...
[server] A2V listening on http://127.0.0.1:8892

Resident memory: 51.7 GiB (estimated 40% reduction from bf16's 86 GiB). Looks good.

Then the first /generate request:

[timing] prompt_encode=0.75s
[timing] audio_encode=0.39s
  0%|          | 0/30 [00:00<?, ?it/s]
[http] POST /generate 400

Crashes at step 0/30. The error:

{"error": "linear(): argument 'weight' (position 2) must be Tensor, not NoneType"}

So some Linear layer is calling torch.nn.functional.linear with weight=None: after quanto's freeze(), that layer's self.weight resolves to None somewhere in the forward path.

Why Does weight Become None?

Two rough hypotheses:

  1. LTX-2's Linear handling may assume __class__ reassignment. As with Fp8CastLinear, that pattern keeps instance state intact and only swaps the class-level forward. quanto's quantize() → freeze() instead replaces nn.Linear with its own QLinear wrapper, and that replacement likely breaks the weight attribute reference somewhere along the way.

  2. EXCLUDE_PATTERNS may not apply in the blockwise path. ltx-trainer's _quantize_blockwise pulls out one transformer_block at a time and calls quantize(block, exclude=EXCLUDE_PATTERNS). But EXCLUDE_PATTERNS (glob patterns like patchify_proj, *adaln*, time_proj) is written against module paths in the whole model, not against a single block. Patterns that depend on those paths won't match the relative names inside a block, so layers that should be excluded may end up quantized.
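To make hypothesis 2 concrete: assuming the exclude list is matched with shell-style globbing (fnmatch) against qualified module names, which is how I understand optimum-quanto's quantize() to work, only patterns anchored to a model-level path are affected. The pattern transformer_blocks.*.ff.* and the module names below are hypothetical, not LTX-2's actual EXCLUDE_PATTERNS or layer names.

```python
from fnmatch import fnmatch


def excluded(name: str, patterns: list[str]) -> bool:
    """True if a qualified module name matches any shell-style exclude pattern."""
    return any(fnmatch(name, pattern) for pattern in patterns)


# Hypothetical exclude list: one unanchored wildcard, one anchored to a model-level path.
PATTERNS = ["*adaln*", "transformer_blocks.*.ff.*"]

# The same (hypothetical) module, named from the whole model vs. from inside block 3.
full_name = "transformer_blocks.3.ff.net.0.proj"
block_relative_name = "ff.net.0.proj"

print(excluded(full_name, PATTERNS))            # anchored pattern matches the full path
print(excluded(block_relative_name, PATTERNS))  # prefix is gone, nothing matches
print(excluded("adaln.linear", PATTERNS))       # unanchored wildcard matches either way
```

Unanchored wildcards such as *adaln* still match inside a block, so under this assumption only the exact-name or path-anchored entries would silently stop excluding anything.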

Either way, fixing this properly means reading through quanto's wrapper implementation plus all the forward paths in the LTX-2 transformer. The cost isn't worth it. I decided to cut my losses and switch to LTX-2 native fp8_cast.


Switching to fp8_cast

Three lines of code:

# Just pass the quantization policy when building the pipeline
pipeline_quantization = None
if transformer_quantization == "fp8-cast":
    from ltx_core.quantization import QuantizationPolicy
    pipeline_quantization = QuantizationPolicy.fp8_cast()

self.pipeline = A2VidPipelineTwoStage(
    ...,
    quantization=pipeline_quantization,
    ...
)

fp8_cast downcasts weights to fp8 during the load phase. Because sd_ops hooks into state_dict loading, the 43 GB safetensors checkpoint is converted to fp8 as it streams in. Unlike the quanto path, which fully materializes the bf16 model in memory before quantizing, peak VRAM never spikes to the full bf16 size, which is a nice property.
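A toy model of why the order matters. It tracks bytes per tensor under two idealized strategies (bf16 = 2 bytes/param, fp8 = 1 byte/param) and ignores allocator caching, activations, and any layers kept in bf16. The function names and the 48-block split are hypothetical, not taken from LTX-2 or quanto.

```python
def peak_quantize_after_load(sizes: list[int]) -> int:
    """Build the whole model in bf16 first, then convert tensors to fp8 one by one."""
    resident = 2 * sum(sizes)           # full bf16 model in memory
    peak = resident
    for n in sizes:
        peak = max(peak, resident + n)  # new fp8 copy coexists with its bf16 source
        resident -= n                   # free 2n bytes of bf16, keep n bytes of fp8
    return peak


def peak_downcast_while_loading(sizes: list[int]) -> int:
    """Stream tensors: load one in bf16, downcast it, drop the bf16 copy, repeat."""
    resident = 0
    peak = 0
    for n in sizes:
        peak = max(peak, resident + 2 * n + n)  # bf16 staging buffer + its fp8 copy
        resident += n                            # only the fp8 copy stays
    return peak


# Hypothetical split: 48 equal blocks of ~458M params each (~22B total).
blocks = [22_000_000_000 // 48] * 48
GIB = 1024 ** 3
print(f"quantize after load:    {peak_quantize_after_load(blocks) / GIB:.1f} GiB")
print(f"downcast while loading: {peak_downcast_while_loading(blocks) / GIB:.1f} GiB")
```

With a single tensor the two strategies tie. With many equal-sized tensors, the streaming peak approaches the final fp8 size plus one bf16 tensor, while the post-hoc peak stays slightly above twice the final fp8 size.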

On startup:

[load] A2VidPipelineTwoStage builders (pipeline_quantization=QuantizationPolicy(sd_ops=...fp8_cast...))
...
[cuda] after stage_1 transformer: allocated=31.30GiB reserved=35.18GiB
[cuda] after stage_2 transformer: allocated=49.43GiB reserved=53.64GiB
[server] A2V listening on http://127.0.0.1:8892

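Resident allocated memory (51.7 GiB) is on par with int8-quanto, but reserved memory is only 53.6 GiB, dramatically lower than int8-quanto's 70.9 GiB. Reserved is what PyTorch's caching allocator holds on to, so on a 96 GiB card this leaves roughly 42 GiB instead of 25 GiB for activations and the other services.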

And the first /generate:

{
  "elapsed_seconds": 39.367,
  "peak_vram_gib": 57.918,
  "width": 768, "height": 512, "num_frames": 97
}

It works. Back on track.


Benchmarks

Benchmark setup: persistent mode + fp8-cast, 3 resolutions × 3 runs each, with these fixed inputs:

  • Image: 1024×512 portrait
  • Audio: 9.08-second Japanese sample generated with Irodori-TTS
  • Prompt: "A young woman speaks calmly to the camera in a softly lit room."
  • num_frames: 97 (= 4.04s @ 24fps)
  • seed: 42 fixed

| Resolution | Avg elapsed (s) | Peak VRAM (GiB) |
| --- | --- | --- |
| 768×512 / 97f | 39.84 | 57.92 |
| 1024×768 / 97f | 66.71 | 59.06 |
| 1280×768 / 97f | 84.02 | 58.30 |

Key observations:

  • Near-zero variance across 3 runs (fixed seed → byte-identical output mp4)
  • Peak VRAM is almost independent of resolution (57.9–59.1 GiB). Resident weights dominate; activations add only about 6–7 GiB on top of the ~52 GiB resident
  • 1280×768 now works stably in persistent mode. This resolution was effectively impossible with bf16 persistent (~91 GiB peak)
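The activation figure can be checked directly from the benchmark table. This sketch assumes the 51.7 GiB resident figure from startup applies unchanged during generation; summarize() is a hypothetical helper that subtracts it from each peak and also compares pixel count and time against the smallest run.

```python
RESIDENT_GIB = 51.7  # persistent fp8-cast resident memory, from the startup numbers

# (width, height) -> (avg elapsed s, peak VRAM GiB), copied from the benchmark table
RUNS = {
    (768, 512): (39.84, 57.92),
    (1024, 768): (66.71, 59.06),
    (1280, 768): (84.02, 58.30),
}


def summarize(runs, resident_gib):
    """Per resolution: (GiB above resident, pixel ratio, time ratio) vs. the smallest run."""
    base_res = min(runs, key=lambda wh: wh[0] * wh[1])
    base_pixels = base_res[0] * base_res[1]
    base_time = runs[base_res][0]
    return {
        (w, h): (peak - resident_gib, (w * h) / base_pixels, elapsed / base_time)
        for (w, h), (elapsed, peak) in runs.items()
    }


for (w, h), (extra, px, t) in summarize(RUNS, RESIDENT_GIB).items():
    print(f"{w}x{h}: +{extra:.1f} GiB over resident, {px:.1f}x pixels, {t:.2f}x time")
```

Under that assumption the extra memory stays between about 6 and 7.5 GiB at every resolution, while time grows more slowly than pixel count (2.5× the pixels for about 2.1× the time at 1280×768).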

Cold-Start Also Wins

Production runs in cold-start mode (A2V fires once or twice every few minutes and must coexist with TTS). Since the fp8_cast policy is applied via sd_ops at pipeline construction time, it carries over naturally to per-request cold-start builds.

Cold-start + fp8-cast, single run (768×512 / 97f):

{
  "elapsed_seconds": 88.775,
  "peak_vram_gib": 23.901
}

| Metric | bf16 cold-start | fp8-cast cold-start |
| --- | --- | --- |
| Per-request time | ~60–90 s | 88.8 s (disk-I/O bound, same order) |
| Peak VRAM | ~40 GiB | 23.9 GiB (~40% reduction) |
| Idle VRAM | 0 GiB | 0 GiB |
| Coexistence (TTS + Ditto + Qwen3 + MuseTalk) | Possible | Comfortable (~30 GiB peak) |

Speed is bottlenecked by disk I/O, so fp8 doesn't hurt there, but freeing up ~16 GiB of peak headroom matters: Qwen3-TTS-vLLM (7 GiB) and MuseTalk warmup can now run concurrently with A2V generation without OOM.
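The headroom arithmetic, as a small sketch. It treats the card's nominal 96 GiB as fully usable, which slightly overstates real headroom once the CUDA context and driver take their share; the helper names are hypothetical.

```python
GPU_GIB = 96.0  # RTX PRO 6000 Blackwell nominal capacity


def reduction_pct(before_gib: float, after_gib: float) -> float:
    """Relative reduction, in percent."""
    return 100.0 * (before_gib - after_gib) / before_gib


def headroom_gib(a2v_peak_gib: float, gpu_gib: float = GPU_GIB) -> float:
    """VRAM left for co-resident services while A2V is at its peak."""
    return gpu_gib - a2v_peak_gib


BF16_PEAK, FP8_PEAK = 40.0, 23.9  # cold-start peaks, 768x512 / 97f

print(f"peak reduction: {reduction_pct(BF16_PEAK, FP8_PEAK):.0f}%")
print(f"headroom at A2V peak: {headroom_gib(BF16_PEAK):.1f} GiB -> {headroom_gib(FP8_PEAK):.1f} GiB")
```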


Decision Matrix

| Use case | Recommended mode | Rationale |
| --- | --- | --- |
| Conversation-first, occasional A2V | cold-start + fp8-cast | Idle 0 GiB, peak 24 GiB; coexists comfortably with TTS/Ditto |
| Frequent A2V (batch generation, automated direction) | persistent + fp8-cast | Pay the 52 GiB resident cost to get ~40 s/request |
| 1024+ resolution, quality focus | persistent + fp8-cast | 1280×768 runs stably (impossible with bf16 persistent) |
| Single GPU hosting everything | cold-start + fp8-cast | Persistent pins 52 GiB; the right call depends on how VRAM is budgeted across services |

Production decision: cold-start + fp8-cast for now, since conversation is the primary workload. I'll switch to persistent + fp8-cast if paying users drive enough A2V volume to justify the idle cost.
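The switch-over point can be framed as a simple trade: persistent mode pins ~52 GiB at all times, while cold-start adds roughly 49 s per request at 768×512 (88.8 s cold-start vs. 39.8 s persistent average). The sketch below only does that arithmetic; the request rates are hypothetical, and where the break-even lies depends on how much VRAM the other services actually need.

```python
COLD_START_S = 88.775  # cold-start + fp8-cast, 768x512 / 97f, single run
PERSISTENT_S = 39.84   # persistent + fp8-cast, same resolution, 3-run average
RESIDENT_GIB = 52      # idle VRAM cost of persistent mode


def extra_seconds_per_hour(requests_per_hour: float) -> float:
    """Extra generation time per hour that cold-start adds over persistent."""
    return requests_per_hour * (COLD_START_S - PERSISTENT_S)


for rph in (2, 10, 30):  # hypothetical request rates
    minutes = extra_seconds_per_hour(rph) / 60
    print(f"{rph:>3} req/h: cold-start adds ~{minutes:.1f} min/h; persistent pins {RESIDENT_GIB} GiB")
```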


Summary

  • LTX-2 22B resident in bf16 (86 GiB at idle) nearly monopolizes a single GPU. Quantization is close to mandatory.
  • optimum-quanto is incompatible with the LTX-2 transformer: it dies with F.linear(weight=None). The root cause is likely the __class__ reassignment pattern and/or EXCLUDE_PATTERNS not matching in the blockwise path. Not worth digging into.
  • LTX-2 native QuantizationPolicy.fp8_cast() is the right answer. fp8 at load time, bf16 upcast during forward. Three lines of code to enable.
  • cold-start + fp8-cast: peak 40 → 24 GiB. persistent + fp8-cast: 1280×768 becomes usable.
  • LTX-2 also has fp8_scaled_mm (requires tensorrt_llm) — worth trying if you're willing to set up TRT.

Appendix: Launch Command and Reproduction

Production cold-start + fp8-cast launch:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True nohup uv run python scripts/persistent_a2v_server.py \
  --port 8892 \
  --checkpoint-path models/LTX-2.3/ltx-2.3-22b-dev.safetensors \
  --distilled-lora-path models/loras/ltx-2.3-22b-distilled-lora-384-1.1.safetensors \
  --spatial-upsampler-path models/LTX-2.3/ltx-2.3-spatial-upscaler-x2-1.1.safetensors \
  --gemma-root models/gemma-3-12b-it-qat-q4_0-unquantized \
  --output-dir outputs/a2v_server \
  --transformer-quantization fp8-cast \
  --cold-start \
  > /tmp/ltx_a2v_server.log 2>&1 &

persistent_a2v_server.py is the official LTX-2 repo script extended for A2V. The --transformer-quantization fp8-cast flag was added via a local patch.

Implementation patch (key parts):

# scripts/persistent_a2v_server.py
pipeline_quantization = None
if transformer_quantization in ("fp8-cast", "fp8-scaled-mm"):
    from ltx_core.quantization import QuantizationPolicy  # late import: avoids a circular import with ltx_core.loader
    pipeline_quantization = (
        QuantizationPolicy.fp8_cast()
        if transformer_quantization == "fp8-cast"
        else QuantizationPolicy.fp8_scaled_mm()
    )

self.pipeline = A2VidPipelineTwoStage(
    ...,
    quantization=pipeline_quantization,
    ...,
)

Putting from ltx_core.quantization import QuantizationPolicy at module top level causes a circular import with ltx_core.loader, so the import has to be deferred into the function body.
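The failure mode is the generic Python one: a module-level from X import Name that runs while X is still only partially initialized. Below is a self-contained toy reproduction. toy_policy and toy_loader are hypothetical stand-ins for ltx_core.quantization and ltx_core.loader; this shows the general mechanism, not LTX-2's actual import graph.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# toy_loader plays the role of the loader: it needs Policy at import time.
LOADER_SRC = textwrap.dedent('''
    from toy_policy import Policy

    def load(policy):
        assert isinstance(policy, Policy)
        return f"loaded with {policy.name}"
''')

# Eager version: a module-level import closes the cycle policy -> loader -> policy.
EAGER_POLICY_SRC = textwrap.dedent('''
    from toy_loader import load

    class Policy:
        name = "fp8_cast"

        def apply(self):
            return load(self)
''')

# Late version: the import runs only when apply() is called, after both modules exist.
LATE_POLICY_SRC = textwrap.dedent('''
    class Policy:
        name = "fp8_cast"

        def apply(self):
            from toy_loader import load
            return load(self)
''')


def try_policy(policy_src: str) -> str:
    """Import toy_policy from a throwaway directory and call Policy().apply()."""
    with tempfile.TemporaryDirectory() as tmp:
        Path(tmp, "toy_loader.py").write_text(LOADER_SRC)
        Path(tmp, "toy_policy.py").write_text(policy_src)
        sys.path.insert(0, tmp)
        importlib.invalidate_caches()
        try:
            return importlib.import_module("toy_policy").Policy().apply()
        except ImportError as exc:
            return f"ImportError: {exc}"
        finally:
            sys.path.remove(tmp)
            for name in ("toy_policy", "toy_loader"):
                sys.modules.pop(name, None)


print(try_policy(EAGER_POLICY_SRC))
print(try_policy(LATE_POLICY_SRC))
```

The eager version fails because toy_loader asks for Policy before toy_policy has finished executing; deferring the import to call time means both modules are fully initialized by the time the name is looked up.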
