DEV Community

Cover image for Run Any HuggingFace Model on TPUs: A Beginner's Guide to TorchAX

Run Any HuggingFace Model on TPUs: A Beginner's Guide to TorchAX

TorchAX

What if you could run any HuggingFace model on TPUs — without rewriting a single line of model code?

Here is what the end result looks like:

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it", torch_dtype="bfloat16")

import torchax
torchax.enable_globally()  # Enable AFTER loading the model

model.to("jax")  # That's it. Now running on JAX.
Enter fullscreen mode Exit fullscreen mode

Five lines. Your PyTorch model is now executing on JAX — with access to TPUs, JIT compilation, and automatic parallelism across devices.

In this tutorial, we will go from zero to building a working chatbot powered by a HuggingFace model running on JAX. Along the way, you will learn key JAX concepts, see real benchmarks, and understand why this approach exists.

Open Full Tutorial In Colab Open Quick Start In Colab


Why This Matters: The HuggingFace + JAX Problem

In 2024, HuggingFace removed native JAX and TensorFlow support from its transformers library to focus development on PyTorch. This left thousands of JAX users — especially those running on Google Cloud TPUs — without a straightforward way to use HuggingFace's massive model collection.

What is JAX?

If you are new to JAX, think of it as Google's high-performance numerical computing library. It looks like NumPy on the surface, but under the hood it offers three powerful capabilities:

  1. JIT Compilation — JAX can compile your Python code into optimized machine code using the XLA compiler. The first run is slower (compilation), but every subsequent call is dramatically faster.

  2. TPU Support — JAX is the native programming model for Google's Tensor Processing Units. If you want to use TPUs, JAX is the most natural path.

  3. Automatic Parallelism — JAX can automatically distribute computation across multiple devices (TPUs or GPUs) using a single-program model called gSPMD. You describe what should be sharded; the compiler figures out how.

Enter TorchAX

torchax is a library from Google that bridges PyTorch and JAX. It works by creating a special torch.Tensor subclass that secretly holds a jax.Array inside. When PyTorch operations are called on this tensor, torchax intercepts them and executes the JAX equivalent instead.

Think of it like a Trojan horse: PyTorch thinks it is working with regular tensors, but the computation is actually happening on JAX.

PyTorch Model
    |
    v
torchax.Tensor (looks like torch.Tensor)
    |
    v
jax.Array (actual computation on TPU/GPU)
Enter fullscreen mode Exit fullscreen mode

This means you can take any PyTorch model — including HuggingFace models — and run it on JAX without modifying the model code at all.

Credits: This tutorial builds on the excellent 3-part blog series by Han Qi (@qihqi), the author of torchax, and on the torchax documentation. We expand on those tutorials with beginner-friendly explanations, a different model (Gemma instead of Llama), benchmarks, and a complete Colab-ready notebook.


TorchAX vs. the Alternatives

Before diving into code, it helps to understand where torchax fits in the broader ecosystem:

Approach Effort Performance Best For
Rewrite in Flax/Equinox High (full rewrite) Native JAX speed New projects starting in JAX
torch-xla (PyTorch/XLA) Low (add XLA device) Good (XLA compiled) PyTorch training on TPUs
torchax Low (change device to 'jax') Great (JAX JIT + interop) Running HF models on JAX, mixing PyTorch + JAX
ONNX export Medium (export + runtime) Variable Cross-framework deployment

When should you use torchax? When you have a PyTorch model (especially from HuggingFace) and want to leverage JAX's JIT compilation, TPU support, or interop with JAX libraries — without rewriting the model.


Prerequisites & Setup

What you need:

  • Python 3.10+
  • Basic familiarity with PyTorch (loading models, running inference)
  • A Google Colab account (free tier works for the 1B model)

Zero-setup option: Click the Colab badge above. The notebook handles all installation automatically.

Local setup:

# 1. Install PyTorch (CPU version — torchax handles the accelerator)
pip install torch --index-url https://download.pytorch.org/whl/cpu  # Linux
# pip install torch  # macOS

# 2. Install JAX for your accelerator
pip install -U jax[tpu]     # Google Cloud TPU
# pip install -U jax[cuda12]  # NVIDIA GPU
# pip install -U jax          # CPU only

