
PEPPERCORN


[Day 8] Pushing Looped Transformers Beyond Addition: OpenMythos on Bracket-Matching Depth


Intro

Day 8!

A direct follow-up to Day 7: same OpenMythos-style mini model (3.4M params), same training pipeline, one task change — multi-digit addition swapped for nested-bracket parsing. The goal was to ask two follow-up questions Day 7 left open:

  1. Does the "training-time loop count is the peak" finding generalize across tasks?
  2. If we increase the structural complexity of the input (deeper nesting), does inference-time loop count start to matter?

Tools used: my home AI machine (DGX Spark, GB10) + OpenMythos (PyTorch reconstruction of the rumored Claude Mythos architecture) + synthetic bracket sequences.


Today's setup

Why bracket matching?

Day 7's task was 2-5 digit addition. Addition tests "carry propagation from low to high digit" — a fundamentally local, left-to-right state update. To probe whether looped depth helps with a different kind of structural reasoning, I wanted a task where:

  • The output depends on left-to-right state tracking (rules out attention-based global aggregation shortcuts).
  • The task admits an explicit notion of depth I can vary as a controlled difficulty knob.

Bracket matching fits both. The standard linear-time algorithm is push-on-open / pop-on-close with a stack. A model that has internalized that algorithm should scale gracefully with depth — and one that hasn't will visibly fall over.

Task: first-break-position prediction

Input: a string of ( ) [ ] { } characters, terminated by =.
Output: the left-most position at which the bracket structure breaks, as 2 digits, terminated by $. If the sequence is balanced, output --$.

Examples:

((()))=          → --$       (balanced)
([{}])=          → --$       (balanced)
([)]=            → 02$       (')' at position 2 doesn't match preceding '[')
(()(=            → 04$       (stack non-empty at end of string, position = len)
))=              → 00$       (close on empty stack at position 0)

The "break position" is defined by a stack parser scanning left-to-right:

  1. Close bracket whose type ≠ stack top → return that close position.
  2. Close bracket on empty stack → return that close position.
  3. End of string with non-empty stack → return len(s).
  4. Otherwise balanced → return -1 (output --).
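The four rules above map directly onto a small stack scanner. Here is a minimal Python sketch, assuming the string is passed without its `=` terminator; `first_break_position` and `format_answer` are names I made up for illustration, not functions from the repo.

```python
OPEN_FOR = {")": "(", "]": "[", "}": "{"}  # close bracket -> the open it must match


def first_break_position(s: str) -> int:
    """Left-most break position of a bracket string (without the '='), or -1 if balanced."""
    stack = []
    for i, ch in enumerate(s):
        if ch in "([{":
            stack.append(ch)
        elif ch in OPEN_FOR:
            # Rules 1 and 2: close on an empty stack, or close whose type != stack top
            if not stack or stack[-1] != OPEN_FOR[ch]:
                return i
            stack.pop()
    # Rule 3: unclosed opens remain -> len(s); rule 4: balanced -> -1
    return len(s) if stack else -1


def format_answer(pos: int) -> str:
    """Two-digit position followed by '$', or '--$' for a balanced input."""
    return "--$" if pos < 0 else f"{pos:02d}$"


for s in ["((()))", "([{}])", "([)]", "(()(", "))", "(([)])"]:
    print(f"{s}= -> {format_answer(first_break_position(s))}")
```

Looping over the examples reproduces the answers listed above, plus `03$` for the `(([)])` walk-through further down.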

Why not just binary balanced / imbalanced?

That was the original plan. A first smoke run with T/F output saturated at 100% accuracy across all depths (up to 10) by step 4,000. The binary task offers too many shortcut signals (length parity, open/close counts, etc.), so a transformer never has to learn the actual stack algorithm.

The first-break-position output forces the model to commit to a specific character position, which can only be answered by tracking state left-to-right. After this change, smoke results at 5,000 steps showed clean depth-dependent difficulty (d=2: 100%, d=20: 71%) and the loss had room to keep dropping. That's the signal I needed to study loop-count behavior meaningfully.

Difficulty knob: depth

I trained and evaluated across depths {2, 4, 6, 8, 10, 12, 16, 20}, with pair count capped at min(2 * depth, 50) so the 2-digit position output stays in range. Balanced and imbalanced sequences mixed 50/50; imbalanced sequences generated by deleting a close (30%), deleting an open (30%), or substituting a bracket (40%).

Architectural changes from Day 7

Minimal — only what the new vocab and longer sequences required:

| Setting | Day 7 (addition) | Day 8 (brackets) |
|---|---|---|
| vocab_size | 16 | 20 |
| max_seq_len | 32 | 128 |
| max_loop_iters (train) | 4 | 4 |
| Difficulty axis | 2-5 digits | depth 2-20 |
| Answer tokens | 1-6 (digits + $) | 2 + $ |
| Total params | 3.39M | 3.39M |

Same MythosConfig template otherwise. Same hyperparameters (AdamW, max LR 3e-4, warmup 2000, cosine decay, 30k steps, fp32, 4 seeds in parallel).

Headline finding

  • The Day 7 "peak at training loop count" finding largely generalizes. With training max_loop_iters=4, accuracy peaks at T=4 (or at the neighboring T=8, which ties or edges ahead by 1-3 points at d=4 through d=12) and decays beyond that window in both directions at every depth I tested.
  • But the peak height is much lower. Best accuracy was 66% at depth 2; depth 20 caps at ~36%. Day 7 hit 100% even on 5-digit addition; brackets at the same parameter budget plateau dozens of points short.
  • Inference-time loop extrapolation does NOT improve deep-nesting performance. The hypothesis "deeper inputs benefit from more loops" did not reproduce: T=16 and T=32 are worse than T=4 at every depth, and at the deepest inputs (d=16, d=20) even T=8 trails T=4, just as in Day 7.
  • Fixed-point reproduced, slightly later. Cosine similarity between consecutive hidden states reaches ~0.95 around T=3-4 and exceeds 0.99 by T=8, a step or two later than addition (which reached ~0.95 by T=2).

🪢 The task in pictures

Input:  ( ( [ ) ] ) =
Pos:    0 1 2 3 4 5

stack walk:
  pos 0: '(' → push '('             stack: ( 
  pos 1: '(' → push '('             stack: ( (
  pos 2: '[' → push '['             stack: ( ( [
  pos 3: ')' → top is '[', mismatch!  → first break at position 3

Expected output: 03$

The interesting thing about this task vs. addition: the answer can be anywhere from 0 to ~80 depending on the input (a d=20 sample can have up to 40 pairs, i.e. 80 characters), and the model has to commit to a specific integer. There's no global-aggregation shortcut: you have to walk left-to-right and remember what you've seen.


🔧 Pipeline

OpenMythos tiny (3.4M params, same as Day 7 modulo vocab + max_seq_len)
  ↓
Train 4 seeds in parallel, 30k steps, fp32 on DGX Spark (GB10)
  ↓
Experiment A: greedy autoregressive accuracy
              loops ∈ {1, 2, 4, 8, 16, 32}  ×  depth ∈ {2, 4, 6, 8, 10, 12, 16, 20}
  ↓
Experiment B: cosine similarity between consecutive hidden states
              ⇒ does the recurrent block reach a fixed-point?
              ⇒ does the fixed-point timing depend on depth?
  ↓
Compare against Day 7 (digits) along the same axes

Training throughput note (vs Day 7)

Day 7's 4-seed parallel training was fast because max_seq_len=32 left the GPU underutilized per process. With max_seq_len=128, a single process already saturates the GB10: running 4 seeds in parallel drops per-process throughput from ~60K tok/s to ~12.8K tok/s (a 79% per-process drop). Aggregate parallel throughput is actually ~15% slower than running the 4 seeds sequentially.
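The ~15% figure follows directly from the two throughput numbers, since running the seeds sequentially gives each one the whole GPU. A quick check of the arithmetic (the inputs are the approximate figures above, not new measurements):

```python
single_process = 60_000   # tok/s, one training process alone on the GB10
per_process_4x = 12_800   # tok/s per process with 4 seeds running in parallel

aggregate_4x = 4 * per_process_4x                    # total tok/s across all four seeds
per_process_drop = 1 - per_process_4x / single_process
vs_sequential = aggregate_4x / single_process - 1    # sequential runs each get ~60K tok/s

print(f"aggregate parallel: {aggregate_4x:,} tok/s")
print(f"per-process drop: {per_process_drop:.0%}")
print(f"parallel vs sequential: {vs_sequential:+.0%}")
```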

I let it run in parallel anyway because it was overnight and I had no other DGX usage scheduled. Worth noting for anyone planning similar replications: longer sequences kill the "free" benefit of multi-seed parallelism on a single GPU.

GPU draw stayed at 51W / 72°C / 95% utilization throughout — comfortable enough to leave running.


📊 Results

Experiment A: accuracy heatmap

accuracy heatmap of bracket-matching across loop counts and depths

Mean exact-match accuracy across 4 seeds, 500 eval samples per condition:

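| Inference loops | d=2 | d=4 | d=6 | d=8 | d=10 | d=12 | d=16 | d=20 |
|---|---|---|---|---|---|---|---|---|
| 1 | 0.11 | 0.05 | 0.03 | 0.02 | 0.01 | 0.01 | 0.01 | 0.02 |
| 2 | 0.32 | 0.20 | 0.13 | 0.08 | 0.08 | 0.08 | 0.07 | 0.07 |
| 4 (train) | 0.66 | 0.56 | 0.50 | 0.45 | 0.44 | 0.41 | 0.41 | 0.36 |
| 8 | 0.58 | 0.56 | 0.51 | 0.47 | 0.46 | 0.44 | 0.39 | 0.34 |
| 16 | 0.55 | 0.51 | 0.44 | 0.41 | 0.40 | 0.38 | 0.36 | 0.32 |
| 32 | 0.55 | 0.48 | 0.42 | 0.40 | 0.39 | 0.37 | 0.36 | 0.31 |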

Observations:

  • Peak at or next to T=4 in every depth column. Day 7's "loops help only in a narrow window centered on training" finding generalizes: every depth column has its best accuracy at T=4 or T=8, and T=1, 2, 16, and 32 all fall below T=4 at every depth.
  • Depth scaling is graceful but the ceiling is low. Going from d=2 to d=20 at T=4, accuracy degrades smoothly (0.66 → 0.36), but the absolute numbers stay far from saturation.
  • The "deeper input ⇒ more loops" hypothesis does not hold. I'd hoped to see T=8 or T=16 begin to dominate at d=20, indicating inference-time scaling could rescue depth. Instead, the deepest columns (d=16, d=20) peak at T=4 and decay, the same shape as Day 7's digit-count columns, just stretched lower; T=8's small edge shows up only at mid-depths and is gone again by d=16.
  • T=8 is unusually competitive at mid-depths. At d=4 through d=12, T=8 ties or slightly beats T=4 (by 1-3 points). Possibly T=4 and the next setting up, T=8, are both near-optimal test-time depths around the training value.
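Each cell in the heatmap is an exact-match rate under greedy decoding: the model generates after `=` until `$`, and a sample only counts if the whole answer matches. A minimal sketch of that scoring loop, assuming a hypothetical `logits_fn(text, n_loops)` interface and a guessed ordering of the 20-token vocabulary; `oracle_logits` is a fake "model" that exists only to exercise the loop:

```python
from typing import Callable, Dict, List, Sequence, Tuple

# 20-token vocabulary matching the config comment; this ordering is an assumption.
VOCAB = list("()[]{}=$ -0123456789")
CLOSE_TO_OPEN: Dict[str, str] = {")": "(", "]": "[", "}": "{"}

# Hypothetical model interface: (text so far, n_loops) -> one logit per VOCAB entry.
LogitsFn = Callable[[str, int], Sequence[float]]


def greedy_answer(logits_fn: LogitsFn, prompt: str, n_loops: int, max_new: int = 3) -> str:
    """Greedy autoregressive decoding after '=': argmax token each step, stop at '$'."""
    out = ""
    for _ in range(max_new):
        logits = logits_fn(prompt + out, n_loops)
        tok = VOCAB[max(range(len(VOCAB)), key=lambda j: logits[j])]
        out += tok
        if tok == "$":
            break
    return out


def exact_match(logits_fn: LogitsFn, samples: List[Tuple[str, str]], n_loops: int) -> float:
    """Fraction of samples whose whole answer (2 position chars + '$') matches exactly."""
    hits = sum(greedy_answer(logits_fn, p, n_loops) == t for p, t in samples)
    return hits / len(samples)


def oracle_logits(text: str, n_loops: int) -> List[float]:
    """Stand-in 'model' for exercising the loop: ignores n_loops, always emits the right token."""
    brackets, generated = text.split("=", 1)
    stack: List[str] = []
    pos = -1
    for i, ch in enumerate(brackets):
        if ch in "([{":
            stack.append(ch)
        elif not stack or stack.pop() != CLOSE_TO_OPEN[ch]:
            pos = i  # close on empty stack, or type mismatch
            break
    else:
        pos = len(brackets) if stack else -1
    target = "--$" if pos < 0 else f"{pos:02d}$"
    nxt = target[len(generated)]
    return [1.0 if v == nxt else 0.0 for v in VOCAB]


samples = [("([)]=", "02$"), ("((()))=", "--$"), ("))=", "01$")]  # last target is deliberately wrong
print(exact_match(oracle_logits, samples, n_loops=4))
```

In the real sweep, `exact_match` would be called once per (seed, n_loops, depth) cell with n_loops ∈ {1, 2, 4, 8, 16, 32} and 500 samples per condition, then averaged over the 4 seeds.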

Experiment B: fixed-point analysis

fixed-point cosine similarity curve across loop steps and depths

Mean cosine similarity between consecutive hidden states cos(h_t, h_{t-1}) measured at the first-answer-token position, averaged across 4 seeds, 200 samples per depth:

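| Loop step t | d=2 | d=4 | d=8 | d=12 | d=16 | d=20 |
|---|---|---|---|---|---|---|
| 1 | 0.85 | 0.89 | 0.92 | 0.92 | 0.93 | 0.91 |
| 2 | 0.91 | 0.91 | 0.94 | 0.95 | 0.95 | 0.97 |
| 3 | 0.94 | 0.94 | 0.92 | 0.92 | 0.94 | 0.95 |
| 4 | 0.95 | 0.97 | 0.96 | 0.95 | 0.93 | 0.93 |
| 8 | 0.998 | 0.995 | 0.998 | 0.996 | 0.996 | 0.992 |
| 16 | 0.9994 | 0.9996 | 0.9989 | 0.9985 | 0.9976 | 0.9979 |
| 32 | 0.9998 | 0.9998 | 0.9998 | 0.9997 | 0.9995 | 0.9996 |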

Three things to note:

  1. Fixed-point timing is slightly later than Day 7. Day 7 reached ~0.95 by T=2; brackets sit around 0.92-0.95 at T=3 and 0.93-0.97 at T=4, and only exceed 0.99 by T=8. That is one to two extra loop steps on this metric. Possibly the more complex left-to-right state needs a beat longer to settle.

  2. Depth dependence is small. d=20 traces almost on top of d=2, again echoing Day 7 (where digit-count had only marginal effect on fixed-point timing). "Harder problem ⇒ slower fixed-point" did not appear.

  3. The hidden state is nearly stationary from T=8 on (cosine > 0.99), yet accuracy does not recover and mostly keeps drifting down from T=8 through T=32. Same paradox as Day 7: extra loops are computation without information. Either the small late-loop perturbations still move the logits away from a converged answer, or this is purely a distribution-shift artifact of training only at T=4.

Comparison with Day 7

| Axis | Day 7 (addition) | Day 8 (brackets) |
|---|---|---|
| Loop-count peak at T=train (=4) | Yes | Yes |
| Best accuracy at peak | 100% (all digits) | 66% (d=2), 36% (d=20) |
| Inference-time loop extrapolation | Hurts | Hurts |
| Cosine fixed-point arrival | ~T=2 | ~T=3 |
| Depth/digit dependence on fixed-point | Small | Small |
| Training dynamics | Grokking (sudden phase transition) | Smooth slow climb |

Day 8 reproduces all the qualitative findings of Day 7. What changes is the quantitative ceiling: at the same parameter budget and the same training compute, structure-tracking caps far below saturation while addition saturates.


💡 Tying back to the three perspectives

Day 7 tested looped transformers against three published views:

  • Saunshi et al. — loops can match deeper fixed-depth networks on algorithmic tasks
  • Geiping et al. (Huginn) — at scale, extra loops give marginal gains
  • Micheal Bee — loops plateau early at small scale (T=2 fixed-point)

Day 8 adds three more data points to the picture:

  1. The "peak at training loop count" pattern persists across qualitatively different algorithmic tasks (addition vs. bracket parsing). This is consistent with Saunshi's framing but argues against naive depth-extrapolation at inference.

  2. The fixed-point arrives at slightly different times for different tasks. Bee's "T=2" appears to be a property of the specific task and training recipe, not a universal property of looped transformers. Brackets need ~T=3-4 to plateau, addition needs ~T=2.

  3. Task structural complexity matters more than loop count. At a fixed budget, the ceiling on accuracy is set by something else (model capacity? loss landscape? data efficiency?), not by the number of inference loops. Adding more loops can't compensate.

A useful refinement: looped transformers carry compute up to a depth bounded by the task's algorithmic complexity and the model's expressive capacity. Beyond that, the hidden state stops moving meaningfully and additional loops are computation without information. Day 7 showed this for a task within capacity (addition saturates); Day 8 shows it for a task that bumps against capacity (bracket parsing caps short).


🛠️ Technical details

Smoke history (why the task definition changed)

Initial smoke: balanced/imbalanced binary classification, depths 2-10.
Result: 100% accuracy across all depths by step 4,000.
Diagnosis: too many shortcut signals (length parity, open/close count), so the model never had to learn the stack algorithm, even with mutations that should defeat counting shortcuts. The one-bit output gives the model no incentive to track position-by-position state.

Second smoke: first-break-position output, depths 2-20.
Result at 5,000 steps: d=2 100%, d=20 71%, with loss at 0.32 and still falling.
Diagnosis: depth-dependent difficulty visible, room to scale training to expose loop-count effects.

Lesson worth recording: output information density matters as much as task structure for studying loop behavior. A binary classifier with global-aggregation shortcuts is a weak probe of recurrent depth.

Config and hyperparameters

MythosConfig(
    vocab_size=20,         # 6 brackets + '=' + '$' + space + '-' + '0'-'9'
    dim=256,
    n_heads=8,
    n_kv_heads=2,          # GQA
    max_seq_len=128,       # Day 7 was 32
    max_loop_iters=4,
    prelude_layers=1,
    coda_layers=1,
    attn_type="gqa",
    n_experts=4,           # MoE FFN inside recurrent block
    n_shared_experts=1,
    n_experts_per_tok=2,
    expert_dim=512,
    lora_rank=8,
    rope_theta=10000.0,
)

Total parameters: 3,394,338 (~3.4M, matches Day 7 to within rounding).

Training:

  • Optimizer: AdamW, betas (0.9, 0.95), wd 0.1
  • LR: max 3e-4, warmup 2000 steps, cosine decay to 1e-5
  • Grad clip: 1.0
  • Batch size: 128
  • Max steps: 30000
  • dtype: fp32 (same RoPE-complex-buffer reason as Day 7)
  • 4 seeds {0, 1, 2, 3} in parallel
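The schedule in the list above can be written as a plain function of the step. A sketch assuming linear warmup (the post only says "warmup 2000 steps") and a cosine that reaches the 1e-5 floor exactly at step 30,000; `lr_at` is a hypothetical helper, not taken from the scripts:

```python
import math


def lr_at(step: int, max_lr: float = 3e-4, min_lr: float = 1e-5,
          warmup: int = 2000, total: int = 30_000) -> float:
    """Linear warmup to max_lr, then cosine decay to min_lr at `total` steps."""
    if step < warmup:
        return max_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / (total - warmup))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for s in (0, 1999, 2000, 16_000, 30_000):
    print(f"step {s:>6}: lr = {lr_at(s):.3e}")
```

To drive PyTorch with it, one option is `torch.optim.AdamW(params, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)` combined with `torch.optim.lr_scheduler.LambdaLR(opt, lambda s: lr_at(s) / 3e-4)`, since `LambdaLR` scales the base LR by the returned factor; the grad clip maps to `torch.nn.utils.clip_grad_norm_(params, 1.0)`.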

Data generation

On-the-fly synthetic. For each sample:

  • Sample depth d ∈ {2, 4, 6, 8, 10, 12, 16, 20} uniformly
  • Sample pair count n_pairs ~ U[max(1, d-1), min(2*d, 50)]
  • Generate balanced parenthesization (random bracket types, nested or sequential)
  • With prob 0.5, apply a mutation: delete close (30%), delete open (30%), substitute (40%)
  • Compute first-break position with the stack parser; format output
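The generation steps above can be sketched in a few functions. This is a hypothetical re-implementation, not the repo's generator: the balanced-string construction (a nested "spine" of depth d plus extra pairs inserted without exceeding d) is my assumption for "nested or sequential", and because the post's lower bound allows d-1 pairs, `depth` acts as a cap on nesting rather than a guarantee.

```python
import random
from typing import Tuple

OPENS, CLOSES = "([{", ")]}"
CLOSE_TO_OPEN = dict(zip(CLOSES, OPENS))
DEPTHS = [2, 4, 6, 8, 10, 12, 16, 20]


def first_break_position(s: str) -> int:
    """The post's stack rules: break index, len(s) if opens remain, -1 if balanced."""
    stack = []
    for i, ch in enumerate(s):
        if ch in OPENS:
            stack.append(ch)
        elif not stack or stack.pop() != CLOSE_TO_OPEN[ch]:
            return i
    return len(s) if stack else -1


def random_balanced(n_pairs: int, depth: int, rng: random.Random) -> str:
    """Hypothetical generator: a nested spine of min(depth, n_pairs) pairs, then the
    remaining pairs inserted only where they keep max nesting <= depth."""
    spine = [rng.randrange(3) for _ in range(min(depth, n_pairs))]
    seq = [OPENS[k] for k in spine] + [CLOSES[k] for k in reversed(spine)]
    for _ in range(n_pairs - len(spine)):
        level, slots = 0, []
        for i, ch in enumerate(seq):
            if level < depth:  # level = nesting just before position i
                slots.append(i)
            level += 1 if ch in OPENS else -1
        slots.append(len(seq))  # end of string, level 0
        i, k = rng.choice(slots), rng.randrange(3)
        seq[i:i] = [OPENS[k], CLOSES[k]]  # new pair nests at level + 1 <= depth
    return "".join(seq)


def mutate(s: str, rng: random.Random) -> str:
    """Delete a close (30%), delete an open (30%), or substitute one bracket (40%)."""
    r = rng.random()
    if r < 0.6:
        kind = CLOSES if r < 0.3 else OPENS
        i = rng.choice([j for j, ch in enumerate(s) if ch in kind])
        return s[:i] + s[i + 1:]
    i = rng.randrange(len(s))
    new = rng.choice([b for b in OPENS + CLOSES if b != s[i]])
    return s[:i] + new + s[i + 1:]


def make_sample(rng: random.Random) -> Tuple[str, str]:
    """One (prompt, target) pair such as ('([)]=', '02$')."""
    d = rng.choice(DEPTHS)
    n_pairs = rng.randint(max(1, d - 1), min(2 * d, 50))
    s = random_balanced(n_pairs, d, rng)
    if rng.random() < 0.5:  # half the samples get one mutation
        s = mutate(s, rng)
    pos = first_break_position(s)
    return s + "=", ("--$" if pos < 0 else f"{pos:02d}$")


rng = random.Random(0)
for _ in range(3):
    print(make_sample(rng))
```

The label is always recomputed with the stack parser after mutation rather than assumed from the mutation type, so every target follows the four stack rules by construction.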

Loss is applied only at positions following = (i.e., on the 2-digit answer + $).

Evaluation

  • Experiment A: greedy autoregressive generation, exact 3-token match (position digits + $). 500 samples per (seed, n_loops, depth).
  • Experiment B: re-implementation of OpenMythos forward to expose per-loop hidden states. Cosine similarity at the first answer-token position. 200 samples per (seed, depth), 32 loop iterations.
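Once per-loop hidden states are exposed, the Experiment B metric is just a cosine between consecutive states. A sketch of the measurement with a toy contraction map standing in for the recurrent block (an assumption made purely to exercise the metric; real states would come from the re-implemented OpenMythos forward at the first answer-token position):

```python
import math
import random
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def consecutive_cosines(states: List[Sequence[float]]) -> List[float]:
    """[cos(h_1, h_0), ..., cos(h_T, h_{T-1})] for per-loop states [h_0, ..., h_T]."""
    return [cosine(states[t], states[t - 1]) for t in range(1, len(states))]


def toy_loop_states(dim: int = 16, n_loops: int = 32, seed: int = 0) -> List[List[float]]:
    """Toy stand-in for a recurrent block (NOT OpenMythos): h_t = tanh(W h_{t-1} + x).

    Each row of W has absolute sum <= 0.5, so the update is a contraction and h_t
    converges to a fixed point, which is exactly what the cosine metric detects."""
    rng = random.Random(seed)
    scale = 0.5 / dim
    W = [[rng.uniform(-scale, scale) for _ in range(dim)] for _ in range(dim)]
    x = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
    states = [list(x)]  # h_0: the injected input
    for _ in range(n_loops):
        h = states[-1]
        states.append([math.tanh(sum(W[i][j] * h[j] for j in range(dim)) + x[i])
                       for i in range(dim)])
    return states


for t, c in enumerate(consecutive_cosines(toy_loop_states()), start=1):
    if t in (1, 2, 3, 4, 8, 16, 32):
        print(f"t={t:>2}  cos(h_t, h_t-1) = {c:.4f}")
```

With PyTorch tensors the same per-step number is `torch.nn.functional.cosine_similarity(h_t, h_prev, dim=-1)`, averaged over samples and seeds.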

What I'd want to try next

  • Increase training-time loop count and re-measure. Does the peak track with training depth (suggesting it's purely a distribution-shift artifact) or does extrapolation stay broken?
  • Scale model dim while keeping loops fixed. Does a 10x bigger model break through the ~66% / ~36% bracket ceiling, or does the structure-tracking task itself need a different inductive bias?
  • Mix tasks in training. Train on addition + brackets jointly and see if there's interference or transfer.
  • Inject explicit halting (ACT). Let the model choose how many loops per token. Does it match the empirical optimum or settle elsewhere?

References

Training and evaluation scripts: https://github.com/SAETAG/dgx-100-experiments/tree/main/days/day08-bracket-matching/scripts.


Tomorrow: Day 9

Switching gears to something much more personal — handing private chat data to a local model and seeing what it surfaces…!

#100ExperimentsWithDGX #LocalLLM
