DEV Community

Mayank Ketkar


Your Ray Data Pipeline Works at 10K Samples. Here's Why It Crashes at 1M.

There's a moment every ML infrastructure engineer knows: the evaluation pipeline that worked perfectly on 10,000 samples crashes catastrophically when you point it at a million.

The model didn't change. The GPUs are fine. The inference code is identical. The data pipeline is the bottleneck — and it fails in ways that are completely invisible at small scale.

I spent a week scaling a Ray Data pipeline from 8,600 samples to 965,000 multi-image samples for a vision-language model (Qwen3-VL). Every sample contained 5-10 video frames — so the real data volume was closer to 5-10 million images flowing through the system. Along the way, I hit five distinct distributed systems problems, each of which required a different fix.

This is the field guide.


The Pipeline

S3 (MDS shards) → Read → Preprocess (CPU) → Inference (GPU) → Results
   966 shards       download,    decode images,     Qwen3-VL        CSV +
   ~266MB each      base64       resize, format     via vLLM        metrics

Each stage has workers. Ray Data orchestrates passing items between stages automatically — like conveyor belts between factory stations. The key insight: this is a streaming pipeline. Data should flow through continuously, not accumulate at any stage.
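Here is that idea as a toy sketch in plain Python, not Ray Data itself. Each stage is a generator that pulls one item at a time from the stage before it. The functions `read`, `preprocess`, and `infer` are hypothetical stand-ins for the three stages in the diagram.

```python
from typing import Iterator, List


def read(n: int, trace: List[str]) -> Iterator[int]:
    # Stand-in for the S3 read stage: yields one sample at a time.
    for i in range(n):
        trace.append(f"read {i}")
        yield i


def preprocess(items: Iterator[int], trace: List[str]) -> Iterator[int]:
    # Stand-in for CPU preprocessing: pulls one item, transforms it, passes it on.
    for i in items:
        trace.append(f"preprocess {i}")
        yield i * 10


def infer(items: Iterator[int], trace: List[str]) -> List[int]:
    # Stand-in for the GPU stage: consumes each item as soon as it arrives.
    results = []
    for x in items:
        trace.append(f"infer {x}")
        results.append(x + 1)
    return results


trace: List[str] = []
results = infer(preprocess(read(3, trace), trace), trace)
print(trace[:3])  # ['read 0', 'preprocess 0', 'infer 0']
```

Sample 0 reaches inference before sample 1 has even been read. Ray Data applies the same streaming idea across processes and nodes, and it moves blocks of rows rather than single items.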

At 10K samples, everything fits in memory. At ~1M samples with multi-image inputs, nothing does.


Problem 1: The AllToAll Barrier That Defeats Streaming

The original pipeline had a repartition() call between data loading and preprocessing:

# Step 1: Load dataset (streaming from S3)
dataset = dataset_handler.load_dataset(
    dataset_path=config.dataset_s3_path,
    num_samples=config.num_samples,
)

# Step 2: Repartition for parallel preprocessing  <-- THE BARRIER
dataset = dataset.repartition(config.num_blocks)    # AllToAll!
# Nothing downstream starts until ALL of ReadMDS finishes

# Step 3: Parallel preprocessing
processed = dataset.map(preprocessor, concurrency=160)

# Step 4: GPU inference
processed = processed.map_batches(VLMEngine, num_gpus=1, ...)

Repartition is an AllToAll barrier. Think of it like a quality checkpoint at a car factory where every single chassis must arrive before any chassis can move to the paint shop. Even if the first 100 are ready, they sit idle while chassis #965,000 is still being welded.

At 10K samples, the repartition was instant. At 1M samples being streamed from S3, it meant: download all 256GB first, hold it in memory, then start processing. GPUs idle for the entire download.
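A toy model shows the cost directly. In this plain-Python sketch, the `pipeline` function is hypothetical and a `list()` call stands in for the AllToAll barrier. It counts how many samples are read before the first one reaches inference, and the peak number of samples held at once:

```python
from typing import Iterable, Iterator, Optional, Tuple


def pipeline(n: int, barrier: bool) -> Tuple[Optional[int], int]:
    """Toy model: returns (reads before first inference, peak samples held)."""
    resident = 0   # samples read but not yet consumed by inference
    peak = 0
    reads = 0
    reads_before_first_infer = None

    def source() -> Iterator[int]:
        nonlocal resident, peak, reads
        for i in range(n):
            reads += 1
            resident += 1
            peak = max(peak, resident)
            yield i

    stream: Iterable[int] = source()
    if barrier:
        # Like an AllToAll repartition: materialize everything before moving on.
        stream = list(stream)

    for _ in stream:
        if reads_before_first_infer is None:
            reads_before_first_infer = reads
        resident -= 1  # inference consumed the sample, so it can be freed

    return reads_before_first_infer, peak


print(pipeline(1000, barrier=False))  # (1, 1)
print(pipeline(1000, barrier=True))   # (1000, 1000)
```

Without the barrier, both numbers stay at 1. With it, both equal the dataset size. Scaled to the real dataset, that peak is the entire ~256GB download.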

The fix: Remove it. With a streaming datasource producing multiple blocks, preprocessing can pull rows directly:

from ray.data import ActorPoolStrategy

# Step 1: Stream data from S3 (no full materialization)
dataset = dataset_handler.load_dataset(
    dataset_path=config.dataset_s3_path,
    num_samples=config.num_samples,
)

# Step 2: Parallel preprocessing (CPU-bound)
processed = dataset.map(
    SceneIQPreprocessor,
    fn_constructor_kwargs={"config": config},
    compute=ActorPoolStrategy(size=config.num_preprocessing_workers),
)

# Step 3: GPU inference — data flows here immediately
processed = processed.map_batches(
    VLMEngine,
    batch_size=config.batch_size,
    num_gpus=1,
    concurrency=config.num_vllm_engines,
)

# Step 4: Postprocess and collect results
results = processed.map(postprocess).materialize()

Problem 2: Ray Thinks Each Task Uses 500MB. It's Actually 39GB.

The custom datasource creates ReadTask objects that download MDS shards from S3. Each task reports expected memory via BlockMetadata.size_bytes.

The original code set this to the raw shard size on S3 (~266MB). But in memory:

On S3 (compressed):     266 MB
After base64 encoding:  364 MB  (1.37x expansion)
In PyArrow table:       ~500 MB (column overhead)
Total per-task (16 tasks, ~60 shards each): ~39 GB
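The base64 step is easy to sanity-check. This sketch uses a hypothetical helper name. It encodes random bytes, which stand in for already-compressed JPEG data, and measures how much they grow. Standard `base64.b64encode` output is exactly 4/3 of the input when the length is a multiple of 3, so ~1.33x. The ~1.37x above is slightly higher. That gap might come from line breaks or other per-string overhead, but this is an assumption and is not measured here.

```python
import base64
import os


def base64_expansion(num_bytes: int) -> float:
    # Random bytes are incompressible, like JPEG payloads.
    raw = os.urandom(num_bytes)
    encoded = base64.b64encode(raw)  # no line breaks added
    return len(encoded) / len(raw)


ratio = base64_expansion(3 * 1024 * 1024)
print(f"{ratio:.3f}")  # 1.333
```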

The fix: Tell Ray the truth:

# Before: size_bytes = raw bytes on S3 (~266MB per shard)
# After:  tell Ray the estimated in-memory size

task_bytes = 0
for shard_info in task_shards:
    task_bytes += shard_info["raw_data"]["bytes"]

# Base64 expands ~1.37x, plus PyArrow/dict overhead.
# Use 4x raw bytes as a conservative in-memory estimate.
estimated_mem = task_bytes * 4

meta = BlockMetadata(
    num_rows=task_samples,
    size_bytes=estimated_mem,    # was just task_bytes
    input_files=input_files,
    exec_stats=None,
)

This single line was the difference between "workers crash every run" and "steady-state for 12 hours."


Problem 3: 16 Tasks for 966 Shards = 64GB Per Task

Even with correct memory estimation, 16 tasks were far too few, because each one still covered ~60 shards:

# This constant changed 4 times. Each wrong value
# crashed the cluster.
#
# max_tasks -> shards/task -> est. memory/task (4x raw) -> result
# 16  -> ~60 shards/task -> 64 GB -> OOM
# 48  -> ~20 shards/task -> 21 GB -> OOM
# 128 -> ~8 shards/task  -> 8.5 GB -> OOM (barely)
# 512 -> ~2 shards/task  -> 2 GB   -> Stable
#
# More tasks = less memory per task.
# Ray schedules them across CPUs, not all at once.
_DEFAULT_MAX_TASKS = 512

Having 512 tasks doesn't mean 512 simultaneous downloads. It means 512 small, schedulable units. Ray runs a handful at a time based on available CPU.
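The arithmetic behind these numbers fits in a few lines. This sketch assumes 966 shards of ~266MB each, the 4x multiplier from Problem 2, and contiguous ceil-division grouping. `concurrent_slots` is a hypothetical number of reads that Ray has CPUs for at once. It is not a measured value.

```python
import math
from typing import Tuple

SHARDS = 966
SHARD_MB = 266
MEM_FACTOR = 4  # the conservative in-memory multiplier from Problem 2


def per_task_estimate(max_tasks: int) -> Tuple[int, float]:
    """Return (shards per task, estimated GB per task) for contiguous grouping."""
    shards_per_task = math.ceil(SHARDS / max_tasks)
    est_gb = shards_per_task * SHARD_MB * MEM_FACTOR / 1000
    return shards_per_task, est_gb


def peak_read_memory_gb(max_tasks: int, concurrent_slots: int) -> float:
    # Only min(tasks, slots) reads run at once; the rest wait to be scheduled.
    shards_per_task, est_gb = per_task_estimate(max_tasks)
    num_tasks = math.ceil(SHARDS / shards_per_task)
    return min(num_tasks, concurrent_slots) * est_gb


for max_tasks in (16, 48, 128, 512):
    spt, gb = per_task_estimate(max_tasks)
    print(max_tasks, spt, round(gb, 1), round(peak_read_memory_gb(max_tasks, 8), 1))
```

With 16 tasks, each one covers 61 shards (~65 GB estimated). With 512 tasks, each covers 2 shards (~2.1 GB). The second function shows why the task count doesn't translate into simultaneous load. Peak read memory is bounded by how many tasks actually run at once, multiplied by the per-task size.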


Problem 4: CPU Oversubscription

The original config was designed for a large cluster (280 CPU). On a smaller cluster (56 CPU):

# BEFORE: CPU oversubscription (crashed)
# vLLM:         20 engines x 4 CPU = 80 CPU
# Preprocessing: 160 workers x 1 CPU = 160 CPU
# Total:                               240 CPU
# Available:                            56 CPU
# Headroom:                           -184 CPU  <-- CRASH (oversubscribed)

# AFTER: Right-sized for actual hardware
num_vllm_engines: 6         # 6 x 4 CPU = 24 CPU
num_preprocessing_workers: 16  # 16 x 1 CPU = 16 CPU
# Total:                               40 CPU
# Available:                           56 CPU
# Headroom:                            16 CPU  <-- Safe

Preprocessing is fast — 16 workers can easily keep 6 GPUs fed at ~22 samples/s.
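A pre-flight check like the following would have caught this before launch. It is a minimal sketch that uses the per-actor CPU counts above, and the function name is hypothetical. On a live cluster, the total could come from `ray.cluster_resources()["CPU"]` instead of a hard-coded 56.

```python
def cpu_budget(cluster_cpus: int, engines: int, cpus_per_engine: int,
               preprocess_workers: int, cpus_per_worker: int = 1) -> int:
    """Return CPU headroom; a negative value means the plan is oversubscribed."""
    requested = engines * cpus_per_engine + preprocess_workers * cpus_per_worker
    return cluster_cpus - requested


before = cpu_budget(56, engines=20, cpus_per_engine=4, preprocess_workers=160)
after = cpu_budget(56, engines=6, cpus_per_engine=4, preprocess_workers=16)
print(before, after)  # -184 16
```

The value -184 is the oversubscribed original config. The final config leaves 16 CPU of headroom, which is 71% utilization.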


Problem 5: Two Engines on a 64GB Pod

8 vLLM engines across 6 worker nodes means some pods get 2 engines:

Worker pod: 64 GB RAM
2 vLLM engines: ~30-40 GB (model + KV cache)
+ Object store: ~10-15 GB
+ ReadMDS data: ~5-10 GB
= ~45-65 GB --> OOM at the upper end

The fix: 1 engine per worker node. Physical constraint, not a tuning parameter.
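The same kind of check works per node. This sketch takes the upper end of each range above. It assumes engines are spread as evenly as possible, which is the best case, so the busiest node hosts ceil(engines / nodes) of them. All constant names are hypothetical.

```python
import math

NODE_RAM_GB = 64
ENGINE_GB = 20        # upper end of "~30-40 GB" for two engines
OBJECT_STORE_GB = 15  # upper end of the object-store range
READ_DATA_GB = 10     # upper end of the ReadMDS range


def worst_node_ram_gb(num_engines: int, num_nodes: int) -> int:
    # Even with a perfectly even spread, the busiest node hosts ceil(engines / nodes).
    engines_on_busiest = math.ceil(num_engines / num_nodes)
    return engines_on_busiest * ENGINE_GB + OBJECT_STORE_GB + READ_DATA_GB


for engines in (8, 6):
    need = worst_node_ram_gb(engines, num_nodes=6)
    print(engines, need, "OOM risk" if need > NODE_RAM_GB else "fits")
```

With 8 engines on 6 nodes, at least one node gets 2 engines and needs 65 GB, which is over the 64 GB limit. With 6 engines on 6 nodes, each node needs 45 GB. Ray doesn't guarantee an even spread, so real placement can't beat this bound.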


Under the Hood: The Custom Datasource

The fix for Problems 2 and 3 lives in a custom Ray Datasource. Here are the relevant excerpts.

The architecture: the driver reads a lightweight index.json from S3 (under 1KB), groups shards into bounded ReadTasks, and each task independently downloads and decodes its shards on a Ray worker:

# Maximum number of ReadTasks to create.
# More tasks = less memory per task (avoids OOM),
# but too many adds scheduling overhead.
#
# With datasets up to ~1000 shards at ~266MB each,
# use a high limit so each task gets 1-2 shards
# (~1GB in memory), preventing worker OOM.
# Ray Data will schedule them across available CPUs,
# so only a few run concurrently.
_DEFAULT_MAX_TASKS = 512

The get_read_tasks() method is where memory estimation happens. This is the method Ray Data calls to plan its work — and where the 4x multiplier prevents OOM:

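import functools

from ray.data.block import BlockMetadata
from ray.data.datasource import ReadTask


def get_read_tasks(self, parallelism, per_task_row_limit=None):
    # Collect shards up to the effective_samples limit
    active_shards = []
    remaining = self._effective_samples
    for shard in self._shard_infos:
        if remaining <= 0:
            break
        n = min(shard["samples"], remaining)
        active_shards.append((shard, n))
        remaining -= n

    num_shards = len(active_shards)
    if num_shards == 0:
        return []
    num_tasks = min(num_shards, self._max_tasks)

    # Distribute shards across tasks in contiguous groups (ceil division)
    shards_per_task = (num_shards + num_tasks - 1) // num_tasks
    shard_groups = [
        active_shards[i:i + shards_per_task]
        for i in range(0, num_shards, shards_per_task)
    ]

    read_tasks = []
    offset = 0  # global id of the first sample in the current group
    for group in shard_groups:
        shards = [shard for shard, _ in group]
        task_samples = sum(n for _, n in group)
        task_bytes = sum(s["raw_data"]["bytes"] for s in shards)
        # self._remote_path: assumed attribute holding the S3 prefix
        input_files = [f"{self._remote_path}/{s['raw_data']['basename']}" for s in shards]

        # Base64 expands ~1.37x, plus PyArrow/dict overhead.
        # Use 4x raw bytes as a conservative in-memory estimate
        # so Ray can schedule tasks without OOM-killing workers.
        estimated_mem = task_bytes * 4

        meta = BlockMetadata(
            num_rows=task_samples,
            size_bytes=estimated_mem,  # was just task_bytes!
            input_files=input_files,
            exec_stats=None,
        )

        # Bind this group's arguments now; a bare closure over the loop
        # variables would see only the last group's values.
        read_fn = functools.partial(
            _read_mds_shards, self._remote_path, shards, offset, task_samples
        )
        read_tasks.append(ReadTask(read_fn=read_fn, metadata=meta))
        offset += task_samples

    return read_tasks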
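The grouping step is worth checking on its own, because it relies on ceil division. This pure-Python mirror uses a hypothetical `group_shards` function and a hypothetical index of 966 shards with 1,000 samples each. It applies the sample cap, then splits the active shards into contiguous groups:

```python
from typing import Dict, List, Tuple


def group_shards(shard_infos: List[Dict], effective_samples: int,
                 max_tasks: int) -> List[List[Tuple[Dict, int]]]:
    """Cap total samples, then split active shards into contiguous groups."""
    active, remaining = [], effective_samples
    for shard in shard_infos:
        if remaining <= 0:
            break
        n = min(shard["samples"], remaining)
        active.append((shard, n))
        remaining -= n
    if not active:
        return []
    num_tasks = min(len(active), max_tasks)
    per_task = (len(active) + num_tasks - 1) // num_tasks  # ceil division
    return [active[i:i + per_task] for i in range(0, len(active), per_task)]


# Hypothetical index: 966 shards of 1,000 samples each.
shards = [{"samples": 1000, "raw_data": {"bytes": 266_000_000}} for _ in range(966)]
groups = group_shards(shards, effective_samples=965_000, max_tasks=512)
print(len(groups), max(len(g) for g in groups))  # 483 2
```

Ceil division has a side effect here. With a cap of 512, this example yields 483 tasks rather than 512, because each group takes two shards and the last group takes the single leftover.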

Each ReadTask runs this function on a Ray worker — it downloads 1-2 shards, decodes samples, and returns a PyArrow table:

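import os
import tempfile

import boto3
import pyarrow as pa

# MDSReader comes from MosaicML's `streaming` package (import path varies by
# version); _parse_s3_path and _extract_sample are helpers defined elsewhere.


def _read_mds_shards(remote_path, shard_group, offset, max_samples):
    """Download MDS shards from S3 and extract samples.
    Runs on Ray workers; each task handles 1-2 shards.

    Assumed semantics: `offset` is the global id of the task's first sample
    and `max_samples` caps how many samples this task reads.
    """
    s3 = boto3.client("s3")
    rows = []
    sample_id = offset

    with tempfile.TemporaryDirectory() as tmp_dir:
        for shard_info in shard_group:
            read_so_far = sample_id - offset
            if read_so_far >= max_samples:
                break
            # Download shard from S3 into the temp directory
            basename = shard_info["raw_data"]["basename"]
            bucket, key = _parse_s3_path(f"{remote_path}/{basename}")
            local_path = os.path.join(tmp_dir, basename)
            s3.download_file(bucket, key, local_path)

            # Decode samples via MDSReader (split=None: file sits directly in tmp_dir)
            reader = MDSReader.from_json(dirname=tmp_dir, split=None, obj=shard_info)
            shard_samples = min(shard_info["samples"], max_samples - read_so_far)
            for idx in range(shard_samples):
                row = _extract_sample(reader.get_item(idx), sample_id)
                sample_id += 1
                if row is not None:
                    rows.append(row)

    # Return as a PyArrow table for Ray Data streaming
    return [pa.table({
        "sample_id": [r["sample_id"] for r in rows],
        "system_prompt": [r["system_prompt"] for r in rows],
        "prompt_blocks": [r["prompt_blocks"] for r in rows],
        "target": [r["target"] for r in rows],
    })]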

The key design choices: each task is self-contained (downloads its own shards, no shared state), the memory estimate is conservative (4x raw bytes), and the PyArrow table output integrates directly with Ray Data's streaming execution.


The Expanding Suitcase

Here's what most people miss. Each sample isn't text — it's 5-10 video frames:

Per sample in ReadMDS:      ~5-20 MB  (base64 strings)
Per sample in preprocessing: ~50-100 MB (PIL Images, uncompressed!)
  A 1 MB JPEG -> ~10 MB as a PIL Image
  x 10 frames = 100 MB per sample

This is the "expanding suitcase" problem. You pack vacuum-sealed clothes (compressed images, ~266MB per shard). At the destination, you unseal them and they expand 4-10x.
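The expansion comes from decoding. Once a JPEG is decoded into pixels, as a PIL "RGB" image is, it costs width × height × 3 bytes no matter how well it compressed. The article doesn't give the frame resolution, so the 1080p frame below is an assumption, and `decoded_bytes` is a hypothetical helper:

```python
def decoded_bytes(width: int, height: int, channels: int = 3) -> int:
    # An 8-bit image held as raw pixels needs one byte per channel per pixel,
    # regardless of how small the compressed JPEG was.
    return width * height * channels


frame = decoded_bytes(1920, 1080)
print(frame)             # 6220800  (~6.2 MB per 1080p RGB frame)
print(frame * 10 / 1e6)  # 62.208   (~62 MB for a 10-frame sample)
```

Higher resolutions or an extra alpha channel push each frame toward the ~10 MB figure above.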

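The saving grace: vLLM is the bottleneck. At ~22 samples/s, it is the slowest stage. Backpressure therefore throttles the upstream read and preprocessing stages to its pace instead of letting decoded frames pile up in memory, and that keeps the pipeline stable.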
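Backpressure here just means a fast stage is made to wait when the buffer ahead of it is full. This sketch shows the idea with plain threads and a bounded `queue.Queue`. It is not how Ray Data implements backpressure internally, and `MAX_IN_FLIGHT` is a hypothetical limit:

```python
import queue
import threading
import time

MAX_IN_FLIGHT = 4  # hypothetical buffer limit between two stages


def run(num_items: int):
    buf = queue.Queue(maxsize=MAX_IN_FLIGHT)
    peak = 0
    done = []

    def producer():
        # Fast stage: put() blocks whenever the buffer is full (backpressure).
        for i in range(num_items):
            buf.put(i)
        buf.put(None)  # sentinel: no more items

    def consumer():
        nonlocal peak
        while True:
            peak = max(peak, buf.qsize())
            item = buf.get()
            if item is None:
                break
            time.sleep(0.001)  # slow stage, standing in for GPU inference
            done.append(item)

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return peak, done


peak, done = run(50)
print(peak <= MAX_IN_FLIGHT, len(done))  # True 50
```

The producer could otherwise run far ahead, but `put()` blocks once four items are waiting. Memory between the stages is therefore bounded by the buffer size rather than by the dataset size.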


The Result

Parameter               Original          Final          Change
Repartition             AllToAll barrier  Removed        Eliminated
Memory estimation       1x (raw bytes)    4x multiplier  Realistic
ReadMDS max_tasks       16                512            +3100%
vLLM engines            20                6 (1/node)     -70%
Preprocessing workers   160               16             -90%
CPU utilization         100%              71%            Headroom

1M multi-image samples. ~22 samples/s. 12+ hours. Zero OOM crashes.

The model code didn't change. Every fix was in how data gets to the GPUs.


Five Takeaways

  1. Repartition kills streaming. It's an AllToAll barrier that forces full materialization. Remove it if your datasource already produces multiple blocks.

  2. BlockMetadata.size_bytes is your memory contract with Ray. For image/video data, in-memory size can be 4-10x on-disk size. Set it explicitly.

  3. More tasks = less memory per task. It's the simplest OOM fix. Having 512 tasks doesn't mean 512 simultaneous downloads.

  4. Right-size for physical topology. Count GPUs, CPUs, and RAM per node, not just totals. One engine per node avoids hidden RAM contention.

  5. The data pipeline is always the bottleneck at scale. At 10K, GPU is the constraint. At 1M, the plumbing is. The inference code doesn't need to change.


Ray Data custom datasources: docs.ray.io | Performance tips: docs.ray.io
