Rusheel

I built a pre-flight check tool for PyTorch, because silent failures are the worst kind

Last month I was debugging a training run that produced suspiciously bad results. The loop ran fine. No errors. No crashes. Just a model that learned nothing useful.

After three days of debugging I found it: the validation set had samples from the training set. Label leakage. The model had been cheating the entire time and I had no idea.
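The check that would have caught this is simple in principle: fingerprint every training sample, then look each validation sample up in that set. A minimal sketch of the idea (the byte strings here are made-up stand-ins for real samples, not how preflight stores them):

```python
import hashlib

def fingerprint(sample: bytes) -> str:
    # Hash the raw sample bytes so identical samples produce identical keys.
    return hashlib.sha1(sample).hexdigest()

train = [b"sample-001", b"sample-002", b"sample-003"]
val = [b"sample-002", b"sample-104"]  # sample-002 leaked from train

train_hashes = {fingerprint(s) for s in train}
leaked = [s for s in val if fingerprint(s) in train_hashes]
print(leaked)  # [b'sample-002']
```

Ten lines of Python, three days of my life. That asymmetry is the whole point of automating it.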

That was the moment I decided to build preflight.


What is preflight?

preflight is a CLI tool you run before your training loop starts. It catches the silent failures that waste GPU time — the bugs that don't crash Python but quietly ruin your model.

pip install preflight-ml
preflight run --dataloader my_dataloader.py

Output:

preflight — pre-training check report
╭────────────────────────┬──────────┬────────┬──────────────────────────────────────────────────╮
│ Check                  │ Severity │ Status │ Message                                          │
├────────────────────────┼──────────┼────────┼──────────────────────────────────────────────────┤
│ nan_inf_detection      │ FATAL    │ PASS   │ No NaN or Inf values found in 10 sampled batches │
│ normalisation_sanity   │ WARN     │ PASS   │ Normalisation looks reasonable (mean=0.001)      │
│ channel_ordering       │ WARN     │ PASS   │ Channel ordering looks correct (NCHW)            │
│ label_leakage          │ FATAL    │ FAIL   │ Found 12/50 val samples (24%) in train set       │
│ split_sizes            │ INFO     │ PASS   │ train=800 samples, val=200 samples               │
│ vram_estimation        │ WARN     │ PASS   │ Estimated peak VRAM: 2.1 GB / 8.0 GB (26%)       │
│ class_imbalance        │ WARN     │ PASS   │ Class distribution looks balanced                │
│ shape_mismatch         │ FATAL    │ PASS   │ Model accepted input shape (3, 224, 224)         │
│ gradient_check         │ FATAL    │ PASS   │ All gradients look healthy                       │
╰────────────────────────┴──────────┴────────┴──────────────────────────────────────────────────╯

  1 fatal  0 warnings  8 passed

Pre-flight failed. Fix fatal issues before training.

It exits with code 1 on any FATAL failure — which means it blocks CI automatically.
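The rule behind that exit code is deliberately dumb: any failed FATAL check fails the whole run, warnings don't. A hypothetical sketch of the decision (the result shape here is illustrative, not preflight's actual internals):

```python
def exit_code(results):
    # results: list of (check_name, severity, passed) tuples
    fatal_failures = [name for name, sev, ok in results
                      if sev == "FATAL" and not ok]
    return 1 if fatal_failures else 0

results = [
    ("nan_inf_detection", "FATAL", True),
    ("label_leakage", "FATAL", False),
    ("class_imbalance", "WARN", True),
]
print(exit_code(results))  # 1, which CI treats as a failed step
```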


The 10 checks

preflight runs 10 checks grouped into three severity tiers:

FATAL — these stop the run:

  • nan_inf_detection — NaN or Inf values anywhere in sampled batches
  • label_leakage — samples appearing in both train and val sets
  • shape_mismatch — dataset output shape incompatible with model input
  • gradient_check — zero gradients, dead layers, exploding gradients before training

WARN — these flag issues worth fixing:

  • normalisation_sanity — data that looks unnormalised (raw pixel values etc.)
  • channel_ordering — NHWC tensors when PyTorch expects NCHW
  • vram_estimation — estimated peak VRAM exceeds 90% of GPU memory
  • class_imbalance — severe class imbalance beyond a configurable threshold

INFO — these are logged for awareness:

  • split_sizes — empty or degenerate train/val splits
  • duplicate_samples — identical samples within a split
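To make the tiers concrete, here's roughly what a check like nan_inf_detection does: pull a handful of batches and scan every input tensor. This is my own illustrative sketch, not the shipped implementation:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def check_nan_inf(dataloader, n_batches=10):
    """Scan up to n_batches of inputs for NaN or Inf values."""
    checked = 0
    for x, _ in dataloader:
        if checked >= n_batches:
            break
        if torch.isnan(x).any() or torch.isinf(x).any():
            return False, f"NaN/Inf found in batch {checked}"
        checked += 1
    return True, f"No NaN or Inf values found in {checked} sampled batches"

# Plant a single NaN and confirm the check catches it.
x = torch.randn(64, 3, 8, 8)
x[5, 0, 0, 0] = float("nan")
loader = DataLoader(TensorDataset(x, torch.zeros(64)), batch_size=16)
ok, msg = check_nan_inf(loader)
print(ok, msg)  # False NaN/Inf found in batch 0
```

Each check in the real tool follows this shape: sample, test, return a pass/fail plus a human-readable message for the report table.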

Why not just use pytest?

pytest tests code logic. preflight tests data state.

These are different failure modes at different levels of the stack. A pytest suite can pass completely while your dataset has NaNs, your labels are leaking, and your tensors are in the wrong channel order. preflight fills the gap between "my code runs" and "my training will actually work."


Why not Deepchecks or Great Expectations?

Both are excellent tools. But they're platforms — heavy, general-purpose, and slow to set up. preflight is a tool. One pip install, one command, 30 seconds. No config required to get started.

The goal is to make running preflight feel as natural as running pytest before a commit.


How to use it

Basic usage — just a dataloader:

# my_dataloader.py
import torch
from torch.utils.data import DataLoader, TensorDataset

x = torch.randn(200, 3, 224, 224)
y = torch.randint(0, 10, (200,))
dataloader = DataLoader(TensorDataset(x, y), batch_size=32)
preflight run --dataloader my_dataloader.py

Full usage — with model, loss, and val set:

preflight run \
  --dataloader my_dataloader.py \
  --model my_model.py \
  --loss my_loss.py \
  --val-dataloader my_val_dataloader.py

In CI — add to your GitHub Actions workflow:

- name: Install preflight
  run: pip install preflight-ml

- name: Run pre-flight checks
  run: preflight run --dataloader scripts/dataloader.py --format json

With config — add .preflight.toml to your repo:

[thresholds]
imbalance_threshold = 0.05
nan_sample_batches = 20

[checks]
vram_estimation = false

What preflight does NOT do

This is important. preflight is a minimum safety bar, not a guarantee.

  • It does not replace unit tests
  • It does not guarantee a correct model
  • It does not run your full training loop
  • It does not catch every possible failure

Think of it like an actual pre-flight checklist: it catches the known failure modes, but the pilot still has to fly the plane.


What's next

The roadmap for upcoming releases:

  • --fix flag — auto-patch common issues like channel ordering and normalisation
  • Dataset snapshot + drift detection (preflight diff baseline.json new_data.pt)
  • Full dry-run mode — one batch through model + loss + backward
  • Jupyter magic command (%load_ext preflight)
  • preflight-monai plugin for medical imaging specific checks
  • preflight-sktime plugin for time series checks

Links

If you've ever lost hours to a silent training failure, give it a try. And if you want to contribute — especially new checks — PRs are very welcome. Every check needs a passing test, a failing test, and a fix hint. Check out CONTRIBUTING.md.