# 3. Install torchax, transformers, and flax (for JAX compatibility)
pip install -U torchax transformers flax
Enter fullscreen mode Exit fullscreen mode

Key Concepts for Beginners

Before we write code, let's demystify three JAX concepts you will encounter throughout this tutorial.

Pytrees: JAX's Data Containers

A pytree is any nested structure of Python containers (dicts, lists, tuples) with arrays as leaves. JAX uses pytrees everywhere — model weights are pytrees, function inputs/outputs are pytrees.

Think of a pytree like a shipping box with labeled compartments. JAX knows how to open standard boxes (dicts, lists, tuples), pull out all the arrays, do math on them, and put them back.

The catch: JAX does not know how to open custom boxes. HuggingFace defines custom output types like CausalLMOutputWithPast — we need to teach JAX how to unpack and repack these. This is called pytree registration, and we will see it in action shortly.

JIT Compilation: Translate Once, Run Fast Forever

JIT (Just-In-Time) compilation is like translating a recipe from English to machine code. The first time you call a JIT-compiled function, JAX traces through it, records all the operations, and compiles an optimized version. Subsequent calls skip the tracing and run the compiled version directly.

First call:  Python code → trace → compile → execute  (slow)
Second call: compiled code → execute                   (fast!)
Enter fullscreen mode Exit fullscreen mode

The speedup can be 10-100x or more. The trade-off is that the compiled function is specialized for the input shapes it was traced with — if shapes change, JAX recompiles.

Static vs. Dynamic Values

When JAX traces a function for JIT, it treats inputs as abstract shapes, not concrete values. If your code has a branch like if use_cache:, JAX cannot evaluate it during tracing because use_cache is abstract. This causes a ConcretizationTypeError.

The fix: mark such values as static (compile-time constants) so JAX knows their actual value during tracing. We will see two ways to do this: closures and static_argnums.


Step 1: Your First Forward Pass

Let's load a model and run it on JAX. We will use Gemma 3 1B IT — a small, instruction-tuned model from Google that runs comfortably on free Colab hardware.

import torch
import torchax
import jax
import time

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load model and tokenizer
model_name = "google/gemma-3-1b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cpu"
)

# Enable torchax globally AFTER model loading
# This prevents intercepting unsupported initialization ops
torchax.enable_globally()

# Move model weights to the JAX device
model.to("jax")

# Tokenize an input prompt
prompt = "The secret to baking a good cake is"
inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs["input_ids"].to("jax")

# Run a forward pass (eager mode)
start = time.perf_counter()
with torch.no_grad():
    outputs = model(input_ids, use_cache=False)
elapsed = time.perf_counter() - start

print(f"Output logits shape: {outputs.logits.shape}")
print(f"Eager forward pass: {elapsed:.3f}s")
Enter fullscreen mode Exit fullscreen mode

What happened:

  1. We load the model on CPU first, then call torchax.enable_globally(). This ordering is important — enabling torchax before model loading can intercept unsupported initialization ops and cause errors.
  2. model.to("jax") moves every parameter from CPU to the JAX device — just like model.to("cuda") for GPUs.
  3. The forward pass runs through PyTorch's code path, but every operation is executed by JAX under the hood.

The output logits tensor has shape (1, sequence_length, vocab_size). Each position contains a score for every token in the vocabulary — the highest score is the model's prediction for the next token.


Step 2: Speed It Up with JIT Compilation

The eager forward pass works, but it is slow — every operation goes through Python one at a time. Let's compile the model for dramatically faster inference.

The extract_jax Approach

The torchax.extract_jax() function converts a PyTorch model into a pure JAX function:

# Extract a JAX-callable function and the model weights as a pytree
weights, jax_func = torchax.extract_jax(model)
Enter fullscreen mode Exit fullscreen mode

This returns two things:

  • weights — the model's state_dict as a pytree of jax.Arrays
  • jax_func — a function with signature jax_func(weights, args_tuple, kwargs_dict)

Register HuggingFace Output Types as Pytrees

Before we can JIT this function, we need to teach JAX about HuggingFace's custom types:

from jax.tree_util import register_pytree_node
from transformers import modeling_outputs, cache_utils

# Register CausalLMOutputWithPast
def output_flatten(v):
    return v.to_tuple(), None

