DEV Community

Marco Rinaldi


Structured channel pruning got our detector under 12ms on a Jetson

TL;DR: Our frame-based defect detector ran at 31ms on a Jetson Orin Nano, and the production line needed 12ms. Structured channel pruning at a 0.45 ratio plus a three-epoch fine-tune got us to 11.4ms for a 0.006 mAP drop (0.910 to 0.904). Unstructured pruning looked beautiful on paper and gave essentially zero real-world speedup, so we deleted it.

So, the thing is, accuracy was never the problem. The detector hit 0.91 mAP in the lab and everyone was happy. Then we put it on the actual Orin Nano sitting next to the conveyor, and it ran at 31ms per frame. The line moves a stamped part every 12ms. You can imagine how that meeting went.

Let me give you the full picture here. This was a side project away from my usual event-camera work at Prophesee, a plain RGB detector for surface defects on stamped metal panels. Small team, three of us, a hard latency budget and no room for a bigger GPU on the line. Cutting model size was the only lever we had.

The trap I fell into first

I reached for unstructured pruning because the literature loves it: torch.nn.utils.prune, magnitude-based, zero out the smallest 60% of weights, retrain. The sparsity numbers were gorgeous. Sixty percent of the parameters gone, mAP barely moved.
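Roughly, that setup looks like this. It's a minimal sketch on a toy two-conv network, not our detector. The layer sizes are arbitrary. It uses PyTorch's `torch.nn.utils.prune` with global L1 (magnitude) pruning at 60%, followed by `prune.remove` to make the zeros permanent. The thing to notice is that the weight shapes come out exactly as they went in.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)

# Toy stand-in for a detector backbone (arbitrary sizes, not the real model).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
)
convs = [(m, "weight") for m in model.modules() if isinstance(m, nn.Conv2d)]

# Global magnitude pruning: zero the 60% smallest |w| across all conv weights.
prune.global_unstructured(convs, pruning_method=prune.L1Unstructured, amount=0.6)

# Fold the masks into the weights and drop the pruning reparametrisation.
for module, name in convs:
    prune.remove(module, name)

total = sum(m.weight.numel() for m, _ in convs)
zeros = sum(int((m.weight == 0).sum()) for m, _ in convs)
for m, _ in convs:
    print(tuple(m.weight.shape))  # same shapes as before pruning
print(f"global sparsity: {zeros / total:.0%}")
```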

Latency on the Jetson: 30.8ms. Against a 31.0ms baseline, that is noise.

Here is why it does nothing. Unstructured pruning keeps every tensor the exact shape it was and writes zeros into it. cuDNN still runs a dense convolution over those zeros. Ampere-class tensor cores, the Orin's included, can only skip zeros in the 2:4 structured pattern (at most two nonzeros in every group of four weights), and only through kernels built for it. Scattered magnitude pruning does not produce that pattern. So you bought a checkpoint that compresses well and nothing else. The thing was multiplying by zero at full price.
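To see why a magnitude-pruned mask doesn't qualify for 2:4 acceleration, here is a toy pure-Python sketch. The function names are made up for illustration. It compares global magnitude pruning with per-group 2:4 pruning on a single row of eight weights. Both end up 50% sparse, but only the second keeps at most two nonzeros in each group of four. Real 2:4 kernels define the groups along a specific axis of the weight matrix, and this sketch ignores that detail.

```python
def magnitude_prune(weights, amount):
    """Unstructured: zero the `amount` fraction of smallest-|w| entries anywhere."""
    k = int(round(amount * len(weights)))
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    drop = set(order[:k])
    return [0.0 if i in drop else w for i, w in enumerate(weights)]


def prune_2_4(weights):
    """2:4 pattern: keep the two largest-|w| entries in every group of four."""
    out = []
    for g in range(0, len(weights), 4):
        group = weights[g:g + 4]
        keep = sorted(range(len(group)), key=lambda i: abs(group[i]), reverse=True)[:2]
        out.extend(w if i in keep else 0.0 for i, w in enumerate(group))
    return out


def is_2_4_sparse(weights):
    """True if every contiguous group of four has at least two zeros."""
    assert len(weights) % 4 == 0
    return all(
        sum(1 for w in weights[g:g + 4] if w == 0.0) >= 2
        for g in range(0, len(weights), 4)
    )


row = [0.9, -0.8, 0.7, 0.6, 0.05, -0.04, 0.03, 0.02]

unstructured = magnitude_prune(row, 0.5)
print(unstructured, is_2_4_sparse(unstructured))
# [0.9, -0.8, 0.7, 0.6, 0.0, 0.0, 0.0, 0.0] False

structured = prune_2_4(row)
print(structured, is_2_4_sparse(structured))
# [0.9, -0.8, 0.0, 0.0, 0.05, -0.04, 0.0, 0.0] True
```

Both vectors have the same length and the same sparsity. Magnitude pruning happened to concentrate all its zeros in one group, and that is typical for real weights too.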

Where the latency actually lives

Structured pruning removes whole filters and channels, so the tensors get physically smaller. A conv layer with 256 output channels becomes 140. FLOPs drop, memory traffic drops, the kernel finishes sooner. Real speedup.
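A quick worked example shows why shrinking the shapes is what counts. It uses the standard multiply-accumulate (MAC) count for a dense convolution, C_in x C_out x k x k x H_out x W_out. The 40x40 feature map and 3x3 kernel are hypothetical. The 256 to 140 channel change is the one from the paragraph above.

```python
def conv_macs(c_in, c_out, k, h_out, w_out):
    """Multiply-accumulates of a dense k x k convolution (no groups, bias ignored)."""
    return c_in * c_out * k * k * h_out * w_out


H = W = 40  # hypothetical feature-map size
K = 3       # hypothetical kernel size

dense = conv_macs(256, 256, K, H, W)

# Unstructured pruning: zeros are written in but shapes stay 256 -> 256,
# so a dense kernel performs exactly the same number of MACs.
zeroed = conv_macs(256, 256, K, H, W)

# Structured pruning: this layer keeps 140 of its 256 output filters.
fewer_outputs = conv_macs(256, 140, K, H, W)

# A layer sandwiched between two pruned layers loses inputs and outputs,
# so its cost falls with the square of the kept fraction.
both_sides = conv_macs(140, 140, K, H, W)

print(zeroed / dense)         # 1.0
print(fewer_outputs / dense)  # 0.546875
print(both_sides / dense)     # 0.299072265625
```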

The painful part is dependency tracking. Remove an output channel in one layer and you must remove the matching input channel everywhere it feeds, across residual adds, concats, the lot. Do it by hand and you will spend a Sunday debugging shape mismatches instead of eating lunch in Bologna. torch-pruning traces the whole graph with its DepGraph and prunes coupled layers together.
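Here is a toy illustration of what "dependency tracking" means. It is not how torch-pruning implements DepGraph, which also tracks index mappings through concats, splits and so on. Treat every channel axis as a node and link any two axes that must stay the same size. Union-find then gives you the groups that have to be pruned together. The layer names and the block layout are hypothetical.

```python
class DisjointSet:
    """Union-find over hashable nodes."""

    def __init__(self):
        self.parent = {}

    def find(self, x):
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a, b):
        self.parent[self.find(a)] = self.find(b)


# Hypothetical block: conv1 -> conv2, residual add (conv1_out + conv2_out) -> conv3 -> head.
# Each node is one channel axis: (layer, "in" or "out").
couplings = [
    (("conv1", "out"), ("conv2", "in")),   # conv1 feeds conv2
    (("conv1", "out"), ("conv2", "out")),  # residual add: both operands must match
    (("conv2", "out"), ("conv3", "in")),   # the sum feeds conv3
    (("conv3", "out"), ("head", "in")),    # conv3 feeds the head
]

ds = DisjointSet()
for a, b in couplings:
    ds.union(a, b)

groups = {}
for node in list(ds.parent):
    groups.setdefault(ds.find(node), []).append(node)

# Removing channel i from any axis in a group means removing channel i from all of them.
for members in groups.values():
    print(sorted(members))
```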

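```python
import torch
import torch_pruning as tp

model = load_detector().eval()  # our own model loader, not shown here
example = torch.randn(1, 3, 640, 640)  # dummy input used to trace the graph

# group importance by L2 norm of each filter group
imp = tp.importance.GroupNormImportance(p=2)

pruner = tp.pruner.MetaPruner(
    model,
    example,
    importance=imp,
    pruning_ratio=0.45,
    ignored_layers=[model.head.cls],  # leave the classification head alone
)

pruner.step()  # physically removes the selected channels from every coupled layer
# count_ops_and_params counts MACs; the "GFLOPs" label follows common usage
macs, params = tp.utils.count_ops_and_params(model, example)
print(f"{macs/1e9:.1f} GFLOPs, {params/1e6:.2f}M params")
```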

After pruning, the model is broken in the accuracy sense. You fine-tune to recover. Three epochs on the same training set pulled mAP back from 0.86 to 0.904. Not free, but cheap compared to retraining from scratch.

The numbers

| Approach | Params removed | FLOPs | Orin Nano latency | mAP |
|---|---|---|---|---|
| Baseline (FP16) | 0% | 17.2 GFLOPs | 31.0 ms | 0.910 |
| Unstructured (60%) | 60% | 17.2 GFLOPs (dense) | 30.8 ms | 0.905 |
| Structured (ratio 0.45) | 41% | 9.4 GFLOPs | 11.4 ms | 0.904 |

Structured pruning removed fewer parameters on paper and delivered all of the speedup. That gap between "params removed" and "FLOPs removed" is the whole lesson: FLOPs only fall when the shapes shrink.

Trade-offs and limitations

Pruning ratio has a cliff. At 0.45 we lost 0.006 mAP. At 0.60 the small-defect class collapsed and no amount of fine-tuning brought it back, because those filters were genuinely doing work. You have to sweep the ratio per model; there is no universal number.
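The sweep itself is a loop: prune a fresh copy at each ratio, fine-tune, export, and measure on the device. The selection at the end can be as simple as the sketch below. The `SweepPoint` structure and the `pick_ratio` helper are hypothetical. The two data points are the baseline and ratio-0.45 rows from the table above. The 12 ms budget comes from the line, and the 0.01 mAP tolerance is an assumed value.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SweepPoint:
    ratio: float       # pruning ratio given to the pruner
    latency_ms: float  # measured on the target device after export
    mean_ap: float     # mAP after fine-tuning


def pick_ratio(points: List[SweepPoint], baseline_map: float,
               budget_ms: float, max_map_drop: float) -> Optional[SweepPoint]:
    """Best-mAP point that fits the latency budget and the accuracy floor."""
    ok = [
        p for p in points
        if p.latency_ms <= budget_ms and baseline_map - p.mean_ap <= max_map_drop
    ]
    return max(ok, key=lambda p: p.mean_ap) if ok else None


points = [
    SweepPoint(ratio=0.00, latency_ms=31.0, mean_ap=0.910),  # baseline row
    SweepPoint(ratio=0.45, latency_ms=11.4, mean_ap=0.904),  # structured row
]
best = pick_ratio(points, baseline_map=0.910, budget_ms=12.0, max_map_drop=0.01)
print(best)
```

The helper prefers accuracy once the budget is met, so a loose budget hands back the unpruned model. A ratio that collapses a class should fail the `max_map_drop` check, as long as you evaluate per class as well as overall.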

The head is sensitive. We pinned the classification head with ignored_layers after an early run where pruning it tanked recall on the rare defect type. Layers are not equally prunable, and treating the network as if they were will cost you.

Speedup is hardware-shaped. The same 9.4 GFLOPs model runs at 11.4ms on the Orin Nano and at 4ms on a desktop RTX 4070. Pruning helps compute-bound layers; if your bottleneck is memory bandwidth or a fat softmax, channel pruning moves the needle less than you hope. Profile with trtexec or Nsight before you assume FLOPs equal time.
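A back-of-envelope roofline estimate is a cheap sanity check to run before the profiler. Divide a layer's FLOPs by the bytes it must move, then compare the result with the device's ridge point, which is peak FLOP/s divided by memory bandwidth. The device numbers below are placeholders, not Orin Nano specs, so take yours from the datasheet. The traffic model is optimistic: each tensor is touched once and nothing is fused. Treat the result as a hint, not a measurement.

```python
def conv_intensity(c_in, c_out, k, h, w, bytes_per_elem=2):
    """FLOPs per byte for a stride-1 'same' conv in FP16.

    Traffic model (optimistic): read input and weights once, write output once.
    """
    flops = 2 * c_in * c_out * k * k * h * w  # 1 multiply-accumulate = 2 FLOPs
    elems = c_in * h * w + c_out * c_in * k * k + c_out * h * w
    return flops / (elems * bytes_per_elem)


def is_compute_bound(intensity, peak_flops_per_s, bytes_per_s):
    """Roofline test: at or above the ridge point, compute is the limit."""
    return intensity >= peak_flops_per_s / bytes_per_s


# Placeholder device numbers (NOT Orin Nano specs): ridge point = 100 FLOP/byte.
PEAK_FLOPS = 10e12
BANDWIDTH = 100e9

wide_3x3 = conv_intensity(256, 256, 3, 40, 40)  # many channels, small map
thin_1x1 = conv_intensity(64, 64, 1, 160, 160)  # few channels, large map

for name, ai in [("3x3, 256ch, 40x40", wide_3x3), ("1x1, 64ch, 160x160", thin_1x1)]:
    bound = "compute" if is_compute_bound(ai, PEAK_FLOPS, BANDWIDTH) else "memory"
    print(f"{name}: {ai:.0f} FLOP/byte -> {bound}-bound")
```

In this traffic model, cutting channels shrinks a layer's FLOPs roughly with c_in x c_out but its activation traffic only linearly. Pruning therefore also lowers a layer's arithmetic intensity and can push it toward the memory-bound side.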

And retraining needs data. We had labels, so recovery was painless. The handful of ambiguous border crops where annotators disagreed went to a larger vision model for a second opinion, and we routed those calls through an AI gateway (we run Bifrost, some teams use LiteLLM) so nobody had to wire up provider SDKs by hand. That kept the labelling loop moving without a separate integration project.

One more honest note. Structured pruning fights quantisation a little. Our INT8 TensorRT export of the pruned model needed re-calibration, and the calibration cache from the dense model was useless. Budget time for that step.
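One simple guard against reusing a stale cache is to name the calibration cache after a hash of the exported model. Any change to the graph or weights then gives a new file name, so the dense model's cache can never be picked up for the pruned one. This is a hypothetical helper, not part of TensorRT. In a Python INT8 calibrator you would return this file's bytes from `read_calibration_cache` and write to it in `write_calibration_cache`.

```python
import hashlib
from pathlib import Path


def calibration_cache_path(onnx_path, cache_dir="calib_cache"):
    """Cache file named after a hash of the exported ONNX model.

    Any change to the graph or weights (e.g. pruning) yields a new name,
    so a cache calibrated on the dense model is never reused by accident.
    """
    onnx_path = Path(onnx_path)
    digest = hashlib.sha256(onnx_path.read_bytes()).hexdigest()[:16]
    return Path(cache_dir) / f"{onnx_path.stem}-{digest}.calib"
```

Hashing the whole file costs a few hundred milliseconds at most for a detector-sized model. That is nothing next to a calibration run.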

What I would tell past me

Stop optimising the metric that does not move the clock. Sparsity is a vanity number unless your runtime can spend it. Measure latency on the real device, with the real batch size, before and after every change. The espresso machine in our office has more consistent timing than my early benchmark scripts did, and I trusted it more.
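A minimal timing harness in that spirit follows. It assumes you can wrap one forward pass at the production batch size in a zero-argument callable. For a CUDA model, pass `torch.cuda.synchronize` as `sync`; otherwise you are timing kernel launches rather than kernels. Reporting p99 and max alongside the median is a design choice. A line that ships a part every 12ms cares about the slow frames, not the average.

```python
import statistics
import time


def measure_latency_ms(run_once, warmup=20, iters=200, sync=None):
    """Time one inference call repeatedly and summarise in milliseconds.

    run_once: zero-arg callable doing one forward pass at the real batch size.
    sync: optional callable that blocks until the device is idle
          (e.g. torch.cuda.synchronize for a CUDA model).
    """
    for _ in range(warmup):  # let clocks, caches and autotuners settle
        run_once()
    if sync:
        sync()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        run_once()
        if sync:
            sync()
        samples.append((time.perf_counter() - t0) * 1e3)
    samples.sort()
    return {
        "median": statistics.median(samples),
        "p99": samples[min(len(samples) - 1, int(0.99 * len(samples)))],
        "max": samples[-1],
    }


print(measure_latency_ms(lambda: sum(range(10_000))))
```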

If you cannot make the model smaller and faster on the actual target, you do not yet understand where its time goes. Pruning forced me to learn that, layer by layer.

Further Reading

  • torch-pruning (DepGraph) - structured pruning with dependency tracking
  • Pruning Filters for Efficient ConvNets, Li et al. 2017 - the filter-pruning baseline
  • What is the State of Neural Network Pruning?, Blalock et al. 2020 - a sober look at pruning claims
  • NVIDIA TensorRT best practices - profiling and INT8 calibration
