<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Ahmed Elnaggar</title>
    <description>The latest articles on DEV Community by Ahmed Elnaggar (@ahmed_elnaggar).</description>
    <link>https://dev.to/ahmed_elnaggar</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3700950%2Ffcc3958d-302e-4d73-90cb-fee4604631b2.jpg</url>
      <title>DEV Community: Ahmed Elnaggar</title>
      <link>https://dev.to/ahmed_elnaggar</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/ahmed_elnaggar"/>
    <language>en</language>
    <item>
      <title>Run Any HuggingFace Model on TPUs: A Beginner's Guide to TorchAX</title>
      <dc:creator>Ahmed Elnaggar</dc:creator>
      <pubDate>Sun, 29 Mar 2026 17:04:14 +0000</pubDate>
      <link>https://dev.to/gde/run-any-huggingface-model-on-tpus-a-beginners-guide-to-torchax-4ln0</link>
      <guid>https://dev.to/gde/run-any-huggingface-model-on-tpus-a-beginners-guide-to-torchax-4ln0</guid>
      <description>&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcvjk9opldk0897wus527.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcvjk9opldk0897wus527.png" alt="TorchAX" width="800" height="446"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  What if you could run any HuggingFace model on TPUs — without rewriting a single line of model code?
&lt;/h2&gt;

&lt;p&gt;Here is what the end result looks like:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;

&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;google/gemma-3-1b-it&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;torch_dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;bfloat16&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;
&lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;enable_globally&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Enable AFTER loading the model
&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# That's it. Now running on JAX.
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Five lines. Your PyTorch model is now executing on JAX — with access to TPUs, JIT compilation, and automatic parallelism across devices.&lt;/p&gt;

&lt;p&gt;In this tutorial, we will go from zero to building a working chatbot powered by a HuggingFace model running on JAX. Along the way, you will learn key JAX concepts, see real benchmarks, and understand &lt;em&gt;why&lt;/em&gt; this approach exists.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://colab.research.google.com/github/ahmed-elnaggar/torchax-huggingface/blob/main/notebooks/torchax_huggingface_tutorial.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open Full Tutorial In Colab" width="117" height="20"&gt;&lt;/a&gt; &lt;a href="https://colab.research.google.com/github/ahmed-elnaggar/torchax-huggingface/blob/main/notebooks/torchax_quickstart.ipynb" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fcolab.research.google.com%2Fassets%2Fcolab-badge.svg" alt="Open Quick Start In Colab" width="117" height="20"&gt;&lt;/a&gt;&lt;/p&gt;




&lt;h2&gt;
  
  
  Why This Matters: The HuggingFace + JAX Problem
&lt;/h2&gt;

&lt;p&gt;In 2024, HuggingFace &lt;a href="https://github.com/huggingface/transformers/issues/25004" rel="noopener noreferrer"&gt;removed native JAX and TensorFlow support&lt;/a&gt; from its &lt;code&gt;transformers&lt;/code&gt; library to focus development on PyTorch. This left thousands of JAX users — especially those running on Google Cloud TPUs — without a straightforward way to use HuggingFace's massive model collection.&lt;/p&gt;

&lt;h3&gt;
  
  
  What is JAX?
&lt;/h3&gt;

&lt;p&gt;If you are new to JAX, think of it as Google's high-performance numerical computing library. It looks like NumPy on the surface, but under the hood it offers three powerful capabilities:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;JIT Compilation&lt;/strong&gt; — JAX can compile your Python code into optimized machine code using the XLA compiler. The first run is slower (compilation), but every subsequent call is dramatically faster.&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;TPU Support&lt;/strong&gt; — JAX is the native programming model for Google's Tensor Processing Units. If you want to use TPUs, JAX is the most natural path.&lt;/p&gt;&lt;/li&gt;
&lt;li&gt;&lt;p&gt;&lt;strong&gt;Automatic Parallelism&lt;/strong&gt; — JAX can automatically distribute computation across multiple devices (TPUs or GPUs) using a single-program, multiple-data approach called GSPMD. You describe &lt;em&gt;what&lt;/em&gt; should be sharded; the compiler figures out &lt;em&gt;how&lt;/em&gt;.&lt;/p&gt;&lt;/li&gt;
&lt;/ol&gt;
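
&lt;p&gt;Here is the first of those capabilities in action. This is a minimal, self-contained sketch (the function and shapes are arbitrary, and absolute timings vary by hardware):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import time

import jax
import jax.numpy as jnp

def rms_norm(x):
    # NumPy-style math that XLA can compile and fuse into one kernel
    return x / jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + 1e-6)

fast_rms_norm = jax.jit(rms_norm)

x = jnp.ones((4096, 4096))
fast_rms_norm(x).block_until_ready()   # first call: trace + compile (slow)

start = time.perf_counter()
fast_rms_norm(x).block_until_ready()   # compiled program (fast)
print(f"compiled call: {time.perf_counter() - start:.5f}s")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;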

&lt;h3&gt;
  
  
  Enter TorchAX
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://github.com/google/torchax" rel="noopener noreferrer"&gt;&lt;strong&gt;torchax&lt;/strong&gt;&lt;/a&gt; is a library from Google that bridges PyTorch and JAX. It works by creating a special &lt;code&gt;torch.Tensor&lt;/code&gt; subclass that secretly holds a &lt;code&gt;jax.Array&lt;/code&gt; inside. When PyTorch operations are called on this tensor, torchax intercepts them and executes the JAX equivalent instead.&lt;/p&gt;

&lt;p&gt;Think of it like a Trojan horse: PyTorch &lt;em&gt;thinks&lt;/em&gt; it is working with regular tensors, but the computation is actually happening on JAX.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;PyTorch Model
    |
    v
torchax.Tensor (looks like torch.Tensor)
    |
    v
jax.Array (actual computation on TPU/GPU)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This means you can take &lt;em&gt;any&lt;/em&gt; PyTorch model — including HuggingFace models — and run it on JAX without modifying the model code at all.&lt;/p&gt;
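
&lt;p&gt;You can see the disguise for yourself. The snippet below is a small sketch that continues from a session where &lt;code&gt;torchax.enable_globally()&lt;/code&gt; has already been called:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Continuing from above: torchax is already enabled globally
import torch

x = torch.randn(2, 3, device="jax")   # storage is a jax.Array under the hood
y = x @ x.T                           # the matmul is dispatched to JAX

print(isinstance(x, torch.Tensor))    # True: PyTorch sees a normal tensor
print(type(x))                        # the torchax tensor subclass
print(y.shape)                        # torch.Size([2, 2])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;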

&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;Credits:&lt;/strong&gt; This tutorial builds on the excellent &lt;a href="https://huggingface.co/blog/qihqi/huggingface-jax-01" rel="noopener noreferrer"&gt;3-part blog series&lt;/a&gt; by Han Qi (&lt;a href="https://github.com/qihqi" rel="noopener noreferrer"&gt;@qihqi&lt;/a&gt;), the author of torchax, and on the &lt;a href="https://google.github.io/torchax/" rel="noopener noreferrer"&gt;torchax documentation&lt;/a&gt;. We expand on those tutorials with beginner-friendly explanations, a different model (Gemma instead of Llama), benchmarks, and a complete Colab-ready notebook.&lt;/p&gt;
&lt;/blockquote&gt;




&lt;h2&gt;
  
  
  TorchAX vs. the Alternatives
&lt;/h2&gt;

&lt;p&gt;Before diving into code, it helps to understand where torchax fits in the broader ecosystem:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Approach&lt;/th&gt;
&lt;th&gt;Effort&lt;/th&gt;
&lt;th&gt;Performance&lt;/th&gt;
&lt;th&gt;Best For&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;Rewrite in Flax/Equinox&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;High (full rewrite)&lt;/td&gt;
&lt;td&gt;Native JAX speed&lt;/td&gt;
&lt;td&gt;New projects starting in JAX&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;torch-xla (PyTorch/XLA)&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Low (add XLA device)&lt;/td&gt;
&lt;td&gt;Good (XLA compiled)&lt;/td&gt;
&lt;td&gt;PyTorch training on TPUs&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;torchax&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Low (change device to 'jax')&lt;/td&gt;
&lt;td&gt;Great (JAX JIT + interop)&lt;/td&gt;
&lt;td&gt;Running HF models on JAX, mixing PyTorch + JAX&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;strong&gt;ONNX export&lt;/strong&gt;&lt;/td&gt;
&lt;td&gt;Medium (export + runtime)&lt;/td&gt;
&lt;td&gt;Variable&lt;/td&gt;
&lt;td&gt;Cross-framework deployment&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;When should you use torchax?&lt;/strong&gt; When you have a PyTorch model (especially from HuggingFace) and want to leverage JAX's JIT compilation, TPU support, or interop with JAX libraries — without rewriting the model.&lt;/p&gt;




&lt;h2&gt;
  
  
  Prerequisites &amp;amp; Setup
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;What you need:&lt;/strong&gt;&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Python 3.10+&lt;/li&gt;
&lt;li&gt;Basic familiarity with PyTorch (loading models, running inference)&lt;/li&gt;
&lt;li&gt;A Google Colab account (free tier works for the 1B model)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;Zero-setup option:&lt;/strong&gt; Click the Colab badge above. The notebook handles all installation automatically.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Local setup:&lt;/strong&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="c"&gt;# 1. Install PyTorch (CPU version — torchax handles the accelerator)&lt;/span&gt;
pip &lt;span class="nb"&gt;install &lt;/span&gt;torch &lt;span class="nt"&gt;--index-url&lt;/span&gt; https://download.pytorch.org/whl/cpu  &lt;span class="c"&gt;# Linux&lt;/span&gt;
&lt;span class="c"&gt;# pip install torch  # macOS&lt;/span&gt;

&lt;span class="c"&gt;# 2. Install JAX for your accelerator&lt;/span&gt;
pip &lt;span class="nb"&gt;install&lt;/span&gt; &lt;span class="nt"&gt;-U&lt;/span&gt; jax[tpu]     &lt;span class="c"&gt;# Google Cloud TPU&lt;/span&gt;
&lt;span class="c"&gt;# pip install -U jax[cuda12]  # NVIDIA GPU&lt;/span&gt;
&lt;span class="c"&gt;# pip install -U jax          # CPU only&lt;/span&gt;

&lt;span class="c"&gt;# 3. Install torchax, transformers, and flax (for JAX compatibility)&lt;/span&gt;
pip &lt;span class="nb"&gt;install&lt;/span&gt; &lt;span class="nt"&gt;-U&lt;/span&gt; torchax transformers flax
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Key Concepts for Beginners
&lt;/h2&gt;

&lt;p&gt;Before we write code, let's demystify three JAX concepts you will encounter throughout this tutorial.&lt;/p&gt;

&lt;h3&gt;
  
  
  Pytrees: JAX's Data Containers
&lt;/h3&gt;

&lt;p&gt;A &lt;strong&gt;pytree&lt;/strong&gt; is any nested structure of Python containers (dicts, lists, tuples) with arrays as leaves. JAX uses pytrees everywhere — model weights are pytrees, function inputs/outputs are pytrees.&lt;/p&gt;

&lt;p&gt;Think of a pytree like a shipping box with labeled compartments. JAX knows how to open standard boxes (dicts, lists, tuples), pull out all the arrays, do math on them, and put them back.&lt;/p&gt;
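
&lt;p&gt;To make that concrete, here is a tiny sketch that opens a standard box with &lt;code&gt;jax.tree_util&lt;/code&gt; (the names are arbitrary):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import jax
import jax.numpy as jnp

params = {"linear": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}}

# Open the box, double every array leaf, and repack the same structure
doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, params)

leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves))   # 2 arrays pulled out of the nested dicts
print(treedef)       # the recipe for putting them back
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;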

&lt;p&gt;The catch: JAX does not know how to open &lt;em&gt;custom&lt;/em&gt; boxes. HuggingFace defines custom output types like &lt;code&gt;CausalLMOutputWithPast&lt;/code&gt; — we need to teach JAX how to unpack and repack these. This is called &lt;strong&gt;pytree registration&lt;/strong&gt;, and we will see it in action shortly.&lt;/p&gt;

&lt;h3&gt;
  
  
  JIT Compilation: Translate Once, Run Fast Forever
&lt;/h3&gt;

&lt;p&gt;&lt;strong&gt;JIT&lt;/strong&gt; (Just-In-Time) compilation is like translating a recipe once instead of re-interpreting it every time you cook. The first time you call a JIT-compiled function, JAX traces through it, records all the operations, and compiles an optimized version. Subsequent calls skip the tracing and run the compiled version directly.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;First call:  Python code → trace → compile → execute  (slow)
Second call: compiled code → execute                   (fast!)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The speedup can be 10-100x or more. The trade-off is that the compiled function is specialized for the input shapes it was traced with — if shapes change, JAX recompiles.&lt;/p&gt;
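
&lt;p&gt;You can watch this happen. Python side effects such as &lt;code&gt;print&lt;/code&gt; run only while JAX traces, so the sketch below stays silent on a cached shape and speaks up again when a new shape forces a recompile:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    print("tracing for shape", x.shape)   # runs only during tracing
    return x * 2

double(jnp.ones(3))   # prints: tracing for shape (3,)
double(jnp.ones(3))   # silent: reuses the compiled version
double(jnp.ones(4))   # prints again: new shape, new compilation
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;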

&lt;h3&gt;
  
  
  Static vs. Dynamic Values
&lt;/h3&gt;

&lt;p&gt;When JAX traces a function for JIT, it treats inputs as abstract shapes, not concrete values. If your code has a branch like &lt;code&gt;if use_cache:&lt;/code&gt;, JAX cannot evaluate it during tracing because &lt;code&gt;use_cache&lt;/code&gt; is abstract. This causes a &lt;code&gt;ConcretizationTypeError&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;The fix: mark such values as &lt;strong&gt;static&lt;/strong&gt; (compile-time constants) so JAX knows their actual value during tracing. We will see two ways to do this: closures and &lt;code&gt;static_argnums&lt;/code&gt;.&lt;/p&gt;
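
&lt;p&gt;We use the closure approach in Step 2; for comparison, here is what &lt;code&gt;static_argnums&lt;/code&gt; looks like on a toy function (a sketch, not the real model code):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import jax
import jax.numpy as jnp

def forward(x, use_cache):
    if use_cache:          # a Python branch needs a concrete value at trace time
        return x + 1
    return x * 2

# Without static_argnums, branching on a traced use_cache fails under jit.
# Marking argument 1 as static bakes its value into the compiled program instead.
jitted = jax.jit(forward, static_argnums=(1,))

print(jitted(jnp.ones(3), False))   # compiles a use_cache=False version
print(jitted(jnp.ones(3), True))    # compiles a separate use_cache=True version
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;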




&lt;h2&gt;
  
  
  Step 1: Your First Forward Pass
&lt;/h2&gt;

&lt;p&gt;Let's load a model and run it on JAX. We will use &lt;a href="https://huggingface.co/google/gemma-3-1b-it" rel="noopener noreferrer"&gt;Gemma 3 1B IT&lt;/a&gt; — a small, instruction-tuned model from Google that runs comfortably on free Colab hardware.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;

&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;

&lt;span class="c1"&gt;# Load model and tokenizer
&lt;/span&gt;&lt;span class="n"&gt;model_name&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;google/gemma-3-1b-it&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="n"&gt;tokenizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoTokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model_name&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;AutoModelForCausalLM&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;from_pretrained&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;model_name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;torch_dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;bfloat16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device_map&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;cpu&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Enable torchax globally AFTER model loading
# This prevents intercepting unsupported initialization ops
&lt;/span&gt;&lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;enable_globally&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="c1"&gt;# Move model weights to the JAX device
&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Tokenize an input prompt
&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;The secret to baking a good cake is&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Run a forward pass (eager mode)
&lt;/span&gt;&lt;span class="n"&gt;start&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;perf_counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;elapsed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;perf_counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;start&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Output logits shape: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Eager forward pass: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;elapsed&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;What happened:&lt;/strong&gt;&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;We load the model on CPU first, &lt;em&gt;then&lt;/em&gt; call &lt;code&gt;torchax.enable_globally()&lt;/code&gt;. This ordering is important — enabling torchax before model loading can intercept unsupported initialization ops and cause errors.&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;model.to("jax")&lt;/code&gt; moves every parameter from CPU to the JAX device — just like &lt;code&gt;model.to("cuda")&lt;/code&gt; for GPUs.&lt;/li&gt;
&lt;li&gt;The forward pass runs through PyTorch's code path, but every operation is executed by JAX under the hood.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The output &lt;code&gt;logits&lt;/code&gt; tensor has shape &lt;code&gt;(1, sequence_length, vocab_size)&lt;/code&gt;. Each position contains a score for every token in the vocabulary — the highest score is the model's prediction for the next token.&lt;/p&gt;
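
&lt;p&gt;To turn those scores into text, take the highest-scoring token at the last position and decode it. Continuing from the code above:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Greedy pick: the highest-scoring token at the last position
next_token_id = torch.argmax(outputs.logits[0, -1, :]).item()
print(tokenizer.decode([next_token_id]))   # the model's guess for the next token
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;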




&lt;h2&gt;
  
  
  Step 2: Speed It Up with JIT Compilation
&lt;/h2&gt;

&lt;p&gt;The eager forward pass works, but it is slow — every operation goes through Python one at a time. Let's compile the model for dramatically faster inference.&lt;/p&gt;

&lt;h3&gt;
  
  
  The extract_jax Approach
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;torchax.extract_jax()&lt;/code&gt; function converts a PyTorch model into a pure JAX function:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Extract a JAX-callable function and the model weights as a pytree
&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jax_func&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;extract_jax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This returns two things:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;code&gt;weights&lt;/code&gt; — the model's state_dict as a pytree of &lt;code&gt;jax.Array&lt;/code&gt;s&lt;/li&gt;
&lt;li&gt;
&lt;code&gt;jax_func&lt;/code&gt; — a function with signature &lt;code&gt;jax_func(weights, args_tuple, kwargs_dict)&lt;/code&gt;
&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Register HuggingFace Output Types as Pytrees
&lt;/h3&gt;

&lt;p&gt;Before we can JIT this function, we need to teach JAX about HuggingFace's custom types:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;jax.tree_util&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;register_pytree_node&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;modeling_outputs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache_utils&lt;/span&gt;

&lt;span class="c1"&gt;# Register CausalLMOutputWithPast
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;output_flatten&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;to_tuple&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;output_unflatten&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;aux&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;modeling_outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;CausalLMOutputWithPast&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;register_pytree_node&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;modeling_outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;CausalLMOutputWithPast&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;output_flatten&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;output_unflatten&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Register DynamicCache
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_flatten_dynamic_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;return &lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_unflatten_dynamic_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;aux&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;cache_utils&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;DynamicCache&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt;

&lt;span class="nf"&gt;register_pytree_node&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;cache_utils&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;DynamicCache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;_flatten_dynamic_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;_unflatten_dynamic_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Handle Static Arguments with a Closure
&lt;/h3&gt;

&lt;p&gt;The &lt;code&gt;use_cache&lt;/code&gt; flag is a Python boolean that controls a branch, so JAX cannot treat it as a traced value. We wrap it in a closure to make it a compile-time constant:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward_no_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;jax_func&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,),&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;use_cache&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;})&lt;/span&gt;

&lt;span class="n"&gt;jitted_forward&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;jit&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;forward_no_cache&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Benchmark: Eager vs. JIT
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Convert input to a native JAX array for jax.jit
&lt;/span&gt;&lt;span class="n"&gt;jax_input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device_put&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;numpy&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

&lt;span class="c1"&gt;# Warm up (first call triggers compilation)
&lt;/span&gt;&lt;span class="n"&gt;res&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;jitted_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jax_input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;block_until_ready&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;res&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Benchmark 3 runs
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;start&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;perf_counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;res&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;jitted_forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jax_input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;block_until_ready&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;res&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;elapsed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;time&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;perf_counter&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;start&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Run &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;elapsed&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;s&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Expected output (times will vary by hardware):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight console"&gt;&lt;code&gt;&lt;span class="gp"&gt;Run 0: 0.0142s  #&lt;/span&gt;&lt;span class="w"&gt; &lt;/span&gt;Already compiled from warm-up
&lt;span class="go"&gt;Run 1: 0.0038s
Run 2: 0.0035s
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The JIT-compiled version can run orders of magnitude faster than eager mode (the exact factor depends on the model and hardware). This is the power of XLA compilation — operations are fused, memory is optimized, and the accelerator runs a single optimized program.&lt;/p&gt;




&lt;h2&gt;
  
  
  Step 3: The Simpler API — torchax.compile
&lt;/h2&gt;

&lt;p&gt;The &lt;code&gt;extract_jax&lt;/code&gt; + manual JIT approach gives you full control, but for most cases there is a simpler way. The catch is that &lt;code&gt;torchax.compile()&lt;/code&gt; uses &lt;code&gt;jax.jit&lt;/code&gt; under the hood, so we need to avoid passing dynamic boolean flags like &lt;code&gt;use_cache&lt;/code&gt;. We wrap the model in a thin module that bakes in these constants:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;NoCacheModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;base_model&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;base_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;base_model&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# Return only logits to avoid HuggingFace output class pytree issues
&lt;/span&gt;        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;base_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="c1"&gt;# One-liner: compile the wrapped model
&lt;/span&gt;&lt;span class="n"&gt;compiled_model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;compile&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nc"&gt;NoCacheModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Use it like a normal PyTorch model
&lt;/span&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;compiled_model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Under the hood, &lt;code&gt;torchax.compile()&lt;/code&gt; wraps your model in a &lt;code&gt;JittableModule&lt;/code&gt; and applies &lt;code&gt;jax.jit&lt;/code&gt;. The first call triggers compilation; subsequent calls are fast. The &lt;code&gt;NoCacheModel&lt;/code&gt; wrapper ensures that boolean flags are constants (not traced) and that the output is a plain tensor (not a custom HuggingFace type that needs pytree registration).&lt;/p&gt;
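
&lt;p&gt;You can verify the warm-up behavior by timing the compiled wrapper the same way we benchmarked &lt;code&gt;jitted_forward&lt;/code&gt;. This is a sketch continuing from the code above; the &lt;code&gt;.item()&lt;/code&gt; call forces JAX's asynchronous execution to finish before the clock stops:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import time

# Warm-up: the first call traces and compiles
with torch.no_grad():
    compiled_model(input_ids)

for i in range(3):
    start = time.perf_counter()
    with torch.no_grad():
        logits = compiled_model(input_ids)
    _ = logits[0, 0, 0].item()   # force materialization before stopping the clock
    elapsed = time.perf_counter() - start
    print(f"Run {i}: {elapsed:.4f}s")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;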




&lt;h2&gt;
  
  
  Step 4: Text Classification
&lt;/h2&gt;

&lt;p&gt;Let's use our JIT-compiled model for a practical task — sentiment classification. Since Gemma is an instruction-tuned model, we can use prompt engineering:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;classify_sentiment&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt;Classify the following text as POSITIVE or NEGATIVE.
Text: &lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;
Classification:&lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;

    &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="n"&gt;outputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Get the predicted next token
&lt;/span&gt;    &lt;span class="n"&gt;next_token_logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;outputs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt;
    &lt;span class="n"&gt;next_token_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;next_token_logits&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;prediction&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;next_token_id&lt;/span&gt;&lt;span class="p"&gt;]).&lt;/span&gt;&lt;span class="nf"&gt;strip&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;prediction&lt;/span&gt;

&lt;span class="c1"&gt;# Test it
&lt;/span&gt;&lt;span class="n"&gt;texts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;This movie was absolutely fantastic, I loved every minute!&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;The service was terrible and the food was cold.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;A perfectly average experience, nothing special.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;texts&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;classify_sentiment&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Text: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;...  =&amp;gt;  &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Step 5: Text Generation (Autoregressive Decoding)
&lt;/h2&gt;

&lt;p&gt;Classification is useful, but the real power of LLMs is generating text. Let's understand how this works.&lt;/p&gt;

&lt;h3&gt;
  
  
  How Autoregressive Decoding Works
&lt;/h3&gt;

&lt;p&gt;An LLM predicts one token at a time. Given an input of length &lt;code&gt;n&lt;/code&gt;, it produces scores for the next token. We pick one (e.g., the highest-scoring token via greedy decoding), append it to the input, and repeat:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Iteration 1: input (1, n)     → output (1, n)     → pick token
Iteration 2: input (1, n+1)   → output (1, n+1)   → pick token
Iteration 3: input (1, n+2)   → output (1, n+2)   → pick token
...
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The problem: &lt;strong&gt;input shapes change every iteration&lt;/strong&gt;. JIT compilation specializes for fixed shapes, so changing shapes means recompilation every step — worse than eager mode.&lt;/p&gt;
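
&lt;p&gt;For reference, the most naive version of this loop (eager model, no cache) is a short sketch like the one below, continuing with &lt;code&gt;model&lt;/code&gt;, &lt;code&gt;tokenizer&lt;/code&gt;, and &lt;code&gt;input_ids&lt;/code&gt; from Step 1. It works, but every step re-processes the entire growing sequence:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Naive greedy decoding: re-run the whole, growing sequence each step
generated = input_ids
new_tokens = []
for _ in range(20):
    with torch.no_grad():
        out = model(generated, use_cache=False)
    next_token = torch.argmax(out.logits[:, -1, :], dim=-1)[:, None]
    new_tokens.append(next_token[0, 0].item())
    generated = torch.cat([generated, next_token], dim=1)   # shape grows every step

print(tokenizer.decode(new_tokens))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;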

&lt;h3&gt;
  
  
  The KV Cache Solution
&lt;/h3&gt;

&lt;p&gt;The &lt;strong&gt;KV (Key-Value) cache&lt;/strong&gt; stores intermediate computations from previous tokens so the model only needs to process the &lt;em&gt;new&lt;/em&gt; token each iteration:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Iteration 1: input (1, n)              → output + kv_cache(n)
Iteration 2: input (1, 1) + cache(n)   → output + kv_cache(n+1)
Iteration 3: input (1, 1) + cache(n+1) → output + kv_cache(n+2)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;With a &lt;strong&gt;DynamicCache&lt;/strong&gt;, the cache grows each step — shapes still change. With a &lt;strong&gt;StaticCache&lt;/strong&gt;, the cache has a fixed maximum length — shapes stay constant, making it JIT-friendly.&lt;/p&gt;

&lt;h3&gt;
  
  
  Implementation with StaticCache
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;transformers.cache_utils&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;StaticCache&lt;/span&gt;

&lt;span class="c1"&gt;# Register StaticCache as a pytree
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_flatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;return &lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt;
    &lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="nf"&gt;getattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;device&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="nf"&gt;getattr&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;dtype&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;_unflatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;aux&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;aux&lt;/span&gt;
    &lt;span class="n"&gt;kwargs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;device&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="ow"&gt;not&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;dtype&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;
    &lt;span class="n"&gt;cache&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="n"&gt;kwargs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;key_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;value_cache&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;children&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;cache&lt;/span&gt;

&lt;span class="nf"&gt;register_pytree_node&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;_flatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="n"&gt;_unflatten_static_cache&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;generate_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;inputs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;return_tensors&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;pt&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;input_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;to&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;seq_length&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;

    &lt;span class="c1"&gt;# Create a static cache with fixed maximum length
&lt;/span&gt;    &lt;span class="n"&gt;past_key_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;StaticCache&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;config&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;max_batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;max_cache_len&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;seq_length&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;cache_position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Prefill: process the full prompt
&lt;/span&gt;    &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;past_key_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;return_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;next_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;generated_ids&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;next_token&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()]&lt;/span&gt;
    &lt;span class="n"&gt;cache_position&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;seq_length&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;jax&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# Decode: generate one token at a time
&lt;/span&gt;    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_new_tokens&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
            &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;past_key_values&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
                &lt;span class="n"&gt;next_token&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;cache_position&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;past_key_values&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;return_dict&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
                &lt;span class="n"&gt;use_cache&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;next_token&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;argmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)[:,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;token_id&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;next_token&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;token_id&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eos_token_id&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="k"&gt;break&lt;/span&gt;
        &lt;span class="n"&gt;generated_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;token_id&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;cache_position&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;generated_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;skip_special_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Generate!
&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;generate_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;The secret to baking a good cake is&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;result&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;






&lt;h2&gt;
  
  
  Step 6: Distributed Inference (Tensor Parallelism)
&lt;/h2&gt;

&lt;p&gt;If you have access to multiple devices (e.g., a TPU v2-8, which exposes 8 devices to JAX, or a multi-GPU machine), you can shard the model weights across them for faster inference.&lt;/p&gt;
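
&lt;p&gt;Before sharding anything, it helps to confirm how many devices JAX can actually see. A quick sanity check (output varies by runtime):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import jax

# List the accelerator devices visible to JAX
print(jax.devices())        # e.g., 8 TpuDevice entries on a TPU v2-8
print(jax.device_count())   # this count feeds the mesh created below
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;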

&lt;h3&gt;
  
  
  How Tensor Parallelism Works
&lt;/h3&gt;

&lt;p&gt;In tensor parallelism, we split weight matrices across devices:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Column-parallel&lt;/strong&gt;: Q, K, V, Gate, and Up projections are split along the output dimension&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Row-parallel&lt;/strong&gt;: O and Down projections are split along the input dimension&lt;/li&gt;
&lt;li&gt;Between these two, only a single &lt;strong&gt;all-reduce&lt;/strong&gt; operation is needed per layer&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;JAX's GSPMD partitioner handles the communication automatically — you just specify how each weight should be sharded.&lt;/p&gt;
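
&lt;p&gt;To see why a single all-reduce suffices, here is a minimal NumPy sketch of the two split styles (pure illustration; the shapes and names are made up):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import numpy as np

x  = np.random.randn(2, 8)    # activations: (batch, hidden)
W1 = np.random.randn(8, 16)   # column-parallel weight (think up_proj)
W2 = np.random.randn(16, 8)   # row-parallel weight (think down_proj)

# Column-parallel: split W1 along its output dimension; each "device"
# computes its own slice of the hidden state with no communication
W1_a, W1_b = np.split(W1, 2, axis=1)
h_a, h_b = x @ W1_a, x @ W1_b

# Row-parallel: split W2 along its input dimension to match those slices
W2_a, W2_b = np.split(W2, 2, axis=0)
partial_a, partial_b = h_a @ W2_a, h_b @ W2_b

# The only cross-device communication: one all-reduce (sum) of the partials
y = partial_a + partial_b
assert np.allclose(y, x @ W1 @ W2)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;On real hardware, that final sum is the all-reduce. With GSPMD you never write it yourself: the XLA compiler inserts it based on the shardings you declare.&lt;/p&gt;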

&lt;h3&gt;
  
  
  Sharding the Weights
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;jax.sharding&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;PartitionSpec&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;NamedSharding&lt;/span&gt;

&lt;span class="c1"&gt;# Create a device mesh
&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;make_mesh&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device_count&lt;/span&gt;&lt;span class="p"&gt;(),),&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;axis&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;shard_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;sharded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tensor&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="nf"&gt;any&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;q_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;k_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;v_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;gate_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;up_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
            &lt;span class="n"&gt;spec&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;P&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;axis&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Column-parallel
&lt;/span&gt;        &lt;span class="k"&gt;elif&lt;/span&gt; &lt;span class="nf"&gt;any&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;name&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;o_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;down_proj&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;lm_head&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;embed_tokens&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
            &lt;span class="n"&gt;spec&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;P&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;axis&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# Row-parallel
&lt;/span&gt;        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;spec&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;P&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;  &lt;span class="c1"&gt;# Replicate (e.g., layer norms)
&lt;/span&gt;        &lt;span class="n"&gt;sharded&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;name&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device_put&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;NamedSharding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;spec&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;sharded&lt;/span&gt;

&lt;span class="c1"&gt;# Apply sharding
&lt;/span&gt;&lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;jax_func&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torchax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;extract_jax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;weights&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;shard_weights&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;weights&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Replicate the input across all devices
&lt;/span&gt;&lt;span class="n"&gt;input_ids_sharded&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;jax&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;device_put&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
    &lt;span class="n"&gt;inputs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;input_ids&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="nc"&gt;NamedSharding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;mesh&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nc"&gt;P&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;With sharded weights, the same &lt;code&gt;jax.jit&lt;/code&gt;-compiled function now runs in parallel across all devices. The XLA compiler automatically inserts the necessary all-reduce operations.&lt;/p&gt;
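
&lt;p&gt;To verify that a given weight really was split rather than replicated, inspect its sharding. The weight name below is illustrative; print &lt;code&gt;weights.keys()&lt;/code&gt; to find the actual names in your model:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import jax

# The exact key depends on the architecture; this one is a typical example
name = "model.layers.0.self_attn.q_proj.weight"
print(weights[name].sharding)  # expect a NamedSharding with P('axis', None)

# ASCII diagram of which device holds which slice of the array
jax.debug.visualize_array_sharding(weights[name])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;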

&lt;blockquote&gt;
&lt;p&gt;&lt;strong&gt;Note:&lt;/strong&gt; Tensor parallelism requires a multi-device environment. On a single-device runtime, this section is illustrative only; use a TPU v2-8 or a multi-GPU setup to actually run it.&lt;/p&gt;
&lt;/blockquote&gt;




&lt;h2&gt;
  
  
  Step 7: Build a Mini Chatbot
&lt;/h2&gt;

&lt;p&gt;Let's wrap everything into a simple chat function using Gemma's instruction template:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;chat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;user_message&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Gemma instruction format
&lt;/span&gt;    &lt;span class="n"&gt;prompt&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;start_of_turn&amp;gt;user&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;user_message&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;end_of_turn&amp;gt;&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;start_of_turn&amp;gt;model&lt;/span&gt;&lt;span class="se"&gt;\n&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
    &lt;span class="n"&gt;response&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;generate_text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prompt&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;response&lt;/span&gt;

&lt;span class="c1"&gt;# Example conversation
&lt;/span&gt;&lt;span class="n"&gt;questions&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;What is JAX and why would I use it?&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Explain tensor parallelism in simple terms.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Write a haiku about machine learning.&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;questions&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;User: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;Gemma: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="nf"&gt;chat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;tokenizer&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
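
&lt;p&gt;The function above is single-turn. A minimal multi-turn sketch (reusing the same &lt;code&gt;generate_text&lt;/code&gt; helper and Gemma's turn template) simply accumulates earlier turns in the prompt:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def chat_multi_turn(model, tokenizer, history, user_message, max_new_tokens=100):
    # history: list of (user_text, model_text) pairs from earlier turns
    prompt = ""
    for user_turn, model_turn in history:
        prompt += f"&amp;lt;start_of_turn&amp;gt;user\n{user_turn}&amp;lt;end_of_turn&amp;gt;\n"
        prompt += f"&amp;lt;start_of_turn&amp;gt;model\n{model_turn}&amp;lt;end_of_turn&amp;gt;\n"
    prompt += f"&amp;lt;start_of_turn&amp;gt;user\n{user_message}&amp;lt;end_of_turn&amp;gt;\n&amp;lt;start_of_turn&amp;gt;model\n"
    response = generate_text(model, tokenizer, prompt, max_new_tokens)
    history.append((user_message, response))
    return response

history = []
print(chat_multi_turn(model, tokenizer, history, "Hi! What can you do?"))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;Keep in mind that every new prompt length is a new input shape, so a jitted forward pass will recompile as the conversation grows; production chatbots typically pad prompts to a few fixed bucket lengths to avoid this.&lt;/p&gt;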






&lt;h2&gt;
  
  
  Swapping to a Larger Model
&lt;/h2&gt;

&lt;p&gt;Everything above uses &lt;code&gt;google/gemma-3-1b-it&lt;/code&gt; (1B parameters). To use a larger model, change the model name:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# 4B model — needs more memory (Colab Pro or multi-device)
&lt;/span&gt;&lt;span class="n"&gt;model_name&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;google/gemma-3-4b-it&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The rest of the code remains identical. Larger models produce higher-quality outputs but require more memory and compute; the 4B and larger variants benefit significantly from tensor parallelism on multi-device setups.&lt;/p&gt;

&lt;p&gt;Other models that work well with torchax include any standard HuggingFace &lt;code&gt;AutoModelForCausalLM&lt;/code&gt; architecture — GPT-2, Llama, Mistral, Phi, and more.&lt;/p&gt;
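
&lt;p&gt;For example, swapping in a different architecture looks exactly the same (the checkpoint names below are just examples; some, like Llama, require accepting a license on HuggingFace first):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # or "microsoft/phi-2", "mistralai/Mistral-7B-Instruct-v0.3", ...
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load first, then enable torchax and move the model to "jax" as before
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;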




&lt;h2&gt;
  
  
  Troubleshooting
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;&lt;code&gt;TypeError: ... is not a valid JAX type&lt;/code&gt;&lt;/strong&gt;&lt;br&gt;
You need to register the type as a pytree. See the registration examples above for &lt;code&gt;CausalLMOutputWithPast&lt;/code&gt;, &lt;code&gt;DynamicCache&lt;/code&gt;, and &lt;code&gt;StaticCache&lt;/code&gt;.&lt;/p&gt;
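
&lt;p&gt;As a reminder of the pattern, here is a generic registration sketch (&lt;code&gt;MyOutput&lt;/code&gt; is a stand-in for whatever class the error names):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import jax
from dataclasses import dataclass

@dataclass
class MyOutput:  # stand-in for the offending class
    logits: object
    cache: object

jax.tree_util.register_pytree_node(
    MyOutput,
    lambda o: ((o.logits, o.cache), None),      # flatten: (children, aux_data)
    lambda aux, children: MyOutput(*children),  # unflatten: rebuild the object
)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;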

&lt;p&gt;&lt;strong&gt;&lt;code&gt;ConcretizationTypeError: Abstract tracer value encountered&lt;/code&gt;&lt;/strong&gt;&lt;br&gt;
A value that changes between calls (like a boolean flag) needs to be either: (1) made static via &lt;code&gt;static_argnums&lt;/code&gt; in &lt;code&gt;jax.jit&lt;/code&gt;, or (2) baked into a closure as a constant.&lt;/p&gt;
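
&lt;p&gt;Option (1) looks like this (a generic sketch, not specific to torchax):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import functools
import jax
import jax.numpy as jnp

# Argument 1 is static: jit compiles a separate variant per distinct value
@functools.partial(jax.jit, static_argnums=(1,))
def scale(x, double):
    return x * 2 if double else x  # plain Python `if` is fine on a static arg

print(scale(jnp.arange(3), True))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;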

&lt;p&gt;&lt;strong&gt;&lt;code&gt;UserWarning: A large amount of constants were captured&lt;/code&gt;&lt;/strong&gt;&lt;br&gt;
Model weights are being inlined as constants in the compiled graph. Pass them as explicit function arguments instead of closing over them.&lt;/p&gt;
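
&lt;p&gt;A minimal before/after sketch of the fix:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import jax
import jax.numpy as jnp

W = jnp.ones((1024, 1024))

# Bad: W is closed over, so it gets inlined into the compiled graph
slow = jax.jit(lambda x: x @ W)

# Good: W is an explicit argument and stays a regular device buffer
fast = jax.jit(lambda w, x: x @ w)
out = fast(W, jnp.ones((8, 1024)))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;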

&lt;p&gt;&lt;strong&gt;&lt;code&gt;RuntimeError: No available devices&lt;/code&gt;&lt;/strong&gt;&lt;br&gt;
Ensure JAX can see your accelerator: &lt;code&gt;print(jax.devices())&lt;/code&gt;. In Colab, check that your runtime type is set to TPU or GPU.&lt;/p&gt;




&lt;h2&gt;
  
  
  Conclusion
&lt;/h2&gt;

&lt;p&gt;In this tutorial, we went from zero to a working chatbot running a HuggingFace model on JAX:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Forward pass&lt;/strong&gt; — moved a PyTorch model to JAX with &lt;code&gt;model.to("jax")&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;JIT compilation&lt;/strong&gt; — compiled for 10-100x speedup with &lt;code&gt;jax.jit&lt;/code&gt;
&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Text classification&lt;/strong&gt; — used prompt engineering for sentiment analysis&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Text generation&lt;/strong&gt; — implemented autoregressive decoding with StaticCache&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Distributed inference&lt;/strong&gt; — sharded weights across devices with tensor parallelism&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Chatbot&lt;/strong&gt; — wrapped generation in an instruction-following chat function&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The key insight: torchax lets you use the entire HuggingFace ecosystem — models, tokenizers, configs — while running on JAX's high-performance backend. No model rewrites needed.&lt;/p&gt;

&lt;h3&gt;
  
  
  Resources
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;a href="https://github.com/google/torchax" rel="noopener noreferrer"&gt;torchax GitHub&lt;/a&gt; — library source and documentation&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://google.github.io/torchax/" rel="noopener noreferrer"&gt;torchax Docs&lt;/a&gt; — official getting started guide&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://huggingface.co/blog/qihqi/huggingface-jax-01" rel="noopener noreferrer"&gt;Original tutorial series by Han Qi&lt;/a&gt; — the 3-part blog series this tutorial builds on&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://docs.jax.dev/" rel="noopener noreferrer"&gt;JAX Documentation&lt;/a&gt; — JIT compilation, pytrees, distributed arrays&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://huggingface.co/docs/transformers/llm_optims" rel="noopener noreferrer"&gt;HuggingFace LLM Inference Optimization&lt;/a&gt; — StaticCache and torch.compile docs&lt;/li&gt;
&lt;li&gt;
&lt;a href="https://github.com/ahmed-elnaggar/torchax-huggingface" rel="noopener noreferrer"&gt;Companion GitHub repo&lt;/a&gt; — all code, notebooks, and diagrams&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  Credits
&lt;/h3&gt;

&lt;p&gt;This tutorial would not have been possible without the work of:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Han Qi (&lt;a href="https://github.com/qihqi" rel="noopener noreferrer"&gt;@qihqi&lt;/a&gt;)&lt;/strong&gt; — author of torchax and the &lt;a href="https://github.com/qihqi/learning_machine/tree/main/jax-huggingface" rel="noopener noreferrer"&gt;original HuggingFace + JAX tutorial series&lt;/a&gt;
&lt;/li&gt;
&lt;li&gt;The &lt;strong&gt;torchax team at Google&lt;/strong&gt; — for building and maintaining the library&lt;/li&gt;
&lt;li&gt;The &lt;strong&gt;HuggingFace team&lt;/strong&gt; — for the transformers ecosystem&lt;/li&gt;
&lt;li&gt;The &lt;strong&gt;JAX team at Google&lt;/strong&gt; — for JAX, XLA, and TPU support&lt;/li&gt;
&lt;/ul&gt;




&lt;p&gt;What model will you try running on TPUs first? Let me know in the comments!&lt;/p&gt;

</description>
      <category>machinelearning</category>
      <category>pytorch</category>
      <category>python</category>
      <category>tutorial</category>
    </item>
  </channel>
</rss>
