<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Asish Kumar Dalal</title>
    <description>The latest articles on DEV Community by Asish Kumar Dalal (@asishdalal).</description>
    <link>https://dev.to/asishdalal</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3897215%2Feadf1c52-fd75-4776-8719-611188f8fdab.png</url>
      <title>DEV Community: Asish Kumar Dalal</title>
      <link>https://dev.to/asishdalal</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/asishdalal"/>
    <language>en</language>
    <item>
      <title>Teaching Small Language Models to Remember: Giving LLMs a Notebook with Differentiable Neural Computers</title>
      <dc:creator>Asish Kumar Dalal</dc:creator>
      <pubDate>Sat, 25 Apr 2026 08:31:37 +0000</pubDate>
      <link>https://dev.to/asishdalal/teaching-small-language-models-to-remember-giving-llms-a-notebook-with-differentiable-neural-42dp</link>
      <guid>https://dev.to/asishdalal/teaching-small-language-models-to-remember-giving-llms-a-notebook-with-differentiable-neural-42dp</guid>
      <description>&lt;blockquote&gt;
&lt;p&gt;&lt;em&gt;"Large models memorize the world in their weights. Small models need a notepad."&lt;/em&gt;&lt;/p&gt;
&lt;/blockquote&gt;




&lt;h2&gt;
  
  
  The Problem: Small Models Forget Facts
&lt;/h2&gt;

&lt;p&gt;Large Language Models (LLMs) like GPT-4 are remarkably good at recalling facts — "Delhi is the capital of India," "Einstein developed the theory of relativity" — because they have billions of parameters acting as a massive, compressed knowledge store. The model bakes facts into weights during pre-training, and retrieval is implicit in the forward pass.&lt;/p&gt;

&lt;p&gt;But what happens when you shrink the model?&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Small Language Models (SLMs)&lt;/strong&gt; — the kind you can actually run on a laptop or edge device — have far fewer parameters. There simply isn't enough capacity to reliably encode factual associations. They can handle grammar, style, and short-range reasoning reasonably well, but ask them a factual question and they hallucinate, hedge, or go blank.&lt;/p&gt;

&lt;p&gt;The parametric memory paradigm breaks down at small scale.&lt;/p&gt;

&lt;h3&gt;
  
  
  The Insight
&lt;/h3&gt;

&lt;p&gt;Humans don't store all their knowledge in their neurons alone. We use &lt;strong&gt;external memory&lt;/strong&gt; — notebooks, calendars, books, sticky notes. We offload facts to the environment and look them up when needed. The neural machinery handles &lt;em&gt;reasoning&lt;/em&gt;; the notepad handles &lt;em&gt;retrieval&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;What if we gave a small language model an explicit, learnable notepad?&lt;/p&gt;

&lt;p&gt;That's precisely what a &lt;strong&gt;Differentiable Neural Computer (DNC)&lt;/strong&gt; does.&lt;/p&gt;




&lt;h2&gt;
  
  
  What Is a DNC?
&lt;/h2&gt;

&lt;p&gt;A Differentiable Neural Computer, introduced by DeepMind in 2016, augments a neural network controller with an &lt;strong&gt;external memory matrix&lt;/strong&gt; — a structured, differentiable store that the network can read from and write to via learned attention mechanisms.&lt;/p&gt;

&lt;p&gt;Think of it as RAM for a neural network.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Memory Matrix  M  ∈  ℝ^(N × W)
                   │
          N = number of memory slots (rows)
          W = width of each slot (columns)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The controller (in our case, a small GPT-2) interacts with this memory through &lt;strong&gt;soft, differentiable read and write heads&lt;/strong&gt; — so the whole system is end-to-end trainable with backpropagation.&lt;/p&gt;

&lt;p&gt;Unlike a hash map or database, the DNC doesn't look up memory by exact key. It uses &lt;strong&gt;content-based addressing&lt;/strong&gt; — cosine similarity between a query key and stored vectors — blended with &lt;strong&gt;usage-based allocation&lt;/strong&gt; to decide where to write new information.&lt;/p&gt;
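&lt;p&gt;To make the addressing concrete, here is a minimal NumPy sketch of content-based lookup. The helper name &lt;code&gt;content_address&lt;/code&gt; and the toy memory are illustrative, not taken from the article's implementation:&lt;/p&gt;

```python
import numpy as np

def content_address(key, M, tau=2.0):
    """Soft attention over memory slots via cosine similarity to a query key."""
    # Normalise the key and every memory row, then compare with dot products.
    key_n = key / (np.linalg.norm(key) + 1e-8)
    M_n = M / (np.linalg.norm(M, axis=1, keepdims=True) + 1e-8)
    scores = M_n @ key_n * tau           # (N,) similarities, sharpened by tau
    e = np.exp(scores - scores.max())    # numerically stable softmax
    return e / e.sum()

M = np.eye(4, 8)                  # 4 slots, width 8; slot i stores basis vector e_i
w = content_address(M[2] + 0.01, M)  # slightly noisy query for slot 2's content
```

&lt;p&gt;Because the lookup is a softmax over similarities rather than an exact key match, a near-miss query still concentrates most of its weight on the best-matching slot — and the whole operation stays differentiable.&lt;/p&gt;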




&lt;h2&gt;
  
  
  Architecture: GPT-2 + DNC Memory
&lt;/h2&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ft8ah395ydu3tv7xuvprt.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ft8ah395ydu3tv7xuvprt.png" alt=" " width="800" height="572"&gt;&lt;/a&gt;&lt;br&gt;
The full model layers two components:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;                    ┌─────────────────────────────┐
                    │       GPT-2 Backbone        │
                    │  (Masked Self-Attention +   │
                    │   Feed-Forward Layers)      │
                    └──────────────┬──────────────┘
                                   │  hidden state h_t  (B, D)
                                   ▼
                    ┌─────────────────────────────┐
                    │        DNC Memory           │
                    │  ┌─────────────────────┐    │
                    │  │  M ∈ ℝ^(N × W)      │    │  ← external RAM
                    │  └─────────────────────┘    │
                    │   write → read → update     │
                    └──────────────┬──────────────┘
                                   │  read_vec  (B, R*W)
                                   ▼
                         read_proj → h_t + read_vec
                                   │
                                   ▼
                              LM Head → logits
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;At each time step &lt;code&gt;t&lt;/code&gt;, the GPT-2 hidden state &lt;code&gt;h_t&lt;/code&gt; is used to:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;
&lt;strong&gt;Write&lt;/strong&gt; new information into memory&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Read&lt;/strong&gt; relevant information back out&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Fuse&lt;/strong&gt; the read vector with &lt;code&gt;h_t&lt;/code&gt; before projecting to vocabulary logits&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The memory persists across time steps within a sequence, making it a form of &lt;strong&gt;working memory&lt;/strong&gt; — information written at step 3 can be retrieved at step 47.&lt;/p&gt;




&lt;h2&gt;
  
  
  The Memory Module: Read &amp;amp; Write Mechanics
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Memory State
&lt;/h3&gt;

&lt;p&gt;The memory at any step is a matrix &lt;code&gt;M ∈ ℝ^(B × N × W)&lt;/code&gt; — a batch of &lt;code&gt;N&lt;/code&gt; slots, each a &lt;code&gt;W&lt;/code&gt;-dimensional vector. A usage vector &lt;code&gt;u ∈ ℝ^(B × N)&lt;/code&gt; tracks how much each slot has been written to.&lt;/p&gt;

&lt;h3&gt;
  
  
  Projections from Hidden State
&lt;/h3&gt;

&lt;p&gt;Given the controller hidden state &lt;code&gt;h_t ∈ ℝ^(B × D)&lt;/code&gt;, the memory module computes:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Projection&lt;/th&gt;
&lt;th&gt;Shape&lt;/th&gt;
&lt;th&gt;Purpose&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;write_key&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(B, W)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;
&lt;em&gt;Where&lt;/em&gt; to write (content addressing)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;write_vec&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(B, W)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;
&lt;em&gt;What&lt;/em&gt; to write&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;erase_vec&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(B, W)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;What to erase before writing (sigmoid-gated)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;write_gate&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(B, 1)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;
&lt;em&gt;How much&lt;/em&gt; to write (0 = skip, 1 = full write)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;read_keys&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;&lt;code&gt;(B, R, W)&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Where to read from (&lt;code&gt;R&lt;/code&gt; read heads)&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;h3&gt;
  
  
  Write Weighting
&lt;/h3&gt;

&lt;p&gt;The write address &lt;code&gt;w_write ∈ ℝ^(B × N)&lt;/code&gt; is a soft attention distribution over slots:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;w_content = softmax( cosine(write_key, M) × τ )
w_alloc   = softmax( (1 − u) × τ )

w_write   = 0.5 × w_content + 0.5 × w_alloc
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Content addressing&lt;/strong&gt; (&lt;code&gt;w_content&lt;/code&gt;): write near slots whose content resembles the current write key — useful for &lt;em&gt;updating&lt;/em&gt; existing facts.&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Allocation&lt;/strong&gt; (&lt;code&gt;w_alloc&lt;/code&gt;): prefer &lt;em&gt;less-used&lt;/em&gt; slots — useful for storing &lt;em&gt;new&lt;/em&gt; facts without overwriting old ones.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;code&gt;τ&lt;/code&gt; is a learned temperature parameter that sharpens or softens the distribution.&lt;/p&gt;
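&lt;p&gt;A NumPy sketch of the blend above — &lt;code&gt;write_weighting&lt;/code&gt; is a hypothetical helper following the two formulas, with toy shapes:&lt;/p&gt;

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def write_weighting(write_key, M, usage, tau=2.0):
    # Content term: prefer slots whose content resembles the write key.
    key_n = write_key / (np.linalg.norm(write_key) + 1e-8)
    M_n = M / (np.linalg.norm(M, axis=1, keepdims=True) + 1e-8)
    w_content = softmax(M_n @ key_n * tau)
    # Allocation term: prefer slots with low usage.
    w_alloc = softmax((1.0 - usage) * tau)
    return 0.5 * w_content + 0.5 * w_alloc

M = np.zeros((4, 8))
M[0, 0] = 1.0                               # slot 0 already holds content
usage = np.array([0.9, 0.0, 0.0, 0.0])      # ...and is heavily used
w = write_weighting(np.ones(8), M, usage)   # valid distribution over 4 slots
```

&lt;p&gt;With slot 0 heavily used, the allocation term pushes write mass toward the empty slots even though slot 0 wins on content similarity — exactly the "store new facts without overwriting old ones" behaviour described above.&lt;/p&gt;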

&lt;h3&gt;
  
  
  Write Operation
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;M_new = M × (1 − w_write ⊗ erase_vec) + w_write ⊗ write_vec

M_out = M + write_gate × (M_new − M)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The &lt;code&gt;write_gate&lt;/code&gt; is the key knob:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;write_gate ≈ 0  →  memory unchanged  (model relies on parametric knowledge)
write_gate ≈ 1  →  full write        (model externalizes knowledge)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This gate is learned entirely from data. The model discovers &lt;em&gt;when&lt;/em&gt; it's worth writing.&lt;/p&gt;
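&lt;p&gt;The two equations above can be checked in a few lines of NumPy. &lt;code&gt;write_step&lt;/code&gt; is an illustrative helper (not the article's code), and a hard one-hot write address is used purely for clarity:&lt;/p&gt;

```python
import numpy as np

def write_step(M, w_write, erase_vec, write_vec, write_gate):
    """Gated erase-then-write, following the M_new / M_out equations."""
    # Outer products spread the erase/write vectors over the addressed slots.
    erase = np.outer(w_write, erase_vec)     # (N, W)
    add   = np.outer(w_write, write_vec)     # (N, W)
    M_new = M * (1.0 - erase) + add
    return M + write_gate * (M_new - M)      # gate interpolates old vs new

N, W = 4, 8
M = np.ones((N, W))
w = np.array([1.0, 0.0, 0.0, 0.0])          # hard write to slot 0
out_closed = write_step(M, w, np.ones(W), np.full(W, 5.0), write_gate=0.0)
out_open   = write_step(M, w, np.ones(W), np.full(W, 5.0), write_gate=1.0)
```

&lt;p&gt;With the gate closed the memory comes back untouched; with the gate open, slot 0 is fully erased and replaced while every other slot is left alone.&lt;/p&gt;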

&lt;h3&gt;
  
  
  Read Operation
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;w_read  = softmax( read_keys · M^T × τ )   ∈ ℝ^(B × R × N)
read_vec = w_read · M                       ∈ ℝ^(B × R × W)
         → reshape to (B, R*W)
         → projected back to (B, D) via read_proj
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;code&gt;R&lt;/code&gt; read heads allow the model to simultaneously query &lt;code&gt;R&lt;/code&gt; different "topics" from memory.&lt;/p&gt;
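&lt;p&gt;A small NumPy sketch of the read path — &lt;code&gt;read_step&lt;/code&gt; is a hypothetical helper with toy shapes, mirroring the equations above:&lt;/p&gt;

```python
import numpy as np

def read_step(read_keys, M, tau=2.0):
    """R read heads attend over N slots; each returns a weighted sum of rows."""
    scores = read_keys @ M.T * tau                        # (R, N)
    e = np.exp(scores - scores.max(axis=1, keepdims=True))
    w_read = e / e.sum(axis=1, keepdims=True)             # softmax per head
    return w_read @ M                                     # (R, W) read vectors

M = np.zeros((4, 8))
M[1, :] = 3.0                                  # only slot 1 holds content
read_vec = read_step(np.ones((2, 8)), M)       # 2 heads issuing the same query
```

&lt;p&gt;Both heads lock onto the only non-empty slot and return (approximately) its contents; in the full model the &lt;code&gt;(R, W)&lt;/code&gt; result is flattened and projected back to the controller width.&lt;/p&gt;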

&lt;h3&gt;
  
  
  State Update
&lt;/h3&gt;

&lt;p&gt;Usage is updated after each write so the allocator tracks which slots are "full":&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;usage_new&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;usage&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;usage&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;w_write&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;detach&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The &lt;code&gt;.detach()&lt;/code&gt; prevents gradients from flowing back through the usage signal — it's a bookkeeping variable, not a learned one.&lt;/p&gt;




&lt;h2&gt;
  
  
  The Write Gate: Knowing When to Remember
&lt;/h2&gt;

&lt;p&gt;The write gate is the most interpretable component of the whole system. After training, you can run &lt;code&gt;inspect_writes()&lt;/code&gt; and visualize per-token gate activations:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;Token                  Gate   bar
────────────────────────────────────────────────
Albert                 0.821  ████████████████████████
Einstein               0.904  ███████████████████████████
was                    0.112  ███
born                   0.287  ████████
in                     0.094  ██
1879                   0.756  ██████████████████████
in                     0.071  ██
Ulm                    0.683  ████████████████████
He                     0.143  ████
developed              0.201  ██████
the                    0.058  █
theory                 0.388  ███████████
of                     0.062  █
relativity             0.712  █████████████████████
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;The model learns to write on content-bearing tokens&lt;/strong&gt; (proper nouns, dates, key concepts) and skip function words. Nobody taught it this — it emerged from the loss functions.&lt;/p&gt;
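&lt;p&gt;A table like the one above is easy to render from raw gate values. &lt;code&gt;gate_bar&lt;/code&gt; is a hypothetical formatting helper in the spirit of &lt;code&gt;inspect_writes()&lt;/code&gt;, not part of the described implementation:&lt;/p&gt;

```python
def gate_bar(token, gate, scale=30):
    """Render one row of a write-gate inspection table."""
    bar = "█" * max(1, round(gate * scale))
    return f"{token:<12} {gate:.3f}  {bar}"

rows = [("Einstein", 0.904), ("was", 0.112), ("1879", 0.756)]
table = "\n".join(gate_bar(tok, g) for tok, g in rows)
print(table)
```
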




&lt;h2&gt;
  
  
  Loss Functions
&lt;/h2&gt;

&lt;p&gt;Training minimises a weighted sum of three losses:&lt;/p&gt;

&lt;h3&gt;
  
  
  1. Language Modelling Loss (Cross-Entropy)
&lt;/h3&gt;

&lt;p&gt;The standard next-token prediction loss:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;L_lm = CrossEntropy(logits[:, :-1], input_ids[:, 1:])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is the primary loss. The model must still predict the next token correctly.&lt;/p&gt;
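&lt;p&gt;The one subtlety is the off-by-one shift: the logits at position &lt;code&gt;t&lt;/code&gt; are scored against the token at &lt;code&gt;t+1&lt;/code&gt;. A NumPy sketch (&lt;code&gt;lm_loss&lt;/code&gt; is an illustrative helper, not the article's code):&lt;/p&gt;

```python
import numpy as np

def lm_loss(logits, input_ids):
    """Next-token cross-entropy: position t predicts token t+1."""
    pred = logits[:, :-1]                    # (B, T-1, V) predictions
    tgt  = input_ids[:, 1:]                  # (B, T-1) shifted targets
    logp = pred - np.log(np.exp(pred).sum(-1, keepdims=True))  # log-softmax
    picked = np.take_along_axis(logp, tgt[..., None], axis=-1)
    return -picked.mean()

logits = np.full((1, 3, 5), -10.0)
ids = np.array([[0, 4, 2]])
logits[0, 0, 4] = 10.0    # position 0 confidently predicts the next token (4)
logits[0, 1, 2] = 10.0    # position 1 confidently predicts the next token (2)
loss = lm_loss(logits, ids)
```

&lt;p&gt;When every prediction matches the shifted target the loss collapses toward zero; uniform logits recover the familiar &lt;code&gt;log(V)&lt;/code&gt; baseline.&lt;/p&gt;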

&lt;h3&gt;
  
  
  2. Routing Loss
&lt;/h3&gt;

&lt;p&gt;This loss asks: &lt;em&gt;when the write gate is high, does memory actually change the prediction?&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;If the model writes to memory but the output distribution looks identical to the no-memory baseline, that write was pointless. The routing loss penalises this:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;kl&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;KL&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt; &lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p_no_mem&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;||&lt;/span&gt; &lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p_mem&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;detach&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;L_routing&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;gate&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;kl&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The KL divergence between the memory model and a frozen no-memory baseline is computed per token, then multiplied by the gate and negated. As a result, this loss:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Rewards&lt;/strong&gt; high gates when memory &lt;em&gt;changes&lt;/em&gt; the prediction (high KL)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Punishes&lt;/strong&gt; high gates when memory &lt;em&gt;doesn't matter&lt;/em&gt; (low KL → wasted write)&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;The &lt;code&gt;.detach()&lt;/code&gt; on the KL ensures gradients only flow through the gate, not the no-memory logits.&lt;/p&gt;
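&lt;p&gt;The pseudocode above can be unpacked into a runnable NumPy sketch. &lt;code&gt;routing_loss&lt;/code&gt; and the toy logits are illustrative; since NumPy has no autograd, the "detach" is just a comment here:&lt;/p&gt;

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def routing_loss(gate, logits_mem, logits_no_mem):
    """Reward high gates only where memory actually shifts the output."""
    p = softmax(logits_no_mem)
    q = softmax(logits_mem)
    # Per-token KL(p || q); in the PyTorch version this term is detached.
    kl = (p * (np.log(p + 1e-8) - np.log(q + 1e-8))).sum(-1)
    return -(gate * kl).mean()

logits_no_mem = np.array([[2.0, 0.0, 0.0]])
useful = routing_loss(np.array([1.0]), np.array([[0.0, 2.0, 0.0]]), logits_no_mem)
wasted = routing_loss(np.array([1.0]), logits_no_mem.copy(), logits_no_mem)
```

&lt;p&gt;A write that changes the output distribution earns a negative (rewarding) loss; a write that leaves it unchanged earns nothing, so the gate learns to stay closed there.&lt;/p&gt;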

&lt;h3&gt;
  
  
  3. Entropy Loss (Write Sparsity)
&lt;/h3&gt;

&lt;p&gt;A diffuse write weighting — spreading activation uniformly across all &lt;code&gt;N&lt;/code&gt; slots — is wasteful. It's like writing one word across every page of your notebook instead of on a single page.&lt;/p&gt;

&lt;p&gt;The entropy loss encourages sharp, decisive writes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;H&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w_writes&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w_writes&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mf"&gt;1e-8&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;()).&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;L_entropy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;H&lt;/span&gt;   &lt;span class="c1"&gt;# minimized during training
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Low entropy → sparse write attention → the model commits to specific slots.&lt;/p&gt;
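&lt;p&gt;The effect is easy to verify numerically — a sharp write distribution scores much lower entropy than a uniform one. &lt;code&gt;write_entropy&lt;/code&gt; is an illustrative NumPy rendering of the loss above:&lt;/p&gt;

```python
import numpy as np

def write_entropy(w_writes):
    """Entropy of write weightings; lower means sharper, more decisive writes."""
    return -(w_writes * np.log(w_writes + 1e-8)).sum(-1).mean()

sharp   = np.array([[0.97, 0.01, 0.01, 0.01]])  # commits to one slot
diffuse = np.full((1, 4), 0.25)                 # smeared over all slots
```

&lt;p&gt;Minimising this term therefore pushes the write attention toward the &lt;code&gt;sharp&lt;/code&gt; regime.&lt;/p&gt;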

&lt;h3&gt;
  
  
  Total Loss
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;L&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;L_lm&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;λ_r&lt;/span&gt; &lt;span class="err"&gt;·&lt;/span&gt; &lt;span class="n"&gt;L_routing&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;λ_e&lt;/span&gt; &lt;span class="err"&gt;·&lt;/span&gt; &lt;span class="n"&gt;L_entropy&lt;/span&gt;

&lt;span class="c1"&gt;# defaults: λ_r = 0.1,  λ_e = 0.05
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The auxiliary losses are kept small relative to &lt;code&gt;L_lm&lt;/code&gt; so language modelling remains the primary objective. The routing and entropy terms act as &lt;strong&gt;structural regularizers&lt;/strong&gt; that shape &lt;em&gt;how&lt;/em&gt; the memory is used, not just whether the model gets tokens right.&lt;/p&gt;




&lt;h2&gt;
  
  
  Code Walkthrough
&lt;/h2&gt;

&lt;h3&gt;
  
  
  DNCMemory Module
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;DNCMemory&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mem_slots&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mem_width&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_reads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;controller_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mem_slots&lt;/span&gt;   &lt;span class="c1"&gt;# number of memory rows
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;W&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;mem_width&lt;/span&gt;   &lt;span class="c1"&gt;# width of each row
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;R&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;num_reads&lt;/span&gt;   &lt;span class="c1"&gt;# number of read heads
&lt;/span&gt;
        &lt;span class="c1"&gt;# All projections from controller hidden state
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;write_key_proj&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;controller_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mem_width&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;write_vec_proj&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;controller_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mem_width&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;erase_vec_proj&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;controller_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mem_width&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;write_gate_proj&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;controller_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;read_key_proj&lt;/span&gt;   &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;controller_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;mem_width&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;num_reads&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;temp&lt;/span&gt;            &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Parameter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;  &lt;span class="c1"&gt;# learned sharpness
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  DNCLLM Forward Pass
&lt;/h3&gt;

&lt;p&gt;The key loop — stepping through time and interleaving memory reads/writes with transformer hidden states:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;memory&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;usage&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# Run all tokens through GPT-2 in parallel (causal masking handles ordering)
&lt;/span&gt;    &lt;span class="n"&gt;hidden_states&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transformer&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="n"&gt;last_hidden_state&lt;/span&gt;  &lt;span class="c1"&gt;# (B, T, D)
&lt;/span&gt;
    &lt;span class="n"&gt;all_logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;all_gates&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;all_ww&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;input_ids&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;size&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)):&lt;/span&gt;
        &lt;span class="n"&gt;h_t&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hidden_states&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt;                     &lt;span class="c1"&gt;# (B, D) — current hidden state
&lt;/span&gt;
        &lt;span class="c1"&gt;# Memory interaction for this timestep
&lt;/span&gt;        &lt;span class="n"&gt;read_vec&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;memory&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;usage&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;write_gate&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;w_write&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;memory&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h_t&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;memory&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;usage&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="c1"&gt;# Fuse read vector back into hidden state
&lt;/span&gt;        &lt;span class="n"&gt;h_out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h_t&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;read_proj&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;read_vec&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;           &lt;span class="c1"&gt;# residual addition
&lt;/span&gt;
        &lt;span class="n"&gt;all_logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;lm_head&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h_out&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;           &lt;span class="c1"&gt;# project to vocab
&lt;/span&gt;        &lt;span class="n"&gt;all_gates&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;write_gate&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;all_ww&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w_write&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="n"&gt;logits&lt;/span&gt;      &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;all_logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;         &lt;span class="c1"&gt;# (B, T, V)
&lt;/span&gt;    &lt;span class="n"&gt;write_gates&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;all_gates&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;          &lt;span class="c1"&gt;# (B, T, 1)
&lt;/span&gt;    &lt;span class="n"&gt;w_writes&lt;/span&gt;    &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;all_ww&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;             &lt;span class="c1"&gt;# (B, T, N)
&lt;/span&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;memory&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;usage&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;write_gates&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;w_writes&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Why the sequential loop?&lt;/strong&gt; Memory has a causal dependency — &lt;code&gt;memory[t]&lt;/code&gt; depends on what was written at steps &lt;code&gt;0..t-1&lt;/code&gt;. This can't be parallelized like self-attention. It's the main compute overhead of DNC over a pure transformer.&lt;/p&gt;

&lt;h3&gt;
  
  
  Memory Initialization
&lt;/h3&gt;

&lt;p&gt;Memory is initialized to zeros at the start of each sequence:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;init_memory&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;memory&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cfg&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mem_slots&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cfg&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mem_width&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;usage&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;cfg&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;mem_slots&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;device&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;memory&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;usage&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This means memory is &lt;strong&gt;per-sequence&lt;/strong&gt;, not persistent across batch items or between training steps. It acts as within-sequence working memory, not a cross-sequence knowledge base.&lt;/p&gt;




&lt;h2&gt;
  
  
  Metrics to Watch
&lt;/h2&gt;

&lt;p&gt;During training, several metrics beyond loss reveal whether the memory system is working correctly:&lt;/p&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;Metric&lt;/th&gt;
&lt;th&gt;What It Tells You&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;avg_gate&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Mean write gate activation. Should settle between 0.2 and 0.7; too high means the model writes everything, too low means it never writes&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;gate_std&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Gate polarization. High std means the model discriminates — writes on some tokens, skips others&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;write_rate&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;Fraction of timesteps with gate &amp;gt; 0.7. Tracks how aggressively the model uses memory&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;write_sparsity&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;How concentrated the write weighting is. High sparsity = sharp slot selection&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;&lt;code&gt;mem_kl&lt;/code&gt;&lt;/td&gt;
&lt;td&gt;KL divergence between memory and no-memory predictions. Non-zero means memory is changing outputs&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;A healthy DNC should show &lt;strong&gt;high gate_std&lt;/strong&gt; (selective writing) and &lt;strong&gt;high write_sparsity&lt;/strong&gt; (concentrated writes), with &lt;strong&gt;non-trivial mem_kl&lt;/strong&gt; (memory actually matters).&lt;/p&gt;
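&lt;p&gt;These diagnostics are cheap to compute from tensors the forward pass already returns. The sketch below is illustrative rather than taken from the repo; the tensor shapes (&lt;code&gt;write_gates&lt;/code&gt; as [B, T], &lt;code&gt;w_writes&lt;/code&gt; as [B, T, N], logits as [B, T, V]) are assumptions:&lt;/p&gt;

```python
import math
import torch

def memory_metrics(write_gates, w_writes, logits_mem, logits_nomem):
    """Compute the diagnostics from the table above.
    Assumed (hypothetical) shapes: write_gates [B, T], w_writes [B, T, N],
    logits_mem / logits_nomem [B, T, V]."""
    avg_gate = write_gates.mean().item()
    gate_std = write_gates.std().item()
    write_rate = (write_gates > 0.7).float().mean().item()

    # Sparsity as 1 minus the normalized entropy of the write weighting:
    # a one-hot weighting scores ~1, a uniform weighting scores 0.
    p = w_writes.clamp_min(1e-8)
    entropy = -(p * p.log()).sum(dim=-1)  # [B, T]
    write_sparsity = (1.0 - entropy / math.log(w_writes.size(-1))).mean().item()

    # KL(p_mem || p_nomem) over the vocabulary, averaged over positions.
    log_p = torch.log_softmax(logits_mem, dim=-1)
    log_q = torch.log_softmax(logits_nomem, dim=-1)
    mem_kl = (log_p.exp() * (log_p - log_q)).sum(dim=-1).mean().item()

    return {"avg_gate": avg_gate, "gate_std": gate_std,
            "write_rate": write_rate, "write_sparsity": write_sparsity,
            "mem_kl": mem_kl}
```

&lt;p&gt;Logging these once per eval step is enough to catch a collapsed gate (near-zero &lt;code&gt;gate_std&lt;/code&gt;) early in training.&lt;/p&gt;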




&lt;h2&gt;
  
  
  Practical Configuration
&lt;/h2&gt;

&lt;p&gt;The config used in the experiments:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Config&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="c1"&gt;# GPT-2 backbone
&lt;/span&gt;    &lt;span class="n"&gt;hidden_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;768&lt;/span&gt;
    &lt;span class="n"&gt;num_layers&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;
    &lt;span class="n"&gt;num_heads&lt;/span&gt;   &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
    &lt;span class="n"&gt;seq_len&lt;/span&gt;     &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;

    &lt;span class="c1"&gt;# DNC memory
&lt;/span&gt;    &lt;span class="n"&gt;mem_slots&lt;/span&gt;   &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;64&lt;/span&gt;     &lt;span class="c1"&gt;# N: number of memory slots
&lt;/span&gt;    &lt;span class="n"&gt;mem_width&lt;/span&gt;   &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;    &lt;span class="c1"&gt;# W: width of each slot
&lt;/span&gt;    &lt;span class="n"&gt;num_reads&lt;/span&gt;   &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;      &lt;span class="c1"&gt;# R: number of read heads
&lt;/span&gt;
    &lt;span class="c1"&gt;# Loss weights
&lt;/span&gt;    &lt;span class="n"&gt;lambda_routing&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;
    &lt;span class="n"&gt;lambda_entropy&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.05&lt;/span&gt;

    &lt;span class="c1"&gt;# Training
&lt;/span&gt;    &lt;span class="n"&gt;batch_size&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;
    &lt;span class="n"&gt;lr&lt;/span&gt;          &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;3e-4&lt;/span&gt;
    &lt;span class="n"&gt;grad_clip&lt;/span&gt;   &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1.0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Memory footprint&lt;/strong&gt;: The external memory adds &lt;code&gt;N × W = 64 × 128 = 8,192&lt;/code&gt; floats per batch item — negligible compared to the model weights themselves. The overhead is in the sequential forward loop, not storage.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Parameter count&lt;/strong&gt;: DNC adds roughly &lt;code&gt;5 × (D × W)&lt;/code&gt; parameters from the five projection matrices. At &lt;code&gt;D=768, W=128&lt;/code&gt; that's ~490K parameters — about 0.5% overhead on a 6-layer GPT-2.&lt;/p&gt;
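&lt;p&gt;Both figures are easy to sanity-check, assuming the five projection matrices are plain D-by-W linear maps without biases:&lt;/p&gt;

```python
# Back-of-envelope check of the numbers quoted above.
D, N, W = 768, 64, 128

memory_floats = N * W   # external memory per batch item: 8192 floats
dnc_params = 5 * D * W  # five hidden-to-memory projections: 491520 params

print(memory_floats, dnc_params)  # 8192 491520 (~490K, ~0.5% of the backbone)
```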




&lt;h2&gt;
  
  
  Limitations and What's Next
&lt;/h2&gt;

&lt;p&gt;This architecture is a proof of concept. Several known limitations:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Sequential bottleneck&lt;/strong&gt;: The time-step loop cannot be parallelized. For long sequences, this significantly slows training relative to the pure-transformer baseline.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;No cross-sequence persistence&lt;/strong&gt;: Memory resets between sequences. A truly useful factual memory would persist across the lifetime of the model — closer to a retrieval-augmented generation (RAG) system.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Gradient flow through time&lt;/strong&gt;: Backpropagating through &lt;code&gt;T&lt;/code&gt; sequential memory steps can cause vanishing/exploding gradients for long sequences. Gradient clipping (&lt;code&gt;grad_clip = 1.0&lt;/code&gt;) helps but doesn't solve it.&lt;/p&gt;
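&lt;p&gt;One standard mitigation (an assumption here, not something the post's training loop does) is truncated backpropagation through time: periodically detach the memory state so gradients flow through at most a fixed window of memory updates. A minimal sketch of a hypothetical helper:&lt;/p&gt;

```python
import torch

def maybe_detach(memory, usage, t, tbptt_window=32):
    """Truncated-BPTT sketch (hypothetical helper, not from the repo):
    cut the gradient path through the memory state every `tbptt_window`
    timesteps, so backprop spans at most that many memory updates."""
    if t > 0 and t % tbptt_window == 0:
        memory, usage = memory.detach(), usage.detach()
    return memory, usage
```

&lt;p&gt;Called inside the per-timestep loop, this trades some long-range credit assignment for stable gradients on long sequences.&lt;/p&gt;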

&lt;p&gt;&lt;strong&gt;Potential extensions&lt;/strong&gt;:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;strong&gt;Persistent memory&lt;/strong&gt;: Keep a global memory matrix that accumulates knowledge across a training corpus and is frozen at inference time (like a learned knowledge base)&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Sparse attention writes&lt;/strong&gt;: Replace soft write weighting with a top-k hard selection to reduce memory write diffusion&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Layer-wise memory&lt;/strong&gt;: Attach a memory module to each transformer layer, not just the final hidden state&lt;/li&gt;
&lt;li&gt;
&lt;strong&gt;Memory-augmented RAG&lt;/strong&gt;: Use DNC writes as an online summary buffer, and retrieve from it alongside a static vector DB&lt;/li&gt;
&lt;/ul&gt;
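&lt;p&gt;As a concrete sketch of the "sparse attention writes" extension, a top-k hard selection with a straight-through estimator might look like this (hypothetical helper; shapes assumed):&lt;/p&gt;

```python
import torch

def topk_write_weights(w_soft, k=1):
    """Sparse-write sketch: keep only the top-k slots of the soft write
    weighting, renormalize, and use a straight-through estimator so
    gradients still reach the soft weights.
    w_soft: [B, N] soft weighting over memory slots (assumed shape)."""
    topv, topi = w_soft.topk(k, dim=-1)
    hard = torch.zeros_like(w_soft).scatter(-1, topi, topv)
    hard = hard / hard.sum(dim=-1, keepdim=True).clamp_min(1e-8)
    # Forward pass uses `hard`; backward flows through `w_soft`.
    return w_soft + (hard - w_soft).detach()
```

&lt;p&gt;The straight-through trick keeps the write differentiable while guaranteeing that at most &lt;code&gt;k&lt;/code&gt; slots are touched per step, which directly counters write diffusion.&lt;/p&gt;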




&lt;h2&gt;
  
  
  Summary
&lt;/h2&gt;

&lt;div class="table-wrapper-paragraph"&gt;&lt;table&gt;
&lt;thead&gt;
&lt;tr&gt;
&lt;th&gt;&lt;/th&gt;
&lt;th&gt;GPT-2 Baseline&lt;/th&gt;
&lt;th&gt;GPT-2 + DNC&lt;/th&gt;
&lt;/tr&gt;
&lt;/thead&gt;
&lt;tbody&gt;
&lt;tr&gt;
&lt;td&gt;Factual recall&lt;/td&gt;
&lt;td&gt;Parametric only&lt;/td&gt;
&lt;td&gt;Parametric + external memory&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Memory type&lt;/td&gt;
&lt;td&gt;Weights (static)&lt;/td&gt;
&lt;td&gt;N×W matrix (dynamic, per-sequence)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Write mechanism&lt;/td&gt;
&lt;td&gt;None&lt;/td&gt;
&lt;td&gt;Content + allocation addressing&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Selective writing&lt;/td&gt;
&lt;td&gt;No&lt;/td&gt;
&lt;td&gt;Yes (learned write gate)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Extra parameters&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;td&gt;~490K (~0.5%)&lt;/td&gt;
&lt;/tr&gt;
&lt;tr&gt;
&lt;td&gt;Training overhead&lt;/td&gt;
&lt;td&gt;—&lt;/td&gt;
&lt;td&gt;Sequential loop over T steps&lt;/td&gt;
&lt;/tr&gt;
&lt;/tbody&gt;
&lt;/table&gt;&lt;/div&gt;

&lt;p&gt;The DNC doesn't replace the transformer's parametric knowledge — it &lt;em&gt;supplements&lt;/em&gt; it. The model learns when to trust its weights and when to externalise a fact to the notepad. On a small model operating in a domain with many precise facts, that notepad can make all the difference.&lt;/p&gt;

&lt;p&gt;The write gate is the centrepiece of the design. When it fires on "Einstein" and "1879" and stays quiet on "was" and "the", you know the model has learned something non-trivial: &lt;strong&gt;not all tokens are worth remembering&lt;/strong&gt;.&lt;/p&gt;




&lt;p&gt;GitHub code: &lt;a href="https://github.com/AsishKumarDalal/memoryllm" rel="noopener noreferrer"&gt;https://github.com/AsishKumarDalal/memoryllm&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;em&gt;Implementation: PyTorch. Dataset: WikiText-2. Backbone: GPT-2 (6 layers, 768 hidden, 8 heads). DNC config: N=64, W=128, R=4 read heads. Loss: L_lm + 0.1·L_routing + 0.05·L_entropy.&lt;/em&gt;&lt;/p&gt;

</description>
      <category>ai</category>
      <category>python</category>
    </item>
  </channel>
</rss>
