Part 2: Scaling to MNIST — The Challenge of Abstracting Math in Rust
After the first success with Adaline, I felt confident. I thought, "How hard can it be to scale this to a Multi-Layer Perceptron (MLP) for MNIST?"
The answer: Hard enough to make me delete everything.
The "AI Trap" and the Great Reset
Within 3-4 days of starting, I had a fully working program with every optimisation I could think of. However, looking at the source code, I felt like a stranger to my own logic. I had relied too much on AI for the Rust boilerplate. I had a "working" project, but I wasn't truly learning the mechanics of Rust's ownership or the underlying math. It was still a great learning experience: I went from zero knowledge to a solid grasp of multi-layer perceptron backpropagation and of handling the MNIST dataset. But none of it had fully settled in my head. It happened too quickly, with too little mental effort on my side to set the new knowledge in stone.
So, I did the unthinkable: I deleted the repository. Starting over, line by line by myself, was the best decision I made. It forced me to figure out every piece of the program on my own.
Phase 1: The Transition to ndarray
My initial attempt used only Vec<f32>. It worked for Adaline, but for MNIST (60,000 images, 784 pixels each), it was a nightmare.
On my Mac M2 Pro, each epoch took around 4 seconds. While it wasn't "slow" for a small test, it became a bottleneck as I started experimenting with hyperparameters. Switching to the ndarray crate was essential for performance, but it triggered a cascading refactor.
In Rust, your type system is your contract. Changing Vec to Array2 meant rewriting everything from data loaders to weight initializers. I used this as an opportunity to restart with a cleaner, modular structure, avoiding AI-assisted boilerplate to ensure I fully understood every architectural decision.
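To make the refactor concrete, here is a minimal sketch (with hypothetical names, not code from the repository) of the raw-`Vec<f32>` style I started with. Weights for one layer live in a single flat, row-major vector, and every matrix-vector product is a hand-written loop; with `ndarray`, the entire loop below collapses into a single `dot` call on an `Array2`/`Array1` pair.

```rust
// Sketch of the raw-Vec approach (hypothetical names, not repo code).
// A (rows x cols) weight matrix stored flat, row-major, in one Vec<f32>.
fn matvec(weights: &[f32], input: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(weights.len(), rows * cols);
    assert_eq!(input.len(), cols);
    let mut out = vec![0.0f32; rows];
    for r in 0..rows {
        // Manual dot product over one weight row; with ndarray this
        // whole function becomes `weights.dot(&input)`.
        let row = &weights[r * cols..(r + 1) * cols];
        out[r] = row.iter().zip(input).map(|(w, x)| w * x).sum();
    }
    out
}

fn main() {
    // Tiny 2x3 example: [[1,2,3],[4,5,6]] * [1,1,1] = [6, 15]
    let w = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let x = vec![1.0, 1.0, 1.0];
    println!("{:?}", matvec(&w, &x, 2, 3)); // [6.0, 15.0]
}
```

Every function that touched one of these flat vectors had to change its signature during the migration, which is exactly why the refactor cascaded.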
Phase 2: The "Refactoring Trap"
By late February, I had a working MVP. The code was "dirty" and monolithic, but the math was correct, and the network was learning. Then, I decided to refactor it for modularity.
And that’s when the "house of cards" collapsed.
In standard Web Development, we are used to juggling modules and applying design patterns. But with Neural Networks, where the core logic is a dense wall of matrix calculus, abstraction is a double-edged sword. What worked as a single "wall of math" became difficult to follow once broken into separate modules. I realized that in math-heavy code, a "clean" architecture can sometimes break your mental map of the algorithm, making it harder to track how tensors flow through the layers.
Phase 3: Results & The "Visual" Architecture
To verify my refactored logic, I focused on concrete metrics and visualization. I chose a specific architecture to satisfy both the math and my curiosity:
- Input Layer: 784 (28x28 pixels)
- Hidden Layer: 36 neurons (Chosen specifically because 36 = 6x6, allowing for a perfect square heatmap visualization)
- Output Layer: 10 (digits 0-9)
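The forward pass implied by this 784 → 36 → 10 architecture can be sketched as two dense sigmoid layers. This is a self-contained illustration with hypothetical names and zeroed weights (the real project stores weights as `ndarray` matrices and trains them), but the shapes and the math match the list above.

```rust
// Minimal forward-pass sketch for a 784 -> 36 -> 10 MLP.
// Hypothetical names; real weights come from training, not zeros.
fn sigmoid(z: f32) -> f32 {
    1.0 / (1.0 + (-z).exp())
}

// One dense layer: out[r] = sigmoid(bias[r] + sum_c w[r][c] * input[c]),
// with `w` stored flat in row-major order.
fn layer(w: &[f32], b: &[f32], input: &[f32]) -> Vec<f32> {
    let cols = input.len();
    b.iter()
        .enumerate()
        .map(|(r, bias)| {
            let z: f32 = w[r * cols..(r + 1) * cols]
                .iter()
                .zip(input)
                .map(|(wi, xi)| wi * xi)
                .sum();
            sigmoid(z + bias)
        })
        .collect()
}

fn forward(img: &[f32; 784]) -> Vec<f32> {
    // Zero weights keep the sketch self-contained.
    let w1 = vec![0.0f32; 36 * 784];
    let b1 = vec![0.0f32; 36];
    let w2 = vec![0.0f32; 10 * 36];
    let b2 = vec![0.0f32; 10];
    let hidden = layer(&w1, &b1, img); // 36 activations (the 6x6 grid)
    layer(&w2, &b2, &hidden)           // 10 digit scores
}

fn main() {
    let out = forward(&[0.0; 784]);
    assert_eq!(out.len(), 10);
    // With all-zero weights, every sigmoid outputs exactly 0.5.
    assert!(out.iter().all(|&p| (p - 0.5).abs() < 1e-6));
}
```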
The Benchmarks (on Mac M2 Pro):
| Metric | Result |
|---|---|
| Final Test Accuracy | 95.56% (9556/10000 correct) |
| Training Time (per epoch) | ~0.65s (with ndarray) |
| Initial Time (raw vectors) | ~4.0s |
| Total Epochs | 30 |
| Hyperparameters | Batch Size: 10, Learning Rate: 3.0 |
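The "Batch Size: 10, Learning Rate: 3.0" row corresponds to the standard mini-batch SGD update, w ← w − (η / m) · Σ∇, where the gradients of a batch are summed and then scaled. A minimal sketch (illustrative names, not repository code):

```rust
// Mini-batch SGD update: w <- w - (lr / batch) * sum of per-example
// gradients. Illustrative names, not taken from the repository.
fn sgd_step(weights: &mut [f32], grad_sum: &[f32], lr: f32, batch: usize) {
    let scale = lr / batch as f32;
    for (w, g) in weights.iter_mut().zip(grad_sum) {
        *w -= scale * g;
    }
}

fn main() {
    let mut w = vec![1.0f32, -0.5];
    // Gradients summed over a batch of 10 examples.
    let grad_sum = vec![10.0f32, 20.0];
    sgd_step(&mut w, &grad_sum, 3.0, 10);
    // Step = (3.0 / 10) * grad_sum = [3.0, 6.0]
    assert_eq!(w, vec![-2.0, -6.5]);
}
```

A learning rate of 3.0 looks aggressive next to typical deep-learning defaults, but dividing by the batch size tempers each step.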
Phase 4: Seeing the Brain Work
I didn't just want to see numbers; I wanted to see what the network was learning, so I built a terminal-based visualizer for MNIST digits and implemented weight heatmaps.
By auto-detecting grid dimensions (like reassembling the 36 hidden layer weights back into a 6x6 grid), I could finally validate my theories on how different layers respond to features. Seeing the weights evolve from random noise into recognizable patterns confirmed that my modular architecture hadn't corrupted the mathematical logic.
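The grid auto-detection itself can be sketched in a few lines: walk down from √n to the nearest divisor, giving the most square-like rows × cols factorisation. This is a hypothetical helper, not the repository's actual code, but it captures the idea of 36 weights mapping to a 6x6 heatmap.

```rust
// Auto-detect heatmap grid dimensions for a neuron count `n`:
// the most square-ish rows x cols factorisation. Hypothetical helper.
fn grid_dims(n: usize) -> (usize, usize) {
    // Perfect squares like 36 resolve to (6, 6); e.g. 12 gives (3, 4).
    let mut rows = (n as f64).sqrt() as usize;
    while rows > 1 && n % rows != 0 {
        rows -= 1;
    }
    (rows, n / rows)
}

fn main() {
    assert_eq!(grid_dims(36), (6, 6));    // the hidden layer in this post
    assert_eq!(grid_dims(784), (28, 28)); // an MNIST image itself
    assert_eq!(grid_dims(12), (3, 4));
}
```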
Key Takeaways:
- Performance Gains: Moving to ndarray on M2 Pro cut training time by ~6x (from 4s down to 0.65s per epoch).
- Abstractions are Expensive: In math-heavy code, "cleaner" isn't always better if it breaks your mental map of the matrix flow.
- Visualize for Validation: Choosing architecture sizes (like 36 neurons for a 6x6 grid) specifically for visualization is a powerful debugging strategy.
Deep Dive into the Math
Because this project was about learning the "how" and "why," I didn't stop at the code. I’ve documented the entire mathematical journey in the repository's README.
If you're curious about the calculus behind backpropagation, weight initialization, or the specific program workflow, you'll find a massive breakdown there. I’ve included detailed explanations of matrix transformations and full instructions on how to run the training yourself.
What's Next?
Now that the modular structure is stable and the MLP is hitting 95%+ accuracy, I'm focusing on:
- Weight persistence: Saving the trained "brain" to disk.
- Modularity: Making activation functions truly swappable.
- Optimization: Exploring BLAS acceleration for even faster operations.
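For the "truly swappable" activation goal, one idiomatic Rust route is a small trait that each activation implements, so a layer only depends on `dyn Activation`. This is a sketch of that design under assumed names, not the project's actual API:

```rust
// Sketch of swappable activations via a trait object.
// Names are illustrative, not from the repository.
trait Activation {
    fn f(&self, z: f32) -> f32;  // forward
    fn df(&self, z: f32) -> f32; // derivative, needed for backprop
}

struct Sigmoid;
impl Activation for Sigmoid {
    fn f(&self, z: f32) -> f32 {
        1.0 / (1.0 + (-z).exp())
    }
    fn df(&self, z: f32) -> f32 {
        let s = self.f(z);
        s * (1.0 - s)
    }
}

struct Relu;
impl Activation for Relu {
    fn f(&self, z: f32) -> f32 {
        z.max(0.0)
    }
    fn df(&self, z: f32) -> f32 {
        if z > 0.0 { 1.0 } else { 0.0 }
    }
}

// A layer only needs `&dyn Activation`, so swapping sigmoid for ReLU
// becomes a one-line change at construction time.
fn activate(act: &dyn Activation, zs: &[f32]) -> Vec<f32> {
    zs.iter().map(|&z| act.f(z)).collect()
}

fn main() {
    let zs = [-1.0f32, 0.0, 2.0];
    assert_eq!(activate(&Relu, &zs), vec![0.0, 0.0, 2.0]);
    assert!((activate(&Sigmoid, &zs)[1] - 0.5).abs() < 1e-6);
}
```

The trade-off from Phase 2 applies here too: the trait adds indirection, so it's worth keeping the backprop math visible in one place even when the activations are pluggable.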
Check the full source code & math breakdown here: Nutscracker87/rust-mnist
Let's connect on LinkedIn: Smirnov Nikita