Post-Training Gemma 3 for Earth Observation (EO) Understanding: A JAX Stack + TPU Pipeline for Multi-Label Sentinel Satellite Remote Sensing Scene Classification
Earth observation has no shortage of data. What it still needs is broader, more practical adoption of scalable approaches to adapt modern multimodal models to geospatial tasks — without turning training into a months-long infrastructure project.
That was the motivation behind building GemmaEarth as part of the TPU Sprint.
GemmaEarth is a domain-focused post-training and benchmarking pipeline built with Tunix, designed to adapt Google’s Gemma 3 4B IT model for Earth Observation (EO) understanding using the dataset introduced in EarthDial: Turning Multi-sensory Earth Observations into Interactive Dialogues (2025, https://arxiv.org/abs/2412.15190). EarthDial provides a large-scale instruction-tuning dataset spanning diverse EO tasks, including classification, captioning, visual question answering (VQA), and reasoning.
While the long-term goal is to fine-tune the model across the full range of EarthDial tasks, TPU resource allocation constraints led this project to begin with a focused and practical use case: satellite remote-sensing scene classification. This initial stage leverages the BigEarthNet subset, a large-scale Sentinel benchmark for multi-label land-use and land-cover classification.
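To make the task concrete, here is a rough sketch of what a single training example can look like once a BigEarthNet patch is turned into an instruction-style image–text pair. The prompt wording and label formatting below are illustrative only; they are not the exact EarthDial template used in the pipeline.

# Illustrative only: the real EarthDial/BigEarthNet prompt template and label
# formatting may differ. Each Sentinel-2 patch carries multiple land-cover
# labels, so the target is a set of classes rather than a single category.
example = {
    "image": "<Sentinel-2 patch>",
    "prompt": "Identify all land use and land cover types visible in this satellite image.",
    "target": "Coniferous forest, Mixed forest, Inland waters",
}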
Project Contributions
GemmaEarth provides a reproducible, end-to-end workflow for adapting Gemma 3 4B IT to Earth Observation tasks, with a primary focus on multi-label satellite scene classification. The pipeline is designed to be extensible to a broader set of EO tasks within the EarthDial dataset, leveraging a TPU-native JAX stack.
Rather than focusing on a single training run, the project emphasizes a reference workflow spanning data preparation, post-training, evaluation, and benchmarking.
The implementation leverages the full JAX ecosystem (https://jaxstack.ai/):
- Tunix orchestrates the PEFT-based training pipeline
- Grain enables scalable data processing
- Optax handles optimization and scheduling
- Orbax manages checkpointing
- Qwix injects LoRA adapters for efficient fine-tuning
All experiments were executed on Google Cloud TPU v5litepod-8, with an emphasis on scalability and reproducibility.
Key contributions include:
- A reproducible JAX-based pipeline for end-to-end post-training and benchmarking
- Parameter-efficient adaptation of Gemma 3 using LoRA in a TPU-sharded setup
- Multimodal training support using image–text EO data
- TPU-native data and training pipelines for scalable experimentation
- A practical evaluation framework with interpretable metrics for multi-label classification
The JAX Stack in Practice
GemmaEarth integrates multiple components from the JAX ecosystem into a cohesive, TPU-native training pipeline, structured across key layers: data processing, model adaptation, distributed training, optimization, and checkpointing.
Data Loading and Preprocessing
Efficient data handling is critical for TPU utilization. In GemmaEarth, Grain is used to construct a high-performance, composable input pipeline for multimodal data.
This pipeline handles shuffling, transformation, batching, and iteration, ensuring a steady stream of preprocessed data into the training loop.
def _build_train_pipeline(
    self,
    split: Any,
    image_processor: Any,
    tokenizer: tokenizer_lib.Tokenizer,
    num_epochs: int | None,
    split_name: str,
) -> Any:
    """Construct a shuffled, mapped, and batched Grain iterable pipeline.

    The pipeline applies the following operations in order: shuffle, map to
    training example, map to training input (tokenize + pad), batch, repeat,
    and convert to an iterable dataset.

    Args:
        split: HuggingFace Dataset split used as the Grain data source.
        image_processor: Tunix ImageProcessor forwarded to
            _to_training_example.
        tokenizer: Tunix tokenizer forwarded to _to_training_input.
        num_epochs: Number of times to repeat the dataset, or None to
            repeat indefinitely. Pass 1 for a single pass, typically
            used for validation.
        split_name: Human-readable label ("train" or "validation")
            used in progress log messages.

    Returns:
        A Grain iterable dataset ready for use as train_ds or
        eval_ds in GemmaEarth.train.
    """
    settings = self.settings
    logger.info("Building %s pipeline with %d rows...", split_name, len(split))
    return (
        grain.MapDataset.source(split)
        .shuffle(seed=settings.shuffle_seed)
        .map(lambda x: self._to_training_example(x, image_processor))
        .map(lambda x: self._to_training_input(x, tokenizer))
        .batch(settings.batch_size, drop_remainder=True)
        .repeat(num_epochs)
        .to_iter_dataset()
    )
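As a usage sketch, the resulting pipeline behaves like any other Grain iterable dataset. The names below (pipeline, train_split, image_processor, tokenizer) are hypothetical stand-ins for the objects GemmaEarth constructs internally:

import jax

# Hypothetical usage sketch; the real entry point in GemmaEarth may differ.
train_ds = pipeline._build_train_pipeline(
    split=train_split,
    image_processor=image_processor,
    tokenizer=tokenizer,
    num_epochs=None,   # repeat indefinitely; the trainer's max_steps bounds the run
    split_name="train",
)

# Peek at the first batch to sanity-check shapes before a long TPU run.
first_batch = next(iter(train_ds))
print(jax.tree_util.tree_map(lambda x: getattr(x, "shape", None), first_batch))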
Distributed Training (Tunix + JAX)
In GemmaEarth, Tunix orchestrates the training loop, coordinating model updates, data flow, and optimization across the entire stack. It provides a unified interface for:
- State-of-the-art TPU performance through JAX-native execution
- Supervised Fine-Tuning (SFT) for instruction-based adaptation
- Reinforcement Learning (RL) for policy optimization
- Agentic RL, enabling training loops where agents interact with dynamic environments
def _build_trainer(
    self,
    lora_model: nnx.Module,
    optimizer: optax.GradientTransformation,
    max_steps: int,
    tokenizer: tokenizer_lib.Tokenizer,
) -> peft_trainer.PeftTrainer:
    """Build PeftTrainer (including training config) and attach input callback.

    Args:
        lora_model: LoRA-augmented model to optimize.
        optimizer: Optimizer instance used during training.
        max_steps: Maximum training steps used to shape training config.
        tokenizer: Tokenizer used when building model-input callback.

    Returns:
        Configured ``peft_trainer.PeftTrainer`` instance.
    """
    settings = self.settings
    checkpointing_options = ocp.CheckpointManagerOptions(
        save_interval_steps=settings.save_interval_steps,
        max_to_keep=settings.max_to_keep,
    )
    metrics_logging_options = metrics_logger.MetricsLoggerOptions(
        log_dir=self._tensorboard_root(),
        flush_every_n_steps=20,
    )
    training_config = peft_trainer.TrainingConfig(
        eval_every_n_steps=settings.eval_every_n_steps,
        max_steps=max_steps,
        checkpoint_root_directory=self._checkpoint_root(),
        checkpointing_options=checkpointing_options,
        metrics_logging_options=metrics_logging_options,
    )
    trainer = peft_trainer.PeftTrainer(
        model=lora_model,
        optimizer=optimizer,
        training_config=training_config,
    )
    return trainer.with_gen_model_input_fn(self._gen_model_input_fn(tokenizer))
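With the trainer built, a full run roughly follows the sequence below: create the device mesh (shown in the next snippet), inject LoRA, build the optimizer, and hand everything to Tunix. The exact signature of the trainer's train method is Tunix-specific, so treat this as a sketch rather than the verbatim GemmaEarth entry point:

# Rough end-to-end wiring (a sketch; method names mirror the snippets in this post).
mesh = pipeline.create_mesh()                       # TPU device mesh
lora_model = pipeline.build_lora_model(mesh)        # Qwix LoRA + sharded state
optimizer = pipeline._build_optimizer(max_steps)    # AdamW + warmup/cosine schedule
trainer = pipeline._build_trainer(lora_model, optimizer, max_steps, tokenizer)

# Tunix drives the loop over the Grain datasets built earlier; the exact
# train(...) signature may differ across Tunix versions.
with mesh:
    trainer.train(train_ds, eval_ds)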
Beneath Tunix, JAX provides the low-level primitives that enable efficient, hardware-aware distributed training. In particular, this pipeline leverages the device mesh abstraction, which gives precise control over how computation is distributed across TPU cores.
This design enables GemmaEarth to scale from a small number of devices to large TPU configurations without changing the training logic.
def create_mesh(self) -> jax.sharding.Mesh:
    """Create a TPU device mesh sized to the locally available devices."""
    num_devices = jax.local_device_count()
    logger.info("Detected %d local devices for JAX mesh creation.", num_devices)
    if num_devices >= 8:
        # Full v5litepod-8: shard parameters across 4 devices (fsdp) and
        # split tensor-parallel work across 2 (tp).
        return jax.make_mesh(
            (1, 4, 2),
            ("data", "fsdp", "tp"),
            axis_types=(jax.sharding.AxisType.Auto,) * 3,
        )
    if num_devices >= 2:
        # Smaller slices: place every device on the fsdp axis.
        return jax.make_mesh(
            (num_devices, 1),
            ("fsdp", "tp"),
            axis_types=(jax.sharding.AxisType.Auto,) * 2,
        )
    # Single device: a degenerate 1x1 mesh keeps the sharding code path uniform.
    return jax.make_mesh(
        (1, 1),
        ("fsdp", "tp"),
        axis_types=(jax.sharding.AxisType.Auto,) * 2,
    )
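On a v5litepod-8 the first branch is taken, so the eight cores are laid out as 1 × 4 × 2 across the data, fsdp, and tp axes. A quick way to confirm the layout:

# Sanity-check the mesh layout (sketch; `pipeline` is a hypothetical instance).
mesh = pipeline.create_mesh()
print(dict(mesh.shape))     # {'data': 1, 'fsdp': 4, 'tp': 2} on 8 local devices
print(mesh.devices.shape)   # (1, 4, 2) grid of TPU devices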
Model Adaptation (LoRA + Structured State)
Model adaptation is handled through Qwix and Flax NNX, enabling parameter-efficient fine-tuning and structured state management.
- Qwix injects LoRA adapters, enabling efficient adaptation without updating the full model
- Flax NNX + JAX manages model state and enforces sharding across the device mesh
Result: a LoRA-augmented model with explicitly sharded state, ready for distributed TPU execution.
def build_lora_model(self, mesh: jax.sharding.Mesh) -> nnx.Module:
    """Inject LoRA adapters into the base model and shard its state over the mesh."""
    if self.base_model is None:
        raise RuntimeError("Base model was not initialized by load_base_model().")
    settings = self.settings
    lora_provider = qwix.LoraProvider(
        module_path=self.LORA_MODULE_PATH,
        rank=settings.lora_rank,
        alpha=settings.lora_alpha,
    )
    model_input = self.base_model.get_model_input()
    lora_model = qwix.apply_lora_to_model(self.base_model, lora_provider, **model_input)
    with mesh:
        # Resolve each parameter's partition spec and constrain the state to it,
        # so the LoRA-augmented model is laid out across the TPU mesh.
        state = nnx.state(lora_model)
        pspecs = nnx.get_partition_spec(state)
        sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
        nnx.update(lora_model, sharded_state)
    return lora_model
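A quick way to see the payoff of LoRA is to compare the adapter parameter count against the frozen base weights. The sketch below assumes Qwix exposes its adapters under parameter paths containing "lora", which may vary between versions:

import jax
from flax import nnx

def count_params_by_group(lora_model: nnx.Module) -> dict:
    """Rough accounting of adapter vs. base parameters (a sketch; assumes Qwix
    adapter weights carry "lora" somewhere in their parameter path)."""
    counts = {"lora": 0, "base": 0}
    leaves, _ = jax.tree_util.tree_flatten_with_path(nnx.state(lora_model))
    for path, leaf in leaves:
        if not hasattr(leaf, "size"):
            continue
        group = "lora" if "lora" in jax.tree_util.keystr(path).lower() else "base"
        counts[group] += int(leaf.size)
    return counts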
Optimization
Optimization is handled via Optax, a JAX-native gradient processing library.
Pairing a warmup + cosine decay schedule with AdamW helps keep LoRA-based fine-tuning stable and efficient, particularly in large-scale TPU environments.
def _build_optimizer(self, max_steps: int) -> optax.GradientTransformation:
    settings = self.settings
    warmup_steps = max(1, int(max_steps * settings.warmup_ratio))
    lr_schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=settings.learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=max(1, max_steps),
        end_value=0.0,
    )
    return optax.adamw(
        learning_rate=lr_schedule,
        weight_decay=settings.weight_decay,
    )
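Because Optax schedules are plain callables mapping a step to a learning rate, the shape of the warmup + cosine decay curve is easy to sanity-check before a run (the values below are illustrative, not the project's actual configuration):

import optax

# Illustrative hyperparameters only; the real values come from GemmaEarth's settings.
max_steps, warmup_steps, peak_lr = 30_000, 900, 1e-4
lr_schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=peak_lr,
    warmup_steps=warmup_steps,
    decay_steps=max_steps,
    end_value=0.0,
)

for step in (0, warmup_steps, max_steps // 2, max_steps):
    # Linear ramp up to peak_lr during warmup, then cosine decay back to 0.
    print(step, float(lr_schedule(step)))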
Checkpointing and Experiment Management
Checkpointing is managed using Orbax, which provides flexible and structured model persistence.
Orbax supports:
- PyTree-based checkpoints
- Partial restoration, allowing recovery even when model structures evolve
This makes experimentation more robust and iterative.
manager = ocp.CheckpointManager(
    self._checkpoint_root(),
    item_handlers={"model_params": ocp.PyTreeCheckpointHandler()},
)
try:
    # Default to the most recent checkpoint when no explicit step is given.
    resolved_step = step if step is not None else manager.latest_step()
    if resolved_step is None:
        raise RuntimeError(f"No checkpoint found under {self._checkpoint_root()}")
    # Use the live model's state as the abstract target so Orbax knows the
    # expected PyTree structure, then restore (partially, if structures evolved).
    abstract_state = nnx.state(model)
    restore_args = ocp.checkpoint_utils.construct_restore_args(target=abstract_state)
    checkpoint = manager.restore(
        resolved_step,
        args=ocp.args.Composite(
            model_params=ocp.args.PyTreeRestore(
                item=abstract_state,
                restore_args=restore_args,
                partial_restore=True,
            )
        ),
    )
    nnx.update(model, checkpoint.model_params)
    return resolved_step
finally:
    manager.close()
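During training, saving is handled automatically through the checkpointing options passed to the PeftTrainer; for completeness, the manual save-side counterpart of the restore above might look roughly like this (a sketch; checkpoint_root, step, and model are hypothetical variables):

import orbax.checkpoint as ocp
from flax import nnx

# Sketch of the save side (Tunix normally does this via checkpointing_options).
manager = ocp.CheckpointManager(
    checkpoint_root,
    item_handlers={"model_params": ocp.PyTreeCheckpointHandler()},
)
manager.save(
    step,
    args=ocp.args.Composite(model_params=ocp.args.PyTreeSave(nnx.state(model))),
)
manager.wait_until_finished()  # block until the async save completes
manager.close()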
Benchmark Results (1.5k samples)
After post-training, for the satellite scene classification task on a 1.5k-sample evaluation dataset, the fine-tuned model shows consistent and significant improvements over the baseline across all metrics, as illustrated in the figure below.
Performance gains are strongest in recall-oriented metrics, while improvements in F1 and Jaccard scores indicate better prediction quality and consistency.
The increase in macro_f1 confirms improvements across both frequent and less frequent classes, and higher exact match scores indicate better full-label-set predictions.
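The reported metrics are standard multi-label quantities computed over binary indicator matrices. As a reference, they can be reproduced roughly as follows (scikit-learn shown for illustration; the repository contains the actual evaluation code):

import numpy as np
from sklearn.metrics import accuracy_score, f1_score, jaccard_score, recall_score

# y_true / y_pred: (num_samples, num_classes) binary indicator matrices,
# e.g. one column per BigEarthNet class. Toy data shown for illustration.
y_true = np.array([[1, 0, 1], [0, 1, 1]])
y_pred = np.array([[1, 0, 0], [0, 1, 1]])

metrics = {
    "micro_f1": f1_score(y_true, y_pred, average="micro"),
    "macro_f1": f1_score(y_true, y_pred, average="macro"),        # per-class F1, then averaged
    "recall": recall_score(y_true, y_pred, average="micro"),
    "jaccard": jaccard_score(y_true, y_pred, average="samples"),  # per-sample label-set IoU
    "exact_match": accuracy_score(y_true, y_pred),                # full label set must match
}
print(metrics)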
Training dynamics were stable: evaluation loss decreased (~0.33 → ~0.22) and perplexity improved (~1.45 → ~1.25) without signs of overfitting, with performance plateauing around 25k–30k steps, indicating convergence.
More detailed metrics and configurations are available in the GitHub repository.
Future Work
The current implementation focuses on the classification subset of EarthDial (BigEarthNet). Future work will extend this pipeline to additional tasks, including captioning, VQA, and reasoning.
Given the modular design, adapting to these tasks primarily involves updating DATASET_RELATIVE_DIR and re-running the pipeline. Detailed instructions are available in the repository.
Closing Thoughts
GemmaEarth demonstrates how the JAX ecosystem can be composed into a practical, production-oriented training stack for domain-specific multimodal models.
Beyond model performance, its key contribution is an extensible engineering blueprint — enabling the transition from experimentation to scalable, reproducible systems leveraging TPU acceleration.
Thanks for taking the time to explore GemmaEarth! If you found this useful, feel free to explore the repository, experiment with the pipeline, or share feedback — I’d love to see how others extend this work.
Happy building!!
Acknowledgements
Google Cloud credits for this project were provided through the #TPUSprint.
Although scripts are provided in the GitHub repository to provision GCP services — including TPU setup — you can refer to the official GCP documentation for more detailed guidance on creating and configuring a TPU VM:
https://docs.cloud.google.com/tpu/docs/create-tpu-vm