Fine-Tuning Gemma 4B with Cloud Run Jobs: Unlocking Serverless GPU Power with the NVIDIA RTX 6000 Ada Generation for Pet Breed Classification
In this tutorial, you’ll learn how to fine-tune Google’s Gemma 4B model using Cloud Run Jobs and NVIDIA RTX 6000 Ada Generation GPUs for a pet breed classification task. We’ll walk through everything from dataset preparation to deploying a serverless training job — all without managing a single VM.
By the end, you’ll have a production-ready pipeline that leverages serverless GPU power, ideal for bursty, short-duration training jobs.
🛠️ Prerequisites
Before we begin, ensure you have:
- A Google Cloud Project with billing enabled
- Cloud SDK (`gcloud`) installed and authenticated
- Docker installed locally
- Basic Python & PyTorch knowledge
- A Google Cloud service account with these roles:
  - Cloud Run Admin
  - Storage Admin
  - Service Account User
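Before creating any resources, it also helps to enable the services this tutorial relies on. A one-time setup sketch (the Cloud Build and Artifact Registry services assume you build and store the container image on Google Cloud later on):

```shell
# One-time setup: enable the APIs this tutorial uses
gcloud services enable \
  run.googleapis.com \
  cloudbuild.googleapis.com \
  artifactregistry.googleapis.com \
  storage.googleapis.com
```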
1. Prepare Your Dataset
We’ll use a simplified version of the Stanford Dogs Dataset for pet (dog) breed classification.
Download and Structure Data
```shell
# Create directory
mkdir -p data/dogs/{train,test}

# Example: Use gsutil or wget to download images
# For demo, we'll assume you have a folder like:
# data/dogs/train/labrador/retriever-1.jpg
# data/dogs/train/poodle/poodle-1.jpg
```
The label structure comes from the folder names: each sub-directory name is treated as a class name.
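Because the `imagefolder` loader infers labels from directory names, you can sanity-check your layout before training. A minimal, self-contained sketch (the breed and file names here are illustrative, and a placeholder file stands in for a real image):

```python
import os
import tempfile

# Build a tiny tree mirroring the expected layout:
# <root>/train/<breed>/<image>.jpg
root = tempfile.mkdtemp()
for split in ("train", "test"):
    for breed in ("labrador", "poodle"):
        os.makedirs(os.path.join(root, split, breed), exist_ok=True)
        # Touch a placeholder file standing in for a real image
        open(os.path.join(root, split, breed, f"{breed}-1.jpg"), "w").close()

# Class names are simply the sub-directory names of each split
classes = sorted(os.listdir(os.path.join(root, "train")))
print(classes)  # ['labrador', 'poodle']
```

If the printed list doesn't match the breeds you expect, fix the folder layout before pointing the training script at it.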
2. Set Up Python Environment
Create a virtual environment and install dependencies:
```shell
python -m venv gemma-env
source gemma-env/bin/activate
pip install torch torchvision transformers datasets accelerate pillow google-cloud-storage
```
3. Write the Fine-Tuning Script
Create train.py:
```python
# train.py
import os

import torch
from datasets import load_dataset
from google.cloud import storage
from transformers import (
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)

# Load dataset
dataset = load_dataset("imagefolder", data_dir="data/dogs")

# Use a vision-capable checkpoint (Gemma is text-only -- we'll use ViT instead)
# Note: Gemma is LLM-only. For image tasks, we use ViT or SigLIP.
# But this pattern applies to any Hugging Face model.
model_name = "google/vit-base-patch16-224"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=10,  # e.g., 10 dog breeds
    id2label={i: f"breed_{i}" for i in range(10)},
    label2id={f"breed_{i}": i for i in range(10)},
)

# Preprocess images
def transform(example):
    inputs = feature_extractor(example["image"], return_tensors="pt")
    inputs["labels"] = example["label"]
    return inputs

dataset = dataset.with_transform(transform)

# DataLoader collate function
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"].squeeze() for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

# Training setup
training_args = TrainingArguments(
    output_dir="model-output",
    per_device_train_batch_size=16,
    num_train_epochs=3,
    save_steps=100,
    logging_steps=10,
    report_to="none",
    remove_unused_columns=False,
    dataloader_pin_memory=False,  # Important for GPU in containers
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collate_fn,
)

print("Starting training...")
trainer.train()

# Save model
trainer.save_model("model-output")

# Optional: Upload to GCS
if os.getenv("GCS_BUCKET"):
    client = storage.Client()
    bucket = client.bucket(os.getenv("GCS_BUCKET"))
    for root, _, files in os.walk("model-output"):
        for file in files:
            local_path = os.path.join(root, file)
            gcs_path = os.path.join("models/gemma4-pet", file)
            blob = bucket.blob(gcs_path)
            blob.upload_from_filename(local_path)
    print("Model uploaded to GCS.")
```
🔍 Note: Gemma 4B is a text-only LLM. For image classification, we use ViT as a stand-in to demonstrate the pipeline. You can adapt this for multimodal models like Fuyu or Gemini later.
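One practical adaptation worth making: the script hard-codes placeholder labels (`breed_0` through `breed_9`). In a real run you would derive the label maps from the dataset's actual class names, which `imagefolder` exposes as `dataset["train"].features["label"].names`. A small sketch with an illustrative class list:

```python
# Illustrative class list; in practice, use:
#   classes = dataset["train"].features["label"].names
classes = ["beagle", "boxer", "labrador", "poodle", "pug"]

# Build the forward and reverse label maps expected by from_pretrained()
id2label = {i: name for i, name in enumerate(classes)}
label2id = {name: i for i, name in id2label.items()}

print(id2label[3])      # poodle
print(label2id["pug"])  # 4
```

Pass `num_labels=len(classes)` along with these dicts so the classifier head, the label names in the saved config, and the dataset all stay in sync.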
4. Create Dockerfile
Create Dockerfile:
```dockerfile
# Use PyTorch image with CUDA
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime

WORKDIR /app

COPY requirements.txt .
RUN pip install -r requirements.txt

COPY train.py .

# Stream logs unbuffered so Cloud Run captures them in real time
ENV PYTHONUNBUFFERED=1

# Run training
ENTRYPOINT ["python", "train.py"]
```
Create `requirements.txt` with the same packages we installed in step 2:

```txt
torch==2.1.0
torchvision==0.16.0
transformers==4.38.0
datasets
accelerate
pillow
google-cloud-storage
```
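5. Build and Deploy the Cloud Run Job

With the image definition in place, the remaining step is to build the container and launch it as a Cloud Run Job with a GPU attached. The commands below are a sketch: the project, region, repository, bucket, and job names are placeholders, and GPU support for jobs may still require the beta component in your gcloud version, so confirm the exact flag names with `gcloud beta run jobs create --help`. Likewise, `nvidia-l4` is the commonly documented `--gpu-type` value; substitute whatever GPU type your region actually offers.

```shell
# Placeholders -- substitute your own project, region, repo, and bucket names
PROJECT=my-project
REGION=us-central1
IMAGE=${REGION}-docker.pkg.dev/${PROJECT}/training/gemma-pet-train:latest

# Build the image with Cloud Build and push it to Artifact Registry
gcloud builds submit --tag "$IMAGE" .

# Create the job with one GPU attached
gcloud beta run jobs create gemma-pet-train \
  --image "$IMAGE" \
  --region "$REGION" \
  --gpu 1 --gpu-type nvidia-l4 \
  --cpu 8 --memory 32Gi \
  --set-env-vars GCS_BUCKET=my-model-bucket \
  --max-retries 0 \
  --task-timeout 3600

# Kick off a training run
gcloud beta run jobs execute gemma-pet-train --region "$REGION"
```

Because a job bills only while its tasks run and scales to zero afterwards, this pattern fits the bursty, short-duration training described in the introduction, with the trained model landing in the GCS bucket set via `GCS_BUCKET`.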