
Shakudo

Originally published at shakudo.io

A Quick Introduction to JAX

There are enough Python libraries out there that you’ll never understand or use them all. The more pertinent task is choosing the right one for your specific project. At Shakudo it’s pretty common for our team to begin using NumPy or another library, only to figure out halfway through that it’s not effective for our use case.

Shakudo provides data teams with an operating system for data stacks, and we’ve taken a liking to JAX lately for machine learning and data processing. We’ll explain why in this quick intro to JAX.

What is JAX and what does it do?

Google’s JAX is a high-performance Python package built to accelerate machine learning research. It provides a lightweight, NumPy-like API for array-based computing, and adds a set of composable function transformations: automatic differentiation, just-in-time (JIT) compilation, and automatic vectorization and parallelization of your code. We’ll talk about those more later on.
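To make "composable" concrete, here is a minimal sketch that stacks three of these transformations on one ordinary function: jax.grad differentiates it, jax.vmap maps that derivative over a batch of inputs, and jax.jit compiles the whole pipeline. The function f and the input values are illustrative choices, not something from the original post.

```python
import jax
import jax.numpy as jnp

def f(x):
    # An ordinary scalar function (illustrative)
    return jnp.sin(x) * x

# grad -> derivative of f for a single scalar input
# vmap -> apply that derivative to every element of a batch
# jit  -> compile the whole pipeline with XLA
batched_derivative = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 3.0, 5)
print(batched_derivative(xs))  # cos(x) * x + sin(x) for each x
```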

The same JAX code runs on CPU, GPU, or TPU with few or no changes, which makes it easy to speed up big projects in a short amount of time. We’ve seen it used for some really cool projects, including protein folding research, robotics control, and physics simulations.

Automatic differentiation is a procedure for computing derivatives that avoids the pitfalls of numerical differentiation (expensive and numerically unstable) and symbolic differentiation (an exponential increase in the number of expressions). It takes a function (a program), breaks it into a sequence of primitive operations whose derivatives are easy to compute, and combines those derivatives with the chain rule. Applied in reverse mode to a neural network, this procedure is known as backpropagation.
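As a small sketch of this idea, the snippet below (with an illustrative function f of our choosing) differentiates a function built from primitive operations in both reverse mode, via jax.grad, and forward mode, via jax.jvp, and checks both against the derivative worked out by hand with the product rule. Both modes give the exact derivative up to floating-point rounding, with no step size to tune as in numerical differentiation.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Composed of primitives JAX knows how to differentiate: sin, power, multiply
    return jnp.sin(x) * x**2

def df_by_hand(x):
    # Product rule: d/dx [sin(x) * x^2] = cos(x) * x^2 + 2x * sin(x)
    return jnp.cos(x) * x**2 + 2 * x * jnp.sin(x)

x = 1.5

# Reverse mode: one backward pass from the output (the mode backpropagation uses)
reverse = jax.grad(f)(x)

# Forward mode: push a tangent of 1.0 through f alongside the value
value, forward = jax.jvp(f, (x,), (1.0,))

print(reverse, forward, df_by_hand(x))
```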

Why use JAX?

Because JAX syntax is so similar to NumPy’s, it can be dropped into projects where NumPy isn’t cutting it performance-wise, or where you need extra features that JAX supports, with just a few code changes. Machine learning, blockchain, and other data- and compute-heavy use cases can all benefit from JAX’s improved performance. Maybe you’re researching JAX because you’ve hit a wall scaling your data project.

Beyond speed, JAX is a great all-around tool for prototyping, because it’s easy to pick up if you already work with NumPy. Its composable transformations (grad, jit, vmap, and pmap) also make features like automatic differentiation, compilation, and batching available through ordinary function calls.

Tests have shown that JAX can run basic functions up to 8600% faster than NumPy, which is highly valuable for data-heavy, application-facing models, or just for getting more machine learning experiments done in a day. Most real-world applications won’t see a speed jump this large, but it does show the potential value of switching.

Figure: NumPy vs. JAX runtime comparison. The JAX KDE density function is 1500x faster.

JAX is capable of these crazy-high speeds for the following reasons:

Vectorization: vectorization processes multiple data elements with a single instruction. It works whenever the same simple operation is applied to all of the data. Most matrix operations apply the same operation across the rows and columns of a matrix, which makes them very amenable to vectorization and gives great speedups for linear algebra and machine learning.

JAX allows you to use jax.vmap to automatically generate a vectorized implementation of a function:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
    # 1D convolution with a window of 3, written for a single example
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

# Stack two copies of each input to form a batch
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
# Array([[11., 20., 29.],
#        [11., 20., 29.]], dtype=float32)
```
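vmap can also batch over only some of a function’s arguments. A minimal sketch, assuming you want to apply one shared filter w to a batch of signals: in_axes=(0, None) tells vmap to map over the first axis of xs and pass w through unchanged. convolve is redefined here so the snippet runs on its own, and the second signal (x + 1.0) is just an illustrative input.

```python
import jax
import jax.numpy as jnp

def convolve(x, w):
    # 1D convolution with a window of 3, written for a single example
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

x = jnp.arange(5.0)
w = jnp.array([2.0, 3.0, 4.0])
xs = jnp.stack([x, x + 1.0])  # a batch of two signals

# in_axes=(0, None): map over the first axis of xs, reuse the same w every time
batch_convolve_shared_w = jax.vmap(convolve, in_axes=(0, None))
print(batch_convolve_shared_w(xs, w))
# [[11. 20. 29.]
#  [20. 29. 38.]]
```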

Code parallelization: the process of taking serial code that runs on a single processor and spreading the work across multiple processors. The problem is broken into smaller pieces that can be processed simultaneously, which is much more efficient than waiting for one piece to finish before starting the next.
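In JAX, one way to express this is jax.pmap, which runs the same function on each available device, with one slice of the input’s leading axis per device. The sketch below is our own illustration (the sum-of-squares task and the axis name "devices" are arbitrary choices): each device reduces its own chunk, and jax.lax.psum combines the partial results. On a CPU-only machine jax.local_device_count() is usually 1, so the code runs but doesn’t actually parallelize; recent JAX versions also support sharding-based parallelism through jax.jit, which this sketch doesn’t cover.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# One row of work per device: the leading axis must equal the device count
data = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

def sum_of_squares(chunk):
    # Each device squares and sums its own chunk...
    local = jnp.sum(chunk ** 2)
    # ...then psum adds the partial results across all devices
    return jax.lax.psum(local, axis_name="devices")

parallel_sum_of_squares = jax.pmap(sum_of_squares, axis_name="devices")
result = parallel_sum_of_squares(data)
print(result)  # one copy of the global total per device
```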

Automatic differentiation: a set of techniques for evaluating the derivative of a function by exploiting the sequence of elementary arithmetic operations it is built from. Differentiation in JAX is pretty straightforward:

```python
from jax import grad

def func(x):
    return x**2

d_func = grad(func)  # d_func(x) computes 2 * x
```

You can also apply grad repeatedly to get higher-order derivatives. For example, applying it again to d_func gives the second derivative of func:

```python
d2_func = grad(d_func)  # d2_func(x) returns 2.0 for any float x
```
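Here is a minimal usage sketch that evaluates these derivatives at x = 3.0, plus the gradient of an illustrative array function. Note that jax.grad expects floating-point inputs (calling d_func(3) with an integer raises an error) and a function that returns a scalar.

```python
import jax
import jax.numpy as jnp

def func(x):
    return x**2

d_func = jax.grad(func)     # first derivative: 2x
d2_func = jax.grad(d_func)  # second derivative: 2

print(d_func(3.0))   # 6.0
print(d2_func(3.0))  # 2.0

# grad also handles array inputs, as long as the function returns a scalar
def sum_of_squares(v):
    return jnp.sum(v**2)

print(jax.grad(sum_of_squares)(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```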

How JAX is built

JAX is built on Accelerated Linear Algebra (XLA) and just-in-time (JIT) compilation. XLA is a domain-specific compiler for linear algebra that fuses operations together, so intermediate results don’t have to be written out to memory, which improves overall speed. JAX uses XLA to compile and run your NumPy-style programs on GPUs and TPUs without changes to your code: it traces your Python code into an intermediate representation, which is then just-in-time compiled.
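To peek at those stages yourself, a sketch like the following (the function f is just an illustration) prints the jaxpr that JAX produces by tracing, then the lowered module that is handed to XLA, and finally calls the ahead-of-time compiled executable. The exact text of both representations varies between JAX versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x) * 2.0 + 1.0

x = jnp.ones(3)

# The jaxpr: JAX's own intermediate representation, produced by tracing f
print(jax.make_jaxpr(f)(x))

# The lowered module that gets handed to XLA for compilation
lowered = jax.jit(f).lower(x)
print(lowered.as_text())

# Compile ahead of time and call the compiled executable
compiled = lowered.compile()
print(compiled(x))
```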

With JIT compilation, a function is compiled to machine code the first time it is called (for a given set of input shapes and dtypes), so subsequent calls run faster. In JAX, applying JIT is a single function call:

```python
from jax import jit

def funct(x):
    return x * (2 + x)

compiled_funct = jit(funct)  # compiled on first call, cached afterwards
```
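A rough way to observe the compile-once behavior, assuming a large float32 input chosen only for illustration: time the first and second calls to compiled_funct. Because JAX dispatches work asynchronously, the sketch calls block_until_ready() so the timer measures the actual computation. Timings depend on your hardware, so no numbers are shown.

```python
import time
import jax
import jax.numpy as jnp

def funct(x):
    return x * (2 + x)

compiled_funct = jax.jit(funct)
x = jnp.arange(1_000_000, dtype=jnp.float32)

def timed(fn, arg):
    start = time.perf_counter()
    fn(arg).block_until_ready()  # wait for the asynchronous result
    return time.perf_counter() - start

first = timed(compiled_funct, x)   # includes tracing + XLA compilation
second = timed(compiled_funct, x)  # reuses the cached compiled version
print(f"first call: {first:.4f}s, second call: {second:.4f}s")

# A new input shape triggers a fresh trace and compile
third = timed(compiled_funct, jnp.arange(10, dtype=jnp.float32))
```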

Although jit is a powerful tool, it doesn’t work for every function. For example, Python control flow that depends on the values of traced arrays can’t be compiled as written. See the JAX documentation on jit to better understand what it can and can’t compile.
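A minimal sketch of the most common failure and two standard workarounds (the function names are illustrative): a Python if on a traced value raises jax.errors.ConcretizationTypeError under jit, jnp.where moves the branch into the traced computation, and static_argnums marks an argument as a compile-time constant so Python control flow can use its value. Each distinct static value triggers a separate compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp

def relu_with_if(x):
    # Python `if` needs a concrete bool, but under jit x is an abstract tracer
    if x > 0:
        return x
    return 0.0

try:
    jax.jit(relu_with_if)(2.0)
    compiled_ok = True
except jax.errors.ConcretizationTypeError:
    compiled_ok = False
print("jit with Python if succeeded:", compiled_ok)

def relu_with_where(x):
    # jnp.where keeps the branch inside the traced computation
    return jnp.where(x > 0, x, 0.0)

print(jax.jit(relu_with_where)(2.0), jax.jit(relu_with_where)(-1.0))

# Alternatively, mark an argument as static so Python control flow can use it
@partial(jax.jit, static_argnums=1)
def repeat_double(x, n):
    for _ in range(n):  # n is a plain Python int at trace time
        x = x * 2
    return x

print(repeat_double(1.0, 3))  # 8.0
```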

How to Install and Import JAX

To install the CPU-only version of JAX, use the following:

```bash
pip install --upgrade pip
pip install --upgrade jax
```

And that’s it! Now you have CPU support to test your code. For GPU support, you’ll need a CUDA-enabled JAX build along with compatible CUDA and cuDNN libraries; the official JAX installation guide lists the exact command for your platform.
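To confirm the install worked, a quick check like this prints the installed version and the devices JAX can see; on a CPU-only install the default backend should be cpu.

```python
import jax
import jax.numpy as jnp

print(jax.__version__)
print(jax.default_backend())  # e.g. "cpu", or "gpu"/"tpu" when an accelerator is found
print(jax.devices())

# A tiny computation to confirm everything works end to end
print(jnp.dot(jnp.ones(3), jnp.ones(3)))  # 3.0
```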

Finally, import JAX and its NumPy interface. If you’ve already begun your project using NumPy, you can switch to jax.numpy and use it for the same operations as in NumPy:

```python
import jax
import numpy as np        # Plain NumPy: works, but JAX can't trace or accelerate it
import jax.numpy as jnp   # Cool kids do this!
```
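As an illustration of how small the switch usually is, here is a hypothetical normalize function written once with NumPy and once with jax.numpy; only the prefix changes. One difference to keep in mind: JAX uses 32-bit floats by default while NumPy uses 64-bit, so results can differ slightly in the last digits.

```python
import numpy as np
import jax.numpy as jnp

def normalize_np(x):
    # Plain NumPy version
    return (x - np.mean(x)) / np.std(x)

def normalize_jnp(x):
    # Same code with the np prefix swapped for jnp
    return (x - jnp.mean(x)) / jnp.std(x)

data = np.array([1.0, 2.0, 3.0, 4.0])
print(normalize_np(data))
print(normalize_jnp(data))  # accepts NumPy arrays, returns a JAX array
```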

Note that there are two constraints to keep in mind when porting a NumPy project:

You can only use pure functions: calling a function twice with the same inputs has to return the same result, without side effects. JAX arrays are also immutable, so you can’t update them in place:

```python
import jax.numpy as jnp

array = jnp.zeros((2, 2))
# array[:, 0] = 1  # Won't work: JAX arrays are immutable (raises TypeError)
array = array.at[:, 0].set(1)  # Better: returns an updated copy
```
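The .at property supports more than set. This sketch, with arbitrary values, chains a few out-of-place updates; each call returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

array = jnp.zeros((2, 2))
updated = array.at[:, 0].set(1.0)   # new array; `array` itself is unchanged
added = updated.at[0, 1].add(5.0)   # out-of-place +=
scaled = added.at[1].multiply(3.0)  # out-of-place *= on a whole row

print(array)   # still all zeros
print(scaled)  # [[1. 5.]
               #  [3. 0.]]
```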

Random number generation is explicit: you create a PRNG key and pass it to every random function:

```python
import jax

key = jax.random.PRNGKey(1)
x = jax.random.normal(key=key, shape=(1000,))
```
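A practical consequence of explicit keys, shown in a small sketch: reusing a key reproduces the same numbers, so to get fresh randomness you split the key with jax.random.split and use each subkey only once. The key value 1 and shape (3,) are arbitrary.

```python
import jax

key = jax.random.PRNGKey(1)

# Reusing the same key reproduces exactly the same numbers
a = jax.random.normal(key, shape=(3,))
b = jax.random.normal(key, shape=(3,))

# For fresh randomness, split the key and use each new subkey once
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(3,))

print(a, b, c)
```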

And there you go! You’re off on your first JAX project. If you want to try out all of this in an easily configurable workspace that pre-integrates all the open source tools and data frameworks you want, reach out to our team!

Authors: Stella Wu and Sabrina Aquino
