DEV Community

StemSplit

Posted on • Originally published at stemsplit.io

I Exported HT-Demucs FT to ONNX in 2026 (4 Blockers Everyone Else Gave Up On)

You can't ship htdemucs_ft on iOS. You can't ship it on Android. You can't run it in a browser. PyTorch Mobile is a 2 GB install and a permission to break, MLX needs Apple Silicon, and the obvious answer — ONNX — has been "broken on htdemucs" across four open GitHub issues for three years.

I just shipped the first working ONNX export of htdemucs_ft. It runs in onnxruntime, is 1.31× faster than PyTorch on CPU, and is numerically equivalent to the original (max absolute difference: 0.000163 on drums, 0.000008 on vocals). All four specialist sub-models are on Hugging Face: StemSplitio/htdemucs-ft-onnx.

This is the engineering writeup — the 4 blockers that killed every prior attempt and the patches that beat them.

What You'll Learn

  • ✅ Why htdemucs is so hard to export (complex tensors, Python dynamism, fused C++ kernels)
  • ✅ How to replace torch.stft with Conv1d without losing accuracy (5 × 10⁻⁶ round-trip diff)
  • ✅ How to patch fractions.Fraction, random.randrange, and aten::_native_multi_head_attention
  • ✅ How to verify parity (max abs diff under 1e-3) before trusting the export
  • ✅ Pure numpy + onnxruntime inference — no PyTorch at runtime

Prerequisites

pip install "torch>=2.4,<2.5" "torchaudio>=2.4,<2.5" demucs onnx onnxruntime numpy soundfile

No GPU required for export or inference. Tested on Apple M4 Pro and Linux x86_64. If you don't have demucs set up yet, follow the complete demucs local setup guide first.


Why doesn't htdemucs export to ONNX out of the box?

Short answer: Because the model uses four PyTorch features that no ONNX exporter has good answers for — complex tensors in the STFT, fractions.Fraction arithmetic in model.segment, random.randrange inside the cross-transformer, and the fused C++ aten::_native_multi_head_attention kernel.

Each one stops torch.onnx.export and torch.onnx.dynamo_export cold. You hit them in order; each new patch unblocks the next failure.


Blocker 1: complex64 STFT output

Short answer: Replace torch.stft with a Conv1d using sin/cos kernels that emit two real-valued channels.

The first op in HT-Demucs is:

z = torch.stft(x, n_fft=4096, hop_length=1024, window=hann,
               win_length=4096, normalized=True, center=True,
               return_complex=True, pad_mode="reflect")

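return_complex=True returns a complex64 tensor. ONNX has no usable complex tensor type here: its own STFT op (opset 17+) emits real/imaginary pairs in a trailing dimension of size 2 rather than a complex output, and the exporter cannot carry a complex64 tensor through the graph, so every downstream slice/transpose fails. The workaround: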

import math, torch
import torch.nn as nn
import torch.nn.functional as F

def _make_stft_kernels(n_fft: int):
    n = torch.arange(n_fft, dtype=torch.float64)
    window = torch.hann_window(n_fft, periodic=True, dtype=torch.float64)
    norm = 1.0 / math.sqrt(n_fft)
    k = torch.arange(n_fft // 2 + 1, dtype=torch.float64).unsqueeze(1)
    angles = 2 * math.pi * k * n.unsqueeze(0) / n_fft
    cos = (window * torch.cos(angles)) * norm
    sin = (window * -torch.sin(angles)) * norm   # negative for forward STFT
    return cos.float().unsqueeze(1), sin.float().unsqueeze(1)

class RealSTFT(nn.Module):
    def __init__(self, n_fft=4096, hop_length=1024):
        super().__init__()
        cos, sin = _make_stft_kernels(n_fft)
        self.register_buffer("cos_kernel", cos)
        self.register_buffer("sin_kernel", sin)
        self.n_fft, self.hop_length = n_fft, hop_length

    def forward(self, x):
        x = F.pad(x.reshape(-1, 1, x.shape[-1]),
                  (self.n_fft // 2,) * 2, mode="reflect")
        real = F.conv1d(x, self.cos_kernel, stride=self.hop_length)
        imag = F.conv1d(x, self.sin_kernel, stride=self.hop_length)
        return torch.stack([real, imag], dim=1)    # (BN, 2, F, T) real

Verify against the real thing before going further:

x = torch.randn(1, 343980)
ref = torch.stft(x, n_fft=4096, hop_length=1024,
                 window=torch.hann_window(4096), win_length=4096,
                 normalized=True, center=True, return_complex=True,
                 pad_mode="reflect")
ref_real = torch.stack([ref.real, ref.imag], dim=1)

stft = RealSTFT()
out = stft(x)                      # (1, 2, F, T), same layout as ref_real
print("max abs diff:", (out - ref_real).abs().max().item())   # ~5e-6

5 × 10⁻⁶ is rounding noise. Use the same trick (ConvTranspose1d with conjugate kernels + overlap-add window-squared envelope) for the inverse STFT. Now every view_as_real / view_as_complex in the model's _magnitude and _mask methods can be rewritten to thread real-channel tensors through the whole forward pass.
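To make the inverse direction concrete, here is a minimal numpy sketch of the same idea at toy size (n_fft=16, hop=4). It is not the code from the export: the function names are mine, and matrix products stand in for Conv1d / ConvTranspose1d. It assumes a periodic Hann window, reflect padding, and the 1/sqrt(n_fft) scaling that normalized=True applies.

```python
import numpy as np

def make_kernels(n_fft):
    """Periodic Hann window plus cos/sin DFT kernels scaled by 1/sqrt(n_fft)."""
    n = np.arange(n_fft)
    k = np.arange(n_fft // 2 + 1)[:, None]
    window = 0.5 - 0.5 * np.cos(2 * np.pi * n / n_fft)
    angles = 2 * np.pi * k * n[None, :] / n_fft
    norm = 1.0 / np.sqrt(n_fft)
    return window, np.cos(angles) * norm, -np.sin(angles) * norm

def real_stft(x, n_fft, hop):
    """Forward STFT as two real matrix products: returns (real, imag), each (frames, bins)."""
    window, cos_k, sin_k = make_kernels(n_fft)
    xp = np.pad(x, n_fft // 2, mode="reflect")
    n_frames = 1 + (len(xp) - n_fft) // hop
    frames = np.stack([xp[t * hop:t * hop + n_fft] for t in range(n_frames)])
    frames = frames * window
    return frames @ cos_k.T, frames @ sin_k.T

def real_istft(real, imag, n_fft, hop, length):
    """Inverse STFT: transposed kernels, overlap-add, window-squared envelope."""
    window, cos_k, sin_k = make_kernels(n_fft)
    # One-sided spectrum: every bin except DC and Nyquist stands for two bins.
    c = np.full(n_fft // 2 + 1, 2.0)
    c[0] = c[-1] = 1.0
    # The transposed kernels rebuild each windowed frame from its real/imag parts.
    frames = (real * c) @ cos_k + (imag * c) @ sin_k
    n_frames = frames.shape[0]
    out = np.zeros((n_frames - 1) * hop + n_fft)
    env = np.zeros_like(out)
    for t in range(n_frames):
        sl = slice(t * hop, t * hop + n_fft)
        out[sl] += frames[t] * window      # synthesis window
        env[sl] += window ** 2             # window-squared envelope
    out = out / np.maximum(env, 1e-11)
    return out[n_fft // 2:n_fft // 2 + length]   # drop the reflect padding

rng = np.random.default_rng(0)
x = rng.standard_normal(64)
re, im = real_stft(x, n_fft=16, hop=4)
y = real_istft(re, im, n_fft=16, hop=4, length=len(x))
print("round-trip max abs diff:", np.abs(x - y).max())
```

The envelope division is what makes the overlapping frames sum back to the original signal; the 1e-11 floor only guards the padded edges, which are trimmed anyway.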


Blocker 2: fractions.Fraction in model.segment

Short answer: Coerce to float before exporting.

Pretrained htdemucs_ft ships with model.segment = Fraction(39, 5) (= 7.8 seconds). Dynamo dies:

torch._dynamo.exc.Unsupported: call_function
UserDefinedClassVariable(<class 'fractions.Fraction'>)

Fix:

from fractions import Fraction

if isinstance(model.segment, Fraction):
    model.segment = float(model.segment)   # 7.8

Mathematically identical at inference. Trivial — but you don't get to the next blocker without it.
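A quick stdlib check of that claim, assuming the 44100 Hz sample rate used throughout this post: the segment length in samples comes out the same whether it is computed from the exact Fraction or from the coerced float.

```python
from fractions import Fraction

SAMPLERATE = 44100                      # sample rate assumed from the rest of this post

segment = Fraction(39, 5)               # value shipped with the pretrained model
exact_samples = int(segment * SAMPLERATE)          # exact rational arithmetic
float_samples = int(float(segment) * SAMPLERATE)   # after the float coercion

assert exact_samples == float_samples == 343980
print(exact_samples, float_samples)
```

For a different segment value, float truncation could in principle land one sample short, so keeping an assert like this next to the coercion is cheap insurance.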


Blocker 3: random.randrange in the cross-transformer

Short answer: Monkey-patch the affected method to hardcode shift=0.

CrossTransformerEncoder._get_pos_embedding calls Python's random:

shift = random.randrange(self.sin_random_shift + 1)

At inference, sin_random_shift = 0, so random.randrange(1) always returns 0 — a no-op. But neither exporter can see through random and both bail out. Patch the method directly:

import types
from demucs.transformer import CrossTransformerEncoder, create_sin_embedding

def _get_pos_embedding_no_random(self_, T, B, C, device):
    if self_.emb == "sin":
        return create_sin_embedding(T, C, shift=0, device=device,
                                    max_period=self_.max_period)
    raise RuntimeError(f"emb={self_.emb} not handled at export")

for m in model.modules():
    if isinstance(m, CrossTransformerEncoder):
        m._get_pos_embedding = types.MethodType(_get_pos_embedding_no_random, m)
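If the types.MethodType pattern is unfamiliar, this toy version shows the mechanics without demucs installed. ToyEncoder and both method names are hypothetical stand-ins, not demucs classes: it confirms that randrange(0 + 1) can only return 0, then binds a deterministic replacement to a single instance.

```python
import random
import types

class ToyEncoder:
    """Hypothetical stand-in for a module that draws a random shift."""
    sin_random_shift = 0

    def _get_shift(self):
        return random.randrange(self.sin_random_shift + 1)

def _get_shift_fixed(self_):
    # Deterministic replacement: the only value randrange(1) can return.
    return 0

enc = ToyEncoder()
before = {enc._get_shift() for _ in range(1000)}   # every value the original produced

# Bind the replacement to this instance only; the class itself stays untouched.
enc._get_shift = types.MethodType(_get_shift_fixed, enc)
after = enc._get_shift()
other = ToyEncoder()                               # unpatched instance for comparison

print(before, after)
```

Patching instances rather than the class keeps the change local to the model being exported.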

Blocker 4: aten::_native_multi_head_attention

Short answer: Replace nn.MultiheadAttention.forward with a plain Linear/bmm/softmax implementation.

Modern PyTorch short-circuits nn.MultiheadAttention.forward to a fused C++ kernel (_native_multi_head_attention) when its preconditions are met. That kernel has no ONNX symbolic at any opset:

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator
'aten::_native_multi_head_attention' to ONNX opset version 17 is not supported.

Patch every MHA instance's forward with a drop-in that uses only ops with stable ONNX symbolics:

def _onnx_friendly_mha_forward(self_, query, key, value,
                                key_padding_mask=None, need_weights=True,
                                attn_mask=None, average_attn_weights=True,
                                is_causal=False):
    if self_.batch_first:
        query, key, value = (t.transpose(0, 1) for t in (query, key, value))
    tgt_len, bsz, embed_dim = query.shape
    head_dim = embed_dim // self_.num_heads

    if self_._qkv_same_embed_dim and torch.equal(query, key) and torch.equal(key, value):
        q, k, v = F.linear(query, self_.in_proj_weight, self_.in_proj_bias).chunk(3, dim=-1)
    else:
        # cross-attention path: three separate projections
        w_q, w_k, w_v = self_.in_proj_weight.chunk(3)
        b_q, b_k, b_v = (self_.in_proj_bias.chunk(3)
                         if self_.in_proj_bias is not None else (None, None, None))
        q = F.linear(query, w_q, b_q)
        k = F.linear(key,   w_k, b_k)
        v = F.linear(value, w_v, b_v)

    q = q.contiguous().view(tgt_len, bsz * self_.num_heads, head_dim).transpose(0, 1)
    k = k.contiguous().view(-1,      bsz * self_.num_heads, head_dim).transpose(0, 1)
    v = v.contiguous().view(-1,      bsz * self_.num_heads, head_dim).transpose(0, 1)

    attn = F.softmax(torch.bmm(q * head_dim ** -0.5, k.transpose(1, 2)), dim=-1)
    out  = torch.bmm(attn, v).transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    return self_.out_proj(out), None

for m in model.modules():
    if isinstance(m, nn.MultiheadAttention):
        m.forward = types.MethodType(_onnx_friendly_mha_forward, m)

Parity vs the fused kernel: 1 × 10⁻⁶ max diff. Safe.
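The reshape bookkeeping in that patch is the easiest part to get wrong, so here is a numpy sketch of the same arithmetic that runs without PyTorch. It is an illustration under my own naming (mha, split_heads), not the exported code, and like the patch it ignores attention masks and dropout.

```python
import numpy as np

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def mha(query, key, value, w_in, b_in, w_out, b_out, num_heads):
    """Plain multi-head attention on (seq, batch, embed) arrays."""
    tgt_len, bsz, embed_dim = query.shape
    head_dim = embed_dim // num_heads
    # Split the packed in-projection into its q, k, v thirds.
    w_q, w_k, w_v = np.split(w_in, 3, axis=0)
    b_q, b_k, b_v = np.split(b_in, 3)
    q = query @ w_q.T + b_q
    k = key @ w_k.T + b_k
    v = value @ w_v.T + b_v

    def split_heads(t):
        # (seq, batch, embed) -> (batch * heads, seq, head_dim)
        return t.reshape(t.shape[0], bsz * num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    attn = softmax((q * head_dim ** -0.5) @ k.transpose(0, 2, 1))
    out = (attn @ v).transpose(1, 0, 2).reshape(tgt_len, bsz, embed_dim)
    return out @ w_out.T + b_out, attn

rng = np.random.default_rng(0)
E, H = 8, 2
w_in, b_in = rng.standard_normal((3 * E, E)), rng.standard_normal(3 * E)
w_out, b_out = rng.standard_normal((E, E)), rng.standard_normal(E)
x = rng.standard_normal((5, 3, E))      # (seq, batch, embed)
mem = rng.standard_normal((7, 3, E))    # longer key/value sequence for cross-attention

self_out, _ = mha(x, x, x, w_in, b_in, w_out, b_out, H)
cross_out, attn = mha(x, mem, mem, w_in, b_in, w_out, b_out, H)

# One fused projection followed by a 3-way split equals three separate projections.
fused_q = np.split(x @ w_in.T + b_in, 3, axis=-1)[0]
separate_q = x @ w_in[:E].T + b_in[:E]
print(np.abs(fused_q - separate_q).max())
```

The last comparison is why the self-attention and cross-attention branches of the patch can share weights: the fused path is just the three projections stacked.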


The export call

Short answer: Legacy torch.onnx.export at opset 17, dynamo=False. dynamo_export dies on the patches anyway; legacy works.

from demucs.pretrained import get_model

bag = get_model("htdemucs_ft")
drums_model = bag.models[0].eval().cpu()
# apply all 4 patches above to drums_model

n = int(float(drums_model.segment) * int(bag.samplerate))   # 343980
dummy = torch.randn(1, 2, n, dtype=torch.float32)

torch.onnx.export(
    drums_model,
    dummy,
    "htdemucs_ft_drums.onnx",
    input_names=["mix"],
    output_names=["stems"],
    opset_version=17,
    dynamo=False,   # keyword added in PyTorch 2.5; remove this line on the 2.4 pin from Prerequisites
    do_constant_folding=True,
)

316 MB per specialist. ~6.5 s export time on CPU. Passes onnx.checker.check_model. 24,765 nodes.

Repeat for bag.models[1] (bass), bag.models[2] (other), bag.models[3] (vocals). All four use the same architecture and patches.


Parity verification — the only acceptance test that matters

Short answer: Run the original PyTorch model and the ONNX model on the same fixed input, compute .abs().max(). Should be < 1e-3.

import numpy as np
import onnxruntime as ort

x = np.random.randn(1, 2, 343980).astype("float32")
sess = ort.InferenceSession("htdemucs_ft_drums.onnx",
                            providers=["CPUExecutionProvider"])
onnx_out = sess.run(["stems"], {"mix": x})[0]

torch_out = drums_model(torch.from_numpy(x)).detach().numpy()
print("max abs diff:", np.abs(onnx_out - torch_out).max())   # 1.63e-4

Per stem, against the original PyTorch htdemucs_ft at fp32:

| Stem   | Max abs diff |
|--------|--------------|
| drums  | 0.000163     |
| bass   | 0.000011     |
| other  | 0.000739     |
| vocals | 0.000008     |

All four are under the 1e-3 tolerance that fp32 reordering normally explains, with "other" the closest at 0.000739. SDR scores measured against MUSDB18-HQ are unchanged; if you want to verify yourself, query the 800-row leaderboard from pandas without re-running the benchmark.
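If you want a pass/fail line per stem rather than one global number, a small helper along these lines works on any pair of (stems, channels, samples) arrays. The function name and the synthetic drift values are mine, chosen only to exercise both sides of the 1e-3 tolerance; in practice you would pass in the PyTorch and ONNX outputs.

```python
import numpy as np

STEMS = ["drums", "bass", "other", "vocals"]

def parity_report(reference, candidate, tol=1e-3):
    """Max abs diff per stem for two (stems, channels, samples) arrays."""
    diffs = np.abs(reference - candidate).reshape(len(STEMS), -1).max(axis=1)
    return {name: (float(d), bool(d < tol)) for name, d in zip(STEMS, diffs)}

# Synthetic stand-ins for the two model outputs.
rng = np.random.default_rng(0)
reference = rng.standard_normal((4, 2, 1000)).astype("float32")
candidate = reference.copy()
candidate[2, 0, 10] += 5e-4     # small drift on "other": inside tolerance
candidate[0, 1, 20] += 5e-3     # larger drift on "drums": outside tolerance

report = parity_report(reference, candidate)
for name, (diff, ok) in report.items():
    print(f"{name:6s} {diff:.6f} {'ok' if ok else 'FAIL'}")
```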


Inference with zero PyTorch (pure numpy + onnxruntime)

Short answer: Load the four ONNX files, run overlap-add chunking, and keep each specialist's own stem row from its output.

import numpy as np
import onnxruntime as ort
import soundfile as sf

SOURCES = ["drums", "bass", "other", "vocals"]
CHUNK = 343980             # 7.8 s * 44100 Hz
HOP   = CHUNK // 4

sessions = {
    s: ort.InferenceSession(f"htdemucs_ft_{s}.onnx",
                            providers=["CPUExecutionProvider"])
    for s in SOURCES
}

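def separate(mix, sr=44100):
    # mix: float32 array shaped (2, samples) at 44.1 kHz; sr is not used for resampling
    pad = (-mix.shape[-1]) % CHUNK          # pad only up to the next multiple of CHUNK
    mix_p = np.pad(mix, ((0, 0), (0, pad)), mode="constant")
    out = {s: np.zeros_like(mix_p) for s in SOURCES}
    weight = np.zeros(mix_p.shape[-1], dtype="float32")
    # drop the zero endpoints of the Hann window so the very first sample keeps nonzero weight
    w = np.hanning(CHUNK + 2)[1:-1].astype("float32")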

    for start in range(0, mix_p.shape[-1] - CHUNK + 1, HOP):
        chunk = mix_p[:, start:start+CHUNK].astype("float32")[None]
        for s in SOURCES:
            y = sessions[s].run(["stems"], {"mix": chunk})[0][0]
            target_row = SOURCES.index(s)
            out[s][:, start:start+CHUNK] += y[target_row] * w
        weight[start:start+CHUNK] += w

    for s in SOURCES:
        out[s] = (out[s] / np.maximum(weight, 1e-8))[:, :mix.shape[-1]]
    return out

mix, sr = sf.read("song.wav", dtype="float32", always_2d=True)
stems = separate(mix.T, sr)
for s in SOURCES:
    sf.write(f"{s}.wav", stems[s].T, sr)

Zero PyTorch. Zero MLX. Runs anywhere onnxruntime runs — iOS via onnxruntime-objc, Android via onnxruntime-android, browsers via onnxruntime-web.
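Before pointing the chunker at a 316 MB model, it is worth checking the overlap-add logic in isolation. This sketch swaps the ONNX session for an arbitrary callable (run_chunk, my name) so an identity function should return the input unchanged. It assumes hop divides chunk and uses a Hann taper with the zero endpoints dropped so every sample gets nonzero weight.

```python
import numpy as np

def overlap_add(mix, run_chunk, chunk, hop):
    """Blend run_chunk(segment) over overlapping segments of a (channels, samples) array."""
    assert chunk % hop == 0, "hop must divide chunk so the last segment reaches the end"
    n = mix.shape[-1]
    padded = np.pad(mix, ((0, 0), (0, (-n) % chunk)))   # zero-pad to a multiple of chunk
    out = np.zeros(padded.shape, dtype=np.float64)
    weight = np.zeros(padded.shape[-1])
    w = np.hanning(chunk + 2)[1:-1]                     # strictly positive Hann taper
    for start in range(0, padded.shape[-1] - chunk + 1, hop):
        seg = padded[:, start:start + chunk]
        out[:, start:start + chunk] += run_chunk(seg) * w
        weight[start:start + chunk] += w
    return (out / weight)[:, :n]                        # normalize, then trim the padding

rng = np.random.default_rng(0)
mix = rng.standard_normal((2, 1000))
same = overlap_add(mix, lambda seg: seg, chunk=64, hop=16)
halved = overlap_add(mix, lambda seg: 0.5 * seg, chunk=64, hop=16)
print(np.abs(same - mix).max())
```

Any per-sample gain passes straight through the blend, which is the property the real separation loop relies on when it divides by the accumulated weight.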


Performance numbers

Short answer: ONNX Runtime CPU EP is 1.31× faster than PyTorch CPU. A single-specialist ONNX run is ~5.7× faster than the full PyTorch bag (~22 s vs ~125 s).

Apple M4 Pro, 3-minute song:

| Backend | Latency | Notes |
|---|---|---|
| ONNX CPU EP, single specialist | ~22 s | Use this for vocal removers / drum extractors |
| ONNX CPU EP, full 4-stem bag | ~88 s | All stems |
| PyTorch CPU, full bag | ~125 s | Baseline |
| PyTorch MPS, full bag | ~47 s | Apple GPU |
| ONNX CUDA, NVIDIA L4 (extrapolated) | ~6 s | Server-side deployment |

The single-specialist trick works because the htdemucs_ft bag is one-hot:

# the bag's per-model weight matrix
weights = [[1, 0, 0, 0],   # drums  = sub-model 0's drums output
           [0, 1, 0, 0],   # bass   = sub-model 1's bass output
           [0, 0, 1, 0],   # other  = sub-model 2's other output
           [0, 0, 0, 1]]   # vocals = sub-model 3's vocals output

Sub-model 3's vocals output is the bag's vocals output, bit-exact. If you're building a vocal remover, ship sub-model 3 alone (~316 MB ONNX, ~75 MB quantized) instead of the full ~1.26 GB bag at identical per-stem quality.
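To see why one-hot weights make the other three sub-models redundant for a given stem, here is a numpy sketch of a weighted bag combine. It reflects my reading of how a bag averages sub-model outputs per stem (multiply by weight, sum over models, divide by the weight total), so treat combine_bag as an assumption rather than the demucs implementation.

```python
import numpy as np

def combine_bag(outputs, weights):
    """Weighted per-stem average over sub-models.

    outputs: (models, stems, channels, samples); weights: (models, stems).
    """
    weights = np.asarray(weights, dtype=outputs.dtype)
    weighted = outputs * weights[:, :, None, None]
    return weighted.sum(axis=0) / weights.sum(axis=0)[:, None, None]

rng = np.random.default_rng(0)
outputs = rng.standard_normal((4, 4, 2, 100))     # synthetic sub-model outputs
bag = combine_bag(outputs, np.eye(4))             # one-hot weights, as in the matrix above

# With one-hot weights, stem i of the bag is exactly stem i of sub-model i.
diagonal = np.stack([outputs[i, i] for i in range(4)])
print(np.array_equal(bag, diagonal))
```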


Wrapping Up

The five new ONNX repos — all MIT-licensed, all parity-verified:

If you'd rather skip the deployment plumbing entirely, the StemSplit hosted vocal remover runs the exact same htdemucs_ft weights with credits, queueing, and a dashboard — same model, just hosted. The full long-form writeup on the StemSplit blog has the iOS/Swift, Android/Kotlin, and onnxruntime-web code samples as well.

Open an issue on any of the repos if you find a stem where your parity diff exceeds 1e-3 — would love to hear about it.
