I recently became a contributor to the Bonsai project, where I translated EfficientNet, U-Net, and a Variational Autoencoder (VAE) into JAX code.
JAX is a fast ML framework with a NumPy-like API and automatic differentiation, providing the performance and scalability that modern machine learning research needs. Its focus on functional programming and composability aligns well with the Bonsai project's mission to offer simple, hackable, and concise implementations of popular models. This approach lowers the barrier to entry for JAX and also supports academic innovation. Gemini is trained with JAX.
Here I will use the Antigravity IDE to develop a VAE and run inference with it. We will combine the efficiency and speed of JAX with the convenience of a modern cloud development environment to walk through the entire development process of this generative model.
GitHub - jax-ml/bonsai: Minimal, lightweight JAX implementations of popular models.
The implementation follows this paper:
https://arxiv.org/abs/1312.6114
We start with two files: modeling.py and params.py. Together they define the structure and initialization logic for the Variational Autoencoder (VAE) model within the JAX Bonsai project, using the Flax NNX module system.
Flax NNX (Neural Networks JAX) is a new, simplified API within the Flax ecosystem designed to make creating, debugging, and analyzing neural networks in JAX easier and more intuitive. It aims to bridge the gap between JAX's functional programming core and the object-oriented style familiar to PyTorch or Keras users.
In essence, Flax NNX allows researchers to leverage JAX’s performance (automatic differentiation, JIT compilation, and hardware acceleration) while enjoying a more intuitive and flexible object-oriented experience.
The VAE Architecture
modeling.py
This file contains the core definitions for the VAE model components and the forward pass logic.
ModelCfg (Data Structure): This dataclass holds the hyperparameters for the VAE, such as the input_dim (e.g., 784 for a flattened 28x28 image), hidden_dims (the size of the intermediate layers), and the latent_dim (the dimensionality of the compressed latent space, z ).
Encoder (NNX Module): This module takes the input data ( x ) and maps it to the parameters of the latent distribution.
- It uses a sequence of fully-connected (Linear) layers with the ReLU activation function.
- The output layer is split into two separate linear layers, fc_mu and fc_logvar, which output the mean (mu) and log-variance (log σ², or logvar) of the latent Gaussian distribution, respectively.
Decoder (NNX Module): This module takes a sample from the latent space ( z ) and reconstructs the input data. It generally uses a mirrored architecture of the encoder (reversed hidden_dims). The final output, fc_out, produces the reconstruction logits, which are used to calculate the reconstruction loss (e.g., Binary Cross-Entropy for images like MNIST).
VAE (NNX Module): This is the main class that combines the Encoder and Decoder.
- reparameterize method: This is the crucial step in VAEs. It implements the reparameterization trick to sample the latent vector z from N(mu, sigma²) using a random noise vector ϵ ∼ N(0, I): z = mu + sigma ⊙ ϵ, where sigma = exp(0.5 · logvar).
- __call__ method: This defines the VAE's forward pass: input x goes through the Encoder, which returns mu and logvar; reparameterize draws the latent sample z; z is then passed to the Decoder for reconstruction.
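To make the mirrored layout concrete, here is a small pure-Python sketch of the layer shapes that the description above implies. The helper name layer_shapes is hypothetical and exists only for illustration; it is not part of Bonsai.

```python
from typing import Sequence


def layer_shapes(input_dim: int, hidden_dims: Sequence[int], latent_dim: int):
    """Return (in_features, out_features) pairs for each Linear layer."""
    hidden = list(hidden_dims)
    # Encoder: input -> hidden layers, then two heads into the latent space.
    encoder = list(zip([input_dim] + hidden, hidden))
    heads = (hidden[-1], latent_dim)  # shape shared by fc_mu and fc_logvar
    # Decoder: latent -> reversed hidden layers, then back to the input size.
    dims = [latent_dim] + list(reversed(hidden))
    decoder = list(zip(dims, dims[1:]))
    fc_out = (dims[-1], input_dim)
    return encoder, heads, decoder, fc_out


encoder, heads, decoder, fc_out = layer_shapes(784, (512, 256), 20)
print(encoder)  # [(784, 512), (512, 256)]
print(heads)    # (256, 20)
print(decoder)  # [(20, 256), (256, 512)]
print(fc_out)   # (512, 784)
```

With the default configuration, the encoder narrows 784 → 512 → 256 → 20 and the decoder widens 20 → 256 → 512 → 784.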
import dataclasses
from typing import Sequence

import jax
import jax.numpy as jnp
from flax import nnx


@dataclasses.dataclass(frozen=True)
class ModelCfg:
    """Configuration for the Variational Autoencoder (VAE) model."""

    input_dim: int = 784  # 28*28 for MNIST
    hidden_dims: Sequence[int] = (512, 256)
    latent_dim: int = 20


class Encoder(nnx.Module):
    """Encodes the input into latent space parameters (mu and logvar)."""

    def __init__(self, cfg: ModelCfg, *, rngs: nnx.Rngs):
        self.hidden_layers = [
            nnx.Linear(in_features, out_features, rngs=rngs)
            for in_features, out_features in zip(
                [cfg.input_dim] + list(cfg.hidden_dims), cfg.hidden_dims
            )
        ]
        self.fc_mu = nnx.Linear(cfg.hidden_dims[-1], cfg.latent_dim, rngs=rngs)
        self.fc_logvar = nnx.Linear(cfg.hidden_dims[-1], cfg.latent_dim, rngs=rngs)

    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array]:
        # Flatten everything except the batch axis.
        x = x.reshape((x.shape[0], -1))
        for layer in self.hidden_layers:
            x = nnx.relu(layer(x))
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar


class Decoder(nnx.Module):
    """Decodes the latent vector back into the original input space."""

    def __init__(self, cfg: ModelCfg, *, rngs: nnx.Rngs):
        # Mirrored architecture of the encoder
        dims = [cfg.latent_dim] + list(reversed(cfg.hidden_dims))
        self.hidden_layers = [
            nnx.Linear(in_features, out_features, rngs=rngs)
            for in_features, out_features in zip(dims, dims[1:])
        ]
        self.fc_out = nnx.Linear(dims[-1], cfg.input_dim, rngs=rngs)

    def __call__(self, z: jax.Array) -> jax.Array:
        for layer in self.hidden_layers:
            z = nnx.relu(layer(z))
        reconstruction_logits = self.fc_out(z)
        return reconstruction_logits


class VAE(nnx.Module):
    """Full Variational Autoencoder model."""

    def __init__(self, cfg: ModelCfg, *, rngs: nnx.Rngs):
        self.cfg = cfg
        self.encoder = Encoder(cfg, rngs=rngs)
        self.decoder = Decoder(cfg, rngs=rngs)

    def reparameterize(self, mu: jax.Array, logvar: jax.Array, key: jax.Array) -> jax.Array:
        """Performs the reparameterization trick to sample from the latent space."""
        std = jnp.exp(0.5 * logvar)
        epsilon = jax.random.normal(key, std.shape)
        return mu + epsilon * std

    def __call__(self, x: jax.Array, sample_key: jax.Array) -> tuple[jax.Array, jax.Array, jax.Array]:
        """Defines the forward pass of the VAE."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar, sample_key)
        reconstruction = self.decoder(z)
        return reconstruction, mu, logvar
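
As a quick sanity check of the reparameterization trick on its own, the sketch below copies the body of reparameterize into a free function and verifies two things empirically: the samples have roughly the requested mean and standard deviation, and gradients flow back to mu and logvar. The sample count and the values of mu and logvar are illustrative assumptions, not taken from the Bonsai code.

```python
import jax
import jax.numpy as jnp


def reparameterize(mu, logvar, key):
    """Sample z = mu + sigma * epsilon with epsilon ~ N(0, I)."""
    std = jnp.exp(0.5 * logvar)
    epsilon = jax.random.normal(key, std.shape)
    return mu + epsilon * std


key = jax.random.PRNGKey(0)
# Illustrative values: 100,000 draws of a 1-D latent with mean 2 and variance 0.25.
mu = jnp.full((100_000, 1), 2.0)
logvar = jnp.full((100_000, 1), jnp.log(0.25))
z = reparameterize(mu, logvar, key)
print(z.mean(), z.std())  # close to 2.0 and 0.5, up to sampling noise

# The noise is an input, not a sampling operation on the parameters,
# so the output is differentiable with respect to mu and logvar.
grad_mu, grad_logvar = jax.grad(
    lambda m, lv: reparameterize(m, lv, key).sum(), argnums=(0, 1)
)(mu, logvar)
print(grad_mu[0])  # dz/dmu is exactly 1 for every element
```

This differentiability is what lets the encoder be trained by backpropagation even though z is random.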
Model Creation and Initialization
params.py
This file is responsible for instantiating the VAE model and preparing it for training or inference, potentially handling distributed execution.
- create_model function: This is the factory function for the VAE. It takes the model configuration (cfg), JAX random number generators (rngs), and an optional JAX device mesh for distributed systems. It initializes the VAE module, which automatically creates and initializes all the internal parameters (weights and biases) of the Linear layers using the provided rngs.
- Distributed Execution Logic: When a mesh is provided, the function first uses nnx.split to separate the model graph/definition (graph_def) from the model parameters/state (state). It then computes the required sharding, that is, how the parameters should be distributed across devices, and uses jax.device_put to place the state variables onto the devices according to that sharding. This prepares the model for large-scale distributed training, which is common in JAX/Flax. Finally, it uses nnx.merge to combine the sharded state back with the graph definition.
import jax
from flax import nnx

from bonsai.models.vae import modeling as vae_lib


def create_model(
    cfg: vae_lib.ModelCfg,
    rngs: nnx.Rngs,
    mesh: jax.sharding.Mesh | None = None,
) -> vae_lib.VAE:
    """
    Create a VAE model with initialized parameters.

    Returns:
        A flax.nnx.Module instance with random parameters.
    """
    model = vae_lib.VAE(cfg, rngs=rngs)
    if mesh is not None:
        graph_def, state = nnx.split(model)
        sharding = nnx.get_named_sharding(model, mesh)
        state = jax.device_put(state, sharding)
        return nnx.merge(graph_def, state)
    else:
        return model
In summary, modeling.py builds the architecture of the VAE, and params.py is used to create an instance of that architecture and initialize its parameters.
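For readers new to JAX sharding, here is a minimal sketch of the device-placement step using only core JAX. It uses a plain dict as a stand-in for the model state, and the mesh axis name "data" and the array names are assumptions made for this example. It runs on a single CPU device as well as on multiple accelerators.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over whatever devices are available.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Stand-in for the model state: a plain dict of arrays (hypothetical names).
state = {"kernel": jnp.ones((784, 512)), "bias": jnp.zeros((512,))}

# An empty PartitionSpec replicates each array on every device in the mesh.
replicated = NamedSharding(mesh, PartitionSpec())
state = jax.device_put(state, replicated)

# A batch can instead be split along its leading axis across the "data" axis.
batch = jnp.ones((2 * devices.size, 784))
batch = jax.device_put(batch, NamedSharding(mesh, PartitionSpec("data")))

print(state["kernel"].sharding)
print(batch.sharding)
```

Replicating the parameters and splitting the batch is the simplest data-parallel layout; create_model instead derives the parameter sharding from the model itself.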
If you are going to train it, you will need to define the loss function:
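- Loss function: The total loss is the negative Evidence Lower Bound (negative ELBO), which the VAE aims to minimize:
L(x) = -E_{q(z|x)}[log p(x|z)] + KL(q(z|x) || p(z))
The ELBO itself is E_{q(z|x)}[log p(x|z)] - KL(q(z|x) || p(z)), a quantity to maximize. Since we are minimizing, we flip the signs, so both terms enter the loss as positive quantities: a reconstruction loss (e.g., Binary Cross-Entropy on the reconstruction logits) plus the KL divergence. For a diagonal Gaussian encoder and a standard normal prior, the KL term has the closed form:
KL(q(z|x) || p(z)) = -0.5 · Σ_j (1 + logvar_j - mu_j² - exp(logvar_j))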
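A minimal sketch of this loss in JAX could look like the following. It assumes a Bernoulli decoder (binary cross-entropy on the logits) and averages over the batch; the function name negative_elbo and the dummy inputs are illustrative and not part of the Bonsai code.

```python
import jax
import jax.numpy as jnp


def negative_elbo(logits, x, mu, logvar):
    """Mean negative ELBO over a batch, assuming a Bernoulli decoder."""
    x = x.reshape((x.shape[0], -1))
    # Binary cross-entropy computed from logits, summed over pixels.
    bce = -(x * jax.nn.log_sigmoid(logits) + (1.0 - x) * jax.nn.log_sigmoid(-logits))
    recon = bce.sum(axis=-1)
    # Closed-form KL between N(mu, sigma^2) and N(0, I), summed over latent dims.
    kl = -0.5 * jnp.sum(1.0 + logvar - mu**2 - jnp.exp(logvar), axis=-1)
    return jnp.mean(recon + kl)


# Dummy inputs: zero logits mean every pixel is predicted with probability 0.5,
# and mu = 0, logvar = 0 make the posterior equal to the prior (zero KL).
x = jnp.zeros((4, 28, 28, 1))
logits = jnp.zeros((4, 784))
mu = jnp.zeros((4, 20))
logvar = jnp.zeros((4, 20))
loss = negative_elbo(logits, x, mu, logvar)
print(loss)  # approximately 784 * log(2)
```

In a training loop, logits, mu, and logvar would come from the VAE's forward pass, and this scalar would be differentiated with respect to the model parameters.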
Inference
import sys
from pathlib import Path

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
import orbax.checkpoint as ocp  # provides PyTreeCheckpointer, used to restore weights
import tensorflow as tf
import tensorflow_datasets as tfds
from flax import nnx

# Make the local Bonsai checkout importable.
bonsai_root = Path.home()
sys.path.insert(0, str(bonsai_root))

from bonsai.models.vae import modeling as vae_lib
from bonsai.models.vae import params as params_lib
Load and Preprocess Data
ds = tfds.load('mnist', split='test', as_supervised=True)

images_list = []
labels_list = []
for image, label in ds.take(10):
    # Scale pixel values from [0, 255] to [0, 1].
    single_image = tf.cast(image, tf.float32) / 255.0
    images_list.append(single_image.numpy())
    labels_list.append(label.numpy())

image_batch = jnp.stack(images_list, axis=0)
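
To check the shapes without downloading MNIST, the same preprocessing can be sketched on synthetic data. The random images below are an assumption made purely for illustration; only the shapes and value range match the real dataset.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for ten MNIST images: uint8, height x width x channel.
raw_images = [rng.integers(0, 256, size=(28, 28, 1), dtype=np.uint8) for _ in range(10)]

# Same preprocessing as above: cast to float32 and scale into [0, 1].
images_list = [img.astype(np.float32) / 255.0 for img in raw_images]
image_batch = jnp.stack(images_list, axis=0)
print(image_batch.shape)  # (10, 28, 28, 1)

# The Encoder flattens everything but the batch axis before its first layer.
flat = image_batch.reshape((image_batch.shape[0], -1))
print(flat.shape)  # (10, 784)
```

The flattened width of 784 is what input_dim must match in ModelCfg.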
Load Pretrained Weights
config = vae_lib.ModelCfg(
    input_dim=28*28,
    hidden_dims=(512,),
    latent_dim=10,
)
rngs = nnx.Rngs(params=0, sample=1)

# Build a randomly initialized model; only its structure (graphdef) is reused.
model_template = params_lib.create_model(cfg=config, rngs=rngs)

ckpt_dir = "/bonsai/bonsai/models/vae/tests/checkpoints"
checkpointer = ocp.PyTreeCheckpointer()
loaded_state_dict = checkpointer.restore(ckpt_dir)

# Combine the model structure with the restored state.
graphdef, _ = nnx.split(model_template)
model = nnx.merge(graphdef, loaded_state_dict['params'], loaded_state_dict['other_vars'])
Reconstruct Input
@jax.jit
def reconstruct(model: vae_lib.VAE, batch: jax.Array, sample_key: jax.Array):
    """Encodes and decodes an image batch using the trained VAE."""
    reconstruction_logits_flat, _, _ = model(batch, sample_key=sample_key)
    # Convert logits to pixel probabilities in [0, 1].
    reconstructed_probs_flat = jax.nn.sigmoid(reconstruction_logits_flat)
    return reconstructed_probs_flat.reshape(batch.shape)


sample_key = rngs.sample()
reconstructed_images = reconstruct(model, image_batch, sample_key)
fig, axes = plt.subplots(2, 10, figsize=(15, 3.5))
for i in range(10):
    # Plot original images on the first row
    axes[0, i].imshow(image_batch[i, ..., 0], cmap='gray')
    axes[0, i].set_title(f"Label: {labels_list[i]}")
    axes[0, i].axis('off')
    # Plot reconstructed images on the second row
    axes[1, i].imshow(reconstructed_images[i, ..., 0], cmap='gray')
    axes[1, i].axis('off')

# Add row labels
axes[0, 0].set_ylabel("Original", fontsize=12, labelpad=15)
axes[1, 0].set_ylabel("Reconstructed", fontsize=12, labelpad=15)
plt.suptitle("VAE Inference: Original vs. Reconstructed MNIST Digits", fontsize=16)
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()
Acknowledgements
✨ Google ML Developer Programs and Google Developers Program supported this work by providing Google Cloud Credits (and awesome tutorials for the Google Developer Experts)✨