def output_unflatten(aux, children):
    return modeling_outputs.CausalLMOutputWithPast(*children)

register_pytree_node(
    modeling_outputs.CausalLMOutputWithPast,
    output_flatten,
    output_unflatten,
)

# Register DynamicCache
def _flatten_dynamic_cache(cache):
    return (cache.key_cache, cache.value_cache), None

def _unflatten_dynamic_cache(aux, children):
    c = cache_utils.DynamicCache()
    c.key_cache, c.value_cache = children
    return c

register_pytree_node(
    cache_utils.DynamicCache,
    _flatten_dynamic_cache,
    _unflatten_dynamic_cache,
)
Enter fullscreen mode Exit fullscreen mode

Handle Static Arguments with a Closure

The use_cache flag is a boolean that JAX cannot trace. We wrap it in a closure to make it a compile-time constant:

def forward_no_cache(weights, input_ids):
    return jax_func(weights, (input_ids,), {"use_cache": False})

jitted_forward = jax.jit(forward_no_cache)
Enter fullscreen mode Exit fullscreen mode

Benchmark: Eager vs. JIT

# Convert input to a native JAX array for jax.jit
jax_input_ids = jax.device_put(inputs["input_ids"].numpy())

# Warm up (first call triggers compilation)
res = jitted_forward(weights, jax_input_ids)
jax.block_until_ready(res)

# Benchmark 3 runs
for i in range(3):
    start = time.perf_counter()
    res = jitted_forward(weights, jax_input_ids)
    jax.block_until_ready(res)
    elapsed = time.perf_counter() - start
    print(f"Run {i}: {elapsed:.4f}s")
Enter fullscreen mode Exit fullscreen mode

Expected output (times will vary by hardware):

Run 0: 0.0142s  # Already compiled from warm-up
Run 1: 0.0038s
Run 2: 0.0035s
Enter fullscreen mode Exit fullscreen mode

The JIT-compiled version runs orders of magnitude faster than eager mode. This is the power of XLA compilation — operations are fused, memory is optimized, and the accelerator runs a single optimized program.


Step 3: The Simpler API — torchax.compile

The extract_jax + manual JIT approach gives you full control, but for most cases there is a simpler way. The catch is that torchax.compile() uses jax.jit under the hood, so we need to avoid passing dynamic boolean flags like use_cache. We wrap the model in a thin module that bakes in these constants:

import torch.nn as nn

class NoCacheModel(nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, input_ids):
        # Return only logits to avoid HuggingFace output class pytree issues
        return self.base_model(input_ids, use_cache=False, return_dict=False)[0]

# One-liner: compile the wrapped model
compiled_model = torchax.compile(NoCacheModel(model))

# Use it like a normal PyTorch model
with torch.no_grad():
    logits = compiled_model(input_ids)
Enter fullscreen mode Exit fullscreen mode

Under the hood, torchax.compile() wraps your model in a JittableModule and applies jax.jit. The first call triggers compilation; subsequent calls are fast. The NoCacheModel wrapper ensures that boolean flags are constants (not traced) and that the output is a plain tensor (not a custom HuggingFace type that needs pytree registration).


Step 4: Text Classification

Let's use our JIT-compiled model for a practical task — sentiment classification. Since Gemma is an instruction-tuned model, we can use prompt engineering:

def classify_sentiment(text, model, tokenizer):
    prompt = f"""Classify the following text as POSITIVE or NEGATIVE.
Text: "{text}"
Classification:"""

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("jax")

    with torch.no_grad():
        outputs = model(input_ids, use_cache=False)

    # Get the predicted next token
    next_token_logits = outputs.logits[0, -1, :]
    next_token_id = torch.argmax(next_token_logits).item()
    prediction = tokenizer.decode([next_token_id]).strip()
    return prediction

# Test it
texts = [
    "This movie was absolutely fantastic, I loved every minute!",
    "The service was terrible and the food was cold.",
    "A perfectly average experience, nothing special.",
]

for text in texts:
    result = classify_sentiment(text, model, tokenizer)
    print(f"Text: {text[:50]}...  =>  {result}")
Enter fullscreen mode Exit fullscreen mode

Step 5: Text Generation (Autoregressive Decoding)

