Building open-source solutions for my 100 Days of AI Agents challenge meant I needed to start looking at frameworks that scale better than standard NumPy and PyTorch. That inevitably led me to JAX.
Transitioning to JAX requires a bit of a paradigm shift. If you are used to the standard Python data science stack, JAX forces you to rewire how you think about array operations, memory, and hardware execution.
I spent today digging into the core mechanics, and I want to share my top 3 takeaways and the exact code snippets that made it click for me.
1. Immutability is a Feature, Not a Bug
This was my first major roadblock. In standard NumPy, if you want to change an element in an array, you just reassign it in place.
Python
import numpy as np
x = np.arange(10)
x[0] = 10
print(x) # Output: [10 1 2 3 4 5 6 7 8 9]
If you try the exact same thing in JAX, it screams at you: TypeError: JAX arrays are immutable.
JAX arrays (jax.Array) cannot be changed once created. This is a core design principle that enables JAX's functional programming nature and automatic differentiation. To update an array, JAX provides an indexed update syntax that returns an updated copy:
Python
import jax.numpy as jnp
x = jnp.arange(10)
y = x.at[0].set(10)  # returns a new array; x itself is left unchanged
print(y) # Output: [10 1 2 3 4 5 6 7 8 9]
print(x) # Output: [0 1 2 3 4 5 6 7 8 9]
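The Catch: Outside of jit, every update allocates a new array, so there is some memory overhead. Inside a jit-compiled function, the compiler can often perform the update in place when the original array is not reused afterwards. Either way, you eliminate the side effects that make distributed computing a nightmare.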
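To make the update syntax above a bit more concrete, here is a minimal sketch of a few other `.at[...]` methods next to the failing in-place assignment. The variable names (`added`, `scaled`, `capped`) are my own, chosen only for illustration.

```python
import jax.numpy as jnp

x = jnp.arange(10)

# In-place assignment is rejected for JAX arrays
try:
    x[0] = 10
except TypeError as err:
    print("TypeError raised:", type(err).__name__)

# Each .at[...] call returns a new array and leaves x untouched
added = x.at[1:4].add(100)      # add 100 to a slice
scaled = x.at[::2].multiply(2)  # double every second element
capped = x.at[-1].min(5)        # clamp the last element to at most 5

print(added)
print(scaled)
print(capped)
print(x)  # still the original values
```

Because every call returns a fresh array, the updates can be chained or passed through transformations without one step silently changing the input of another.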
2. Native Hardware Awareness & Sharding
JAX arrays inherently know where they live. You don't have to jump through hoops to figure out if your data is on the CPU, GPU, or TPU.
By default, JAX places arrays on the first available accelerator (GPU or TPU) and falls back to the CPU when none is found. Running this locally on my MSI Raider, I can inspect exactly where my array is stored using .devices():
Python
x.devices()
# Output: {CpuDevice(id=0)}
More importantly, JAX arrays can be sharded across multiple devices for parallel execution. You can inspect this via the .sharding attribute:
Python
x.sharding
# Output: SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=device)
It feels built from the ground up for modern hardware scaling.
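The single-device output above does not show much actual sharding, so here is a hedged sketch of how an array could be split across whatever devices JAX can see. The mesh axis name `"data"` and the array shape are assumptions for illustration. On a machine with one device, like the CPU-only setup above, it still runs but produces a single shard.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Every device JAX can see (a single CpuDevice on a CPU-only install)
devices = np.array(jax.devices())
n = len(devices)

# 1-D device mesh; the axis name "data" is an arbitrary, hypothetical label
mesh = Mesh(devices, axis_names=("data",))

# Row count is a multiple of n so the rows divide evenly across devices
x = jnp.arange(8 * n, dtype=jnp.float32).reshape(4 * n, 2)

# Split rows along the "data" axis; columns are not partitioned
sharded = jax.device_put(x, NamedSharding(mesh, P("data", None)))

print(sharded.sharding)
for shard in sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

Each device ends up holding a `(4, 2)` block of rows, while `sharded` still behaves like one ordinary array in the rest of the code.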
3. The Magic of JIT Compilation
By default, JAX executes operations one at a time, in sequence (just like standard Python). But if you wrap a function with jax.jit (Just-In-Time compilation), JAX traces it once, hands the whole sequence of operations to the XLA compiler, and runs the optimized result as a single compiled program instead of dispatching each operation separately.
I wrote a simple normalization function to test this:
Python
from jax import jit
import jax.numpy as jnp
import numpy as np
def norm(X):
    # Center each column, then scale it by its standard deviation
    X = X - X.mean(0)
    return X / X.std(0)
norm_compiled = jit(norm)
# Generate some dummy data
np.random.seed(22)
X = jnp.array(np.random.rand(100000, 10))
I benchmarked both functions using %timeit (adding .block_until_ready() to account for JAX's asynchronous dispatch). The results were immediate:
Standard Execution: 1.52 ms ± 16.3 μs per loop
JIT Execution: 1.16 ms ± 26.2 μs per loop
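For anyone without IPython's %timeit, the benchmark described above can be approximated with the standard-library timeit module. This is a rough sketch, not the exact procedure behind the numbers above: the `bench` helper and the repeat count are my own assumptions, and timings will differ by machine.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

norm_compiled = jax.jit(norm)

np.random.seed(22)
X = jnp.array(np.random.rand(100000, 10))

# Warm-up call so that tracing and compilation are not part of the timing
norm_compiled(X).block_until_ready()

def bench(fn, repeats=50):
    """Hypothetical helper: average milliseconds per call of fn(X)."""
    # block_until_ready() waits for the asynchronous computation to finish
    total = timeit.timeit(lambda: fn(X).block_until_ready(), number=repeats)
    return total / repeats * 1e3

eager_ms = bench(norm)
jit_ms = bench(norm_compiled)
print(f"Standard: {eager_ms:.3f} ms per call")
print(f"JIT:      {jit_ms:.3f} ms per call")
```

Without the warm-up call, the first JIT invocation would include compilation time and make the compiled version look slower than it is.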
Because the compiler sees the whole function up front, it can optimize across operations. In this test that cut the runtime by roughly 24% (1.52 ms down to 1.16 ms). The main limitation? Not all JAX code can be JIT compiled: it requires array shapes to be static and known at compile time.
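The static-shape requirement is easiest to see with boolean masking, where the size of the result depends on the data. The sketch below is my own illustration with hypothetical function names. It shows a function that works eagerly but fails under jit, plus a fixed-shape rewrite using jnp.where.

```python
import jax
import jax.numpy as jnp

def positive_sum_masked(x):
    # Boolean indexing: the length of x[x > 0] depends on the values in x
    return x[x > 0].sum()

def positive_sum_where(x):
    # jnp.where keeps the shape fixed and replaces unwanted entries with 0
    return jnp.where(x > 0, x, 0.0).sum()

x = jnp.array([-1.0, 2.0, -3.0, 4.0])

eager_result = positive_sum_masked(x)  # fine without jit
print(eager_result)

jit_error = None
try:
    jax.jit(positive_sum_masked)(x)
except jax.errors.NonConcreteBooleanIndexError as err:
    # Under jit the mask values are unknown at trace time, so the shape is too
    jit_error = type(err).__name__
    print("jit failed with:", jit_error)

jit_result = jax.jit(positive_sum_where)(x)  # static shape, compiles fine
print(jit_result)
```

The two functions compute the same value. Only the second one has an output shape the compiler can determine without looking at the data.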
What's Next?
This is just scratching the surface. My next deep dive is going to cover functional randomness (jax.random), automatic differentiation (jax.grad), and automatic vectorization (jax.vmap).
Has anyone else here made the jump to JAX recently? What was your biggest learning curve? Drop a comment below!
Top comments (2)
The numpy-to-JAX jump is a fun one. The aha is usually realizing how much you were leaving on the table with eager, loop-heavy numpy once jit + vmap + autodiff click into place. The mental shift that trips people: thinking in transformations over arrays instead of step-by-step imperative code, and respecting that JAX wants pure functions (the side-effect surprises are very real). Worth it for the speedups, but the functional discipline is the tax you pay. Not my daily layer, I'm more on the orchestration side with Moonshift, but the same "make it pure and composable" principle shows up everywhere good systems get built. What was your biggest aha, the jit speedup or vmap?
Spot on! My biggest 'aha' so far was definitely seeing that initial JIT speedup and realizing how strict immutability forces you to think differently about memory. I'm actually still figuring out the deeper mechanics of jit and vmap, but the 'functional discipline tax' you mentioned is already making total sense!