
Fine-Tune Any HuggingFace Model like Gemma on TPUs with TorchAX

What if you could fine-tune any HuggingFace model on TPUs — using PyTorch code?

Here is what the end result looks like:

import torchax as tx
import torchax.train

# One function: forward → loss → gradients → optimizer update
step_fn = tx.train.make_train_step(model_fn, loss_fn, optimizer)

# Training loop
for batch in dataloader:
    loss, params, opt_state = step_fn(params, buffers, opt_state, batch, batch["labels"])

Your PyTorch model. JAX's training primitives. Running on TPU. No rewrite needed.

In the first part of this series, we ran HuggingFace models on JAX for fast inference. Now we take the next step: training. We will instruction-tune Gemma 3 1B on the Databricks Dolly 15k dataset using LoRA and torchax's functional training API — all on a free Colab TPU.

Open Full Tutorial In Colab Open Quick Start In Colab


Why Train on TPUs?

Google's Tensor Processing Units (TPUs) are purpose-built for matrix operations — the bread and butter of deep learning. Free Colab gives you access to a TPU v2-8 with ~15GB of high-bandwidth memory. That is enough to fine-tune a 1B parameter model with LoRA.

But training on TPUs traditionally meant rewriting your model in JAX (Flax, Equinox) or using PyTorch/XLA. torchax offers a third path: keep your PyTorch model, but use JAX's functional training primitives.

How torchax Training Differs from Standard PyTorch

| Standard PyTorch | torchax |
| --- | --- |
| loss.backward() | jax.value_and_grad(loss_fn)(params, ...) |
| optimizer.step() | optax.apply_updates(params, updates) |
| Model holds its own state | Params and buffers are separate pytrees |
| Eager execution | JIT-compiled training steps |

The key difference: functional training. Instead of calling loss.backward() and optimizer.step() on a stateful model, torchax separates the model into immutable weight pytrees and passes them through pure functions. This is what enables JAX's jax.jit to compile the entire training step into a single optimized program.
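To make the pattern concrete, here is a tiny self-contained JAX/optax example of a functional training step on a toy linear model (illustrative only; the torchax sections below do the same thing for a real HuggingFace model):

import jax
import jax.numpy as jnp
import optax

# Toy model: params is a plain dict (a pytree), just like torchax weights.
def model_fn(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(pred, target):
    return jnp.mean((pred - target) ** 2)

optimizer = optax.adamw(learning_rate=1e-3)

@jax.jit  # the whole step compiles into one XLA program
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(lambda p: loss_fn(model_fn(p, x), y))(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return loss, optax.apply_updates(params, updates), opt_state

params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
opt_state = optimizer.init(params)
x, y = jnp.ones((8, 4)), jnp.ones((8, 1))
loss, params, opt_state = train_step(params, opt_state, x, y)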


Prerequisites & Setup

What you need:

  • Python 3.10+
  • Basic familiarity with PyTorch and HuggingFace transformers
  • A Google Colab account (free tier works with LoRA)

Zero-setup option: Click the Colab badge above. The notebook handles all installation automatically.

Local setup:

# PyTorch CPU (torchax handles the accelerator via JAX)
pip install torch --index-url https://download.pytorch.org/whl/cpu

# JAX + all training dependencies in a single pip call
pip install -U 'jax[tpu]' torchax transformers flax peft datasets optax   # TPU
# pip install -U 'jax[cuda12]' torchax transformers flax peft datasets optax  # GPU

Colab note: The notebook installs packages and automatically restarts the runtime, since Colab pre-loads an older JAX that stays cached in memory until restart.


Key Concepts for Training

Before writing code, let's understand the four concepts that make torchax training work.

1. Param/Buffer Separation

JAX's jax.value_and_grad needs to know which inputs to differentiate. In standard PyTorch, the model owns its weights. In torchax training, we explicitly separate:

  • params — trainable parameters (get gradients)
  • buffers — everything else (frozen weights, running stats, constants)
params = {n: p for n, p in model.named_parameters() if p.requires_grad}
frozen = {n: p for n, p in model.named_parameters() if not p.requires_grad}
buffers = dict(model.named_buffers())
buffers.update(frozen)

For LoRA, params contains only the tiny adapter weights (~0.5% of the model). For full fine-tuning, it contains everything.
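A quick, illustrative sanity check of that ratio (approximate, since buffers also contains true buffers such as position caches, not only frozen weights):

# Rough count of trainable vs. frozen weights in the two pytrees
n_trainable = sum(p.numel() for p in params.values())
n_frozen = sum(p.numel() for p in buffers.values())
print(f"trainable: {n_trainable:,} ({100 * n_trainable / (n_trainable + n_frozen):.2f}%)")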

2. optax Optimizers

Unlike PyTorch optimizers (which carry hidden mutable state), optax optimizers are pure functions:

# PyTorch: hidden state inside optimizer
optimizer.step()

# optax: explicit state, no hidden pockets
updates, new_opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)

This functional design means the optimizer state is just another pytree that flows through the training step — perfect for jax.jit.
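You can see this for yourself: the optimizer state is an ordinary pytree whose leaves you can inspect or checkpoint with standard JAX utilities. A small illustrative snippet:

import jax
import jax.numpy as jnp
import optax

toy_params = {"w": jnp.zeros((2, 3))}
toy_opt = optax.adam(1e-3)
toy_state = toy_opt.init(toy_params)

# Adam's state (step count, first and second moments) is just nested arrays.
print(jax.tree_util.tree_map(lambda leaf: leaf.shape, toy_state))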

3. make_train_step

torchax.train.make_train_step() is the central API. It composes three pieces into a single JIT-compilable function:

  1. model_fn — a pure function: (weights, buffers, batch) → output
  2. loss_fn — extracts the scalar loss: (output, labels) → loss
  3. optimizer — an optax optimizer

The result is step_fn(params, buffers, opt_state, batch, labels) → (loss, new_params, new_opt_state).

Under the hood, this uses jax.value_and_grad for efficient gradient computation and optax.apply_updates for weight updates — all compiled into a single XLA program.
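In spirit, make_train_step composes something like the following (a simplified sketch, not the actual torchax source, which also handles the torch/JAX interop):

import jax
import optax

def make_train_step_sketch(model_fn, loss_fn, optimizer):
    def step(params, buffers, opt_state, batch, labels):
        def compute_loss(p):
            return loss_fn(model_fn(p, buffers, batch), labels)

        loss, grads = jax.value_and_grad(compute_loss)(params)
        updates, new_opt_state = optimizer.update(grads, opt_state, params)
        new_params = optax.apply_updates(params, updates)
        return loss, new_params, new_opt_state

    return step  # torchax additionally JIT-compiles the returned function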

4. Full Fine-Tuning vs LoRA

| | Full Fine-Tuning | LoRA |
| --- | --- | --- |
| Trainable params | All (~2B) | Tiny adapters (~0.5%) |
| Memory | ~18-20 GB | ~5-7 GB |
| Speed | Slower | Faster |
| Quality | Higher ceiling | Nearly as good |
| Free Colab TPU | Tight / may OOM | Fits comfortably |

LoRA (Low-Rank Adaptation) freezes the base model and adds small trainable matrices to attention layers. Instead of updating the full weight matrix W, it learns a low-rank decomposition: W + (α/r) × B·A where A and B are tiny matrices.
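In code, the LoRA forward pass looks roughly like this (toy dimensions; in the config below r=8 and α=16, and the real adapters live inside the attention projections):

import torch

d_out, d_in, r, alpha = 1024, 1024, 8, 16   # hypothetical layer size, LoRA rank, scaling
W = torch.randn(d_out, d_in)                # frozen base weight
A = torch.randn(r, d_in) * 0.01             # trainable, small random init
B = torch.zeros(d_out, r)                   # trainable, zero init: training starts from the base model
x = torch.randn(1, d_in)

y = x @ W.T + (alpha / r) * (x @ A.T @ B.T)  # base output + low-rank correction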

For free Colab, LoRA is the recommended path.


Step 1: Load and Prepare the Dataset

We use Databricks Dolly 15k — 15,000 human-written instruction-response pairs across 7 categories (QA, summarization, brainstorming, etc.).

import datasets as hf_datasets
from transformers import AutoTokenizer

MODEL_NAME = "google/gemma-3-1b-it"
DATASET_NAME = "databricks/databricks-dolly-15k"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

raw_dataset = hf_datasets.load_dataset(DATASET_NAME, split="train")

Each example has an instruction, optional context, response, and category. We format these into Gemma's chat template:

def format_example(example):
    user_content = example["instruction"]
    if example.get("context", ""):
        user_content += f"\n\nContext: {example['context']}"

    messages = [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": example["response"]},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"text": text}

Then tokenize and create dataloaders:

from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling

# Subset, split, tokenize
subset = raw_dataset.shuffle(seed=42).select(range(2200))
split = subset.train_test_split(test_size=200, seed=42)

def tokenize_example(example):
    formatted = format_example(example)
    return tokenizer(formatted["text"], padding="max_length", max_length=512, truncation=True)

train_tokenized = split["train"].map(tokenize_example, remove_columns=split["train"].column_names)
eval_tokenized = split["test"].map(tokenize_example, remove_columns=split["test"].column_names)

collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
train_dataloader = DataLoader(train_tokenized, shuffle=True, collate_fn=collator, batch_size=2)
eval_dataloader = DataLoader(eval_tokenized, shuffle=False, collate_fn=collator, batch_size=2)

Step 2: Load the Model and Apply LoRA

Here is where the torchax pattern matters: load the model with torchax disabled, then enable it before moving to JAX.

import torch
import torchax as tx
import transformers
import peft

# Load model with torchax disabled to avoid intercepting init ops
with tx.disable_temporarily():
    model = transformers.AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16
    )

# Sync pad_token_id so loss computation properly ignores padding
model.config.pad_token_id = tokenizer.pad_token_id

Why disable? HuggingFace model initialization uses operations (like in-place tensor filling) that torchax does not support. Disabling torchax during loading keeps everything on CPU, then we move to JAX after.

Now apply LoRA:

peft_config = peft.LoraConfig(
    task_type=peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,                             # Rank of the LoRA matrices
    lora_alpha=16,                   # Scaling factor
    lora_dropout=0.0,                # 0.0 for bfloat16 numerical stability
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # All attention layers
)
model = peft.get_peft_model(model, peft_config)
model.print_trainable_parameters()
# Output: trainable params: 5,767,168 || all params: 2,619,206,656 || trainable%: 0.22%

Only 0.22% of parameters are trainable — that is the power of LoRA.

Finally, enable torchax and move to the JAX device:

tx.enable_accuracy_mode()  # Float32 accumulation for bfloat16 stability
tx.enable_globally()
device = torch.device("jax")
model.to(device)
model.train()

Step 3: Baseline Evaluation

Before training, we measure the model's performance to compare against later:

import math

def evaluate_loss(model, dataloader, device, max_batches=50):
    model.eval()
    total_loss, total_batches = 0.0, 0
    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            if i >= max_batches:
                break
            # Drop attention_mask — Gemma's sliding window attention produces NaN
            # with padded masks on torchax/JAX. Labels already mask padding with -100.
            batch = {k: v.to(device) for k, v in batch.items() if k != "attention_mask"}
            outputs = model(**batch)
            total_loss += outputs.loss.item()
            total_batches += 1
    model.train()
    avg_loss = total_loss / max(total_batches, 1)
    return avg_loss, math.exp(min(avg_loss, 100))

baseline_loss, baseline_ppl = evaluate_loss(model, eval_dataloader, device)
print(f"Baseline loss: {baseline_loss:.4f}, perplexity: {baseline_ppl:.2f}")

We also generate sample responses for qualitative comparison. For fast generation, we register StaticCache as a JAX pytree and use KV-cached decoding — only the new token is processed each step instead of the full sequence (~50x faster):

from transformers.cache_utils import StaticCache
from jax.tree_util import register_pytree_node

def _flatten_static_cache(cache):
    return (cache.key_cache, cache.value_cache), (
        cache.config, cache.max_batch_size, cache.max_cache_len,
        getattr(cache, "device", None), getattr(cache, "dtype", None),
    )

def _unflatten_static_cache(aux, children):
    config, max_batch_size, max_cache_len, dev, dtype = aux
    kwargs = {}
    if dev is not None: kwargs["device"] = dev
    if dtype is not None: kwargs["dtype"] = dtype
    sc = StaticCache(config, max_batch_size, max_cache_len, **kwargs)
    sc.key_cache, sc.value_cache = children
    return sc

register_pytree_node(StaticCache, _flatten_static_cache, _unflatten_static_cache)

The generation function uses prefill (process full prompt) then per-token decode with the cache and a tqdm progress bar:

from tqdm.auto import tqdm

def generate_response(model, tokenizer, instruction, device, max_new_tokens=100):
    messages = [{"role": "user", "content": instruction}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
    seq_len = input_ids.shape[1]

    kv = StaticCache(config=model.config, max_batch_size=1,
                     max_cache_len=seq_len + max_new_tokens,
                     device=device, dtype=torch.bfloat16)
    pos = torch.arange(seq_len, device=device)

    model.eval()
    with torch.no_grad():
        # Prefill: process full prompt, populate cache
        logits, kv = model(input_ids, cache_position=pos, past_key_values=kv,
                           return_dict=False, use_cache=True)
        tok = torch.argmax(logits[:, -1], dim=-1)[:, None]
        generated = [tok[:, 0].item()]
        pos = torch.tensor([seq_len], device=device)

        # Decode: one token at a time using cached keys/values
        for _ in tqdm(range(max_new_tokens - 1), desc="Generating", leave=False):
            logits, kv = model(tok, cache_position=pos, past_key_values=kv,
                               return_dict=False, use_cache=True)
            tok = torch.argmax(logits[:, -1], dim=-1)[:, None]
            tid = tok[:, 0].item()
            if tid == tokenizer.eos_token_id:
                break
            generated.append(tid)
            pos += 1

    model.train()
    return tokenizer.decode(generated, skip_special_tokens=True)

Step 4: Set Up Functional Training

This is where torchax diverges from standard PyTorch. We separate the model, create an optax optimizer, and compose everything into a JIT-compiled training step.

Separate params and buffers

import optax
import torchax.train

params = {n: p for n, p in model.named_parameters() if p.requires_grad}
buffers = dict(model.named_buffers())
frozen_params = {n: p for n, p in model.named_parameters() if not p.requires_grad}
buffers.update(frozen_params)

Create the optimizer

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4, warmup_steps=50, decay_steps=500
)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.01),
)
opt_state = tx.interop.call_jax(optimizer.init, params)

Note tx.interop.call_jax — this bridges optax's JAX calls with torchax tensors.

Define model_fn and loss_fn

def model_fn(weights, buffers, batch):
    """Stateless forward pass using functional_call."""
    return torch.func.functional_call(
        model, {**weights, **buffers}, args=(), kwargs=batch
    )

def loss_fn(model_output, labels):
    """Extract loss from HuggingFace model output."""
    return model_output.loss

torch.func.functional_call runs the model as a pure function — no hidden state, just inputs and outputs. This is what enables JAX to trace and compile it.
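A standalone illustration of the same mechanism on a toy module:

import torch
import torch.nn as nn

lin = nn.Linear(4, 2)
weights = dict(lin.named_parameters())

x = torch.randn(3, 4)
# Identical forward pass, but the weights are supplied explicitly
# instead of being read from the module's own state:
y = torch.func.functional_call(lin, weights, args=(x,))
assert torch.allclose(y, lin(x))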

Compose into a training step

step_fn = tx.train.make_train_step(model_fn, loss_fn, optimizer)

That single line creates a function that does: forward pass → loss computation → gradient calculation → optimizer update — all compiled into one XLA program.


Step 5: The Training Loop

import time
from tqdm.auto import tqdm

torch.manual_seed(42)
train_losses = []
start_time = time.time()

for epoch in range(1):
    pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    for step, batch in pbar:
        # Drop attention_mask — Gemma's sliding window attention produces NaN with
        # padded masks on torchax/JAX. Labels already mask padding with -100.
        batch = {k: v.to(device) for k, v in batch.items() if k != "attention_mask"}

        loss, params, opt_state = step_fn(
            params, buffers, opt_state, batch, batch["labels"]
        )

        train_losses.append(loss.item())
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

elapsed = time.time() - start_time
print(f"Training complete! {len(train_losses)} steps in {elapsed:.0f}s")

What to expect:

  • Step 1: ~30-60 seconds (JAX compiles the entire training step)
  • Steps 2+: ~1-3 seconds each (running the compiled program)
  • Total: ~20-40 minutes for 2000 samples with LoRA on free Colab TPU

The first step is slow because JAX traces through the entire model, loss computation, gradient calculation, and optimizer update — then compiles it all into a single optimized XLA program. Every subsequent step reuses this compiled program.
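One way to observe this yourself (illustrative; run it before the main loop so the first call is the one that triggers compilation, and note that it performs two real optimizer updates):

import time

probe = {k: v.to(device) for k, v in next(iter(train_dataloader)).items() if k != "attention_mask"}

t0 = time.perf_counter()
loss, params, opt_state = step_fn(params, buffers, opt_state, probe, probe["labels"])
loss.item()   # block until the asynchronous computation finishes
print(f"first call (trace + compile + run): {time.perf_counter() - t0:.1f}s")

t0 = time.perf_counter()
loss, params, opt_state = step_fn(params, buffers, opt_state, probe, probe["labels"])
loss.item()
print(f"second call (compiled program only): {time.perf_counter() - t0:.1f}s")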


Step 6: Evaluate the Improvement

After training, we compare against our baseline:

# Load trained params back into model
with torch.no_grad():
    for name, param in params.items():
        parts = name.split(".")
        obj = model
        for part in parts[:-1]:
            obj = getattr(obj, part)
        setattr(obj, parts[-1], torch.nn.Parameter(param))

final_loss, final_ppl = evaluate_loss(model, eval_dataloader, device)

print(f"{'Metric':<20} {'Before':>10} {'After':>10}")
print(f"{'Loss':<20} {baseline_loss:>10.4f} {final_loss:>10.4f}")
print(f"{'Perplexity':<20} {baseline_ppl:>10.2f} {final_ppl:>10.2f}")

You should see loss decrease and perplexity improve after training. The qualitative comparison (generated responses before vs. after) is even more telling — the fine-tuned model produces more focused, instruction-following responses.


Step 7: Save and Reload

Save

Convert JAX arrays back to CPU tensors and save using HuggingFace's standard format:

import numpy as np

save_dir = "./fine_tuned_model"

with torch.no_grad():
    cpu_state_dict = {
        name: torch.tensor(np.array(p)).contiguous()
        for name, p in params.items()
    }
    # safe_serialization=False avoids a safetensors/torchax C-extension conflict on reload
    model.save_pretrained(save_dir, state_dict=cpu_state_dict, safe_serialization=False)

tokenizer.save_pretrained(save_dir)

For LoRA, this saves only the tiny adapter weights (~20MB). For full fine-tuning, it saves the entire model (~4GB).
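To verify the footprint, a quick directory listing works (exact file names depend on your transformers/peft versions):

import os

for name in sorted(os.listdir(save_dir)):
    size_mb = os.path.getsize(os.path.join(save_dir, name)) / 1e6
    print(f"{name:<40} {size_mb:8.1f} MB")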

Reload

with tx.disable_temporarily():
    # For LoRA: load base model + adapters separately
    reloaded_model = transformers.AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype=torch.bfloat16
    )
    # torch_device="cpu" forces PEFT to load adapter weights on CPU,
    # avoiding a safetensors/torchax C-extension conflict.
    reloaded_model = peft.PeftModel.from_pretrained(reloaded_model, save_dir, torch_device="cpu")

reloaded_model.to(device)
reloaded_model.eval()

The pattern is the same as loading: disable torchax, load on CPU, then move to JAX. For LoRA models, you load the base model first, then attach the saved adapters with PeftModel.from_pretrained(). The torch_device="cpu" ensures PEFT loads weights through PyTorch's standard path rather than safetensors' C extension, which conflicts with torchax.
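To confirm the round trip worked, you can reuse the evaluate_loss helper from Step 3; the numbers should match the post-training results from Step 6:

# Sanity check: the reloaded adapters should reproduce the fine-tuned metrics
reloaded_loss, reloaded_ppl = evaluate_loss(reloaded_model, eval_dataloader, device)
print(f"Reloaded loss: {reloaded_loss:.4f}, perplexity: {reloaded_ppl:.2f}")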


Full Fine-Tuning: When LoRA Is Not Enough

The notebook supports full fine-tuning by changing one setting:

TRAINING_MODE = "full"

This trains all parameters instead of just the LoRA adapters. The trade-off is much higher memory usage. To make it fit on free Colab TPU:

  • AdaFactor optimizer — uses ~50% less memory than AdamW (stores only row/column statistics instead of per-parameter moments)
  • Reduced sequence length — MAX_SEQ_LEN = 256 halves activation memory
  • Smaller batch size — BATCH_SIZE = 1 with more gradient accumulation steps (see the accumulation sketch after the code below)

USE_ADAFACTOR = True
USE_GRADIENT_CHECKPOINTING = True

if TRAINING_MODE == "full" and USE_ADAFACTOR:
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adafactor(learning_rate=schedule),
    )
else:
    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(learning_rate=schedule, weight_decay=0.01),
    )
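Gradient accumulation is not spelled out in the snippet above; one way to add it on the optax side is to wrap the optimizer in optax.MultiSteps, which accumulates gradients and only applies an update every k micro-batches (illustrative, with a hypothetical ACCUM_STEPS):

ACCUM_STEPS = 8  # hypothetical: effective batch size = BATCH_SIZE * ACCUM_STEPS
# Wrap before calling make_train_step so the step uses the accumulating optimizer.
optimizer = optax.MultiSteps(optimizer, every_k_schedule=ACCUM_STEPS)
opt_state = tx.interop.call_jax(optimizer.init, params)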

Full fine-tuning gives a higher quality ceiling but LoRA gets you 90%+ of the way with a fraction of the compute.


Troubleshooting

| Error | Cause | Fix |
| --- | --- | --- |
| OutOfMemoryError | Model + optimizer too large | Switch to LoRA; reduce BATCH_SIZE or MAX_SEQ_LEN |
| TypeError: not a valid JAX type | Custom HuggingFace type not registered | Register it with jax.tree_util.register_pytree_node() |
| Loss is NaN | Numerical instability in bfloat16 | Call tx.enable_accuracy_mode() before tx.enable_globally(); reduce the learning rate (try 1e-4); set lora_dropout=0.0; add optax.clip_by_global_norm(1.0) |
| Slow first step | Normal: JAX JIT compilation | Wait ~30-60s; subsequent steps are fast |
| make_train_step error | API mismatch | Update: pip install -U torchax |

The Big Picture: Inference + Training

With the inference tutorial and this training tutorial, you now have the complete torchax story:

  1. Run any HuggingFace model on TPU (model.to("jax"))
  2. Benchmark with JIT compilation (10-100x speedup)
  3. Fine-tune with LoRA or full training (make_train_step)
  4. Save and reload for production inference

All using PyTorch code. No JAX rewrite needed.

