DEV Community

Henry Ruiz, PhD for Google Developer Experts

Originally published at Medium


Post-Training Gemma 3 for Earth Observation (EO) Understanding: A JAX Stack + TPU Pipeline for Multi-Label Sentinel Satellite Remote Sensing Scene Classification

Earth observation has no shortage of data. What it still needs is broader, practical adoption of scalable ways to adapt modern multimodal models to geospatial tasks, without turning training into a months-long infrastructure project.

That was the motivation behind building GemmaEarth as part of the TPU Sprint.

GitHub repository: https://github.com/haruiz/gemma_earth

GemmaEarth is a domain-focused post-training and benchmarking pipeline built with Tunix. It adapts Google’s Gemma 3 4B IT (instruction-tuned) model for Earth Observation (EO) understanding using the EarthDial dataset, introduced in "EarthDial: Turning Multi-sensory Earth Observations into Interactive Dialogues" (2025, https://arxiv.org/abs/2412.15190). EarthDial is a large-scale instruction-tuning dataset spanning diverse EO tasks, including classification, captioning, visual question answering (VQA), and reasoning.

While the long-term goal is to fine-tune the model across the full range of EarthDial tasks, limited TPU allocation led the project to start with a focused, practical use case: satellite remote-sensing scene classification. This initial stage uses the BigEarthNet subset, a large-scale Sentinel benchmark for multi-label land-use and land-cover classification.


Figure: EarthDial dataset

Project Contributions

GemmaEarth provides a reproducible, end-to-end workflow for adapting Gemma 3 4B IT to Earth Observation tasks, with a primary focus on multi-label satellite scene classification. Built on a TPU-native JAX stack, the pipeline is designed to extend to the broader set of EO tasks in the EarthDial dataset.

Rather than focusing on a single training run, the project emphasizes a reference workflow spanning data preparation, post-training, evaluation, and benchmarking.

The implementation builds on several libraries from the JAX AI stack (https://jaxstack.ai/):

  • Tunix orchestrates the PEFT-based training pipeline
  • Grain enables scalable data processing
  • Optax handles optimization and scheduling
  • Orbax manages checkpointing
  • Qwix injects LoRA adapters for efficient fine-tuning

All experiments were executed on Google Cloud TPU v5litepod-8, with an emphasis on scalability and reproducibility.

Key contributions include:

  • A reproducible JAX-based pipeline for end-to-end post-training and benchmarking
  • Parameter-efficient adaptation of Gemma 3 using LoRA in a TPU-sharded setup
  • Multimodal training support using image–text EO data
  • TPU-native data and training pipelines for scalable experimentation
  • A practical evaluation framework with interpretable metrics for multi-label classification


Figure: JAX AI stack

The JAX Stack in Practice

GemmaEarth integrates several components of the JAX ecosystem into a cohesive, TPU-native training pipeline with five layers: data loading, distributed training, model adaptation, optimization, and checkpointing.

Data Loading and Preprocessing

Efficient data handling is critical for TPU utilization: if the host cannot prepare batches fast enough, the accelerators sit idle. In GemmaEarth, Grain is used to build a high-performance, composable input pipeline for multimodal (image–text) data.

This pipeline handles shuffling, transformation, batching, repetition across epochs, and iteration, ensuring a steady stream of preprocessed data into the training loop.

Source: https://github.com/haruiz/gemma_earth/blob/5fd2e4c0b6cf3c65eac026bd5aecc136e81072b8/src/gemma_earth/dataset.py#L519

def _build_train_pipeline(
    self,
    split: Any,
    image_processor: Any,
    tokenizer: tokenizer_lib.Tokenizer,
    num_epochs: int | None,
    split_name: str,
) -> Any:
    """Construct a shuffled, mapped, and batched Grain iterable pipeline.

    The pipeline applies the following operations in order: shuffle, map to
    training example, map to training input (tokenize + pad), batch, repeat,
    and convert to an iterable dataset.

    Args:
        split: HuggingFace Dataset split used as the Grain data source.
        image_processor: Tunix ImageProcessor forwarded to
            _to_training_example.
        tokenizer: Tunix tokenizer forwarded to _to_training_input.
        num_epochs: Number of times to repeat the dataset. Pass 1 for a
            single pass, typically used for validation; None makes Grain
            repeat the dataset indefinitely.
        split_name: Human-readable label ("train" or "validation")
            used in progress log messages.

    Returns:
        A Grain iterable dataset ready for use as train_ds or
        eval_ds in GemmaEarth.train.
    """
    settings = self.settings
    logger.info("Building %s pipeline with %d rows...", split_name, len(split))
    return (
        grain.MapDataset.source(split)
        .shuffle(seed=settings.shuffle_seed)
        .map(lambda x: self._to_training_example(x, image_processor))
        .map(lambda x: self._to_training_input(x, tokenizer))
        .batch(settings.batch_size, drop_remainder=True)
        .repeat(num_epochs)
        .to_iter_dataset()
    )
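For readers new to Grain, the following pure-Python sketch mimics what the chain above does to the data, eagerly and on a toy list. It shuffles once with a fixed seed, applies the two per-record transformations, cuts fixed-size batches with drop_remainder=True, and repeats a finite number of epochs. This models the semantics only and is not Grain's implementation, which evaluates lazily and can prefetch in parallel. The to_example and to_input arguments are hypothetical stand-ins for the repo's _to_training_example and _to_training_input.

```python
import random


def toy_pipeline(records, batch_size, num_epochs, seed, to_example, to_input):
    """Eagerly mimic: shuffle -> map -> map -> batch(drop_remainder=True) -> repeat."""
    order = list(range(len(records)))
    random.Random(seed).shuffle(order)  # .shuffle(seed=...), applied once in this toy
    examples = [to_input(to_example(records[i])) for i in order]  # the two .map() calls
    num_full = len(examples) // batch_size  # drop_remainder=True: no partial batch
    batches = [examples[k * batch_size:(k + 1) * batch_size] for k in range(num_full)]
    for _ in range(num_epochs):  # .repeat(num_epochs) with a finite count
        yield from batches


records = [{"id": i} for i in range(10)]
batches = list(
    toy_pipeline(
        records,
        batch_size=4,
        num_epochs=2,
        seed=0,
        to_example=lambda r: r["id"],  # hypothetical stand-in for _to_training_example
        to_input=lambda x: x * 10,  # hypothetical stand-in for _to_training_input
    )
)
print(len(batches))  # 4: two full batches per epoch (2 records dropped), times 2 epochs
```

With 10 records and a batch size of 4, the last 2 records are skipped in each epoch. drop_remainder=True trades those records for static batch shapes, which helps avoid recompiling the jitted training step for a smaller final batch.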

Distributed Training (Tunix + JAX)

In GemmaEarth, Tunix orchestrates the training loop, coordinating model updates, data flow, and optimization across the stack. Tunix is a JAX-native post-training library designed for TPU performance, with a unified interface for:

  • Supervised Fine-Tuning (SFT) for instruction-based adaptation
  • Reinforcement Learning (RL) for policy optimization
  • Agentic RL, enabling training loops where agents interact with dynamic environments

GemmaEarth uses the SFT path through Tunix's PeftTrainer, configured as follows:

Source: https://github.com/haruiz/gemma_earth/blob/5fd2e4c0b6cf3c65eac026bd5aecc136e81072b8/src/gemma_earth/trainers/base.py#L386

def _build_trainer(
    self,
    lora_model: nnx.Module,
    optimizer: optax.GradientTransformation,
    max_steps: int,
    tokenizer: tokenizer_lib.Tokenizer,
) -> peft_trainer.PeftTrainer:
    """Build PeftTrainer (including training config) and attach input callback.

    Args:
        lora_model: LoRA-augmented model to optimize.
        optimizer: Optimizer instance used during training.
        max_steps: Maximum training steps used to shape training config.
        tokenizer: Tokenizer used when building model-input callback.

    Returns:
        Configured ``peft_trainer.PeftTrainer`` instance.
    """
    settings = self.settings
    checkpointing_options = ocp.CheckpointManagerOptions(
        save_interval_steps=settings.save_interval_steps,
        max_to_keep=settings.max_to_keep,
    )
    metrics_logging_options = metrics_logger.MetricsLoggerOptions(
        log_dir=self._tensorboard_root(),
        flush_every_n_steps=20,
    )
    training_config = peft_trainer.TrainingConfig(
        eval_every_n_steps=settings.eval_every_n_steps,
        max_steps=max_steps,
        checkpoint_root_directory=self._checkpoint_root(),
        checkpointing_options=checkpointing_options,
        metrics_logging_options=metrics_logging_options,
    )
    trainer = peft_trainer.PeftTrainer(
        model=lora_model,
        optimizer=optimizer,
        training_config=training_config,
    )
    return trainer.with_gen_model_input_fn(self._gen_model_input_fn(tokenizer))
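with_gen_model_input_fn attaches a callback that converts each training batch into the keyword arguments the model's forward pass expects. The repo's _gen_model_input_fn is not shown here. The sketch below assumes it follows the pattern used in Tunix's Gemma fine-tuning examples, which derive token positions and a causal attention mask from the padding mask. It works on a single sequence of plain Python lists rather than batched JAX arrays, and PAD_ID = 0 is an assumption.

```python
PAD_ID = 0  # assumption: token id 0 is padding


def positions_from_mask(mask):
    """Cumulative count of real tokens minus one, floored at 0 (0-based positions)."""
    positions, count = [], 0
    for m in mask:
        count += m
        positions.append(max(count - 1, 0))
    return positions


def causal_attention_mask(mask):
    """mask[i][j] is True when position i may attend to position j: j is real and j <= i."""
    n = len(mask)
    return [[bool(mask[j]) and j <= i for j in range(n)] for i in range(n)]


def gen_model_input(input_tokens, input_mask):
    """Hypothetical callback: build forward-pass kwargs for one sequence."""
    pad_mask = [int(t != PAD_ID) for t in input_tokens]
    return {
        "input_tokens": input_tokens,
        "input_mask": input_mask,  # passed through unchanged
        "positions": positions_from_mask(pad_mask),
        "attention_mask": causal_attention_mask(pad_mask),
    }


batch = gen_model_input([5, 8, 3, 0, 0], [0, 1, 1, 0, 0])
print(batch["positions"])  # [0, 1, 2, 2, 2]
print(batch["attention_mask"][2])  # [True, True, True, False, False]
```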

Underneath Tunix, JAX provides the low-level primitives for efficient, hardware-aware distributed training. In particular, the pipeline uses JAX's device mesh abstraction. The mesh names the axes of the TPU device grid (here "data", "fsdp", and "tp", for data, fully sharded data, and tensor parallelism) so that model state and computation can be sharded along them.

Because sharding is expressed against this mesh rather than hard-coded in the training loop, moving between device counts mainly means changing the mesh shape, not the training logic.

Source: https://github.com/haruiz/gemma_earth/blob/5fd2e4c0b6cf3c65eac026bd5aecc136e81072b8/src/gemma_earth/trainers/base.py#L279

def create_mesh(self) -> jax.sharding.Mesh:
  num_devices = jax.local_device_count()
  logger.info("Detected %d local devices for JAX mesh creation.", num_devices)

  if num_devices >= 8:
    return jax.make_mesh(
      (1, 4, 2),
      ("data", "fsdp", "tp"),
      axis_types=(jax.sharding.AxisType.Auto,) * 3,
    )

  if num_devices >= 2:
    return jax.make_mesh(
      (num_devices, 1),
      ("fsdp", "tp"),
      axis_types=(jax.sharding.AxisType.Auto,) * 2,
    )

  return jax.make_mesh(
    (1, 1),
    ("fsdp", "tp"),
    axis_types=(jax.sharding.AxisType.Auto,) * 2,
  )
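To show what the 8-device (1, 4, 2) mesh means for memory, the sketch below computes the per-device block of a weight matrix whose rows are split over the "fsdp" axis and whose columns are split over the "tp" axis. Both the partitioning and the 2560 × 10240 shape are illustrative assumptions. In the real pipeline, the partition specs come from the model's own sharding annotations (read via nnx.get_partition_spec). The sketch uses plain Python rather than jax.sharding, so it runs without TPUs.

```python
from math import prod

# Mirrors the 8-device branch of create_mesh: (1, 4, 2) over ("data", "fsdp", "tp")
MESH_SHAPE = {"data": 1, "fsdp": 4, "tp": 2}


def shard_shape(global_shape, partition_spec, mesh_shape):
    """Per-device block shape; each spec entry names a mesh axis, or None for no split."""
    local = []
    for dim, axis in zip(global_shape, partition_spec):
        ways = 1 if axis is None else mesh_shape[axis]
        if dim % ways:  # this toy requires even division
            raise ValueError(f"dimension {dim} is not divisible by {ways} ({axis!r})")
        local.append(dim // ways)
    return tuple(local)


assert prod(MESH_SHAPE.values()) == 8  # the mesh must cover exactly the 8 devices

w_shape = (2560, 10240)  # illustrative weight shape, not read from the model
local = shard_shape(w_shape, ("fsdp", "tp"), MESH_SHAPE)
print(local)  # (640, 5120)
print(prod(w_shape) // prod(local))  # 8: each device stores 1/8 of the matrix
```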

Model Adaptation (LoRA + Structured State)

Model adaptation is handled through Qwix and Flax NNX, enabling parameter-efficient fine-tuning and structured state management.

  • Qwix injects LoRA adapters, enabling efficient adaptation without updating the full model
  • Flax NNX + JAX manage model state and enforce sharding across the device mesh

Result: a LoRA-augmented model whose state is explicitly sharded (partition specs are read with nnx.get_partition_spec and applied with jax.lax.with_sharding_constraint), ready for distributed TPU execution.

Source: https://github.com/haruiz/gemma_earth/blob/5fd2e4c0b6cf3c65eac026bd5aecc136e81072b8/src/gemma_earth/trainers/base.py#L328

def build_lora_model(self, mesh: jax.sharding.Mesh) -> nnx.Module:
  if self.base_model is None:
    raise RuntimeError("Base model was not initialized by load_base_model().")

  settings = self.settings
  lora_provider = qwix.LoraProvider(
    module_path=self.LORA_MODULE_PATH,
    rank=settings.lora_rank,
    alpha=settings.lora_alpha,
  )

  model_input = self.base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(self.base_model, lora_provider, **model_input)

  with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model
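To build intuition for what each adapter adds, the sketch below implements the textbook LoRA formulation (Hu et al., 2021) in plain Python: a frozen weight W plus a trainable low-rank update A·B scaled by alpha / rank. Here rank and alpha play the roles of settings.lora_rank and settings.lora_alpha above. This is not Qwix's code, and Qwix's initialization and scaling details may differ. The matrix sizes, and the rank of 16 used in the parameter count, are illustrative.

```python
import random


def matmul(a, b):
    """Plain-Python matrix product of nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_forward(x, w, a, b, alpha, rank):
    """y = x @ W + (alpha / rank) * (x @ A) @ B; W is frozen, A and B are trainable."""
    base = matmul(x, w)
    update = matmul(matmul(x, a), b)
    scale = alpha / rank
    return [[u + scale * v for u, v in zip(r0, r1)] for r0, r1 in zip(base, update)]


def lora_param_count(d_in, d_out, rank):
    """Trainable parameters LoRA adds to one layer: A is d_in x rank, B is rank x d_out."""
    return rank * (d_in + d_out)


d_in, d_out, rank, alpha = 6, 4, 2, 4.0  # toy sizes
rng = random.Random(0)
w = [[rng.uniform(-1, 1) for _ in range(d_out)] for _ in range(d_in)]
a = [[rng.gauss(0, 0.02) for _ in range(rank)] for _ in range(d_in)]
b = [[0.0] * d_out for _ in range(rank)]  # B starts at zero
x = [[1.0] * d_in]

assert lora_forward(x, w, a, b, alpha, rank) == matmul(x, w)  # identical to the base layer at init
print(lora_param_count(2560, 10240, 16), 2560 * 10240)  # 204800 26214400 (about 0.8%)
```

Because B starts at zero, the adapted layer initially reproduces the base layer exactly. During fine-tuning, only the rank × (d_in + d_out) adapter entries per layer are trained.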

Optimization

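Optimization is handled by Optax, a JAX-native gradient processing and optimization library.

GemmaEarth pairs AdamW (Adam with decoupled weight decay) with a warmup + cosine-decay learning-rate schedule. The learning rate rises linearly from 0 to its peak over the first warmup_ratio fraction of max_steps, then follows a cosine curve back down to 0 at max_steps. This is a common recipe for stable LoRA fine-tuning: the warmup avoids large updates early in training, while the optimizer's moment estimates are still noisy.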

Source: https://github.com/haruiz/gemma_earth/blob/5fd2e4c0b6cf3c65eac026bd5aecc136e81072b8/src/gemma_earth/trainers/base.py#L361

def _build_optimizer(self, max_steps: int) -> optax.GradientTransformation:
  settings = self.settings
  warmup_steps = max(1, int(max_steps * settings.warmup_ratio))
  lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=settings.learning_rate,
    warmup_steps=warmup_steps,
    decay_steps=max(1, max_steps),
    end_value=0.0,
  )

  return optax.adamw(
    learning_rate=lr_schedule,
    weight_decay=settings.weight_decay,
  )
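The schedule built by _build_optimizer can be reproduced in a few lines of plain Python, which is handy for sanity-checking warmup_ratio and max_steps before launching a run. The sketch follows the formula documented for optax.warmup_cosine_decay_schedule: a linear ramp from init_value to peak_value over warmup_steps, then a cosine decay to end_value, where decay_steps counts the warmup as part of the total. The peak learning rate and step counts are placeholders, not the repo's settings.

```python
import math


def warmup_cosine(step, peak, warmup_steps, decay_steps, init=0.0, end=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay until decay_steps."""
    if step < warmup_steps:
        return init + (peak - init) * step / warmup_steps
    span = max(1, decay_steps - warmup_steps)  # decay_steps includes the warmup phase
    t = min(step - warmup_steps, span)  # hold at `end` after decay_steps
    return end + (peak - end) * 0.5 * (1 + math.cos(math.pi * t / span))


max_steps, warmup_ratio, peak_lr = 1000, 0.1, 1e-4  # placeholder settings
warmup_steps = max(1, int(max_steps * warmup_ratio))  # same rule as _build_optimizer
for step in (0, 50, 100, 550, 1000):
    print(step, warmup_cosine(step, peak_lr, warmup_steps, max(1, max_steps)))
```

With these placeholders, warmup ends at step 100. The learning rate falls back to half its peak midway through the decay (step 550) and reaches 0 at max_steps.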

Checkpointing and Experiment Management

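Checkpointing is managed with Orbax, which provides flexible, structured model persistence.

Orbax supports:

  • PyTree-based checkpoints, so the nested state of an NNX model can be saved and restored directly
  • Partial restoration, which restores only the parts of a saved tree that the target structure requests, helping checkpoints stay usable as model structures evolve

In GemmaEarth, the trainer writes a checkpoint every save_interval_steps steps and keeps at most max_to_keep of them. The restore code below loads the latest step unless a specific one is requested.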

Source: https://github.com/haruiz/gemma_earth/blob/5fd2e4c0b6cf3c65eac026bd5aecc136e81072b8/src/gemma_earth/trainers/base.py#L579

manager = ocp.CheckpointManager(
  self._checkpoint_root(),
  item_handlers={"model_params": ocp.PyTreeCheckpointHandler()},
)
try:
  resolved_step = step if step is not None else manager.latest_step()
  if resolved_step is None:
    raise RuntimeError(f"No checkpoint found under {self._checkpoint_root()}")

  abstract_state = nnx.state(model)
  restore_args = ocp.checkpoint_utils.construct_restore_args(target=abstract_state)

  checkpoint = manager.restore(
    resolved_step,
    args=ocp.args.Composite(
      model_params=ocp.args.PyTreeRestore(
        item=abstract_state,
        restore_args=restore_args,
        partial_restore=True,
      )
    ),
  )

  nnx.update(model, checkpoint.model_params)
  return resolved_step
finally:
  manager.close()
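The restore logic above can be modeled with nested dictionaries to show what step resolution and partial restoration do. This is a toy, not Orbax. The hypothetical restore_subset keeps only the entries that the target structure asks for and skips anything extra in the checkpoint. The hypothetical resolve_step mirrors the `step if step is not None else manager.latest_step()` line. The checkpoint contents, including the optimizer_state entry, are made up for illustration.

```python
def restore_subset(saved, target):
    """Toy partial restore: keep only the entries of `saved` that `target` asks for."""
    if isinstance(target, dict):
        # Keys present in the checkpoint but absent from the target are skipped.
        return {k: restore_subset(saved[k], v) for k, v in target.items()}
    return saved  # leaf: take the stored value; the target leaf is only a placeholder


def resolve_step(available_steps, step=None):
    """Use the requested step, or fall back to the latest saved one."""
    if step is None:
        step = max(available_steps, default=None)
    if step is None:
        raise RuntimeError("No checkpoint found")
    return step


# Hypothetical checkpoint: LoRA weights plus an entry the target model does not have.
saved = {
    "layer_0": {"lora_a": [0.1, 0.2], "lora_b": [0.3]},
    "optimizer_state": {"mu": [0.0]},
}
target = {"layer_0": {"lora_a": None, "lora_b": None}}

print(resolve_step([1000, 2000, 3000]))  # 3000
print(restore_subset(saved, target))  # {'layer_0': {'lora_a': [0.1, 0.2], 'lora_b': [0.3]}}
```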

Benchmark Results (1.5k Samples)

After post-training, the fine-tuned model shows clear, consistent improvements over the baseline on every metric for satellite scene classification, evaluated on a 1.5k-sample set, as illustrated in the figure below.

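Gains are largest on recall-oriented metrics, meaning the fine-tuned model misses fewer of the true labels per image. The accompanying gains in F1 and Jaccard show that this higher recall is not simply bought by predicting many extra labels.

Because macro_f1 weights every class equally, its increase suggests that the improvements extend beyond the most frequent classes. Higher exact-match scores mean the model more often predicts an image's complete label set, with no missing or extra labels.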
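The metrics above are standard for multi-label classification. The sketch below computes them from predicted and true label sets in plain Python. It assumes the following definitions:

  • Precision, recall, F1, and Jaccard are computed per image, then averaged.
  • Macro F1 is averaged over classes.
  • Exact match is the fraction of images whose predicted set equals the true set.

The repo's evaluation code may use different averaging or edge-case handling. The labels are made up, loosely modeled on land-cover classes.

```python
def _sample_scores(pred, true):
    """Precision, recall, F1, and Jaccard for one image's predicted vs true label sets."""
    tp = len(pred & true)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(true) if true else 0.0
    f1 = 2 * tp / (len(pred) + len(true)) if (pred or true) else 1.0
    jaccard = tp / len(pred | true) if (pred or true) else 1.0
    return precision, recall, f1, jaccard


def multilabel_metrics(preds, trues):
    """Sample-averaged P/R/F1/Jaccard, macro F1 over observed classes, and exact match."""
    n = len(trues)
    scores = [_sample_scores(p, t) for p, t in zip(preds, trues)]
    classes = sorted(set().union(*preds, *trues))  # a full eval would use the whole label vocabulary
    class_f1 = []
    for c in classes:
        tp = sum(c in p and c in t for p, t in zip(preds, trues))
        fp = sum(c in p and c not in t for p, t in zip(preds, trues))
        fn = sum(c not in p and c in t for p, t in zip(preds, trues))
        class_f1.append(2 * tp / (2 * tp + fp + fn))
    names = ("precision", "recall", "f1", "jaccard")
    metrics = {name: sum(s[i] for s in scores) / n for i, name in enumerate(names)}
    metrics["macro_f1"] = sum(class_f1) / len(class_f1)
    metrics["exact_match"] = sum(p == t for p, t in zip(preds, trues)) / n
    return metrics


# Made-up labels for three images.
trues = [{"forest", "water"}, {"urban"}, {"arable land", "pasture"}]
preds = [{"forest", "water"}, {"urban", "forest"}, {"arable land"}]
metrics = multilabel_metrics(preds, trues)
print({k: round(v, 3) for k, v in metrics.items()})
# f1 0.778, jaccard 0.667, macro_f1 0.733, exact_match 0.333
```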

Training was stable: evaluation loss fell from ~0.33 to ~0.22 and perplexity from ~1.45 to ~1.25, with no signs of overfitting. Both curves plateau around 25k–30k steps, suggesting the model has converged.
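Perplexity is conventionally the exponential of the mean per-token cross-entropy loss, so the two curves carry the same information. Assuming the logged perplexity follows that convention, the conversion is a one-liner:

```python
import math


def perplexity(mean_token_loss):
    """Perplexity = exp(mean per-token cross-entropy loss, measured in nats)."""
    return math.exp(mean_token_loss)


print(round(perplexity(0.22), 2))  # 1.25
```

For example, a loss of ~0.22 corresponds to a perplexity of about 1.25.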

More detailed metrics and configurations are available in the GitHub repository.

Future Work

The current implementation focuses on the classification subset of EarthDial (BigEarthNet). Future work will extend this pipeline to additional tasks, including captioning, VQA, and reasoning.

Given the modular design, adapting to these tasks primarily involves updating DATASET_RELATIVE_DIR and re-running the pipeline. Detailed instructions are available in the repository.

Closing Thoughts

GemmaEarth demonstrates how the JAX ecosystem can be composed into a practical, production-oriented training stack for domain-specific multimodal models.

Beyond model performance, its key contribution is an extensible engineering blueprint for moving from experimentation to scalable, reproducible systems built on TPU acceleration.

Thanks for taking the time to explore GemmaEarth! If you found this useful, feel free to browse the repository, experiment with the pipeline, or share feedback. I’d love to see how others extend this work.

Happy building!

Acknowledgements

Google Cloud credits for this project were provided through the #TPUSprint.

The GitHub repository includes scripts to provision the required GCP services, including the TPU. For more detailed guidance on creating and configuring a TPU VM, see the official Google Cloud documentation:

https://docs.cloud.google.com/tpu/docs/create-tpu-vm

