DEV Community

Yuvraj Raghuvanshi

Originally published at yuvrajraghuvanshis.Medium

Training a Classifier on a Huge Dataset When RAM Is Not Your Friend

I didn’t set out to build a custom data loader. I set out to train a model on the Quick, Draw! dataset.

The data pipeline was supposed to be the boring part — the few lines you write before the interesting work starts. It ended up being most of the work, the source of the most frustrating bugs, and, in retrospect, the most interesting engineering decision of the whole project.

This is the story of why I ended up with a directory containing millions of individual .npy files, and why that turned out to be the right call.

What Quick, Draw! Actually Is

Quick, Draw! is a Google dataset of human drawings collected from a browser game where players had 20 seconds to draw a prompted word. It has 345 categories (cats, airplanes, zigzags, The Eiffel Tower) and about 50 million drawings in total: roughly 145,000 per class on average, though the count varies from class to class.

What makes it interesting for ML, and annoying for data pipelines, is that each drawing has two representations:

Raster images: each drawing rendered as a 28×28 grayscale bitmap, stored as a flat array of 784 values. These come in .npy files where a single file for one class contains an array of shape (N, 784). The values are uint8 pixel intensities (0 to 255), so 100,000 samples come to 100,000 rows of 784 bytes, about 78 MB per file.

Stroke sequences: the original drawing data, a sequence of (dx, dy, pen_state) triplets representing how the pen moved. These come in .npz files, split into train, val, and test keys. The stroke data varies in length per drawing: a simple zigzag might have 10 points, while a detailed drawing of The Great Wall of China might have hundreds.

The model I wanted to build was multimodal: it would take both representations as input simultaneously, letting a CNN process the image and an LSTM process the stroke sequence, then merge their outputs for classification. Which meant the pipeline had to serve both modalities in sync, for every sample, across 345 classes.

Screenshot: sample drawings from Quick, Draw!, with the raster image and stroke visualization side by side

The Naive Approach and Why It Dies

The obvious first attempt is the one-liner:

import numpy as np
data = np.load("cat.npy")  # shape: (~100000, 784)

That loads fine for one class. You run it for a few classes, you’re still fine. Then somewhere around class 20 or 30 your process gets killed by the OOM killer, or your Jupyter kernel crashes silently, or the remote server you’ve SSH’d into drops your connection and takes your training run with it.

With 345 classes at 30,000 samples each (my chosen limit), we’re talking about loading roughly 10 million samples (345 × 30,000 = 10.35 million) into RAM at startup. At 10,000 samples per class, the data already took around 11% of a 128 GB server’s memory, about 14 GB; tripling that to 30,000 samples puts you at roughly a third of the machine. And that’s before you account for the stroke data.
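A back-of-the-envelope estimate makes the scaling concrete. This sketch assumes both modalities are held as float32 after preprocessing (images as (28, 28, 1), strokes padded to (130, 3)) and ignores Python object and allocator overhead, so real usage runs higher; `estimate_gib` is just an illustrative helper.

```python
# Rough RAM estimate for holding the whole preprocessed dataset in memory.
# Assumption: float32 arrays, images (28, 28, 1), strokes padded to (130, 3);
# per-object and allocator overhead is ignored, so real usage will be higher.
NUM_CLASSES = 345
BYTES_PER_FLOAT32 = 4
IMAGE_BYTES = 28 * 28 * 1 * BYTES_PER_FLOAT32  # 3,136 bytes per image
STROKE_BYTES = 130 * 3 * BYTES_PER_FLOAT32     # 1,560 bytes per padded stroke sequence
GIB = 2**30


def estimate_gib(samples_per_class):
    """Return (images_gib, strokes_gib) for every sample held in RAM as float32."""
    n = NUM_CLASSES * samples_per_class
    return n * IMAGE_BYTES / GIB, n * STROKE_BYTES / GIB


for spc in (10_000, 30_000, 100_000):
    img_gib, stk_gib = estimate_gib(spc)
    print(f"{spc:>7,} per class: images {img_gib:6.1f} GiB + strokes {stk_gib:5.1f} GiB")
```

At 30,000 per class that is roughly 30 GiB of images plus 15 GiB of strokes before any overhead, and an uncapped 100,000 per class would need about 150 GiB.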

The real problem isn’t just peak RAM usage. It’s that loading everything upfront means you can’t start training until loading finishes, the loaded arrays stay resident for the entire run, and any shuffle operation has to work over the full dataset in memory. All of this compounds.

There’s also a subtlety with the stroke files: they come pre-split into train/val/test partitions. If you want to do your own splits (which you do, so you can control the ratio and the random seed), you need to recombine them first and re-split yourself.

So before we get to the loader itself, there are three preprocessing steps to run.

Step 1: Downloading the Data

The download script fetches both file types from Google’s Cloud Storage. The listing endpoint returns XML, which the script parses to find the URLs for the classes you’ve defined in base_classes. Downloads run in parallel using a thread pool:

import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
    # list() consumes the results, so exceptions raised in download_file surface here
    list(executor.map(lambda url: download_file(url, download_folder), file_urls))

Two separate calls — one for .npy raster files, one for .npz stroke files, filtered to the sketchrnn/ prefix:

download_quickdraw_files(..., file_type="npy")
download_quickdraw_files(..., file_type="npz", prefix_filter="sketchrnn/")

Files that already exist are skipped, which matters when you’re running this on a remote server where connections drop and you have to restart.
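Here’s a minimal sketch of the two helpers, assuming the listing is a standard Cloud Storage XML bucket listing whose object names appear in `<Key>` elements. `parse_listing`, the exact `download_file` body, and the `.part` temp-file trick are illustrative assumptions, not the original script, and pagination of long listings isn’t handled.

```python
import os
import urllib.parse
import urllib.request
import xml.etree.ElementTree as ET


def parse_listing(xml_text, wanted_classes, file_type, prefix_filter=""):
    """Return object keys from an XML bucket listing for the wanted classes.

    Compares local tag names, so the listing's XML namespace doesn't matter.
    Truncated (paginated) listings are not handled in this sketch.
    """
    keys = []
    for elem in ET.fromstring(xml_text).iter():
        if elem.tag.rsplit("}", 1)[-1] != "Key" or not elem.text:
            continue
        key = elem.text
        if not key.startswith(prefix_filter) or not key.endswith("." + file_type):
            continue
        class_name = os.path.splitext(os.path.basename(key))[0]
        if class_name in wanted_classes:
            keys.append(key)
    return keys


def download_file(url, download_folder):
    """Fetch url into download_folder unless the file is already there."""
    os.makedirs(download_folder, exist_ok=True)
    filename = urllib.parse.unquote(os.path.basename(urllib.parse.urlparse(url).path))
    dest = os.path.join(download_folder, filename)
    if os.path.exists(dest):
        print(f"[SKIP] {filename}")
        return dest
    print(f"[DOWNLOAD] {filename}")
    tmp = dest + ".part"  # only renamed to the final name after a complete download
    urllib.request.urlretrieve(url, tmp)
    os.replace(tmp, dest)
    return dest
```

`wanted_classes` would be the `base_classes` set, and each URL would be the bucket’s base URL plus `urllib.parse.quote(key)`, since names such as The Eiffel Tower contain spaces. Writing to a temporary name first means an interrupted download is never mistaken for a finished one by the skip check.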

Screenshot: terminal output during download — the [DOWNLOAD] and [SKIP] lines showing parallel fetching

Step 2: Recombining the Stroke Splits

Each stroke .npz file has three keys: train, val, and test. If you only load the train key, you’re working with a subset of the available data. The fix is to concatenate all three:

import numpy as np
data = np.load(in_path, allow_pickle=True, encoding="latin1")  # object arrays of variable-length strokes
combined = np.concatenate([data["train"], data["val"], data["test"]], axis=0)
np.savez_compressed(out_path, strokes=combined)

This runs in parallel across all classes using ProcessPoolExecutor. One thing worth noting: after combining, gc.collect() is called explicitly. The pool reuses its worker processes from one class to the next, and Python doesn’t always hand memory back promptly once a large array is no longer needed, so a worker can still be holding the previous class’s data when it starts on the next one. Without the explicit cleanup, a machine with moderate RAM will start sweating as many workers (one per CPU core, ProcessPoolExecutor’s default) hold combined arrays at the same time.
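A minimal sketch of that step, assuming one input .npz per class and a matching output path. `combine_strokes` and `combine_all` are hypothetical names, and `max_workers=4` is an arbitrary example value rather than the original setting.

```python
import gc
from concurrent.futures import ProcessPoolExecutor

import numpy as np


def combine_strokes(in_path, out_path):
    """Merge the train/val/test splits of one stroke .npz into a single array."""
    # The splits are object arrays of variable-length sequences, so allow_pickle
    # is required; latin1 handles arrays that were pickled under Python 2.
    with np.load(in_path, allow_pickle=True, encoding="latin1") as data:
        combined = np.concatenate([data["train"], data["val"], data["test"]], axis=0)
    np.savez_compressed(out_path, strokes=combined)
    # Pool workers are reused across classes: drop the big array and collect
    # before this worker picks up the next file.
    del combined
    gc.collect()
    return out_path


def combine_all(pairs, max_workers=4):
    """Run combine_strokes over (in_path, out_path) pairs with a bounded pool."""
    in_paths, out_paths = zip(*pairs)
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(combine_strokes, in_paths, out_paths))
```

Capping `max_workers` bounds how many combined arrays can be alive at once. On platforms that start workers with spawn (Windows, macOS), `combine_all` should be called from under an `if __name__ == "__main__":` guard.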

Step 3: The Key Idea — One File Per Sample

This is the decision everything else depends on.

Instead of keeping each class as a single large .npy file, we explode every sample out into its own file:

dataset/processed/
  images/
    cat/
      000001.npy ← shape: (28, 28, 1)
      000002.npy
      ...
  strokes/
    cat/
      000001.npy ← shape: (130, 3)
      000002.npy
      ...

The conversion script loops over every class, loads the class-level arrays, preprocesses each sample, and saves them individually. The index is global across all classes — not per-class — which is what keeps image and stroke files aligned:

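import os

import numpy as np

# Hypothetical locations: raw bitmaps, combined strokes, and the output root
RAW_IMG_DIR = "dataset/raw/npy"
RAW_STROKE_DIR = "dataset/raw/strokes_combined"
OUT_DIR = "dataset/processed"

global_idx = 0
max_samples_per_class = 100_000

for label_name in LABEL_MAP:
    img_path = os.path.join(RAW_IMG_DIR, f"{label_name}.npy")
    stroke_path = os.path.join(RAW_STROKE_DIR, f"{label_name}.npz")
    img_dir = os.path.join(OUT_DIR, "images", label_name)
    stroke_dir = os.path.join(OUT_DIR, "strokes", label_name)
    os.makedirs(img_dir, exist_ok=True)  # np.save won't create directories
    os.makedirs(stroke_dir, exist_ok=True)

    images = np.load(img_path, mmap_mode="r")  # memory-mapped: rows are read on demand
    strokes = np.load(
        stroke_path, allow_pickle=True, encoding="latin1"
    )["strokes"]

    N = min(len(images), len(strokes), max_samples_per_class)

    for i in range(N):
        idx = global_idx + i  # one global index names both files of the pair
        np.save(os.path.join(img_dir, f"{idx:06d}.npy"), preprocess_image(images[i]))
        np.save(os.path.join(stroke_dir, f"{idx:06d}.npy"), preprocess_strokes(strokes[i]))

    global_idx += N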

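The image loading uses mmap_mode="r": the file is memory-mapped, so NumPy doesn’t load the entire (100000, 784) array into RAM just to iterate over it row by row. The stroke file can’t get the same treatment, since np.load ignores mmap_mode for .npz archives, so each class’s strokes are loaded whole, one class at a time. The preprocessing happens at this stage, not at training time, so the generator later is just doing file reads.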

This step takes a while to run. On the upside, it runs once.

Screenshot: the processed/ directory structure — showing the per-class subdirectories with numbered .npy files

What Preprocessing Actually Does

Images are straightforward. Reshape (784,) to (28, 28), divide by 255 to get [0, 1] floats, expand the channel dimension:

import numpy as np
def preprocess_image(flat_img):
    img = flat_img.reshape(28, 28).astype(np.float32) / 255.0
    return np.expand_dims(img, axis=-1)  # (28, 28, 1)

Strokes are more involved. The raw data uses relative coordinates — each (dx, dy) is an offset from the previous point, not an absolute position. This makes sense for how drawings are recorded but not for how a model should see them. The preprocessing converts to absolute, centers the drawing at the origin, then scales to a fixed [-100, 100] range:

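import numpy as np

def normalize_strokes(strokes):  # hypothetical name: the centering/scaling half of preprocess_strokes
    # Float copy: the raw offsets are integers, and the in-place float
    # operations below would fail on an integer array
    strokes = np.array(strokes, dtype=np.float32)

    # Relative -> absolute
    strokes[:, 0] = np.cumsum(strokes[:, 0])
    strokes[:, 1] = np.cumsum(strokes[:, 1])

    # Center at origin
    strokes[:, 0] -= strokes[:, 0].mean()
    strokes[:, 1] -= strokes[:, 1].mean()

    # Scale to [-100, 100]
    max_coord = max(np.abs(strokes[:, 0]).max(), np.abs(strokes[:, 1]).max())
    if max_coord > 0:
        strokes[:, 0] *= 100.0 / max_coord
        strokes[:, 1] *= 100.0 / max_coord
    return strokes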

Stroke sequences are variable length. To get a fixed-size tensor for the LSTM, sequences are either truncated or zero-padded to 130 points. Why 130? Empirically, that covers the vast majority of drawings in the dataset without wasting too many zeros on the short ones.

The pen state column (the third feature) is left as-is — it’s already a binary indicator of whether the pen is lifted.
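The padding step itself isn’t shown above. A minimal sketch, assuming it runs after the normalization and that the third column is the pen state; `pad_or_truncate`, `coverage` and `MAX_LEN` are illustrative names. The `coverage` helper is one way to sanity-check a cutoff like 130 against the actual length distribution.

```python
import numpy as np

MAX_LEN = 130  # fixed sequence length for the LSTM input


def pad_or_truncate(strokes, max_len=MAX_LEN):
    """Return a (max_len, 3) float32 array: truncated if long, zero-padded if short."""
    out = np.zeros((max_len, strokes.shape[1]), dtype=np.float32)
    n = min(len(strokes), max_len)
    out[:n] = strokes[:n]
    return out


def coverage(lengths, max_len=MAX_LEN):
    """Fraction of drawings whose full sequence fits in max_len points."""
    lengths = np.asarray(lengths)
    return float(np.mean(lengths <= max_len))
```

Feeding it the raw lengths of a class (for example `[len(s) for s in strokes]`) gives the fraction of drawings that survive untruncated. One caveat: padded rows carry pen_state 0, the same value as a real pen-down point, so if that ever matters, a Keras `Masking` layer in front of the LSTM is one option.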


Screenshot: before/after visualization of a stroke — raw relative coordinates as a mess of lines, then the centered/normalized version looking like the actual drawing

The Loader

After preprocessing, the index step is fast. We walk the processed directory and collect all file paths:

from glob import glob

images, strokes, labels = [], [], []
for cls, label in LABEL_MAP.items():
    image_files = sorted(glob(f"{PROCESSED_DATA_DIR}/images/{cls}/*.npy"))
    stroke_files = sorted(glob(f"{PROCESSED_DATA_DIR}/strokes/{cls}/*.npy"))

    N = min(len(image_files), len(stroke_files), SAMPLES_PER_CLASS)
    for i in range(N):
        images.append(image_files[i])
        strokes.append(stroke_files[i])
        labels.append(label)

At this point, images and strokes are just lists of strings. Nothing has been loaded into memory. The total dataset (345 classes × 30,000 samples) indexes in a few seconds.

There’s also a threshold in the config: IN_MEMORY_THRESHOLD = 30_000. If SAMPLES_PER_CLASS is below that number, the loader will actually call np.load() during indexing and store the arrays directly. For quick experiments on a subset of data, this avoids the per-sample I/O overhead at training time. For large runs, it streams from disk instead.

# strictly below the threshold, so a 30,000-sample run streams from disk
USE_IN_MEMORY = USE_INDIVIDUAL and (SAMPLES_PER_CLASS < IN_MEMORY_THRESHOLD)

Both paths feed into the same generator interface, which is a nice property — you can switch between them by changing one number.

The Generator and the tf.data Pipeline

The generator is a Python function that yields (image, stroke, one_hot_label) tuples:

import numpy as np
import tensorflow as tf


def data_generator(images, strokes, labels):
    if USE_IN_MEMORY:
        # images/strokes already hold arrays loaded during indexing
        for image, stroke, label in zip(images, strokes, labels):
            yield image, stroke, tf.one_hot(label, depth=NUM_CLASSES)
    elif USE_INDIVIDUAL:
        # images/strokes hold file paths: load one pair per sample
        for img_path, str_path, label in zip(images, strokes, labels):
            yield (
                np.load(img_path),
                np.load(str_path),
                tf.one_hot(label, depth=NUM_CLASSES),
            )

This feeds into a tf.data.Dataset via from_generator, which requires an explicit output signature, because TensorFlow needs to know shapes and dtypes upfront and can’t infer them from a Python generator:

import tensorflow as tf

output_signature = (
    tf.TensorSpec(shape=(28, 28, 1), dtype=tf.float32),
    tf.TensorSpec(shape=(130, 3), dtype=tf.float32),
    tf.TensorSpec(shape=(NUM_CLASSES,), dtype=tf.float32),  # tf.one_hot returns float32
)

def gen():
    # from_generator expects a callable that returns a fresh generator
    return data_generator(images, strokes, labels)

ds = tf.data.Dataset.from_generator(gen, output_signature=output_signature)

The full pipeline adds shuffling (over a buffer of 10× the batch size, i.e. 5,120 samples, rather than the entire dataset), repeating, batching at 512, and prefetching:

import tensorflow as tf

def build_dataset(images, strokes, labels, is_shuffle=False):
    ds = tf.data.Dataset.from_generator(
        lambda: data_generator(images, strokes, labels),  # a fresh generator per pass
        output_signature=output_signature,
    )
    if is_shuffle:
        ds = ds.shuffle(BATCH_SIZE * 10)
    ds = ds.repeat()
    ds = ds.map(format_sample, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(BATCH_SIZE)
    return ds.prefetch(tf.data.AUTOTUNE)

The format_sample step reformats the yielded tuple into the dictionary format Keras expects for multi-input models:

def format_sample(img, stroke, label):
    return {"stroke_input": stroke, "image_input": img}, label

Shuffling happens on the index lists and in a bounded in-memory buffer, not by rearranging files. The layout on disk stays exactly as the conversion step wrote it, one directory per class, and the buffer only reorders samples after they have been read, so shuffling itself adds no extra I/O. Reads follow the order of the path lists, which after the random index split jump between class directories, so the access pattern isn’t sequential either way; with files this small, the OS page cache absorbs much of that cost once a file has been read.
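To see what a bounded shuffle buffer does, here is a small pure-Python imitation of the behavior tf.data documents for `shuffle(buffer_size)`: fill the buffer, emit a random element, replace it with the next input. It’s an illustration with a hypothetical name, not TensorFlow’s implementation.

```python
import random


def buffer_shuffle(iterable, buffer_size, seed=None):
    """Yield items in a locally shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]  # emit a random buffered element...
        buffer[i] = item  # ...and put the new item in its slot
    rng.shuffle(buffer)  # drain the remainder in random order
    yield from buffer
```

The consequence is that an element can move earlier by at most about `buffer_size` positions. With a buffer of 5,120 and an input list sorted by class, most batches would come from one or two classes; the setup works as long as the path lists going into `build_dataset` are already in shuffled order (as the seeded index split produces), so the buffer mainly adds fresh variation each epoch.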


Screenshot: htop showing RAM usage during training — relatively flat, not growing with training time

Splitting the Dataset

The split is index-based. We shuffle a global index array once with a fixed seed, then slice it:

indices = np.arange(total)
np.random.seed(42)
np.random.shuffle(indices)

train_end = int(0.8 * total)
val_end = train_end + int(0.1 * total)
train_idx = indices[:train_end]
val_idx = indices[train_end:val_end]
test_idx = indices[val_end:]
Enter fullscreen mode Exit fullscreen mode

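The 80/10/10 ratio applies to the dataset as a whole: the indexing step appends classes one after another, but the seeded shuffle mixes them before slicing. The split isn’t stratified, so per-class counts vary slightly, but with 30,000 samples per class each class should land close to 24,000/3,000/3,000, and there’s practically no risk of a class being entirely in the training set and absent from validation.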
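The slicing only produces index arrays; the path and label lists still have to be gathered with them. A minimal sketch with hypothetical helper names, using `RandomState(seed)`, which yields the same permutation as `np.random.seed(seed)` followed by `np.random.shuffle`, plus a per-class count check:

```python
import numpy as np


def split_lists(images, strokes, labels, seed=42, train_frac=0.8, val_frac=0.1):
    """Shuffle one shared index array, slice it, and gather all three lists with it.

    Using the same indices for images, strokes and labels keeps every
    (image, stroke, label) triple together.
    """
    total = len(labels)
    indices = np.arange(total)
    np.random.RandomState(seed).shuffle(indices)

    train_end = int(train_frac * total)
    val_end = train_end + int(val_frac * total)

    def gather(idx):
        return [images[i] for i in idx], [strokes[i] for i in idx], [labels[i] for i in idx]

    parts = (indices[:train_end], indices[train_end:val_end], indices[val_end:])
    return tuple(gather(idx) for idx in parts)


def per_class_counts(split_labels, num_classes):
    """Samples per class in one split; a zero means that class is missing."""
    return np.bincount(np.asarray(split_labels, dtype=np.int64), minlength=num_classes)
```

Running `per_class_counts` on the validation labels and checking that the minimum is nonzero is a cheap way to confirm the “no missing class” claim on a real run.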

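Validation and test datasets use .take() to consume a fixed number of batches, computed from the split sizes, because build_dataset repeats the dataset indefinitely. Since the count is rounded up, the final batch wraps around and re-reads fewer than BATCH_SIZE samples from the start of the split, which is negligible at this scale: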

import math

val_ds = build_dataset(
    val_images, val_strokes, val_labels
).take(math.ceil(len(val_labels) / BATCH_SIZE))
test_ds = build_dataset(
    test_images, test_strokes, test_labels
).take(math.ceil(len(test_labels) / BATCH_SIZE))
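Because the validation and test datasets repeat, rounding the batch count up means the last batch spills into a second pass. A tiny illustrative helper (hypothetical, not from the original code) makes the arithmetic explicit:

```python
import math


def eval_steps(num_samples, batch_size):
    """Return (batches passed to .take(), samples read a second time)."""
    steps = math.ceil(num_samples / batch_size)
    wrapped = steps * batch_size - num_samples  # extra samples in the wrapped-around last batch
    return steps, wrapped
```

For the 30,000-per-class setup, the validation split is 10% of 10,350,000, i.e. 1,035,000 samples, which gives 2,022 batches of 512 with 264 samples read twice.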

What Went Wrong Along the Way

File count. The processed dataset ends up with roughly 345 × 30,000 × 2 = 20.7 million files. Some filesystems handle this poorly. If you're on a filesystem with inode limits or slow directory listing (common with some HPC storage systems), the sorted(glob(...)) calls at index time can take several minutes. Structured subdirectories (one per class) help, but it's still a lot of files.

Index alignment. The global index scheme — where file names reflect position across all classes, not within a class — exists entirely to prevent a specific bug. An earlier version used per-class indices, which caused a silent alignment failure: image cat/000001.npy and stroke cat/000001.npy were always aligned, but after shuffling, the code was pulling from globally-indexed lists and the class-local numbering didn't correspond. The {idx:06d} naming ensures that whatever index you retrieve from the lists, the image and stroke file names will match.
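A cheap guard against that class of bug is to compare the paired file names after indexing, and again after splitting. A minimal sketch; `check_alignment` is a hypothetical helper that relies on the directory layout shown earlier (modality/class/index.npy):

```python
import os


def check_alignment(image_paths, stroke_paths):
    """Return the first index where the image/stroke pair disagrees, or None.

    A matching pair has the same class directory and the same {idx:06d}.npy name.
    """
    if len(image_paths) != len(stroke_paths):
        return min(len(image_paths), len(stroke_paths))
    for i, (img, stk) in enumerate(zip(image_paths, stroke_paths)):
        img_key = (os.path.basename(os.path.dirname(img)), os.path.basename(img))
        stk_key = (os.path.basename(os.path.dirname(stk)), os.path.basename(stk))
        if img_key != stk_key:
            return i
    return None
```

Because it only compares strings, it runs over millions of paths in seconds and touches no files.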

Training on a remote server with an unstable SSH connection. The training history in the notebook has a gap. BackupAndRestore meant the model weights survived; the history object didn’t. TensorBoard logs were the fallback, and the actual metrics are there; the notebook’s loss and accuracy plots just show what was available from the Python history object after reconnecting. If you’re doing long training runs remotely, save the history separately and frequently, not just at the end.
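One way to follow that advice is to append each epoch’s metrics to a JSON-lines file as soon as the epoch ends, so a dropped session loses at most the epoch in progress. A stdlib-only sketch with hypothetical helper names:

```python
import json
import os


def append_epoch_log(path, epoch, logs):
    """Append one epoch's metrics as a JSON line and force it to disk."""
    record = {"epoch": int(epoch), **{k: float(v) for k, v in (logs or {}).items()}}
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
        f.flush()
        os.fsync(f.fileno())


def load_history(path):
    """Rebuild a Keras-style history dict ({metric: [values, ...]}) from the log."""
    history = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            record.pop("epoch", None)
            for key, value in record.items():
                history.setdefault(key, []).append(value)
    return history
```

To hook it into training, `tf.keras.callbacks.LambdaCallback(on_epoch_end=lambda epoch, logs: append_epoch_log("history.jsonl", epoch, logs))` can be passed to `model.fit`. Keras’s built-in `tf.keras.callbacks.CSVLogger(filename, append=True)` does much the same job with a CSV file.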

Memory growth with TensorFlow’s GPU allocator. By default, TensorFlow maps nearly all of the GPU memory visible to the process as soon as it initializes the GPU. For a machine shared with other users, or one running other processes, this is a problem. The fix is:

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

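This makes TensorFlow allocate GPU memory incrementally as needed. It has to run before the GPUs are initialized (before any op touches them); otherwise TensorFlow raises a RuntimeError. It’s not the default because grabbing the memory upfront reduces fragmentation, so growth mode can cost a little performance in some scenarios, but for shared environments it’s basically always the right call.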

What I’d Do Differently

The main thing I’d want to add is parallel file loading. Right now the generator is single-threaded — it loads one sample at a time, yields it, repeats. tf.data.AUTOTUNE on the prefetch helps by trying to keep the pipeline filled ahead of the model's consumption, but the actual I/O is sequential. Adding multiple generator workers (like PyTorch's num_workers) would reduce the time the GPU spends waiting for data.
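A minimal sketch of the “multiple workers” idea that keeps the generator interface: load each chunk of samples with a thread pool. This swaps in threads (file reads release the GIL while waiting on the OS) rather than PyTorch-style worker processes, and `threaded_generator`, `load_pair`, `num_workers` and `chunk_size` are all hypothetical. A more TensorFlow-native route would be building the dataset from the path lists with `tf.data.Dataset.from_tensor_slices` and loading each file inside `map(..., num_parallel_calls=tf.data.AUTOTUNE)`, for example via `tf.numpy_function`.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import islice

import numpy as np


def load_pair(item):
    """Load one (image, stroke) pair from disk; the label passes through."""
    img_path, stroke_path, label = item
    return np.load(img_path), np.load(stroke_path), label


def threaded_generator(images, strokes, labels, num_workers=8, chunk_size=1024):
    """Yield (image, stroke, label) in order, loading each chunk in parallel.

    Working chunk by chunk bounds the number of in-flight loads instead of
    submitting every path to the pool at once.
    """
    items = zip(images, strokes, labels)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        while True:
            chunk = list(islice(items, chunk_size))
            if not chunk:
                return
            yield from pool.map(load_pair, chunk)
```

It yields plain integer labels, so using it with `from_generator` would need the label spec and `format_sample` adjusted (for instance, one-hot encoding inside the map step). The gain is largest when reads actually go to disk; with a warm page cache, Python-side overhead in `np.load` dominates and threads help less.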

LMDB would also be worth experimenting with. The advantage over millions of small files is that it’s a single file that supports fast key-value lookup and sequential reading, and it avoids per-entry filesystem overhead. The disadvantage is that it complicates the setup and makes debugging harder. For this project the small-files approach was fast enough, but at larger scale it would start to matter.

A smarter caching strategy (keeping recently accessed samples in a bounded RAM buffer) would also help with the “warm up” problem. The first epoch is always slower than subsequent ones because the OS page cache starts cold. A pre-warmed in-memory buffer for the most frequently accessed samples would smooth that out.
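For a bounded RAM buffer, the standard library’s `functools.lru_cache` already provides a size-capped, least-recently-used cache. A sketch around `np.load`; `cached_load` is a hypothetical name and `maxsize=100_000` an arbitrary example (100,000 image arrays of 3,136 bytes is roughly 300 MB of payload):

```python
from functools import lru_cache

import numpy as np


@lru_cache(maxsize=100_000)
def cached_load(path):
    """np.load with a bounded in-process LRU cache keyed by file path."""
    arr = np.load(path)
    arr.setflags(write=False)  # cached arrays are shared between callers, so make them read-only
    return arr
```

How much this helps depends on the access pattern: with a fresh shuffle every epoch, an LRU cache smaller than the training set mostly evicts samples before they come around again, so caching a fixed subset, or leaning on the OS page cache, may do as well. The cache also lives in a single process and isn’t shared between workers.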

The Part That Surprised Me

When I first sketched this out, my expectation was that disk-based loading would be noticeably slower than loading everything to RAM, enough to be a real bottleneck. It wasn’t, for a reason that only became clear after thinking about it: individual .npy file loads are cheap. A (28, 28, 1) float32 array is 3,136 bytes. A (130, 3) float32 stroke array is 1,560 bytes. These are tiny files. Even with Python and np.load overhead, reading one takes a small fraction of a millisecond once it’s in the OS page cache, and the cache handles repeat accesses to recently read files transparently.
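Those sizes are easy to confirm by saving to an in-memory buffer. A sketch (`npy_size` is a hypothetical helper) that reports the array payload and the full .npy file size, which adds a small header on top:

```python
import io

import numpy as np


def npy_size(shape, dtype=np.float32):
    """Return (payload_bytes, file_bytes) for an array written with np.save."""
    arr = np.zeros(shape, dtype=dtype)
    buf = io.BytesIO()
    np.save(buf, arr)
    return arr.nbytes, len(buf.getvalue())


for shape in [(28, 28, 1), (130, 3)]:
    payload, file_bytes = npy_size(shape)
    print(shape, payload, file_bytes)
```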

What you trade away compared to pure in-memory loading is predictability. With everything in RAM, access time is constant. With disk loading, you’re occasionally hitting a file that isn’t cached, and that read takes longer. In practice, the prefetch buffer absorbs most of this variance. The GPU never actually sat idle waiting for data in my runs — the bottleneck was always computation, not I/O.

The other thing that surprised me was how much the single-file-per-class approach had been hiding. When everything for cat is one big (100000, 784) array and you load it with a plain np.load, you have to read the whole thing before you can access any of it. That’s a loading cost you pay every time. With individual files, you pay per sample, and you only pay for the samples you actually use.

The Notebook Setup (in case it’s useful)

One thing worth mentioning for anyone running this on a remote server: the port forwarding setup for Jupyter. If you’re SSH-ing into a machine and want to run notebooks rather than pulling .py files and running them in screen sessions, you forward the Jupyter port to localhost:

ssh -L 8888:localhost:8888 user@server_ip

# On the server:
jupyter notebook --no-browser --port=8888

If you’re going through two layers of SSH (e.g. a department gateway server that routes to a compute node), you just carry the forwarding through:

ssh -L 8888:localhost:8888 user@gateway
# On gateway:
ssh -L 8888:localhost:8888 user@compute_node

And for full control over Python and library versions, running the kernel inside a virtual environment is worth the setup time:

python3.12 -m venv .tens
source .tens/bin/activate
pip install jupyter ipykernel tensorflow numpy tqdm matplotlib
python -m ipykernel install --user --name=.tens

Then you can select .tens as the kernel in Jupyter and know exactly what Python version and library versions are running - which matters if you're planning to later quantize the model and deploy it somewhere like a Raspberry Pi, where the environment constraints are much stricter.

The pipeline ended up being more engineered than I originally wanted. But it runs, it doesn’t crash, and it’ll scale to more classes or more samples per class without changes. For a dataset this size on a memory-constrained machine, that’s the bar.

The code is all in the repository if you want to look at the actual implementation rather than the edited excerpts here. I’ll make this public once the paper is accepted.

This article was rewritten with the help of AI chatbots.

April 14, 2026
