Fine-Tuning Gemma 4B with Cloud Run Jobs: Unlocking Serverless GPU Power for Pet Breed Classification with NVIDIA RTX 6000 Pro
Serverless computing has evolved beyond simple HTTP workloads. With Cloud Run Jobs, Google Cloud now lets you run batch and training tasks—with GPUs—in a fully managed, scalable, and cost-efficient way.
In this tutorial, you’ll learn how to fine-tune Google’s Gemma 4B model on a custom pet breed classification dataset using Cloud Run Jobs and an NVIDIA RTX 6000 Ada Generation GPU—all without managing a single server.
We’ll walk through:
- Preparing a dataset of pet images
- Building a lightweight training script using Hugging Face Transformers
- Containerizing the job with GPU support
- Deploying and running it on Cloud Run Jobs
Let’s get started.
🔧 Prerequisites
Before we begin, ensure you have:
- A Google Cloud Project with billing enabled
- Cloud SDK (gcloud) installed and authenticated
- Docker installed locally (for testing)
- Python 3.10+
- Basic knowledge of PyTorch and Hugging Face
Enable required APIs:
gcloud services enable \
run.googleapis.com \
artifactregistry.googleapis.com \
cloudbuild.googleapis.com
🐶 Step 1: Prepare Your Dataset
We’ll use a subset of the Oxford-IIIT Pet Dataset, which contains 37 pet breeds with labeled images.
Download and Organize Data
# download_data.py
import os
import requests
import tarfile
from pathlib import Path
DATASET_URL = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"
LABELS_URL = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz"
DATA_DIR = Path("data")
def download_extract(url: str, dest: Path):
    dest.mkdir(exist_ok=True)
    filename = url.split("/")[-1]
    filepath = dest / filename
    if not filepath.exists():
        print(f"Downloading {filename}...")
        response = requests.get(url, stream=True)
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
    print(f"Extracting {filename}...")
    with tarfile.open(filepath, "r:gz") as tar:
        tar.extractall(dest)

if __name__ == "__main__":
    download_extract(DATASET_URL, DATA_DIR)
    download_extract(LABELS_URL, DATA_DIR)
Run it:
python download_data.py
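After the script finishes, it's worth sanity-checking the extraction before going further. The snippet below is a small addition of mine (count_images is not from the tutorial); the full Oxford-IIIT Pet dataset should yield roughly 200 images per breed, about 7,400 in total, under data/images/.

```python
# check_download.py -- optional sanity check (my addition, not in the tutorial)
from pathlib import Path

def count_images(root: str) -> int:
    """Count the .jpg files directly under root."""
    return len(list(Path(root).glob("*.jpg")))

if __name__ == "__main__":
    print(f"images found: {count_images('data/images')}")
```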
Now structure your data into train/val folders using labels from annotations/list.txt. Here’s a simplified version:
# organize_data.py
from pathlib import Path

data_dir = Path("data")
images_dir = data_dir / "images"
labels_file = data_dir / "annotations" / "list.txt"

for line in open(labels_file).read().splitlines()[6:]:  # skip the 6 header lines
    if len(line) > 0 and line[0] != "#":
        parts = line.split()
        img_name = parts[0] + ".jpg"
        breed = parts[0].rsplit("_", 1)[0].lower()
        src = images_dir / img_name
        dest_dir = Path("dataset/train") / breed
        dest_dir.mkdir(parents=True, exist_ok=True)
        if src.exists():
            (dest_dir / img_name).write_bytes(src.read_bytes())
💡 You can split into train/ and val/ later using sklearn.model_selection.train_test_split.
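As a concrete sketch of that tip: the helper below (my own, not part of the tutorial) moves a random fraction of each breed's images from dataset/train/ into a parallel dataset/val/ tree, using train_test_split to pick the validation files per breed.

```python
# make_val_split.py -- hypothetical helper, not part of the original tutorial.
import shutil
from pathlib import Path

from sklearn.model_selection import train_test_split

def make_val_split(train_root="dataset/train", val_root="dataset/val",
                   val_fraction=0.2, seed=42):
    """Move val_fraction of each breed's images into a val/ tree."""
    for breed_dir in sorted(Path(train_root).iterdir()):
        if not breed_dir.is_dir():
            continue
        images = sorted(breed_dir.glob("*.jpg"))
        if len(images) < 2:
            continue  # nothing sensible to split
        # train_test_split shuffles, so the val set is a random sample per breed
        _, val_images = train_test_split(images, test_size=val_fraction,
                                         random_state=seed)
        dest = Path(val_root) / breed_dir.name
        dest.mkdir(parents=True, exist_ok=True)
        for img in val_images:
            shutil.move(str(img), dest / img.name)

if __name__ == "__main__":
    make_val_split()
```

Splitting per breed (rather than over the whole file list at once) keeps the class balance roughly equal between train and validation.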
🤖 Step 2: Build the Training Script
A note before we write any training code: Gemma is a text-only language model, so it can't classify images on its own. Pairing it with an image encoder is an option for multimodal tasks, but for plain image classification a vision model is the right tool. We'll therefore fine-tune a Vision Transformer (ViT) with a linear classification head, using Hugging Face transformers and torch, for pet breed classification.
Install Dependencies
pip install torch torchvision transformers datasets accelerate pillow
Training Script (train.py)
# train.py
# train.py
from pathlib import Path

import torch
from PIL import Image
from transformers import (
    ViTImageProcessor,  # successor to the deprecated ViTFeatureExtractor
    ViTForImageClassification,
    TrainingArguments,
    Trainer,
)

# Load image processor and model
model_name = "google/vit-base-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_name)

# Derive the label set once from the training folder names
breeds = sorted({p.parent.name for p in Path("dataset/train").rglob("*.jpg")})

model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(breeds),  # 37 for the full Oxford-IIIT Pet dataset
    id2label={i: lb for i, lb in enumerate(breeds)},
    label2id={lb: i for i, lb in enumerate(breeds)},
    ignore_mismatched_sizes=True,  # swap out the 1000-class ImageNet head
)
# Collect image paths and integer labels from the folder structure
def load_images(root_dir):
    images, labels = [], []
    paths = sorted(Path(root_dir).rglob("*.jpg"))
    label_map = {lb: i for i, lb in enumerate(sorted({p.parent.name for p in paths}))}
    for p in paths:
        images.append(str(p))
        labels.append(label_map[p.parent.name])
    return images, labels
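To turn those image paths and labels into model inputs, one option is a small torch Dataset that runs the ViT processor on each image as it's requested. This is a sketch of mine (PetDataset is a hypothetical name, not from the tutorial or any library), written to slot between load_images and a Hugging Face Trainer or plain DataLoader:

```python
# Hypothetical glue code: wraps (image_paths, labels) from load_images()
# so a DataLoader or Hugging Face Trainer can consume them.
import torch
from torch.utils.data import Dataset
from PIL import Image

class PetDataset(Dataset):
    def __init__(self, image_paths, labels, processor):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor  # e.g. the ViT image processor above

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        encoded = self.processor(images=image, return_tensors="pt")
        return {
            "pixel_values": encoded["pixel_values"].squeeze(0),  # drop batch dim
            "labels": torch.tensor(self.labels[idx]),
        }
```

Processing images lazily in __getitem__ keeps memory flat regardless of dataset size, at the cost of some CPU work per batch; DataLoader workers hide most of that latency.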
---