
Muhammed Shafin P
Thinking Transformers: A Transformer That Reasons Before It Speaks

Most neural language models work the same way: take in a sequence of tokens, run one forward pass, and spit out a prediction. It's fast, it's well understood, and for many tasks it works well. But there's something fundamentally rushed about it — the model has exactly one shot to "think" before it answers, no matter how hard the problem is.

Thinking Transformers takes a different approach. Before producing any output, the model runs its hidden states through the full transformer stack multiple times — these are called think steps. Each pass lets the model refine its internal representation, catch contradictions, and build up a richer picture of the input before committing to an answer. The number of think steps is configurable, and crucially, every one of them is part of the computation graph — training uses full Backpropagation Through Time (BPTT) across all think steps and all layers simultaneously. The model doesn't just learn to predict; it learns how to think.
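To make the "learns how to think" claim concrete, here is a deliberately tiny sketch of BPTT through repeated think steps. It is my own illustration, not the project's C code: the full transformer stack is collapsed into a single scalar update `h' = tanh(w*h + x)` applied for `T` steps with the weight `w` shared across steps, and the gradient of the final loss is accumulated back through every iteration, exactly the shape of training the real model uses.

```python
import math

def think_forward(w, x, h0, steps):
    """Apply the same 'layer' repeatedly: one pass per think step."""
    hs = [h0]
    for _ in range(steps):
        hs.append(math.tanh(w * hs[-1] + x))
    return hs

def bptt_grad(w, hs, y):
    """Backprop through time: unroll the chain rule over all think steps.

    Loss is L = (h_T - y)^2; because w is shared, every step contributes
    to dL/dw, so the gradient sums contributions from the whole loop.
    """
    dL_dh = 2.0 * (hs[-1] - y)
    grad_w = 0.0
    for t in range(len(hs) - 1, 0, -1):
        dtanh = 1.0 - hs[t] ** 2          # tanh'(pre) since hs[t] = tanh(pre)
        grad_w += dL_dh * dtanh * hs[t - 1]  # direct contribution at step t
        dL_dh = dL_dh * dtanh * w            # gradient flowing into step t-1
    return grad_w
```

A finite-difference check confirms the unrolled gradient matches the numerical one, which is the essential property BPTT gives you: every think step is differentiable and contributes to the weight update.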

Alongside the reasoning loop, the architecture includes a small gated memory bank — a set of persistent slots that are read from and written to at each think step. This gives the model a lightweight working memory that can carry context forward across iterations, something a standard single-pass transformer simply cannot do.
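The read/write pattern can be sketched in a few lines. This is a toy of my own with scalar slots rather than vectors: each think step reads the slots by softmax attention against the current state, then writes back through a sigmoid gate, so old content persists unless the gate decides to overwrite it.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def read_memory(state, slots):
    """Read: softmax-weighted sum of slots, scored by similarity to state."""
    scores = [state * m for m in slots]
    mx = max(scores)
    exps = [math.exp(s - mx) for s in scores]
    total = sum(exps)
    return sum(e / total * m for e, m in zip(exps, slots))

def write_memory(state, slots, w_gate):
    """Write: per-slot gate blends the new state into the old content."""
    updated = []
    for m in slots:
        g = sigmoid(w_gate * (state - m))        # gate in (0, 1)
        updated.append(g * state + (1.0 - g) * m)  # gated interpolation
    return updated
```

Because the gate interpolates rather than overwrites, each slot drifts toward the current state only as much as the gate allows, which is what lets context survive across think steps.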

The whole thing is built from scratch in plain C with no external dependencies beyond libm. It compiles into a shared library (transformer.so on Linux/macOS, transformer.dll on Windows) in a single GCC command. The Python layer wraps this library via ctypes, exposing a clean, minimal API through two classes: TransformerConfig, which holds all architecture hyperparameters (vocab size, embedding dimension, number of heads, feed-forward width, layers, sequence length, think steps, memory slots), and ThinkingTransformer, which is the model itself.

The Python API is deliberately straightforward. You call model.train_step(tokens, targets, lr) for a single training iteration — it handles zeroing gradients, running the full BPTT backward pass, and applying an Adam update internally. For more control, zero_grad(), backward(), and step() are all exposed separately. Inference is equally simple: model.generate(prompt, max_new_tokens) does greedy decoding, while model.generate_with_thinking(prompt) wraps the prompt in explicit THINK and PLAN tokens and returns not just the output tokens but the full reasoning structure and logits. Checkpoints save and load cleanly with model.save(path) and model.load(path).
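The shape of a training-and-generation loop against that API looks roughly like the following. Since the real `ThinkingTransformer` requires the compiled shared library, a small stub with the same method names stands in here so the loop is runnable; the stub's behavior (a fake decreasing loss, echoed tokens) is mine, not the project's.

```python
class StubModel:
    """Stand-in for ThinkingTransformer; same method names, toy behavior."""

    def train_step(self, tokens, targets, lr):
        # The real call zeroes gradients, runs full BPTT, and applies an
        # Adam update internally; the stub just returns a shrinking loss.
        self.loss = getattr(self, "loss", 4.0) * 0.9
        return self.loss

    def generate(self, prompt, max_new_tokens):
        # Greedy-decoding stand-in: the prompt plus dummy new tokens.
        return list(prompt) + [0] * max_new_tokens

model = StubModel()
tokens, targets = [1, 2, 3], [2, 3, 4]
losses = [model.train_step(tokens, targets, lr=1e-3) for _ in range(5)]
generated = model.generate([1, 2], max_new_tokens=3)
```

With the real model you would construct a `TransformerConfig`, pass it to `ThinkingTransformer`, and swap the stub out; the loop itself does not change.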

Everything — architecture, training, and the reasoning loop — is open source and available at https://github.com/hejhdiss/Thinking-Transformers. The project is licensed under GNU GPL v3.
