DEV Community

Krish Singaria

How I bypassed PyTorch OOM errors with a Zero-Copy C++ Graph Engine

If you have ever tried to train a Graph Neural Network (GNN) on a massive dataset, you already know the pain of the "Memory Wall."

Loading a dataset like Papers100M into PyTorch Geometric almost always ends the same way on a standard machine: an immediate Out-Of-Memory (OOM) crash on a 24GB+ allocation. Standard loaders try to read the entire edge list and feature matrix into RAM before moving anything to the GPU.

I got tired of my laptop crashing, so I built GraphZero (v0.2.0): a custom C++ data engine that never loads the dataset into RAM up front and instead streams it from the SSD on demand.

Here is how I built a zero-copy pipeline that lets PyTorch train on 30GB of data while allocating 0 bytes of RAM for the dataset itself (the OS still caches recently used pages, but that cache is reclaimable).
[Images: graphzero, PyG]

🧠 The Architecture: mmap and Zero-Copy
The core philosophy of GraphZero is simple: let the Operating System do the heavy lifting.

Instead of parsing CSVs into Python lists or Pandas DataFrames, GraphZero compiles raw data into two heavily optimized binary formats:

.gl files: store the graph topology (edge lists).

.gd files: store the node features, using strict C++ template dispatching to enforce memory layouts (like FLOAT32 or INT64).
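To give a feel for what "compiling" features into a binary file can mean, here is a minimal Python sketch. The header layout, the `FEAT` magic bytes, and the `write_features` / `read_header` helpers are all assumptions made up for illustration; this is not the real .gd format, which the post does not specify.

```python
import os
import struct
import tempfile
from array import array

# Hypothetical header: magic bytes, dtype code, row count, column count.
# This is NOT the real .gd layout, only an illustration of the idea.
HEADER = struct.Struct("<4sIQQ")  # 4 + 4 + 8 + 8 = 24 bytes, no padding
DTYPE_FLOAT32 = 1                 # made-up dtype code

def write_features(path, rows):
    """Write a row-major float32 matrix behind a fixed-size header."""
    num_cols = len(rows[0])
    with open(path, "wb") as f:
        f.write(HEADER.pack(b"FEAT", DTYPE_FLOAT32, len(rows), num_cols))
        for row in rows:
            array("f", row).tofile(f)  # raw 4-byte floats, native byte order

def read_header(path):
    """Read back only the header; the payload stays on disk."""
    with open(path, "rb") as f:
        return HEADER.unpack(f.read(HEADER.size))

path = os.path.join(tempfile.mkdtemp(), "toy_features.bin")
write_features(path, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])
header = read_header(path)        # (b"FEAT", 1, 2, 3)
file_size = os.path.getsize(path) # 24-byte header + 2 * 3 * 4 payload bytes
```

Because every row has the same byte width, the offset of any node's features can be computed directly from its index, which is what makes memory-mapping such a file practical.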

Once compiled, the engine uses POSIX mmap to memory-map the binary files. Using nanobind, we hand the raw C++ pointers directly to PyTorch as zero-copy NumPy arrays.

import graphzero as gz
import torch

# 1. Mount the zero-copy engine
fs = gz.FeatureStore("papers100M_features.gd")

# 2. Instantly map SSD data to PyTorch (RAM used: 0 Bytes)
X = torch.from_numpy(fs.get_tensor())

print(f"Feature Tensor: {X.shape} ({X.dtype})")
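The same idea can be tried without GraphZero, using only Python's standard library. The sketch below is an illustration under assumptions (a made-up file name and a toy 4x3 float32 matrix); it does not show how GraphZero's C++ engine or its nanobind bindings are implemented.

```python
import mmap
import os
import tempfile
from array import array

# Write a small row-major float32 matrix (4 rows x 3 columns) to disk.
NUM_ROWS, NUM_COLS = 4, 3
path = os.path.join(tempfile.mkdtemp(), "toy_features.bin")
with open(path, "wb") as f:
    array("f", range(NUM_ROWS * NUM_COLS)).tofile(f)

with open(path, "rb") as f:
    # Map the whole file read-only. The mapping outlives the file object.
    mapped = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

# Reinterpret the mapped bytes as float32 values without copying them.
floats = memoryview(mapped).cast("f")

def row(i):
    """Return row i as a zero-copy slice of the mapping."""
    return floats[i * NUM_COLS:(i + 1) * NUM_COLS]

row2 = row(2).tolist()  # a copy is made here, but only for this one row
```

With NumPy installed, `np.memmap` gives a similar file-backed array that `torch.from_numpy` can wrap.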

⚡ The Execution: OS Page Faults and OpenMP
During a training loop (like GraphSAGE), PyTorch thinks it has a massive 50GB tensor sitting in RAM.

When the neural network requests a batch of target nodes, it indexes the mapped tensor. Any row that is not yet resident triggers an OS page fault, and the operating system fetches only the pages that hold the requested rows (typically 4KB each, plus whatever readahead the kernel decides on) from the NVMe drive.
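A quick back-of-the-envelope sketch shows why this is cheap. The helper below is hypothetical and assumes a 4KB page size, float32 features, and a feature matrix that starts at offset 0 of the file; it ignores kernel readahead, so real I/O can be somewhat higher.

```python
PAGE_SIZE = 4096  # common default; the real value is platform-dependent

def pages_touched(node_ids, feature_dim, itemsize=4, page_size=PAGE_SIZE):
    """Return the sorted page indices that cover the rows of the given nodes."""
    row_bytes = feature_dim * itemsize
    pages = set()
    for node in node_ids:
        first = (node * row_bytes) // page_size
        last = ((node + 1) * row_bytes - 1) // page_size
        pages.update(range(first, last + 1))
    return sorted(pages)

# 128 float32 features -> 512 bytes per row -> 8 rows share one 4 KB page.
batch = [0, 3, 8, 1_000_000]
pages = pages_touched(batch, feature_dim=128)  # [0, 1, 125000]
bytes_read = len(pages) * PAGE_SIZE            # 12288 bytes for this batch
```

Four scattered nodes cost three page reads here, no matter how large the file is. Nodes that happen to sit on the same page (0 and 3 above) share a single fetch.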

To keep the pipeline saturated, the C++ engine uses OpenMP multi-threading for neighbor sampling (batch_random_fanout). Because this happens in C++, we release the Python GIL, allowing disk I/O, CPU sampling, and GPU math to overlap instead of waiting on each other.
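For readers new to fanout sampling, here is a minimal single-threaded Python sketch of the technique. It is not the signature or implementation of `batch_random_fanout`, which the post does not show; the `sample_fanout` name, the CSR inputs, and the toy graph are assumptions for illustration.

```python
import random

def sample_fanout(indptr, indices, seeds, fanout, rng):
    """For each seed node, pick up to `fanout` distinct neighbors at random."""
    sampled = {}
    for seed in seeds:
        # CSR layout: the neighbors of `seed` sit in one contiguous slice.
        neighbors = indices[indptr[seed]:indptr[seed + 1]]
        if len(neighbors) <= fanout:
            sampled[seed] = list(neighbors)
        else:
            sampled[seed] = rng.sample(neighbors, fanout)
    return sampled

# Toy graph in CSR form:
# node 0 -> {1, 2, 3}, node 1 -> {0}, node 2 -> {}, node 3 -> {0, 2}
indptr = [0, 3, 4, 4, 6]
indices = [1, 2, 3, 0, 0, 2]

batch = sample_fanout(indptr, indices, seeds=[0, 1, 2], fanout=2,
                      rng=random.Random(0))
```

Each seed is processed independently of the others, which is the property that lets a compiled implementation split the loop across threads. In pure Python the GIL would prevent that speedup, so this version only demonstrates the logic.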

🚀 Try it out
Building GraphZero forced me to dive deep into low-level memory management, CI/CD matrix builds, and Python C-bindings.

If you want to train GNNs without melting your RAM, check out the repository. It includes an end-to-end GraphSAGE training script with a synthetic dataset generator so you can test the zero-copy mounting locally.

GitHub repo
I would love any harsh technical feedback on the C++ architecture or the API design!

Top comments (10)

freerave

Brilliant approach! Leveraging mmap to let the OS handle the paging is a huge brain move. I do have a quick technical question: Since you're relying on page faults from the NVMe SSD, how does this affect the training speed (I/O bottleneck) compared to a system that actually has enough RAM to fit the whole dataset? Is the OpenMP multi-threading enough to completely hide that I/O latency?

Krish Singaria

Thanks! You hit the fundamental trade-off. No, OpenMP doesn't completely hide I/O latency—DDR5 RAM will always beat an NVMe SSD.
However, two things mitigate this: the OS Page Cache stores "hot" nodes in free RAM after epoch 1, and OpenMP saturates the NVMe's IOPS queue with parallel requests.
Ultimately, it’s a choice between training at 70% speed versus an instant PyTorch OOM crash!

freerave

That makes perfect sense! Relying on the OS Page Cache for the 'hot' nodes after epoch 1 is a very elegant fallback. And you're absolutely right—70% training speed is infinitely faster than a 0% OOM crashed run 😂. Thanks for the detailed breakdown, brilliant engineering!

Apex Stack

Really cool approach. The mmap strategy resonates with me — I run Llama 3 locally to generate content for a 100k+ page multilingual site, and memory management is a constant battle at that scale. Had to get creative with batch sizing and streaming to avoid OOM on a machine with 32GB RAM.

The 70% speed vs. OOM crash trade-off is the right framing. In my case I found a similar pattern: processing pages in streaming batches at ~60% throughput beats trying to load everything into memory and crashing halfway through a 10-hour generation run.

Curious — have you considered adding a prefetching strategy that pre-warms the page cache based on the graph's access patterns? If your neighbor sampler knows which nodes are likely to be accessed next, you could issue madvise(MADV_WILLNEED) hints ahead of time. Might close some of that 30% gap.
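A rough sketch of what such a hint could look like from Python is below. It assumes fixed-width rows starting at offset 0 and uses the standard library's `mmap.madvise`, which only exists on platforms that provide `madvise()`; the `prefetch_rows` helper and the toy file are hypothetical and not part of GraphZero.

```python
import mmap
import os
import tempfile

def prefetch_rows(mapped, node_ids, row_bytes):
    """Hint the kernel that the rows of `node_ids` will be read soon.

    Returns the number of hints issued (0 where madvise is unavailable).
    """
    if not hasattr(mapped, "madvise") or not hasattr(mmap, "MADV_WILLNEED"):
        return 0
    page = mmap.PAGESIZE
    spans = {}
    for node in node_ids:
        offset = node * row_bytes
        start = offset - offset % page  # madvise needs a page-aligned start
        length = offset + row_bytes - start
        spans[start] = max(spans.get(start, 0), length)
    for start, length in sorted(spans.items()):
        mapped.madvise(mmap.MADV_WILLNEED, start, length)
    return len(spans)

# Toy file: 64 KB of zeros standing in for a feature matrix.
path = os.path.join(tempfile.mkdtemp(), "toy_features.bin")
with open(path, "wb") as f:
    f.write(bytes(64 * 1024))
with open(path, "rb") as f:
    mapped = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)

# Hint the rows of the *next* batch while the current one is being processed.
hints = prefetch_rows(mapped, node_ids=[0, 1, 100], row_bytes=512)
```

The hint is advisory: the kernel may ignore it, and whether it narrows the gap would have to be measured on a real workload.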

Krish Singaria

Thanks for the prefetching idea, I may look into this.

soft-cypher-byte

This is clever
Moving beyond the usual batch size reduction and gradient accumulation bandaids

The zero copy approach makes a lot of sense
Data movement between CPU and GPU is such an overlooked bottleneck
Everyone stares at compute but ignores the overhead of shuffling tensors around

Questions that came to mind
How does this compare to PyTorch's own memory optimization features
Things like checkpointing or max_split_size_mb
Is this generalizable across different architectures
Or does it work best for specific patterns like transformers

Would be interesting to see benchmarks across different gpu hardware
A100 vs H100 vs consumer cards
Also curious about the autograd tradeoffs
Maintaining gradients with zero copy is tricky territory

Have you thought about open sourcing it
Would love to test this on some real workloads
The approach reminds me of tensorrt but keeping it inside the pytorch ecosystem is a nice sweet spot

Nice engineering

Krish Singaria

Those are some hard questions to answer, so you may want to look into them yourself. It is open source (github link) and also available on PyPI:

pip install graphzero
 
Harsh

This is a brilliant solution to a problem every GNN practitioner has faced. The mmap + zero-copy approach is elegant: letting the OS handle page faults instead of fighting against RAM limits. The fact that you're handing raw C++ pointers directly to PyTorch via nanobind is exactly the kind of systems-level thinking that makes deep learning actually practical at scale. Impressive work!

soft-cypher-byte • Edited

This is a fascinating deep dive into low-level optimization! The zero-copy C++ graph engine approach is clever - bypassing PyTorch's overhead while maintaining the computational graph is no small feat.
Key takeaways that impressed me:
Moving beyond just batch size reduction or gradient accumulation (the usual OOM bandaids)
The zero-copy philosophy - minimizing data movement between CPU/GPU is often overlooked but critical for performance
Building a custom C extension that maintains graph connectivity while reducing memory footprint

- How does this compare to PyTorch's own memory optimization features like checkpointing or max_split_size_mb?
- Is this generalizable across different model architectures, or does it work best for specific patterns (like transformers vs CNNs)?
- Did you have to make any trade-offs with autograd functionality? Maintaining gradients with zero-copy can be tricky

Would be interesting to see benchmarks across different GPU architectures (A100 vs H100 vs consumer cards)
Integration with PyTorch's memory profiling tools would make this more accessible
A fallback mechanism for operations where zero-copy isn't optimal
This kind of systems-level optimization work is exactly what the ML engineering community needs. Have you considered open-sourcing it? Would love to test it on some production workloads!
The approach reminds me of how TensorRT optimizes graphs, but keeping it within PyTorch's ecosystem is a nice middle ground. Great engineering!

Greazy Spoon

The goat
