DEV Community

Shixin Zhang

Why JAX Is a Much Better Backend for Quantum Circuit Simulation Than PyTorch

Modern quantum circuit simulation is not just “machine learning with complex tensors.” It involves irregular tensor contractions, sparse operators, statevector transformations, and automatic differentiation through all of them. This makes backend choice unusually important. A backend that is excellent for standard neural-network layers may still be a poor fit for general quantum simulation workloads.

We benchmarked this with a simple VQE workload for the 1D transverse-field Ising
model (TFIM) defined in the benchmark script,

H = -sum_i Z_i Z_{i+1} - sum_i X_i,

using 20 qubits, 10 ansatz layers, complex64 precision, and one NVIDIA RTX 5090 GPU.

Results

| Backend | Compile / Warmup | Value+Grad Runtime |
| --- | --- | --- |
| TensorCircuit-NG, JAX backend | 53.53 s | 0.0265 s |
| TensorCircuit-NG, PyTorch backend | 0.48 s | 0.3299 s |
| TorchQuantum, optimized beyond the default implementation | 0.81 s | 0.4172 s |

The JAX backend is about 12.4x faster than TensorCircuit-NG’s PyTorch backend and about 15.7x faster than TorchQuantum for the post-compilation value-and-gradient step.

The compile time tells the other half of the story: JAX pays a much larger upfront XLA compilation cost. But after compilation, XLA produces a far more effective execution plan for this quantum simulation workload. This is exactly the tradeoff we want in VQE, QAOA, time evolution, and many other iterative algorithms: pay once, run many times.
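
To make the "pay once, run many times" tradeoff concrete, here is a small sketch that estimates how many optimization steps it takes before the JAX backend's compile cost is amortized. It uses only the numbers from the table above. The dictionary keys and helper names are made up for illustration, and the model assumes the warmup column is a one-time cost and that every later step costs the steady-state runtime, which the benchmark does not state explicitly.

```python
# Timings copied from the results table (seconds).
timings = {
    "tc-ng-jax": {"warmup": 53.53, "step": 0.0265},
    "tc-ng-torch": {"warmup": 0.48, "step": 0.3299},
    "torchquantum": {"warmup": 0.81, "step": 0.4172},
}


def total_time(backend, n_steps):
    """One-time warmup plus n_steps value-and-gradient evaluations (assumed cost model)."""
    t = timings[backend]
    return t["warmup"] + n_steps * t["step"]


def break_even_steps(high_warmup, low_warmup):
    """Number of steps after which the high-warmup backend is cheaper overall."""
    a, b = timings[high_warmup], timings[low_warmup]
    return (a["warmup"] - b["warmup"]) / (b["step"] - a["step"])


for other in ("tc-ng-torch", "torchquantum"):
    speedup = timings[other]["step"] / timings["tc-ng-jax"]["step"]
    n_star = break_even_steps("tc-ng-jax", other)
    print(f"{other}: {speedup:.1f}x per step, break-even after ~{n_star:.0f} steps")
```

Under these assumptions the JAX backend overtakes the TensorCircuit-NG PyTorch backend after roughly 175 steps and TorchQuantum after roughly 135 steps, which is well within the length of a typical variational optimization run.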

Why This Happens

Quantum circuit simulation stresses a backend differently from ordinary deep learning. The workload mixes tensor-network contraction, sparse Hamiltonian application, and reverse-mode differentiation. JAX/XLA is designed to see the whole computation and optimize it aggressively as a compiled program on the target device.

PyTorch, in contrast, is strongest where the workload resembles standard neural network layers. For more general tensor programs, especially tensor-network-like simulation code, the compiler stack is less aggressive and less predictable.
In this benchmark, the same TensorCircuit-NG algorithm is more than an order of magnitude faster on JAX than on PyTorch after compilation.
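
To illustrate the kind of tensor program described above, the following NumPy sketch applies a single-qubit gate to a statevector as a tensor contraction over one axis of a reshaped state. It is not TensorCircuit-NG's implementation; `apply_1q_gate` and `rx` are hypothetical helpers, and the qubit ordering (qubit 0 as the most significant bit) is an assumption made for this example.

```python
import numpy as np


def apply_1q_gate(state, gate, qubit, n):
    """Apply a 2x2 gate to one qubit of an n-qubit statevector."""
    psi = state.reshape((2,) * n)  # one axis of size 2 per qubit
    # Contract the gate's input index with the chosen qubit axis.
    psi = np.tensordot(gate, psi, axes=([1], [qubit]))
    # tensordot places the gate's output index first; move it back into position.
    psi = np.moveaxis(psi, 0, qubit)
    return psi.reshape(-1)


def rx(theta):
    """Single-qubit rotation about X, exp(-i * theta / 2 * X)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -1j * s], [-1j * s, c]], dtype=np.complex64)


n = 4
state = np.zeros(2**n, dtype=np.complex64)
state[0] = 1.0  # start in |0000>
for q in range(n):
    state = apply_1q_gate(state, rx(0.3 * (q + 1)), q, n)
```

A full ansatz repeats contractions like this many times with different axes and gate shapes, followed by a backward pass through all of them. That access pattern looks little like a stack of dense layers, which is why a whole-program compiler has more room to help here.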

A Note on TorchQuantum

We also compared against TorchQuantum as a representative PyTorch-native quantum circuit package. To make the comparison generous, we did not use its generic Pauli-string expectation path. That built-in route tends to materialize dense Pauli operators, so it is slow and does not scale. Instead, we implemented a TFIM-specific expectation value computed directly from the statevector:

  • ZZ terms are evaluated from probabilities and precomputed sign tensors.
  • X terms are evaluated by flipping the state axis and taking an inner product.

This is already a substantial low-level optimization. Even with that help, TorchQuantum remains about 15.7x slower than TensorCircuit-NG on the JAX backend. And even if you prefer to stay on PyTorch, TensorCircuit-NG’s PyTorch backend is still the better choice in both warmup and run time.
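
The two bullet rules above can be sketched in NumPy as follows. This is a minimal illustration of the idea, not the code used in the benchmark: `tfim_energy` is a hypothetical helper, it assumes an open chain (the boundary condition is not stated here), and it builds the sign tensors inside the loop for brevity where a tuned version would precompute them once.

```python
import numpy as np


def tfim_energy(state, n):
    """<psi|H|psi> for H = -sum_i Z_i Z_{i+1} - sum_i X_i on an open chain (assumed)."""
    psi = state.reshape((2,) * n)  # one axis per qubit
    probs = np.abs(psi) ** 2
    z = np.array([1.0, -1.0])  # Z eigenvalues for basis states 0 and 1
    energy = 0.0
    for i in range(n - 1):
        # Sign tensor for Z_i Z_{i+1}: +1 when neighbouring bits agree, -1 otherwise.
        shape_i = [1] * n
        shape_i[i] = 2
        shape_j = [1] * n
        shape_j[i + 1] = 2
        signs = z.reshape(shape_i) * z.reshape(shape_j)
        energy -= np.sum(probs * signs)
    for i in range(n):
        # X_i swaps |0> and |1> on qubit i, which reverses axis i of the state tensor.
        energy -= np.vdot(psi, np.flip(psi, axis=i)).real
    return float(energy)


# Usage: the uniform superposition |+...+> has <X_i> = 1 and <Z_i Z_{i+1}> = 0,
# so its energy should be close to -n.
n = 4
plus = np.full(2**n, 2.0 ** (-n / 2))
print(tfim_energy(plus, n))
```

Neither loop builds a 2^n by 2^n operator, so memory stays proportional to the statevector itself. That is the property that makes this route much cheaper than materializing dense Pauli strings.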

Takeaway

The lesson is not merely that one package is faster than another. The deeper point is that backend architecture matters. Quantum simulation benefits from a compiler that can optimize a whole differentiable tensor program, not just a collection of familiar machine-learning layers.

For TensorCircuit-NG, the JAX backend gives exactly that: a high-level quantum programming interface backed by XLA’s aggressive compilation. The result is a backend that is not only elegant for research code, but also dramatically faster for real differentiable quantum simulation workloads.
