DEV Community

Nik L.

Running a JAX Program from Dart Using C++ FFI

🚀 Why Combine Dart and JAX for Machine Learning?

When building applications, selecting the right tools is crucial. You want high performance, easy development, and seamless cross-platform deployment. Popular frameworks offer trade-offs:

  • C++ provides speed but can slow down development.
  • Dart (with Flutter) is slower but simplifies memory management and cross-platform development.

But here’s the catch: most frameworks lack robust native machine learning (ML) support. This gap exists because these frameworks predate the AI boom. The question is:

How can we efficiently integrate ML into applications?

Common solutions such as ONNX Runtime let you export a trained model to a portable format and run it inside an application, but they focus on model inference, aren’t always well optimized for CPU workloads, and aren’t flexible enough for generalized algorithms.

Enter JAX, a Python library that:

  • Enables writing optimized ML and general-purpose algorithms.
  • Offers platform-agnostic execution on CPUs, GPUs, and TPUs.
  • Supports cutting-edge features like autograd and JIT compilation.

In this article, we’ll show you how to:

  1. Write JAX programs in Python.
  2. Export them as XLA HLO specifications.
  3. Deploy optimized JAX code in Dart using C++ FFI.

🧠 What is JAX?

JAX is like NumPy on steroids. Developed by Google, it’s a low-level, high-performance library that makes ML accessible yet powerful.

  • Platform Agnostic: Code runs on CPUs, GPUs, and TPUs without modification.
  • Speed: Powered by the XLA compiler, JAX optimizes and accelerates execution.
  • Flexibility: Perfect for ML models and general algorithms alike.

Here’s an example comparing NumPy and JAX:

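# NumPy version: allocate an array, then fill it in place
import numpy as np

def assign_numpy():
    a = np.empty(1000000)
    a[:] = 1
    return a

# JAX version: arrays are immutable, so .at[].set() expresses the update functionally
import jax
import jax.numpy as jnp

@jax.jit  # compiled by XLA on the first call
def assign_jax():
    a = jnp.empty(1000000)
    return a.at[:].set(1)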

Benchmarking in Google Colab reveals JAX’s performance edge:

  • CPU & GPU: JAX outperforms NumPy (which itself runs only on the CPU).
  • TPU: Speed-ups only become noticeable for large workloads, because data transfer costs dominate small ones.

This flexibility and speed make JAX ideal for production environments where performance is key.
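The Colab benchmark code isn’t included in the article. A minimal timing harness for reproducing that kind of comparison is sketched below; the helper name `benchmark` is hypothetical. It accounts for two JAX specifics: the first call to a `@jax.jit` function includes compilation time, and JAX dispatches work asynchronously, so the timer has to wait on `block_until_ready()` before stopping.

```python
import time
from typing import Any, Callable


def benchmark(fn: Callable[[], Any], repeats: int = 10, warmup: int = 1) -> float:
    """Return the best wall-clock time (seconds) of fn() over `repeats` runs.

    Hypothetical helper, not part of JAX or NumPy.
    """

    def run_once() -> Any:
        out = fn()
        # JAX returns before the computation finishes (async dispatch),
        # so block to time the work itself. NumPy arrays lack this method.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return out

    # Warm-up runs absorb one-off costs such as JIT compilation.
    for _ in range(warmup):
        run_once()

    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        run_once()
        best = min(best, time.perf_counter() - start)
    return best
```

With the functions above, you would compare `benchmark(assign_numpy)` against `benchmark(assign_jax)`. Reporting the best of several runs reduces noise from other processes sharing the machine.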




🛠️ Bringing JAX into Production

Cloud Microservices vs. Local Deployment

  • Cloud: Containerized Python microservices are great for cloud-based compute.
  • Local: Shipping a Python interpreter isn’t ideal for local apps.

Solution: Leverage JAX’s XLA Compilation

JAX translates Python code into HLO (High-Level Optimizer) specifications, which can be compiled and executed using C++ XLA libraries. This enables:

  1. Writing algorithms in Python.
  2. Running them natively via a C++ library.
  3. Integrating with Dart via FFI (Foreign Function Interface).

✍️ Step-by-Step Integration

1. Generate an HLO Proto

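Write your JAX function and export its HLO representation. For example, save the following as jax_example/prog.py, the module path passed to --fn in the command below:

import jax.numpy as jnp

def fn(x, y, z):
    # Matrix product of x and y, divided by z
    return jnp.dot(x, y) / z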
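When the same computation later runs through XLA from Dart, it helps to have known-good numbers to compare against. The sketch below is a dependency-free Python reference for fn (the names `matmul`, `fn_reference`, and `fn_exported` are hypothetical). It binds z = 2.0 with functools.partial to mirror the effect of the --constants '{"z": 2.0}' flag used at export time, leaving only x and y as inputs.

```python
from functools import partial
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product, equivalent to jnp.dot for 2-D inputs."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions do not match")
    cols = len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]


def fn_reference(x: Matrix, y: Matrix, z: float) -> Matrix:
    """Pure-Python counterpart of fn(x, y, z) = jnp.dot(x, y) / z."""
    return [[v / z for v in row] for row in matmul(x, y)]


# Fix z = 2.0, as the exported program does; only x and y remain inputs.
fn_exported = partial(fn_reference, z=2.0)
```

For example, `fn_exported([[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]])` returns `[[9.5, 11.0], [21.5, 25.0]]`. Because the exported program computes in float32, compare its results against this reference with a small tolerance rather than exact equality.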

To generate the HLO, use the jax_to_ir.py script from the JAX repository (jax/tools/jax_to_ir.py). The --constants flag fixes z = 2.0 in the exported program, so only x and y remain inputs:

python jax_to_ir.py \
  --fn jax_example.prog.fn \
  --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
  --constants '{"z": 2.0}' \
  --ir_format HLO \
  --ir_human_dest /tmp/fn_hlo.txt \
  --ir_dest /tmp/fn_hlo.pb
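The --input_shapes value is a Python literal: a list of (argument name, shape) pairs, where each shape is an element type followed by bracketed dimensions, such as f32[2,2] for a 2x2 float32 array. A missing bracket, such as "f32[2,2", is easy to overlook in a long command. The helper below is hypothetical and not part of jax_to_ir.py; it is a sketch, assuming the dtype[dims] form shown above, that checks the spec before you run the export.

```python
import ast
import re
from typing import List, Tuple

# Element type (e.g. f32, s32, bf16, pred) followed by comma-separated dims.
_SHAPE_RE = re.compile(r"^(?P<dtype>[a-z]+\d*)\[(?P<dims>[0-9,\s]*)\]$")


def check_input_shapes(spec: str) -> List[Tuple[str, str, Tuple[int, ...]]]:
    """Parse an --input_shapes string into (name, dtype, dims) tuples.

    Hypothetical pre-flight check; raises ValueError on malformed entries.
    """
    entries = ast.literal_eval(spec)
    parsed = []
    for name, shape in entries:
        match = _SHAPE_RE.match(shape)
        if match is None:
            raise ValueError(f"malformed shape for {name!r}: {shape!r}")
        dims_text = match.group("dims").strip()
        dims = tuple(int(d) for d in dims_text.split(",")) if dims_text else ()
        parsed.append((name, match.group("dtype"), dims))
    return parsed
```

Running it on the command’s spec returns `[("x", "f32", (2, 2)), ("y", "f32", (2, 2))]`, while a spec whose second shape lacks its closing bracket raises a ValueError naming "y".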

Copy the resulting files from /tmp into your app’s assets directory: fn_hlo.txt holds the human-readable HLO text and fn_hlo.pb the serialized HLO proto.


2. Build a C++ Dynamic Library

Modify JAX’s C++ Example Code

Clone the JAX repository and navigate to jax/examples/jax_cpp.

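  • Add a main.h header that declares the function to export with C linkage (define it in main.cc, which should include this header):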
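#ifndef MAIN_H
#define MAIN_H

// Export with C linkage so the symbol name is not mangled and Dart FFI
// can look it up as "bar". Define bar() in main.cc.
#ifdef __cplusplus
extern "C" {
#endif

int bar(int foo);

#ifdef __cplusplus
}
#endif
#endif  // MAIN_H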
  • Update the BUILD file to create a shared library:
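# cc_shared_library only accepts cc_library targets in deps, so :main
# must be declared as a cc_library (with main.h in hdrs), not a cc_binary.
cc_shared_library(
    name = "jax",
    deps = [":main"],
    visibility = ["//visibility:public"],
)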

Compile with Bazel:

bazel build examples/jax_cpp:jax  

On macOS, you’ll find the compiled libjax.dylib (libjax.so on Linux) in Bazel’s output directory, bazel-bin/examples/jax_cpp/.
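Before wiring up Dart, you can confirm that bar is exported with C linkage (an unmangled symbol) by calling it through Python’s standard ctypes module. The sketch below assumes the library exports int bar(int foo) as declared in main.h; the helper names `bind_int32_function` and `smoke_test` and the library path you pass in are assumptions to adapt to your build.

```python
import ctypes


def bind_int32_function(lib: ctypes.CDLL, name: str):
    """Look up `name` and declare it as int32_t name(int32_t).

    Hypothetical helper. Raises AttributeError if the symbol is missing,
    which often means it was not declared extern "C".
    """
    func = getattr(lib, name)
    func.argtypes = [ctypes.c_int32]
    func.restype = ctypes.c_int32
    return func


def smoke_test(library_path: str, arg: int = 42) -> int:
    """Load the shared library and return the result of bar(arg)."""
    lib = ctypes.CDLL(library_path)
    bar = bind_int32_function(lib, "bar")
    return bar(arg)
```

For example, `print(smoke_test("bazel-bin/examples/jax_cpp/libjax.dylib"))`, where the path is an assumption based on Bazel’s default output layout. If this call works but the Dart side fails, the problem is in library loading or the FFI signatures rather than in the C++ build.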


3. Connect Dart with C++ Using FFI

Use Dart’s dart:ffi library to call into the C++ library. Create a jax.dart file:

import 'dart:ffi';
import 'package:dynamic_library/dynamic_library.dart';

// Native signature of the exported C function: int bar(int foo)
typedef BarC = Int32 Function(Int32 foo);
// Matching Dart-side signature
typedef BarDart = int Function(int foo);

class JAX {
  late final DynamicLibrary dylib;

  // Resolve the symbol once and keep a strongly typed Dart function
  late final BarDart _bar = dylib.lookupFunction<BarC, BarDart>('bar');

  JAX() {
    dylib = loadDynamicLibrary(libraryName: 'jax');
  }

  int bar(int foo) => _bar(foo);
}

Place the compiled library somewhere the loader can find it, such as your project directory, then test it with:

final jax = JAX();  
print(jax.bar(42));  

The console prints the value returned by the C++ bar function.


🎯 Next Steps

With this setup, you can:

  • Optimize ML models with JAX and XLA.
  • Run powerful algorithms locally.

Potential use cases include:

  • Search algorithms (e.g., A*).
  • Combinatorial optimization (e.g., scheduling).
  • Image processing (e.g., edge detection).

JAX bridges the gap between Python-based development and production-level performance, letting ML engineers focus on algorithms without worrying about low-level C++ code.

