Fine-Tuning Gemma 4B with Cloud Run Jobs: Unlocking Serverless GPU Power with NVIDIA RTX 6000 Ada Generation GPUs for Pet Breed Classification

In this tutorial, you’ll learn how to fine-tune an image classifier using Cloud Run Jobs and NVIDIA RTX 6000 Ada Generation GPUs for a pet breed classification task. (Gemma 4B itself is a text-only LLM, so we’ll fine-tune a ViT checkpoint as a stand-in; the serverless pattern is identical for any Hugging Face model.) We’ll walk through everything from dataset preparation to deploying a serverless training job, all without managing a single VM.

By the end, you’ll have a production-ready pipeline that leverages serverless GPU power, ideal for bursty, short-duration training jobs.


🛠️ Prerequisites

Before we begin, ensure you have:

  • A Google Cloud Project with billing enabled
  • Cloud SDK (gcloud) installed and authenticated
  • Docker installed locally
  • Basic Python & PyTorch knowledge
  • A Google Cloud service account with the following IAM roles (a gcloud sketch for granting them follows this list):
    • Cloud Run Admin
    • Storage Admin
    • Service Account User
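
If you need to grant these roles, here’s a minimal sketch (the project ID and service-account name are placeholders for your own values):

# Grant the required IAM roles to the training service account
export PROJECT_ID=your-project-id
export SA=trainer@${PROJECT_ID}.iam.gserviceaccount.com

for ROLE in roles/run.admin roles/storage.admin roles/iam.serviceAccountUser; do
  gcloud projects add-iam-policy-binding "$PROJECT_ID" \
    --member="serviceAccount:${SA}" \
    --role="$ROLE"
done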

1. Prepare Your Dataset

We’ll use a simplified version of the Stanford Dogs Dataset for pet (dog) breed classification.

Download and Structure Data

# Create directory
mkdir -p data/dogs/{train,test}

# Example: use gsutil or wget to download images; the Stanford Dogs images
# tarball lives at http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar
# For this demo, we'll assume a folder layout like:
# data/dogs/train/labrador/retriever-1.jpg
# data/dogs/train/poodle/poodle-1.jpg

The folder names double as the class labels: Hugging Face’s imagefolder loader infers one class per subdirectory.
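
Before burning GPU time, a quick local sanity check (a minimal sketch, assuming the layout above) confirms that the loader picks up your classes:

# check_dataset.py: verify the imagefolder layout loads correctly
from datasets import load_dataset

ds = load_dataset("imagefolder", data_dir="data/dogs")
print(ds["train"].features["label"].names)  # class names inferred from folder names
print(ds)                                   # split names and sizes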


2. Set Up Python Environment

Create a virtual environment and install dependencies:

python -m venv gemma-env
source gemma-env/bin/activate
pip install torch torchvision transformers datasets accelerate pillow google-cloud-storage

3. Write the Fine-Tuning Script

Create train.py:

# train.py
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from datasets import load_dataset
import torch
import os
from google.cloud import storage

# Load dataset
dataset = load_dataset("imagefolder", data_dir="data/dogs")

# Gemma is a text-only LLM, so for this image task we use a vision
# checkpoint (ViT) instead; the same pipeline applies to any Hugging
# Face image-classification model.

model_name = "google/vit-base-patch16-224"
image_processor = AutoImageProcessor.from_pretrained(model_name)

# Derive the label maps from the dataset instead of hardcoding them
labels = dataset["train"].features["label"].names
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label={i: name for i, name in enumerate(labels)},
    label2id={name: i for i, name in enumerate(labels)},
    ignore_mismatched_sizes=True,  # swap out the checkpoint's 1000-class ImageNet head
)

# Preprocess images on the fly (with_transform passes batches of examples)
def transform(batch):
    inputs = image_processor(batch["image"], return_tensors="pt")
    inputs["labels"] = batch["label"]
    return inputs

dataset = dataset.with_transform(transform)

# DataLoader collate function

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"].squeeze() for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

# Training setup
training_args = TrainingArguments(
    output_dir="model-output",
    per_device_train_batch_size=16,
    num_train_epochs=3,
    save_steps=100,
    logging_steps=10,
    report_to="none",
    remove_unused_columns=False,
    dataloader_pin_memory=False,  # Important for GPU in containers
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collate_fn,
)

print("Starting training...")
trainer.train()

# Save the model weights along with the preprocessing config
trainer.save_model("model-output")
image_processor.save_pretrained("model-output")

# Optional: Upload to GCS
if os.getenv("GCS_BUCKET"):
    client = storage.Client()
    bucket = client.bucket(os.getenv("GCS_BUCKET"))
    for root, _, files in os.walk("model-output"):
        for file in files:
            local_path = os.path.join(root, file)
            # Preserve the directory layout under the GCS prefix
            rel_path = os.path.relpath(local_path, "model-output")
            blob = bucket.blob(os.path.join("models/gemma4-pet", rel_path))
            blob.upload_from_filename(local_path)
    print("Model uploaded to GCS.")

🔍 Note: Gemma 4B is a text-only LLM. For image classification, we use ViT as a stand-in to demonstrate the pipeline. You can adapt this for multimodal models like Fuyu or Gemini later.
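
Once a run finishes, you can smoke-test the saved checkpoint locally. A minimal sketch, assuming the model-output directory produced by train.py above (the test image path is a placeholder):

# predict.py: minimal inference against the saved checkpoint
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

processor = AutoImageProcessor.from_pretrained("model-output")
model = AutoModelForImageClassification.from_pretrained("model-output")
model.eval()

image = Image.open("data/dogs/test/poodle/poodle-1.jpg")  # placeholder path
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])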


4. Create Dockerfile

Create Dockerfile:

# Use PyTorch image with CUDA
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime

WORKDIR /app

COPY requirements.txt .
RUN pip install -r requirements.txt

COPY train.py .

# Bake the demo dataset into the image (fine for a small demo; for real
# workloads, pull training data from GCS at job startup instead)
COPY data/ ./data/

# Stream Python logs straight to stdout so they reach Cloud Logging
ENV PYTHONUNBUFFERED=1

# Run training
ENTRYPOINT ["python", "train.py"]

Create requirements.txt:


torch==2.1.0
torchvision==0.16.0
transformers==4.38.0
datasets
accelerate
pillow
google-cloud-storage
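
5. Build, Push, and Run the Job

With both files in place, build the image and run it as a Cloud Run Job. The commands below are a sketch rather than a definitive recipe: the Artifact Registry repo, bucket, region, and job names are placeholders, GPU support for jobs may require the beta track depending on your gcloud version, and you should check gcloud beta run jobs create --help for the GPU types available in your region.

# Build and push the training image (assumes an Artifact Registry repo
# named "training" in us-central1; all names are placeholders)
gcloud builds submit \
  --tag us-central1-docker.pkg.dev/PROJECT_ID/training/pet-trainer:latest

# Bucket where train.py uploads the model (read via GCS_BUCKET)
gcloud storage buckets create gs://your-model-bucket --location=us-central1

# Create the job with one GPU attached; substitute an available type for GPU_TYPE
gcloud beta run jobs create pet-trainer-job \
  --image us-central1-docker.pkg.dev/PROJECT_ID/training/pet-trainer:latest \
  --region us-central1 \
  --cpu 8 --memory 32Gi \
  --gpu 1 --gpu-type GPU_TYPE \
  --set-env-vars GCS_BUCKET=your-model-bucket \
  --task-timeout 3600

# Kick off a training run
gcloud run jobs execute pet-trainer-job --region us-central1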
