So you've got a model architecture in mind, maybe a fine-tuning job on a massive LLM, and you look at the memory requirements. A 100B parameter model in full FP32 precision needs roughly 400GB just for the parameters. Add optimizer states (Adam stores two additional copies), gradients, and activations — you're looking at well over a terabyte of memory. Your single A100 has 80GB.
You close the terminal and reconsider your life choices.
I've been there. And a recent paper called MegaTrain caught my eye because it reportedly tackles exactly this problem: full-precision training of 100B+ parameter models on a single GPU. Let me walk through the underlying problem, why naive approaches fail, and the techniques that actually make this possible.
The Memory Wall Problem
Let's do the math on why large model training blows up your GPU memory. For a model with N parameters trained with Adam in FP32:
- Parameters: 4N bytes (FP32 = 4 bytes per param)
- Gradients: 4N bytes
- Adam optimizer states (m and v): 8N bytes (two FP32 copies)
- Activations: variable, but often several multiples of N
For a 100B parameter model, the parameters alone need ~400GB. Total training memory easily exceeds 1.6TB. Even an H100 with 80GB of HBM isn't close.
```python
# Quick sanity check on memory requirements
param_count = 100e9      # 100 billion
bytes_per_param = 4      # FP32
param_memory_gb = (param_count * bytes_per_param) / (1024**3)
adam_states_gb = (param_count * bytes_per_param * 2) / (1024**3)  # m and v
gradient_memory_gb = (param_count * bytes_per_param) / (1024**3)
total_gb = param_memory_gb + adam_states_gb + gradient_memory_gb

print(f"Parameters: {param_memory_gb:.1f} GB")
print(f"Adam states: {adam_states_gb:.1f} GB")
print(f"Gradients: {gradient_memory_gb:.1f} GB")
print(f"Total (excl. activations): {total_gb:.1f} GB")

# Output:
# Parameters: 372.5 GB
# Adam states: 745.1 GB
# Gradients: 372.5 GB
# Total (excl. activations): 1490.1 GB
```
That's ~1.5TB before activations. This is the memory wall.
Why the Naive Solutions Fall Short
Mixed Precision Helps, But Not Enough
BF16/FP16 training cuts parameter and gradient memory in half. Great. But Adam's m and v states stay in FP32 (~750GB for our 100B example), and mixed-precision recipes typically keep an FP32 master copy of the weights on top of that. A single GPU still can't hold this.
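Quick arithmetic on the mixed-precision case, using the same 100B-parameter example as above:

```python
# Mixed-precision memory for a 100B-parameter model
param_count = 100e9
bf16_params_grads = 2 * (param_count * 2) / 1024**3   # params + grads, 2 bytes each
adam_mv_fp32 = 2 * (param_count * 4) / 1024**3        # m and v stay FP32, 4 bytes each

print(f"BF16 params + grads: {bf16_params_grads:.1f} GB")  # 372.5 GB
print(f"Adam m + v (FP32):   {adam_mv_fp32:.1f} GB")       # 745.1 GB
```

The halved parameter and gradient memory barely dents the total; the FP32 optimizer state dominates.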
Model Parallelism Requires Multiple GPUs
Tensor parallelism, pipeline parallelism, ZeRO-style sharding — these are the industry standard answers. But they all assume you have multiple GPUs. If you're a researcher with one machine, or you're trying to minimize cloud costs, you need a different strategy.
Vanilla CPU Offloading Is Too Slow
The idea of offloading tensors to CPU RAM (which is cheap and abundant — 1-2TB is common on workstations) has been around for a while. DeepSpeed ZeRO-Infinity and similar frameworks support it. The problem? PCIe bandwidth becomes the bottleneck.
A PCIe 4.0 x16 link gives you roughly 32 GB/s in each direction. Shuffling hundreds of gigabytes back and forth between CPU and GPU every training step is brutal. Naive offloading can make training 10-50x slower than keeping everything on GPU.
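A back-of-the-envelope calculation shows why (the 32 GB/s figure is the theoretical peak; real achieved bandwidth is lower):

```python
# Time to move 100B FP32 parameters over PCIe 4.0 x16, one direction
param_bytes = 100e9 * 4    # 400 GB of parameters
pcie_bw = 32e9             # ~32 GB/s peak, one direction

one_way_seconds = param_bytes / pcie_bw
print(f"One full parameter transfer: {one_way_seconds:.1f} s")  # 12.5 s
# A naive scheme that streams all params in and all gradients out
# pays this cost at least twice per step, before any compute happens.
```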
This is where the real engineering challenge lies.
The Solution: Smart Offloading and Overlap
The core insight behind approaches like MegaTrain — and this applies to several systems in this space — is that you don't need everything on the GPU at the same time. You need to be smart about what lives where, and when things move.
Here are the key techniques:
1. Layer-by-Layer Processing with Prefetching
Instead of loading the entire model onto the GPU, you process one layer (or a small group of layers) at a time. While the GPU is computing on layer i, you're simultaneously transferring layer i+1 from CPU to GPU over PCIe.
```python
# Conceptual sketch of overlapped offloading. prefetch_to_gpu and
# offload_to_cpu are simplified stand-ins for pinned-memory H2D/D2H copies.
import torch
from concurrent.futures import ThreadPoolExecutor

def prefetch_to_gpu(layer):
    return layer.to("cuda", non_blocking=True)

def offload_to_cpu(layer):
    layer.to("cpu", non_blocking=True)

def train_step_with_offloading(model_layers, input_batch):
    executor = ThreadPoolExecutor(max_workers=1)
    # Start prefetching the first layer
    future = executor.submit(prefetch_to_gpu, model_layers[0])
    activations = input_batch
    saved_activations = []  # inputs saved for the backward pass
    for i, layer in enumerate(model_layers):
        # Wait for the current layer to arrive on the GPU
        gpu_layer = future.result()
        # Start prefetching the next layer while we compute
        if i + 1 < len(model_layers):
            future = executor.submit(prefetch_to_gpu, model_layers[i + 1])
        # Forward pass on the current layer
        saved_activations.append(activations.detach())
        activations = gpu_layer(activations)
        # Offload this layer's params back to CPU (done with them for now)
        offload_to_cpu(gpu_layer)
    executor.shutdown(wait=False)
    return activations, saved_activations
```
The key: if your compute time per layer exceeds the transfer time for the next layer, you hide the data movement entirely. The GPU never stalls waiting for data.
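A rough feasibility check, with every number below an illustrative assumption rather than a measurement:

```python
# Can the next layer's transfer hide behind this layer's compute?
layer_params = 1.2e9        # assumed params in one large transformer layer
bytes_per_param = 4         # FP32
pcie_bw = 25e9              # assumed ~25 GB/s achieved PCIe bandwidth
transfer_s = layer_params * bytes_per_param / pcie_bw

tokens_per_batch = 8192
flops_per_layer = 2 * layer_params * tokens_per_batch  # ~2 FLOPs/param/token
gpu_flops = 60e12           # assumed ~60 TFLOP/s sustained
compute_s = flops_per_layer / gpu_flops

print(f"transfer: {transfer_s*1e3:.0f} ms, compute: {compute_s*1e3:.0f} ms")
print("overlap hides the transfer" if compute_s >= transfer_s else "PCIe-bound")
```

With these numbers the per-layer compute (~328 ms) exceeds the transfer (~192 ms), so the overlap works; shrink the batch and the step becomes PCIe-bound.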
2. Activation Checkpointing (Recomputation)
Storing activations for every layer during the forward pass is memory-expensive. Activation checkpointing (also called gradient checkpointing) trades compute for memory: you only store activations at certain "checkpoint" layers and recompute the rest during the backward pass.
PyTorch has built-in support for this:
```python
import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedTransformerBlock(torch.nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        # Recomputes this layer's forward during backward
        # instead of storing activations
        return checkpoint(self.layer, x, use_reentrant=False)
```
This can reduce activation memory from O(N_layers) to O(sqrt(N_layers)) with the right checkpointing strategy.
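To make the sqrt claim concrete, here's a toy count in units of "one layer's activations" (the checkpoint-every-sqrt(L) interval is the standard choice, but the exact accounting is a simplification):

```python
import math

# Toy count of live activations, in units of one layer's activations
L = 100                     # number of layers
no_ckpt = L                 # vanilla: store every layer's activations
k = int(math.sqrt(L))       # checkpoint every sqrt(L) layers
with_ckpt = L // k + k      # stored checkpoints + one recomputed segment

print(no_ckpt, with_ckpt)   # prints: 100 20
```

100 units down to 20, at the cost of roughly one extra forward pass during backward.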
3. Optimizer State Partitioning
Adam's optimizer states are the biggest memory hog. The approach is to keep them in CPU memory permanently and only transfer the relevant slice to GPU when updating a specific layer's parameters. Since optimizer updates are element-wise (no cross-layer dependencies), this partitions naturally along layer boundaries.
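A minimal sketch of that idea, using a hypothetical `cpu_adam_step` helper (not MegaTrain's actual implementation): m and v live in CPU memory permanently, and only this layer's gradient crosses PCIe.

```python
import torch

# Per-layer Adam step executed on the CPU. param_cpu, m, and v are
# CPU-resident tensors; only grad_gpu is copied over per update.
def cpu_adam_step(param_cpu, grad_gpu, m, v, step, lr=1e-4,
                  betas=(0.9, 0.999), eps=1e-8):
    grad = grad_gpu.to("cpu")                  # D2H copy of this layer's grad
    m.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
    v.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])
    m_hat = m / (1 - betas[0] ** step)         # bias correction
    v_hat = v / (1 - betas[1] ** step)
    param_cpu.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
    return param_cpu
```

Because each layer's update touches only its own slice of m and v, layers can be updated one at a time as their gradients arrive, overlapping with the backward pass of earlier layers.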
4. Efficient Memory Management
Pre-allocating GPU buffers and reusing them across layers avoids the overhead of repeated CUDA allocations. You essentially maintain a fixed-size "working set" on the GPU that's large enough for your biggest layer, and rotate data through it.
What This Actually Gets You
With these techniques combined, GPU memory usage becomes proportional to your largest single layer plus some working buffers — not your total model size. A 100B parameter model might have individual layers that fit in a few gigabytes. That's well within an 80GB GPU's capacity.
The tradeoff is throughput. Even with perfect overlap, you're limited by:
- PCIe bandwidth for layers that are too large to hide behind compute
- Recomputation overhead from activation checkpointing
- CPU-side optimizer step speed
According to the MegaTrain paper, this reportedly achieves reasonable training throughput while maintaining full FP32 precision — no quantization compromises. I haven't benchmarked it myself yet, so I'd encourage you to check the paper for their specific numbers.
Getting Started with CPU Offloading Today
If you want to experiment with these ideas right now, the most accessible path is through existing frameworks:
- DeepSpeed ZeRO-Infinity: Supports CPU and NVMe offloading with overlap. Well-tested at scale.
- PyTorch FSDP: Has CPU offloading support built in as of recent versions.
- Hugging Face Accelerate: Wraps these features with a friendlier API.
The fundamentals are the same: partition state across CPU/GPU, overlap transfer with compute, checkpoint activations aggressively.
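For a concrete starting point, a minimal ZeRO-Infinity-style fragment of a DeepSpeed config looks like this (field names follow DeepSpeed's documented config schema; the values are illustrative and need tuning for your hardware):

```json
{
  "zero_optimization": {
    "stage": 3,
    "offload_param":     { "device": "cpu", "pin_memory": true },
    "offload_optimizer": { "device": "cpu", "pin_memory": true }
  },
  "train_micro_batch_size_per_gpu": 1
}
```

Stage 3 shards parameters, gradients, and optimizer states; the two offload blocks push parameters and optimizer states into pinned CPU memory.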
Prevention Tips (Avoiding the Memory Wall Earlier)
Before you reach for heroic offloading strategies:
- Profile first. Use `torch.cuda.memory_summary()` to understand where your memory actually goes. It's often not what you expect.
- Try gradient accumulation. Smaller micro-batches with accumulated gradients can slash activation memory.
- Consider LoRA/QLoRA for fine-tuning. If you don't need full-parameter training, parameter-efficient methods can reduce memory by 10-100x.
- Check your data pipeline. I've seen cases where data loading buffers were eating 10+ GB of GPU memory unnecessarily.
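Gradient accumulation in particular is a two-line change. A minimal sketch with a toy model (the model and loss are placeholders):

```python
import torch

# Four micro-batches emulate one large batch; activation memory
# scales with the micro-batch size, not the effective batch size.
model = torch.nn.Linear(16, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
accum_steps = 4

opt.zero_grad()
for _ in range(accum_steps):
    x = torch.randn(8, 16)              # micro-batch
    loss = model(x).pow(2).mean()       # toy loss
    (loss / accum_steps).backward()     # scale so grads average correctly
opt.step()                              # one update for the whole batch
```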
The MegaTrain paper represents an interesting push toward making full-precision, full-parameter training accessible on minimal hardware. Whether that specific throughput tradeoff makes sense for your use case depends on your patience and your cloud bill. But the underlying techniques — smart offloading, compute-transfer overlap, activation checkpointing — are worth understanding regardless. They show up everywhere in modern training infrastructure.