DEV Community

Alan West


Why Your PyTorch Training Crawls on a Beefy GPU (And How to Fix It)

Last month I was helping a friend debug a training loop that was running at maybe 15% GPU utilization on an A100. Fifteen percent. On a card that costs more than my first car. He'd already tried bumping the batch size, swapping the optimizer, and rewriting the data loader — nothing moved the needle.

This is one of those frustrating problems where the obvious knobs do nothing, because the obvious knobs aren't where the bottleneck lives. So let's actually walk through how to figure out why your model is slow, instead of just throwing batch sizes at the wall.

The three regimes nobody tells you about

When a deep learning workload is slow, it's almost always slow for one of three reasons. Horace He laid this out really clearly in his "Making Deep Learning Go Brrrr From First Principles" post back in 2022, and the framing has stuck with me ever since:

  • Compute-bound — you're actually saturating the matmul units. Rare. Usually only happens with huge dense layers.
  • Memory-bandwidth-bound — the GPU is mostly waiting on data to move between HBM and the SMs. Way more common than people realize.
  • Overhead-bound — Python, the framework dispatcher, or kernel launch latency is dominating. Death by a thousand papercuts.

The punchline: most "my model is slow" problems are not compute-bound, even though that's where everyone instinctively looks first. If you're running a transformer with a bunch of small ops between the big matmuls, you're probably stuck in regime two or three.

Step 1: Figure out which regime you're in

Don't guess. Profile. PyTorch's built-in profiler will tell you most of what you need:

import torch
from torch.profiler import profile, ProfilerActivity

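# MyModel is a placeholder for your own nn.Module
model = MyModel().cuda()
x = torch.randn(32, 3, 224, 224, device='cuda')

# Warm up with forward and backward: the first iterations include cuDNN
# autotuning and allocator setup, which would otherwise pollute the profile
for _ in range(5):
    model(x).sum().backward()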

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
) as prof:
    for _ in range(10):
        out = model(x)
        out.sum().backward()

# Sort by self CUDA time to see what's actually burning GPU cycles
print(prof.key_averages().table(sort_by='self_cuda_time_total', row_limit=20))

What you're looking for:

  • If a few big gemm / conv kernels dominate → likely compute-bound, and you should be happy.
  • If you see a sea of tiny kernels (add, mul, relu, layer_norm components, etc.) eating real time → memory-bandwidth-bound.
  • If CPU time is way higher than CUDA time, or kernel launches are spaced out with gaps → overhead-bound.
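The summary table won't show gaps between kernels. For that you need the timeline. Below is a minimal sketch that assumes a toy nn.Sequential standing in for your model. It exports a Chrome-format trace with prof.export_chrome_trace, which you can open in Perfetto (ui.perfetto.dev) or chrome://tracing and scan for idle stretches on the GPU rows. The file name trace.json is arbitrary, and the script drops back to CPU-only profiling when no GPU is present.

```python
import os

import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# Toy stand-in for a real model (an assumption for illustration)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 10)).to(device)
x = torch.randn(64, 512, device=device)

# Only request CUDA activity when a GPU is actually present
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

# Warm up outside the profiler
for _ in range(3):
    model(x).sum().backward()

with profile(activities=activities) as prof:
    for _ in range(5):
        model(x).sum().backward()

# Timeline view: open trace.json in Perfetto or chrome://tracing and look for
# idle gaps between kernels on the GPU stream rows
prof.export_chrome_trace("trace.json")
print(f"wrote trace.json ({os.path.getsize('trace.json')} bytes)")

# The table lists CPU totals (and CUDA totals when profiled) per op
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```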

For the friend's model, the profile showed hundreds of tiny pointwise kernels per step. Classic memory bandwidth problem.

Step 2: The arithmetic intensity check

Here's the back-of-the-envelope check that explains why small ops are murder. For each operation, ask: how many FLOPs am I doing per byte of memory I touch?

A modern GPU like an A100 does roughly 312 TFLOPS of dense fp16 matmul on its Tensor Cores but only has about 2 TB/s of HBM bandwidth. That's a ratio of ~150 FLOPs per byte. If your operation does fewer FLOPs per byte than that, you're memory-bound, full stop. For pointwise ops, bigger batches won't help either: the FLOPs-per-byte ratio stays the same no matter how large the tensor gets.

A pointwise relu on an fp32 tensor? You read 4 bytes, write 4 bytes, do 1 FLOP. That's 0.125 FLOPs per byte. You are wildly memory-bound. The GPU spends 99% of its time waiting on memory and 1% doing the actual work.

A dense matmul on big enough matrices? Hundreds of FLOPs per byte. Now you're cooking.
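Here is the same back-of-the-envelope math as a small script, using the A100 figures above. It assumes idealized traffic, where each operand is read once and each result is written once. For a square n×n matmul that works out to 2n³ FLOPs over 3n² elements, so intensity grows roughly as n/3 in fp16. Strictly speaking, relu runs on the ordinary CUDA cores rather than Tensor Cores, so its real ceiling is lower than 156. Still, 0.125 is far below either figure.

```python
# Back-of-the-envelope roofline check with the A100 numbers quoted above
PEAK_FP16_FLOPS = 312e12  # dense fp16 Tensor Core throughput, FLOP/s
HBM_BANDWIDTH = 2e12      # bytes/s (approximate)

# Ridge point: FLOPs per byte needed before compute, not memory, is the limit
RIDGE = PEAK_FP16_FLOPS / HBM_BANDWIDTH


def relu_intensity(bytes_per_elem: int = 4) -> float:
    """Pointwise relu: one read, one write, one FLOP per element."""
    return 1 / (2 * bytes_per_elem)


def square_matmul_intensity(n: int, bytes_per_elem: int = 2) -> float:
    """C = A @ B for n x n matrices: 2*n^3 FLOPs (multiply + add), and in the
    ideal case A and B are read once and C is written once."""
    flops = 2 * n**3
    bytes_moved = 3 * n * n * bytes_per_elem
    return flops / bytes_moved


def bound(intensity: float) -> str:
    return "compute-bound" if intensity >= RIDGE else "memory-bound"


print(f"ridge point: {RIDGE:.0f} FLOPs/byte")
print(f"fp32 relu: {relu_intensity():.3f} FLOPs/byte -> {bound(relu_intensity())}")
for n in (64, 512, 4096):
    i = square_matmul_intensity(n)
    print(f"fp16 matmul n={n}: {i:.1f} FLOPs/byte -> {bound(i)}")
```

Small matmuls (n=64 here) land on the memory-bound side as well. That's one reason a model built from many small layers can stay bandwidth-limited.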

Step 3: Fuse the small stuff

The fix for memory-bandwidth-bound code is almost always operator fusion. Instead of running ten separate kernels that each round-trip through HBM, you run one kernel that keeps intermediate values in registers or shared memory.
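To see why fusion pays off, here is an idealized count of HBM traffic for a chain of unary pointwise ops on fp32 data. It assumes each unfused kernel does one full read and one full write, and that the fused kernel keeps every intermediate in registers. Real kernels also hit caches and may read extra operands, so treat this as a model rather than a measurement. Because these ops are bandwidth-bound, runtime roughly tracks bytes moved.

```python
def hbm_bytes(n_elems: int, n_ops: int, bytes_per_elem: int = 4, fused: bool = False) -> int:
    """Idealized HBM traffic for a chain of n_ops unary pointwise ops.

    Unfused: every op reads its input from HBM and writes its output back.
    Fused: one kernel reads the input once and writes only the final result.
    """
    per_pass = 2 * n_elems * bytes_per_elem  # one full read + one full write
    return per_pass if fused else n_ops * per_pass


n = 32 * 1024 * 1024  # 32M fp32 elements, an arbitrary example size
ops = 10              # e.g. ten separate pointwise kernels
unfused = hbm_bytes(n, ops)
fused = hbm_bytes(n, ops, fused=True)
print(f"unfused: {unfused / 1e9:.2f} GB, fused: {fused / 1e9:.2f} GB, "
      f"ratio: {unfused // fused}x")
```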

The easiest win in modern PyTorch is torch.compile:

import torch

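# MyModel is a placeholder for your own nn.Module
model = MyModel().cuda()
x = torch.randn(32, 3, 224, 224, device='cuda')

# mode='reduce-overhead' uses CUDA graphs to also chip away at launch overhead
# mode='max-autotune' spends more time compiling, benchmarking more kernel
# configurations (including Triton matmul templates) to pick the fastest ones
compiled = torch.compile(model, mode='reduce-overhead')

# The first few calls are slow: they trace, compile, and (with
# reduce-overhead) record a CUDA graph
for _ in range(3):
    _ = compiled(x)

# Subsequent calls with the same shapes hit the cached compiled graph
for _ in range(100):
    out = compiled(x)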

I've seen this give 1.5x–3x speedups on transformer-ish workloads with basically zero code changes. Your mileage varies a lot depending on how dynamic your shapes are. If every batch has a different sequence length, you can trigger repeated recompiles (and, with mode='reduce-overhead', a fresh CUDA graph per shape) and lose most of the win. See the torch.compile docs for the dynamic-shape options, such as its dynamic= argument.
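One common way to keep shapes static is to pad each batch up to one of a few fixed lengths. That way torch.compile, and any CUDA graphs it records, only ever sees a handful of shapes. The sketch below assumes token-ID batches shaped (batch, seq_len) and uses a hypothetical BUCKETS tuple that you would tune to your own length distribution. Your model must honor the returned mask so that padding doesn't change the results.

```python
import torch
import torch.nn.functional as F

# Hypothetical bucket sizes: choose ones that fit your sequence-length distribution
BUCKETS = (64, 128, 256, 512)


def pad_to_bucket(tokens: torch.Tensor, pad_id: int = 0):
    """Right-pad a (batch, seq_len) tensor of token ids up to the smallest bucket
    that fits, and return a boolean mask marking the real (non-pad) positions."""
    seq_len = tokens.shape[1]
    target = next((b for b in BUCKETS if b >= seq_len), None)
    if target is None:
        raise ValueError(f"seq_len {seq_len} exceeds largest bucket {BUCKETS[-1]}")
    # F.pad pads the last dimension first: (left, right) for the sequence dim here
    padded = F.pad(tokens, (0, target - seq_len), value=pad_id)
    mask = torch.zeros(tokens.shape[0], target, dtype=torch.bool, device=tokens.device)
    mask[:, :seq_len] = True
    return padded, mask


# Three batches with different lengths collapse onto just two shapes
for length in (37, 50, 100):
    batch = torch.randint(1, 1000, (8, length))
    padded, mask = pad_to_bucket(batch)
    print(length, "->", tuple(padded.shape))
```

The trade-off is some wasted compute on padding tokens in exchange for far fewer recompiles.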

If you need more control, you can write fused kernels yourself in Triton. For a pointwise chain, it's usually not worth it — torch.compile will fuse those for you. For attention or other patterns with cross-element communication, hand-written kernels (or things like FlashAttention) are still where the big wins live.
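For attention specifically, you often don't need to hand-write anything. Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention dispatches to fused kernels, including a FlashAttention-based one on supported GPUs and dtypes. The sketch below compares it with an unfused version that materializes the full score matrix. The tensor sizes are arbitrary, and which backend actually runs depends on your hardware and PyTorch build.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# (batch, heads, seq_len, head_dim); sizes chosen only for illustration
q, k, v = (torch.randn(2, 4, 128, 64, device=device, dtype=dtype) for _ in range(3))

# Fused path: PyTorch selects a backend (e.g. FlashAttention, memory-efficient,
# or plain math) based on hardware, dtype, and inputs
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Unfused reference: builds the full seq_len x seq_len score matrix in memory
scale = q.shape[-1] ** -0.5
scores = (q @ k.transpose(-2, -1)) * scale
causal = torch.ones(128, 128, dtype=torch.bool, device=device).tril()
scores = scores.masked_fill(~causal, float("-inf"))
reference = scores.softmax(dim=-1) @ v

tol = 1e-2 if dtype == torch.float16 else 1e-4
print(torch.allclose(fused.float(), reference.float(), atol=tol))
```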

Step 4: Crush overhead with CUDA graphs

If your profile shows lots of small gaps between kernels on the GPU timeline, you're overhead-bound. Each kernel launch has fixed CPU-side cost — Python, the dispatcher, CUDA itself. With small kernels, that overhead can be bigger than the kernel runtime.
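A quick way to see launch overhead for yourself is to time the same elementwise expression on a tiny tensor and a large one with torch.utils.benchmark.Timer, which handles warmup and CUDA synchronization. On a GPU, suppose the 1,000-element call costs a sizable fraction of the 10-million-element call even though it does 10,000x less work. In that case, the tiny call is almost entirely fixed overhead. This is a minimal sketch: the sizes and the x * 2 + 1 expression are arbitrary, and absolute numbers depend entirely on your hardware.

```python
import torch
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"

# Time the same elementwise expression on a tiny and a large tensor.
# Timer performs warmup and synchronizes CUDA, so these are per-call wall times.
results = {}
for n in (1_000, 10_000_000):
    x = torch.randn(n, device=device)
    timer = benchmark.Timer(stmt="x * 2 + 1", globals={"x": x})
    results[n] = timer.timeit(50).mean  # mean seconds per call
    print(f"n={n:>10,}: {results[n] * 1e6:10.1f} us per call")
```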

CUDA graphs let you record a sequence of kernel launches once and replay them as a single submission:

import torch

# MyModel and dataloader are placeholders for your own model and data pipeline
model = MyModel().cuda().eval()
static_input = torch.randn(32, 3, 224, 224, device='cuda')

# Warm up on a side stream before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture the graph (inference only, so no autograd bookkeeping is recorded)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g), torch.no_grad():
    static_output = model(static_input)

# To run: copy new data into static_input in place, then replay
for batch in dataloader:
    static_input.copy_(batch)  # in-place copy into the same buffer; shapes must match
    g.replay()
    # static_output now holds the result; clone it if you need it after the next replay

The big gotcha: input tensors have to live at the same memory addresses each call. You're reusing buffers, not allocating new ones. That's why we copy_ instead of reassigning. torch.compile(mode='reduce-overhead') does basically this for you under the hood, which is why I usually reach for that first.
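The address rule is easy to check with data_ptr(). The sketch below runs on the CPU purely for illustration and assumes static_input is the buffer a graph was captured against. An in-place copy_ keeps the address, whereas rebinding the name to a new tensor points it at memory the graph never saw.

```python
import torch

# Pretend a CUDA graph was captured reading from this buffer (CPU here for illustration)
static_input = torch.zeros(4, 3)
captured_addr = static_input.data_ptr()

new_batch = torch.randn(4, 3)

# Right: copy into the existing buffer, so the address the graph reads is unchanged
static_input.copy_(new_batch)
same_after_copy = static_input.data_ptr() == captured_addr

# Wrong: rebinding the name points it at different memory; a captured graph
# would keep reading the old buffer
static_input = new_batch
same_after_rebind = static_input.data_ptr() == captured_addr

print(same_after_copy, same_after_rebind)  # True False
```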

Prevention tips

A few habits that have saved me a lot of grief:

  • Profile before optimizing. Always. I've wasted entire afternoons "optimizing" things that were 2% of the runtime.
  • Watch your shapes. Dynamic shapes can trigger torch.compile recompiles and force CUDA graphs to be re-recorded for every new shape. If you can pad to a few bucket sizes, do it.
  • Stop sprinkling .cpu() and .item() calls. Each one forces a sync and stalls the pipeline. If you're doing it inside the training loop for logging, batch it up.
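  • Check nvidia-smi while training. If utilization is below ~70%, something's probably wrong, and that's your signal to break out the profiler. The reverse doesn't hold, though: nvidia-smi's utilization only reports the fraction of time some kernel was running, so a high number doesn't prove you're saturating compute or bandwidth.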
  • Read the assembly when it really matters. For hot kernels, Triton lets you dump the PTX and see what actually got generated. Sometimes the compiler does something silly.
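For the .item() habit, here is a minimal sketch of the batched pattern. It keeps a running sum as a device tensor and pulls it to the host only once every LOG_EVERY steps, so the loop syncs a few times instead of on every step. The random "loss" stands in for a real one, and LOG_EVERY is a hypothetical knob.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
LOG_EVERY = 100  # hypothetical logging interval

running_loss = torch.zeros((), device=device)  # accumulates on-device
logged = []

for step in range(1, 301):
    # Stand-in for the loss from a real forward pass
    loss = torch.rand((), device=device)
    running_loss += loss.detach()  # stays on the device: no host sync here

    if step % LOG_EVERY == 0:
        # One .item() (one sync) per LOG_EVERY steps instead of one per step
        logged.append(running_loss.item() / LOG_EVERY)
        running_loss.zero_()

print(logged)
```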

The meta-lesson here is that GPU performance is a first-principles problem, not a vibes-based one. Once you know whether you're starved for FLOPs, bandwidth, or launches, the fix usually becomes obvious. The frustrating part is just resisting the urge to skip the profiling step.