Classification is useful, but the real power of LLMs is generating text. Let's understand how this works.

How Autoregressive Decoding Works

An LLM predicts one token at a time. Given an input of length n, it produces scores for the next token. We pick one (e.g., the highest-scoring token via greedy decoding), append it to the input, and repeat:

Iteration 1: input (1, n)     → output (1, n)     → pick token
Iteration 2: input (1, n+1)   → output (1, n+1)   → pick token
Iteration 3: input (1, n+2)   → output (1, n+2)   → pick token
...
Enter fullscreen mode Exit fullscreen mode

The problem: input shapes change every iteration. JIT compilation specializes for fixed shapes, so changing shapes means recompilation every step — worse than eager mode.

The KV Cache Solution

The KV (Key-Value) cache stores intermediate computations from previous tokens so the model only needs to process the new token each iteration:

Iteration 1: input (1, n)              → output + kv_cache(n)
Iteration 2: input (1, 1) + cache(n)   → output + kv_cache(n+1)
Iteration 3: input (1, 1) + cache(n+1) → output + kv_cache(n+2)
Enter fullscreen mode Exit fullscreen mode

With a DynamicCache, the cache grows each step — shapes still change. With a StaticCache, the cache has a fixed maximum length — shapes stay constant, making it JIT-friendly.

Implementation with StaticCache

from transformers.cache_utils import StaticCache

# Register StaticCache as a pytree
def _flatten_static_cache(cache):
    return (
        cache.key_cache, cache.value_cache
    ), (cache.config, cache.max_batch_size, cache.max_cache_len,
        getattr(cache, "device", None), getattr(cache, "dtype", None))

def _unflatten_static_cache(aux, children):
    config, max_batch_size, max_cache_len, device, dtype = aux
    kwargs = {}
    if device is not None: kwargs["device"] = device
    if dtype is not None: kwargs["dtype"] = dtype
    cache = StaticCache(config, max_batch_size, max_cache_len, **kwargs)
    cache.key_cache, cache.value_cache = children
    return cache

register_pytree_node(
    StaticCache,
    _flatten_static_cache,
    _unflatten_static_cache,
)

def generate_text(model, tokenizer, prompt, max_new_tokens=50):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to("jax")
    batch_size, seq_length = input_ids.shape

    # Create a static cache with fixed maximum length
    past_key_values = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=seq_length + max_new_tokens,
        device="jax",
        dtype=model.dtype,
    )
    cache_position = torch.arange(seq_length, device="jax")

    # Prefill: process the full prompt
    with torch.no_grad():
        logits, past_key_values = model(
            input_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            return_dict=False,
            use_cache=True,
        )

    next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
    generated_ids = [next_token[:, 0].item()]
    cache_position = torch.tensor([seq_length], device="jax")

    # Decode: generate one token at a time
    for _ in range(max_new_tokens - 1):
        with torch.no_grad():
            logits, past_key_values = model(
                next_token,
                cache_position=cache_position,
                past_key_values=past_key_values,
                return_dict=False,
                use_cache=True,
            )
        next_token = torch.argmax(logits[:, -1], dim=-1)[:, None]
        token_id = next_token[:, 0].item()

        if token_id == tokenizer.eos_token_id:
            break
        generated_ids.append(token_id)
        cache_position += 1

    return tokenizer.decode(generated_ids, skip_special_tokens=True)

# Generate!
result = generate_text(model, tokenizer, "The secret to baking a good cake is")
print(result)
Enter fullscreen mode Exit fullscreen mode

Step 6: Distributed Inference (Tensor Parallelism)

If you have access to multiple devices (e.g., a TPU v2-8 with 8 chips, or multi-GPU), you can shard the model weights across devices for faster inference.

How Tensor Parallelism Works

In tensor parallelism, we split weight matrices across devices:

  • Column-parallel: Q, K, V, Gate, and Up projections are split along the output dimension
  • Row-parallel: O and Down projections are split along the input dimension
  • Between these two, only a single all-reduce operation is needed per layer

JAX's gSPMD handles the communication automatically — you just specify how each weight should be sharded.

Sharding the Weights

from jax.sharding import PartitionSpec as P, NamedSharding

# Create a device mesh
mesh = jax.make_mesh((jax.device_count(),), ("axis",))

