I've just published RustyASG, a deep learning framework I built from the ground up in Rust, to crates.io. The project's goal was to explore a graph-based architecture, moving away from the eager execution model common in frameworks like PyTorch.
This post is a technical breakdown of its architecture, core design, and the most critical bug I had to fix to make it work correctly.
The Core Idea: Define-then-Run
RustyASG does not compute operations immediately. Instead, every operation constructs a node in an Abstract Semantic Graph (ASG). This graph serves as a complete blueprint of the entire computation before any numbers are crunched.
Once the graph is fully defined, it can be:
- Analyzed: Statically infer tensor shapes to catch dimension mismatches before execution.
- Differentiated: Automatically generate a new graph that calculates the gradients of any node.
- Executed: Run the optimized graph on a backend, such as a CPU or GPU.
This approach provides a high degree of control and enables global optimizations like kernel fusion and static memory planning.
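To make the define-then-run flow concrete, here is a deliberately tiny, self-contained sketch of the idea. It is not RustyASG's actual API or node set; it only illustrates how operations can append nodes to a graph during a "define" phase while a separate "run" phase walks the finished blueprint.

```rust
// Illustrative only -- not RustyASG's actual node set or API.
enum NodeType {
    Constant(f32),
    Add(usize, usize),      // operands are ids of earlier nodes
    Multiply(usize, usize),
}

#[derive(Default)]
struct Graph {
    nodes: Vec<NodeType>,
}

impl Graph {
    // "Define" phase: record the operation and hand back a symbolic id.
    fn push(&mut self, node: NodeType) -> usize {
        self.nodes.push(node);
        self.nodes.len() - 1
    }

    // "Run" phase: evaluate nodes in insertion order, which is already topological.
    fn execute(&self) -> Vec<f32> {
        let mut values = Vec::with_capacity(self.nodes.len());
        for node in &self.nodes {
            let v = match node {
                NodeType::Constant(c) => *c,
                NodeType::Add(a, b) => values[*a] + values[*b],
                NodeType::Multiply(a, b) => values[*a] * values[*b],
            };
            values.push(v);
        }
        values
    }
}

fn main() {
    let mut g = Graph::default();
    // Nothing is computed here -- we only build the blueprint.
    let x = g.push(NodeType::Constant(2.0));
    let y = g.push(NodeType::Constant(3.0));
    let sum = g.push(NodeType::Add(x, y));
    let out = g.push(NodeType::Multiply(sum, x));

    // The whole graph is evaluated in one pass.
    let values = g.execute();
    println!("out = {}", values[out]); // (2.0 + 3.0) * 2.0 = 10.0
}
```

Because the full structure exists before execution, passes such as shape inference or gradient construction can simply walk the node list and emit a new graph.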
Architecture Overview
The framework's components are strictly separated into the following modules:
- `tensor`: A lightweight, symbolic handle to a node in the graph. All operations on it modify the graph.
- `asg`: The core data structures that define the graph (`Node`, `NodeType`). This layer only describes computation.
- `autograd`: The reverse-mode automatic differentiation engine. It transforms a forward-pass graph into a new graph that computes gradients.
- `runtime`: Contains the `Backend` trait and its concrete implementations for CPU (`ndarray`) and GPU (`wgpu`); an illustrative sketch of this kind of trait follows below.
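To give a feel for how the `runtime` layer can stay decoupled from the graph layers, here is a hypothetical sketch of a backend abstraction along these lines. The trait and method names are invented for illustration and are not the real `Backend` trait from RustyASG.

```rust
// Hypothetical sketch -- not the real `Backend` trait from RustyASG's
// `runtime` module. The graph layers stay symbolic, and a backend decides
// how the recorded operations are actually executed.
use ndarray::{array, ArrayD, Ix2};

trait Backend {
    /// Concrete tensor type this backend computes with.
    type Tensor;

    fn add(&self, a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor;
    fn matmul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor;
}

struct CpuBackend;

impl Backend for CpuBackend {
    type Tensor = ArrayD<f32>;

    fn add(&self, a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor {
        a + b
    }

    fn matmul(&self, a: &Self::Tensor, b: &Self::Tensor) -> Self::Tensor {
        // For brevity this sketch only handles 2-D inputs.
        let a2 = a.clone().into_dimensionality::<Ix2>().unwrap();
        let b2 = b.clone().into_dimensionality::<Ix2>().unwrap();
        a2.dot(&b2).into_dyn()
    }
}

fn main() {
    let cpu = CpuBackend;
    let a = array![[1.0_f32, 2.0], [3.0, 4.0]].into_dyn();
    let b = array![[5.0_f32, 6.0], [7.0, 8.0]].into_dyn();
    println!("a + b =\n{}", cpu.add(&a, &b));
    println!("a · b =\n{}", cpu.matmul(&a, &b));
}
```

A GPU implementation of the same abstraction would dispatch compute shaders instead; in RustyASG's case that implementation is built on `wgpu`, as noted above.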
Getting Started: Training a Transformer Block
An example is included to demonstrate a full training loop on a Transformer block.
- Clone the repository:

  ```bash
  git clone https://github.com/Xzdes/RustyAsg.git
  cd RustyAsg
  ```

- Run the demo (GPU is the default backend):

  ```bash
  cargo run --example transformer_demo --release
  ```

  To use the CPU, set the `use_gpu` flag to `false` in `examples/transformer_demo.rs`.
The output shows the training loss decreasing, which confirms that the entire stack—from graph construction to backpropagation on the GPU—is functioning correctly.
```text
--- TRAINING LOOP START ---
Epoch: 1 , Loss: 3.081630
Epoch: 2 , Loss: 2.840065
...
Epoch: 15, Loss: 0.982813
--- TRAINING FINISHED IN 1.34s ---
```
The Hardest Bug: A Flaw in the Gradient Logic
The framework appeared stable until I implemented a gradient checker. The tests compared the analytical gradients produced by the `autograd` engine against numerical estimates. A key test, which emulated a `LayerNorm` operation, was consistently failing with an error margin of nearly 60%.
The root cause was a subtle but critical flaw in the backpropagation logic for division, specifically when broadcasting was involved.
When a vector like `[10, 20]` is divided by a scalar `2`, the result is `[5, 10]`. During backpropagation, the gradient for the scalar `2` receives contributions from both elements of the output. Therefore, the incoming gradients must be summed to produce the final, correct gradient for the original scalar.

My autograd implementation for `Add` and `Subtract` handled this, but I had overlooked it for `Divide`.
```rust
// The incorrect autograd logic for the divisor 'b':
// grad_b = ...                          // mathematically correct, but missed a step
// self.accumulate_grad(b_id, grad_b)?;

// The corrected logic:
let b_shape = self.source_asg.get_node(b_id)?.shape.as_ref().unwrap().clone();
// ... calculate grad_b ...
if b_shape != grad_shape {
    // If 'b' was broadcast, its gradient must be summed back down to its original shape.
    grad_b = self.grad_asg.add_node(None, NodeType::Sum(grad_b));
}
self.accumulate_grad(b_id, grad_b)?;
```
This conditional sum was the missing piece. Adding it fixed the `LayerNorm` test and validated the entire `autograd` engine.
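For readers who want to reproduce this kind of check outside the framework, here is a minimal, dependency-free sketch of the idea (not the actual RustyASG gradient checker): compute the analytical gradient for a broadcast scalar divisor, including the sum over the broadcast elements, and compare it against a central finite difference.

```rust
// A minimal sketch of this kind of check (not the actual RustyASG gradient
// checker). The scalar `b` is broadcast across `a`, so its analytical
// gradient is the SUM of the per-element contributions, and a central
// finite difference should agree with it.
fn main() {
    let a = [10.0_f64, 20.0];
    let b = 2.0_f64;
    let upstream = [1.0_f64, 1.0]; // gradient flowing into c = a / b

    // Analytical rule: d(a_i / b) / db = -a_i / b^2, summed over the broadcast axis.
    let analytical: f64 = a
        .iter()
        .zip(upstream.iter())
        .map(|(ai, g)| g * (-ai / (b * b)))
        .sum();

    // Numerical estimate: central difference on the scalar loss L(b) = sum(upstream * a / b).
    let loss = |b: f64| -> f64 { a.iter().zip(upstream.iter()).map(|(ai, g)| g * ai / b).sum() };
    let eps = 1e-6;
    let numerical = (loss(b + eps) - loss(b - eps)) / (2.0 * eps);

    println!("analytical = {analytical:.6}, numerical = {numerical:.6}"); // both -7.500000
    assert!((analytical - numerical).abs() < 1e-4);
}
```

Without the sum, the analytical side would report a per-element gradient and miss by exactly the kind of large margin the `LayerNorm` test exposed.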
What's Next
The foundation is now stable. The immediate roadmap is focused on expanding core functionality:
- Implement gradients for `MatrixMultiply` and `Softmax` (the standard reverse-mode rules for matrix multiplication are sketched after this list).
- Add more optimizers, starting with `AdamW`.
- Develop a memory-recycling buffer allocator for the `wgpu` backend to improve performance.
- Implement model serialization to save and load trained weights.
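As a reference for the first roadmap item, this is a small sketch of the standard reverse-mode rules for matrix multiplication, written against `ndarray`: for `C = A · B` with upstream gradient `dC`, the rules are `dA = dC · Bᵀ` and `dB = Aᵀ · dC`. It shows the math the `MatrixMultiply` gradient will need to encode, not RustyASG's eventual implementation.

```rust
// Reference sketch, not RustyASG code: the standard reverse-mode rules
// for C = A · B are dA = dC · Bᵀ and dB = Aᵀ · dC.
use ndarray::array;

fn main() {
    let a = array![[1.0, 2.0], [3.0, 4.0]];
    let b = array![[5.0, 6.0], [7.0, 8.0]];
    let d_c = array![[1.0, 1.0], [1.0, 1.0]]; // upstream gradient of C = a.dot(&b)

    let d_a = d_c.dot(&b.t()); // gradient w.r.t. A
    let d_b = a.t().dot(&d_c); // gradient w.r.t. B

    println!("dA =\n{d_a}\ndB =\n{d_b}");
}
```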
Contributions are welcome. Please feel free to open Issues or Pull Requests.
You can find the project at:
- GitHub: https://github.com/Xzdes/RustyAsg
- Crates.io: https://crates.io/crates/rustyasg
One thing you'll notice if you dive into the source code is that the comments are in Russian. Thinking through complex logic like backpropagation is simply faster and clearer for me in my native language, which was critical for getting the project off the ground.
I plan to translate the key module comments to English as the project matures. If you're interested in helping with this effort, feel free to open an issue or a PR!