def shard_weights(mesh, weights):
    sharded = {}
    for name, tensor in weights.items():
        if any(k in name for k in ["q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"]):
            spec = P("axis", None)  # Column-parallel
        elif any(k in name for k in ["o_proj", "down_proj", "lm_head", "embed_tokens"]):
            spec = P(None, "axis")  # Row-parallel
        else:
            spec = P()  # Replicate (e.g., layer norms)
        sharded[name] = jax.device_put(tensor, NamedSharding(mesh, spec))
    return sharded

# Apply sharding
weights, jax_func = torchax.extract_jax(model)
weights = shard_weights(mesh, weights)

# Replicate the input across all devices
input_ids_sharded = jax.device_put(
    inputs["input_ids"], NamedSharding(mesh, P())
)
Enter fullscreen mode Exit fullscreen mode

With sharded weights, the same jax.jit-compiled function now runs in parallel across all devices. The XLA compiler automatically inserts the necessary all-reduce operations.

Note: Tensor parallelism requires a multi-device environment. On free Colab TPU (single device), this section is for illustration. Use a TPU v2-8 or multi-GPU setup to run it.


Step 7: Build a Mini Chatbot

Let's wrap everything into a simple chat function using Gemma's instruction template:

def chat(model, tokenizer, user_message, max_new_tokens=100):
    # Gemma instruction format
    prompt = f"<start_of_turn>user\n{user_message}<end_of_turn>\n<start_of_turn>model\n"
    response = generate_text(model, tokenizer, prompt, max_new_tokens)
    return response

# Example conversation
questions = [
    "What is JAX and why would I use it?",
    "Explain tensor parallelism in simple terms.",
    "Write a haiku about machine learning.",
]

for q in questions:
    print(f"User: {q}")
    print(f"Gemma: {chat(model, tokenizer, q)}")
    print()
Enter fullscreen mode Exit fullscreen mode

Swapping to a Larger Model

Everything above uses google/gemma-3-1b-it (1B parameters). To use a larger model, change the model name:

# 7B model — needs more memory (Colab Pro or multi-device)
model_name = "google/gemma-3-7b-it"
Enter fullscreen mode Exit fullscreen mode

The rest of the code remains identical. Larger models produce higher quality outputs but require more memory and compute. The 7B model benefits significantly from tensor parallelism on multi-device setups.

Other models that work well with torchax include any standard HuggingFace AutoModelForCausalLM architecture — GPT-2, Llama, Mistral, Phi, and more.


Troubleshooting

TypeError: ... is not a valid JAX type
You need to register the type as a pytree. See the registration examples above for CausalLMOutputWithPast, DynamicCache, and StaticCache.

ConcretizationTypeError: Abstract tracer value encountered
A value that changes between calls (like a boolean flag) needs to be either: (1) made static via static_argnums in jax.jit, or (2) baked into a closure as a constant.

UserWarning: A large amount of constants were captured
Model weights are being inlined as constants in the compiled graph. Pass them as explicit function arguments instead of closing over them.

RuntimeError: No available devices
Ensure JAX can see your accelerator: print(jax.devices()). In Colab, check that your runtime type is set to TPU or GPU.


Conclusion

In this tutorial, we went from zero to a working chatbot running a HuggingFace model on JAX:

  1. Forward pass — moved a PyTorch model to JAX with model.to("jax")
  2. JIT compilation — compiled for 10-100x speedup with jax.jit
  3. Text classification — used prompt engineering for sentiment analysis
  4. Text generation — implemented autoregressive decoding with StaticCache
  5. Distributed inference — sharded weights across devices with tensor parallelism
  6. Chatbot — wrapped generation in an instruction-following chat function

The key insight: torchax lets you use the entire HuggingFace ecosystem — models, tokenizers, configs — while running on JAX's high-performance backend. No model rewrites needed.

Resources

Credits

This tutorial would not be possible without the work of:

  • Han Qi (@qihqi) — author of torchax and the original HuggingFace + JAX tutorial series
  • The torchax team at Google — for building and maintaining the library
  • The HuggingFace team — for the transformers ecosystem
  • The JAX team at Google — for JAX, XLA, and TPU support

What model will you try running on TPUs first? Let me know in the comments!

Top comments (0)