<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Hưng Lê Tiến</title>
    <description>The latest articles on DEV Community by Hưng Lê Tiến (@blackbidz).</description>
    <link>https://dev.to/blackbidz</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F3603844%2F8e042e5e-07ae-44cc-8d29-087ac8c4c423.jpg</url>
      <title>DEV Community: Hưng Lê Tiến</title>
      <link>https://dev.to/blackbidz</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/blackbidz"/>
    <language>en</language>
    <item>
      <title>BIG STEPS TO TRANSFORMER (PART 2): BUILDING THE TRANSFORMER</title>
      <dc:creator>Hưng Lê Tiến</dc:creator>
      <pubDate>Sat, 06 Dec 2025 10:02:29 +0000</pubDate>
      <link>https://dev.to/blackbidz/big-steps-to-transformer-part-2-building-the-transformer-302f</link>
      <guid>https://dev.to/blackbidz/big-steps-to-transformer-part-2-building-the-transformer-302f</guid>
      <description>&lt;p&gt;This might be the end of the blogposts I made based on the videos of Andrej, so shout out to him the last time. Respect!&lt;/p&gt;

&lt;p&gt;A word of warning: I'm currently working on projects, so things have gotten messy (I'm going around fixing bugs every day). It's hard to systematically note things down, and sometimes I just want to talk really briefly about a topic. So you may not find detailed explanations in my blog posts anymore; some notions will be under-explained or even wrongly explained. I'm just sharing my thoughts and some of the insights that I came up with, so don't be too hard on me, okay :).&lt;/p&gt;

&lt;p&gt;Now we know that, in the Bigram model, the model uses &lt;strong&gt;just the previous token&lt;/strong&gt; to predict the next one, which really limits the context. In other words, the &lt;strong&gt;context window&lt;/strong&gt; is too small, and we really want to increase it. But when we discuss scaling the N-gram model, we hit a big problem: the number of contexts the model has to keep track of &lt;strong&gt;grows exponentially&lt;/strong&gt; with the context length. &lt;/p&gt;
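
&lt;p&gt;To make that concrete, here's a quick back-of-the-envelope count (assuming a character-level vocabulary of about 65 symbols, roughly what the tiny Shakespeare dataset gives us):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Rough count of contexts an N-gram table would need to cover
vocab_size = 65  # assumption: character-level Shakespeare vocab
for n in [2, 3, 8]:
    print(n, vocab_size ** n)
# 2 gives 4,225 contexts; 3 gives 274,625; 8 already gives about 3.2e14
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;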

&lt;p&gt;And this is when our Attention boy comes in. Attention allows us to look at everything in the past (within a sequence of size &lt;code&gt;block_size&lt;/code&gt;), and we will walk through it from the naive approach to the actual version.&lt;/p&gt;

&lt;h2&gt;
  
  
  The Naive Approach
&lt;/h2&gt;

&lt;p&gt;Let me be specific: what we want here is, for each timestep (as we walk through the time dimension), to see &lt;strong&gt;every character&lt;/strong&gt; behind us in order to make our decision. &lt;/p&gt;

&lt;p&gt;We do that in a fairly simple way: just carry the data of all the previous characters with us via a &lt;strong&gt;weighted sum&lt;/strong&gt;, and in our first case we will just take the mean. Here's how it looks with a toy example:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;
&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;xbow&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;xprev&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,:&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# (t,C)
&lt;/span&gt;    &lt;span class="n"&gt;xbow&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xprev&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;&amp;gt;&amp;gt;&amp;gt;x[0]
tensor([[ 0.3624, -1.6396],
        [-0.3599, -1.2083],
        [-2.0063,  0.8857],
        [ 0.0144,  1.5211],
        [ 0.6635, -0.5929],
        [ 1.0220,  0.8683],
        [ 1.2717, -0.6242],
        [-1.4394,  0.2805]])
&amp;gt;&amp;gt;&amp;gt;xbow[0]
tensor([[ 3.6235e-01, -1.6396e+00],
        [ 1.2276e-03, -1.4240e+00],
        [-6.6794e-01, -6.5406e-01],
        [-4.9734e-01, -1.1026e-01],
        [-2.6517e-01, -2.0679e-01],
        [-5.0639e-02, -2.7610e-02],
        [ 1.3827e-01, -1.1284e-01],
        [-5.8944e-02, -6.3670e-02]])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And you know what? All of that can be done with matrix multiplication. I won't go deep into the details here, that's the responsibility of your Linear Algebra lecturer. So let me just put the code right under this:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tril&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,(&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;a&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;a=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;a&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;b=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;c=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;A few things to notice here. First, we have that &lt;code&gt;tril&lt;/code&gt; thing, which does the job of converting the matrix into a &lt;em&gt;lower triangular&lt;/em&gt; one. The intuition behind this is that no token is allowed to talk to the ones in the future; it can &lt;strong&gt;only look to the past&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;Second, notice that we divided each row by its sum. This ensures that we take the &lt;em&gt;mean&lt;/em&gt; rather than just the aggregated sum.&lt;/p&gt;

&lt;p&gt;Now let's implement that in our real input matrix &lt;code&gt;x&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tril&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;wei&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;xbow2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="c1"&gt;# (B, T, T) @ (B, T, C) ----&amp;gt; (B, T, C)
&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xbow&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xbow2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That will return &lt;code&gt;True&lt;/code&gt;, trust me. And notice the comment I made: there is a bit of broadcasting there, because the dimensions don't really match (we're dealing with batches). Broadcasting just replicates the &lt;code&gt;TxT&lt;/code&gt; matrix across the batch dimension, so things are good here.&lt;/p&gt;

&lt;p&gt;In practice, though, we would get rid of that normalizing part and instead apply the &lt;strong&gt;softmax&lt;/strong&gt;, which I will show you:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;wei&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tril&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-inf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;wei&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;xbow3&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;allclose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xbow&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;xbow3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Remember that we still want to create a lower triangular matrix and then normalize, but how can we tell softmax where the zeros should be? We put &lt;code&gt;-inf&lt;/code&gt; in each place that is supposed to hold a &lt;code&gt;0&lt;/code&gt;; in doing so, softmax will produce exact zeros there, because softmax takes the &lt;strong&gt;exponential&lt;/strong&gt; of the values, and the exponential of -inf is 0. &lt;/p&gt;
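
&lt;p&gt;Here's a tiny sanity check of that claim (just a sketch you can paste into a REPL):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch
import torch.nn.functional as F

# one row of masked scores: only the first slot is visible
row = torch.tensor([0.0, float('-inf'), float('-inf')])
print(F.softmax(row, dim=-1))  # tensor([1., 0., 0.]): the -inf entries contribute nothing
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;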

&lt;p&gt;A small note here before we jump into the Attention part: we need to somehow tell each token which &lt;em&gt;position&lt;/em&gt; it's in. Take "He sat on the mat, with the cat" versus "The cat sat on the mat", for example: the relative order between "sat" and the nouns ("he", "cat") is crucial to understanding the meaning of the sentence. So &lt;em&gt;order matters&lt;/em&gt;, and thus we shall introduce &lt;strong&gt;Positional Embedding&lt;/strong&gt;, which is just an &lt;code&gt;Embedding&lt;/code&gt; layer over positions whose output we add to our &lt;code&gt;token_embedding&lt;/code&gt;. Then we're all set:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# It kinda goes like this
&lt;/span&gt;&lt;span class="n"&gt;pos_emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;posemb_matrix&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arrange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tok_emd&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;pos_emb&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  The Crux of Self Attention
&lt;/h2&gt;

&lt;p&gt;Now, first, there are some problems with our naive approach: it treats each of the previous tokens as an equal contributor, while in practice some tokens can be &lt;strong&gt;more valuable&lt;/strong&gt; in the eyes of other tokens. In other words, we would love to get rid of our naive everything-equal approach and adopt a proper &lt;strong&gt;weighted sum&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;To compute that weighted sum, we need to find the &lt;strong&gt;relationship&lt;/strong&gt; of the examined token with all the previous ones. And we do that in a fairly easy-to-understand way: let each token have two vectors, &lt;strong&gt;Q&lt;/strong&gt; (for query) and &lt;strong&gt;K&lt;/strong&gt; (for key). It goes like this:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;As we march through the time steps, each token emits its &lt;strong&gt;query vector&lt;/strong&gt;, like asking the previous tokens: I'm looking for A, B, C, ... It might look for an adjective, or a noun, or a word to make its context more specific, etc.&lt;/li&gt;
&lt;li&gt;After emitting the query vector, we take each of the &lt;strong&gt;keys&lt;/strong&gt; from the previous tokens, which are simply the answers to the query. And we measure each answer by taking the &lt;strong&gt;dot product&lt;/strong&gt; of the two vectors (remember that dot products actually measure the &lt;em&gt;similarity&lt;/em&gt; between vectors). And then we have our weights!&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;But we won't apply those weights directly to our &lt;code&gt;x&lt;/code&gt;; instead, we introduce another vector: &lt;strong&gt;V&lt;/strong&gt;, which represents the value of each token. The original vector of each token is kinda private information, and the value is a way of saying "If you are interested in me, here's what I will give you, but I won't take myself with you, ain't no way bud".&lt;/p&gt;

&lt;p&gt;And all of that, my friend, is the crux of self-attention. You might want to check a great video and visualization from 3Blue1Brown (and check his series), those videos are phenomenal: &lt;a href="https://youtu.be/eMlx5fFNoYc?si=i7zVZObOevccKiQ-" rel="noopener noreferrer"&gt;Attention in transformers&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Now let's implement in the code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;head_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;16&lt;/span&gt;
&lt;span class="c1"&gt;# This is not a layer, it's just a nice way to
# construct our weight matrices
&lt;/span&gt;&lt;span class="n"&gt;key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;head_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;bias&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;query&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;head_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;bias&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;key&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# B,T,16
&lt;/span&gt;&lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;query&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Need to tranpose k
&lt;/span&gt;&lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;q&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;tril&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tril&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;wei&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;masked_fill&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tril&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;-inf&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;wei&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# Creating the value vectors
&lt;/span&gt;&lt;span class="n"&gt;value&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;head_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;bias&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;False&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;v&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;value&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;wei&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;v&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;A small thing to notice here is that we &lt;strong&gt;transpose&lt;/strong&gt; the K matrix (over its last two dimensions) and then multiply it with Q; that's just a convenient way of taking all the query-key dot products at once.&lt;/p&gt;

&lt;p&gt;That's all of the basics about Self-attention, and there are a few notes from Andrej that you should read. I won't go into specifics because I'm lazy (sorry):&lt;/p&gt;

&lt;h3&gt;
  
  
  Notes:
&lt;/h3&gt;

&lt;ul&gt;
&lt;li&gt;Attention is a &lt;strong&gt;communication mechanism&lt;/strong&gt;. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.&lt;/li&gt;
&lt;li&gt;There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.&lt;/li&gt;
&lt;li&gt;Each example along the batch dimension is, of course, processed completely independently; the examples never "talk" to each other.&lt;/li&gt;
&lt;li&gt;In an "encoder" attention block just delete the single line that does masking with &lt;code&gt;tril&lt;/code&gt;, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.&lt;/li&gt;
&lt;li&gt;"self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)&lt;/li&gt;
&lt;li&gt;"Scaled" attention additional divides &lt;code&gt;wei&lt;/code&gt; by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. &lt;/li&gt;
&lt;/ul&gt;
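
&lt;p&gt;That last note amounts to a one-line change in the head code above (a sketch, reusing the &lt;code&gt;q&lt;/code&gt; and &lt;code&gt;k&lt;/code&gt; tensors we already computed):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Scaled attention: divide the raw scores by sqrt(head_size)
wei = q @ k.transpose(-2, -1) * head_size ** -0.5
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;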

&lt;h2&gt;
  
  
  Toward the Transformer
&lt;/h2&gt;

&lt;p&gt;Now, we're not done yet! We have a lot of things to do to arrive at the final model. First, let's take a look at the architecture in the paper &lt;a href="https://arxiv.org/pdf/1706.03762" rel="noopener noreferrer"&gt;Attention Is All You Need&lt;/a&gt;:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8ci6wueqcmgmkxx122ia.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F8ci6wueqcmgmkxx122ia.png" alt="Architecture" width="695" height="525"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Multi-headed attention
&lt;/h3&gt;

&lt;p&gt;We do that by running several layers of self-attention (each of which we call a &lt;code&gt;Head&lt;/code&gt;) in parallel. The intuition is that we want the model to ask as many questions as possible, and thus understand more about the context. &lt;/p&gt;

&lt;p&gt;For the code, oftentimes we divide the embedding dimension by the number of heads that we want to get each &lt;code&gt;head_size&lt;/code&gt;, and after the multi-head layer we concatenate the heads' outputs &lt;strong&gt;in the C dimension&lt;/strong&gt;. It's like dividing the features into small chunks for the model to process, as in the sketch below.&lt;/p&gt;
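
&lt;p&gt;Here's a minimal sketch of that idea, assuming a &lt;code&gt;Head&lt;/code&gt; module wrapping the single-head code from earlier (the fuller version, with projection and dropout, comes later in this post):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;class MultiHeadAttention(nn.Module):
    """ several heads of self-attention, running in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        # assumes a Head module wrapping the single-head code above
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])

    def forward(self, x):
        # each head returns (B, T, head_size); concatenate along the channel dim
        return torch.cat([h(x) for h in self.heads], dim=-1)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;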

&lt;h3&gt;
  
  
  FeedForward layer
&lt;/h3&gt;

&lt;p&gt;After some layers of attention, we need to allow our neurons some additional time to process the information that they have, and even to communicate with each other more, in order to arrive at something good. I know that I'm describing this in an overly metaphoric way, but I think that's the intuition behind adding an MLP after the Multi-head.&lt;/p&gt;

&lt;p&gt;So let's implement it, shall we? We just need to add a Linear layer followed by a non-linear activation function (which in this case is the &lt;em&gt;ReLU&lt;/em&gt;).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;FeedFoward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt; a simple linear layer followed by a non-linearity &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;net&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;net&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Residual connections
&lt;/h3&gt;

&lt;p&gt;The Multi-head and the FeedForward are the bare bones of our model, so now we know the important parts of the Transformer.&lt;/p&gt;

&lt;p&gt;We shall move on to some smaller, yet still important, details of the architecture. First, I would like to tell you about the &lt;strong&gt;Residual Connection&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;Remember the problem of the &lt;em&gt;Vanishing Gradient&lt;/em&gt;? It happens when the gradient has to be passed back through numerous consecutive layers, getting squashed and sometimes zeroed out along the way (mainly because of saturated activation functions), eventually resulting in minimal to zero updates for the parameters.&lt;/p&gt;

&lt;p&gt;In other words, the gradient has to flow along some winding, twisty roads to bring updates back to our early layers. Well, so why don't we construct another road that allows our gradient to flow smoothly without having to go through any additional pathway? That's how we arrive at the &lt;strong&gt;residual pathway&lt;/strong&gt;, which is really like a highway for the gradients. Look back at the architecture illustration: the additional arrows at the side of the model show that we add the input back to the output of each sub-layer.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# In our block, we modify this
&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sa&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ffwd&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;If you are not really comfortable with this metaphoric way of thinking, then you can think in mathematical terms: when the gradient flows back, the &lt;em&gt;addition operation&lt;/em&gt; distributes the gradient &lt;strong&gt;equally&lt;/strong&gt; to the previous nodes. So the input still receives some unmodified, unsquashed gradient through the backward pass. This residual idea also opened up new kinds of architectures, like the &lt;strong&gt;Wide and Deep Neural Network&lt;/strong&gt; (you can search for that, it's really interesting).&lt;/p&gt;
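
&lt;p&gt;You can verify that "unmodified gradient" claim with a tiny autograd sketch (the quadratic branch here is just a hypothetical stand-in for a sub-layer):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

x = torch.ones(3, requires_grad=True)
# "residual" identity path plus a branch, like x + sublayer(x)
y = (x + x.pow(2)).sum()
y.backward()
print(x.grad)  # tensor([3., 3., 3.]) = 1 (identity path) + 2*x (branch path)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;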

&lt;p&gt;A small detail here is that we also need two projection layers, one after the Multi-head and one after the FeedForward. These act as &lt;strong&gt;mixers&lt;/strong&gt;: we mix things up in order to allow more exchange of information among heads. We just add a linear layer of size &lt;code&gt;n_embd x n_embd&lt;/code&gt; (notice that the input and output sizes are equal, so basically we're just mixing things up).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;#Adding a projection layer
&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;proj&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_emb&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;n_emb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Also notice that in the paper, in the FeedForward layer, the researchers make the inner layer 4 times wider than the model dimension (d_ff = 2048 versus d_model = 512, creating a much larger working space for our neurons, I think), so we should include that in our FeedForward layer too.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;net&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  LayerNorm
&lt;/h3&gt;

&lt;p&gt;There is actually a whole family of normalization techniques, and we already know one of them: BatchNorm. But we also know that BatchNorm has some problems of its own, including the side effect of &lt;em&gt;coupling instances&lt;/em&gt; together, and having to keep track of the &lt;em&gt;population mean and variance&lt;/em&gt;. It turns out that BatchNorm doesn't work well with sequential inputs like ours, mainly because at generation time we don't pass data in &lt;em&gt;in big batches&lt;/em&gt;, but rather &lt;em&gt;one instance at a time&lt;/em&gt;, and researchers have found a way to combat the issue.&lt;/p&gt;

&lt;p&gt;It is LayerNorm, where we don't normalize data along the batch dimension, but rather along the &lt;strong&gt;feature dimension&lt;/strong&gt;, which in our case is the &lt;em&gt;channels&lt;/em&gt;, i.e. the &lt;em&gt;embedding dimension&lt;/em&gt;. So we don't even need to compute the running mean and variance anymore, because now each instance has its own mean and variance, regardless of the batch it's in.&lt;/p&gt;

&lt;p&gt;To compute it, we just need to change the dimension when normalizing: Normalizing rows rather than columns. And that's it.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;1e-5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;momentum&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;eps&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;momentum&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;momentum&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;training&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
    &lt;span class="c1"&gt;# parameters (trained with backprop)
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__call__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;xmean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# layer mean
&lt;/span&gt;    &lt;span class="n"&gt;xvar&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;var&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# layer variance
&lt;/span&gt;    &lt;span class="n"&gt;xhat&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;xmean&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sqrt&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xvar&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;eps&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;gamma&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;xhat&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;beta&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;In recent years there have been many improvements and new findings. The first is that we no longer add the Norm layer after the Multi-head or the FeedForward, but &lt;strong&gt;before&lt;/strong&gt; them (the so-called pre-norm formulation).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Block&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt; Transformer block: communication followed by computation &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_head&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;head_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt; &lt;span class="o"&gt;//&lt;/span&gt; &lt;span class="n"&gt;n_head&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sa&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_head&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;head_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ffwd&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;FeedFoward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="c1"&gt;# Creating our LayerNorm
&lt;/span&gt;        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ln1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;ln2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;LayerNorm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# We apply the LN first, then others
&lt;/span&gt;        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sa&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ln1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ffwd&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ln2&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;They also discovered that LayerNorm has some problems of its own, mainly around interpretability. Researchers took that seriously and invented numerous other normalization techniques, namely &lt;em&gt;InstanceNorm&lt;/em&gt;, &lt;em&gt;GroupNorm&lt;/em&gt; (which is part LayerNorm, part InstanceNorm), and &lt;em&gt;RMSNorm&lt;/em&gt; (widely used in current models: LayerNorm, but we don't subtract the mean, and we divide the inputs by the root mean square rather than the standard deviation).&lt;/p&gt;
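
&lt;p&gt;To make that last one concrete, here's a minimal RMSNorm sketch matching the description above (just an illustration, not any particular library's implementation):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;class RMSNorm(nn.Module):
    """ like LayerNorm, but no mean subtraction: scale by the root mean square """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.gamma * x / rms
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;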

&lt;h3&gt;
  
  
  Dropout
&lt;/h3&gt;

&lt;p&gt;The last small thing, before we wrap up the model, is Dropout. Dropout is just an additional layer, mainly used for regularization. It works by randomly "killing" some neurons during training, making the collaboration game harder, and eventually developing a kind of &lt;em&gt;independence&lt;/em&gt; among neurons, thus making them stronger (no pain no gain, right? I often wonder why these neurons are always treated so softly, I mean we should be strict sometimes, and Dropout really does the trick).&lt;/p&gt;

&lt;p&gt;Dropout is also applied right after we softmax the key-query weight matrix in the Attention layer, so that when multiplying with the value vectors, it randomly prevents some communication between tokens.&lt;br&gt;
&lt;/p&gt;
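
&lt;p&gt;For reference, since this post never spells the single head out in full, here's a sketch of a &lt;code&gt;Head&lt;/code&gt; with that dropout in place (assuming globals &lt;code&gt;n_embd&lt;/code&gt;, &lt;code&gt;block_size&lt;/code&gt; and &lt;code&gt;dropout&lt;/code&gt;):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;class Head(nn.Module):
    """ one head of self-attention, with dropout on the attention weights """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)                                      # (B, T, head_size)
        q = self.query(x)                                    # (B, T, head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # scaled scores (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)  # randomly drop some attention weights
        v = self.value(x)
        return wei @ v           # (B, T, head_size)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;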

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Dropout in the Multihead and the Linear layer
&lt;/span&gt;&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;MultiHeadAttention&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt; multiple heads of self-attention in parallel &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;head_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ModuleList&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="nc"&gt;Head&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;head_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num_heads&lt;/span&gt;&lt;span class="p"&gt;)])&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;proj&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="nf"&gt;h&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;heads&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;proj&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;FeedFoward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="sh"&gt;"""&lt;/span&gt;&lt;span class="s"&gt; a simple linear layer followed by a non-linearity &lt;/span&gt;&lt;span class="sh"&gt;"""&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;net&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ReLU&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Dropout&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;net&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And that's all I want to cover here.&lt;/p&gt;

&lt;h2&gt;
  
  
  Summary
&lt;/h2&gt;

&lt;p&gt;I know this kind of explanation won't please everyone: I introduced concepts rather abruptly, and sometimes I even skipped some parts. But if you understood what I meant, then congrats, my friend, you've just walked through one of the most brilliant architectures in human history.&lt;/p&gt;

&lt;p&gt;The Transformer is just blocks of Multihead and FeedForward (and all the other stuff) stacked onto each other, that's it! Now I think I gotta build this as one of my projects, so goodbye for now.&lt;/p&gt;
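
&lt;p&gt;If you want to see that stacking a bit more concretely, here's a minimal sketch of what one such block could look like. This is my own rough sketch, assuming the multi-head attention module we built above is named &lt;code&gt;MultiHeadAttention&lt;/code&gt;; the layer norms and residual connections are part of the "other stuff":&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;class Block(nn.Module):
    """ one Transformer block: communication (attention) then computation (MLP) """

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)  # assumed name of our multi-head module
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        # residual connections: add each sub-layer's output back onto its input
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;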

&lt;p&gt;Thanks for reading and, as usual, have a good day!&lt;/p&gt;

</description>
      <category>architecture</category>
      <category>deeplearning</category>
      <category>llm</category>
    </item>
    <item>
      <title>BIG STEPS TO TRANSFORMER (PART 1): BUILDING THE BIGRAM</title>
      <dc:creator>Hưng Lê Tiến</dc:creator>
      <pubDate>Sun, 30 Nov 2025 09:39:32 +0000</pubDate>
      <link>https://dev.to/blackbidz/big-steps-to-transformer-part-1-building-the-bigram-1g9b</link>
      <guid>https://dev.to/blackbidz/big-steps-to-transformer-part-1-building-the-bigram-1g9b</guid>
      <description>&lt;p&gt;I'm back! It took a while for me to grasp all of the basics about Transformer, and we're now ready to jump right into that big boy. The days of playing with casual neural net stuff are gone, we shall behold the most modern architecture in this world, and let's break it down step by step.&lt;/p&gt;

&lt;p&gt;This blogpost is, again, based on the series about Neural Nets by Andrej Karpathy. Big respect to him.&lt;/p&gt;

&lt;p&gt;Now let's start our journey. We will start simple, really simple, by calling up our primitive language model again: &lt;strong&gt;The Bigram Language Model&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;You may ask: &lt;em&gt;What is the point of this?&lt;/em&gt; Then relax, my friend: even though the bigram is an extremely simple model, it provides a great setup for our big boy. To be more specific, in this blog we will rebuild the bigram in a more &lt;em&gt;general&lt;/em&gt; way, which we will then adjust step by step until it eventually converges to a Transformer. So bear with me for now.&lt;/p&gt;

&lt;h2&gt;
  
  
  Load the Data
&lt;/h2&gt;

&lt;p&gt;We will consider a much bigger dataset: not a list of American names, but &lt;strong&gt;scripts from Shakespeare&lt;/strong&gt;. It's the one Andrej prefers, so we will stick with that. And as you can probably guess, it is &lt;em&gt;way larger&lt;/em&gt; than the corpus we used in the previous project.&lt;/p&gt;

&lt;p&gt;Let's download the data for our Colab:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
&lt;/span&gt;&lt;span class="err"&gt;!&lt;/span&gt;&lt;span class="n"&gt;wget&lt;/span&gt; &lt;span class="n"&gt;https&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;//&lt;/span&gt;&lt;span class="n"&gt;raw&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;githubusercontent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;com&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;karpathy&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;char&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;rnn&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;master&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;tinyshakespeare&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="nb"&gt;input&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;txt&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And then read the data:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# read it in to inspect it
&lt;/span&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;input.txt&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;encoding&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;utf-8&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;read&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  Tokenize
&lt;/h2&gt;

&lt;p&gt;Just like in all of our previous projects, we have to find a way to &lt;em&gt;encode&lt;/em&gt; our text so that it can be fed into the neural net, and of course we should also have a &lt;em&gt;decode&lt;/em&gt; step to convert the numerical representation back into text. First we just store the characters, to construct the mapping later and also to get the &lt;code&gt;vocab_size&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Construct the list of chars + vocab_size
&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;sorted&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;set&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;span class="n"&gt;vocab_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;''&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We might want to see our chars, which are different from those in the previous projects:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;
 !$&amp;amp;',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Alongside alphabetical chars, we also have numeric and even special ones. Actually &lt;code&gt;chars[0]&lt;/code&gt; is the newline character &lt;code&gt;'\n'&lt;/code&gt;, which is interesting, and our &lt;code&gt;vocab_size&lt;/code&gt; is now 65, not 27 anymore.&lt;/p&gt;
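
&lt;p&gt;You can quickly verify that yourself:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;print(repr(chars[0]))  # '\n', the newline character
print(chars[:3])       # ['\n', ' ', '!']
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;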

&lt;p&gt;Now we should create a mapping to convert all characters into integers, and then apply a function to encode a string into numbers (remember the inverse mapping and the decode step). Unlike the previous project, this time we create explicit &lt;code&gt;encode&lt;/code&gt; and &lt;code&gt;decode&lt;/code&gt; functions, because we are dealing with a lot of characters and things get complicated.&lt;/p&gt;

&lt;p&gt;Here's the code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# create a mapping from characters to integers
&lt;/span&gt;&lt;span class="n"&gt;stoi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt; &lt;span class="n"&gt;ch&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="p"&gt;}&lt;/span&gt;
&lt;span class="n"&gt;itos&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;ch&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="p"&gt;}&lt;/span&gt;
&lt;span class="n"&gt;encode&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;c&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;c&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# encoder: take a string, output a list of integers
&lt;/span&gt;&lt;span class="n"&gt;decode&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="sh"&gt;''&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;l&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="c1"&gt;# decoder: take a list of integers, output a string
&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;encode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;hii there&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;encode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;hii there&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;)))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;It works great! Now let's turn the text into integers:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# let's now encode the entire text dataset and store it into a torch.Tensor
&lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="c1"&gt;# we use PyTorch: https://pytorch.org
&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;encode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;long&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; &lt;span class="c1"&gt;# the 10 characters we looked at earier will to the GPT look like this
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Beautiful! Just one more note here: this process is called &lt;strong&gt;Tokenization&lt;/strong&gt;, and the tokenization we're using is the most naive and simplest form: character-level tokenization. In practice, researchers have developed a whole bunch of methods for tokenizing text. The most groundbreaking one, and the one actually used in LLMs today, is &lt;strong&gt;sub-word tokenization&lt;/strong&gt; (like &lt;em&gt;SentencePiece&lt;/em&gt; from Google or &lt;em&gt;tiktoken&lt;/em&gt; from OpenAI). There is a lot to talk about regarding Tokenization, so I will do a whole new blog for that (promise, because Andrej has a video on that too).&lt;/p&gt;
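
&lt;p&gt;Just to give you a taste of the sub-word flavor, here's a small side demo (assuming you have the &lt;code&gt;tiktoken&lt;/code&gt; package installed; we won't use it in this project):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# pip install tiktoken
import tiktoken

enc = tiktoken.get_encoding("gpt2")  # the BPE tokenizer used by GPT-2
tokens = enc.encode("hii there")
print(tokens)              # a much shorter list than our 9 character-level integers
print(enc.decode(tokens))  # 'hii there'
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;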

&lt;h2&gt;
  
  
  Splitting the dataset
&lt;/h2&gt;

&lt;p&gt;From our previous project, we learned that in order to train the model effectively and reduce the risk of overfitting, we should split the dataset into subsets: a "trainset", a "devset", and a "testset". We will apply the same idea to this project by splitting the data into two parts, one for training and one for validation.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Let's now split up the data into train and validation sets
&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="c1"&gt;# first 90% will be train, rest val
&lt;/span&gt;&lt;span class="n"&gt;train_data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;val_data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  Creating the Context Window
&lt;/h2&gt;

&lt;p&gt;Now, remember that in the previous model we have the &lt;code&gt;block_size&lt;/code&gt;, which is simply the length of the chunk of text that we feed into our model. In this project we will increase our &lt;code&gt;block_size&lt;/code&gt; (or, as people in the modern world call it, the &lt;code&gt;context_window&lt;/code&gt;) to 8. This allows our model to look back further, thus having a better understanding of the &lt;em&gt;context&lt;/em&gt; that it's in.&lt;br&gt;
Another thing to note here is that we're not going to predict a char by looking at just its previous 8 characters like in the past model; instead we will predict the next char at &lt;strong&gt;each position&lt;/strong&gt; in the context window. That means, in a chunk of 8 characters, first we show the first char and the model predicts the next, then we show the first 2 chars and the model predicts again, and the list goes on until the model has predicted across the whole sequence of 8. &lt;br&gt;
In other words, we are not using the "sliding window" technique anymore; rather, we're revealing one character at a time in a chunk of text. Now this is an important moment: because of that sequential processing, we would rather refer to the &lt;code&gt;block_size&lt;/code&gt; as the number of &lt;strong&gt;time steps&lt;/strong&gt;, specifying the &lt;em&gt;time&lt;/em&gt; of our predictions and thus signaling a kind of sequential processing. (This is a cross-over with the &lt;strong&gt;Recurrent Neural Network (RNN)&lt;/strong&gt;, if you know it.)&lt;/p&gt;

&lt;p&gt;Now to make clear to you what I mean, we will introduce that in the code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;block_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;8&lt;/span&gt;
&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train_data&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="c1"&gt;# We also construct the target,
# which is just the x but we offset by a character
&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train_data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;block_size&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; 

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;t&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="n"&gt;target&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;t&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;when input is &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; the target: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;target&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is a different way of prediction, and keep it in mind, because it is the backbone of the Transformer when we scrutinize it more deeply. Remember that we're doing the setup for that.&lt;/p&gt;

&lt;h2&gt;
  
  
  Creating Batches
&lt;/h2&gt;

&lt;p&gt;Now that we know about the time dimension, we should also know about the &lt;strong&gt;Batch dimension&lt;/strong&gt;. In training, we would love to do things &lt;em&gt;in parallel&lt;/em&gt;, meaning that at each time step we want to process more than one thing, separately (it's like creating numerous parallel universes working independently). And that's actually the &lt;code&gt;batch_size&lt;/code&gt;; the "creating parallel universes" part is simply feeding the chunks of text &lt;em&gt;in batches&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;We create batches just like we did in the previous project, by initializing a tensor of random offsets, then using those values to index the chunks of text. For simplicity, the batch dimension will be just 4:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# For reproducibility, we introduce a manual seed
&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1337&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;4&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_batch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train_data&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;split&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="n"&gt;val_data&lt;/span&gt;
  &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;,(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
  &lt;span class="c1"&gt;# The stack method creates the batch dimension for our data
&lt;/span&gt;  &lt;span class="c1"&gt;# simply by stacking chunks onto each other
&lt;/span&gt;  &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt; 
  &lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;stack&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
  &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;
&lt;span class="n"&gt;xb&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;yb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;get_batch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# 4x8 tensor
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And now we can use that data for our neural nets! &lt;/p&gt;
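
&lt;p&gt;A quick sanity check you can run here: the targets &lt;code&gt;yb&lt;/code&gt; are just the same chunks shifted one character to the right, so each row of &lt;code&gt;yb&lt;/code&gt; overlaps with its row in &lt;code&gt;xb&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;print(yb.shape)                            # same 4x8 shape as xb
print(torch.equal(xb[0, 1:], yb[0, :-1]))  # True: yb is xb shifted by one
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;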

&lt;h2&gt;
  
  
  Build the model
&lt;/h2&gt;

&lt;p&gt;Let's look at the code first, and then I will explain:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;
&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1337&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;BigramLanguageModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="c1"&gt;# each token directly reads off the logits for the next token from a lookup table
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;token_embedding_table&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# idx and targets are both (B,T) tensor of integers
&lt;/span&gt;    &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;token_embedding_table&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# B,T,C
&lt;/span&gt;    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;BigramLanguageModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xb&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;yb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;First we import &lt;code&gt;Module&lt;/code&gt; from &lt;code&gt;torch.nn&lt;/code&gt; to get some methods that are convenient for our project. The thing to look at here is the &lt;code&gt;Embedding&lt;/code&gt; layer; it works just like when we embedded our characters manually: it creates a lookup table, which is just a bunch of vectors (with the dimension depending on our initialization), and then when we want to embed something, we just look it up in the table, and the table returns a unique vector for it. &lt;/p&gt;

&lt;p&gt;In this project we create a 65x65 lookup table (as the &lt;code&gt;vocab_size&lt;/code&gt; is 65), so each character is assigned a vector storing 65 values; that's the &lt;strong&gt;embedding size&lt;/strong&gt;. The embedding size is up to our liking, it can be any number. &lt;em&gt;But why 65?&lt;/em&gt; you may ask. Well, we're trying to replicate the &lt;strong&gt;Bigram model&lt;/strong&gt;, remember? And in the Bigram model we also have a lookup table of &lt;code&gt;vocab_size&lt;/code&gt;x&lt;code&gt;vocab_size&lt;/code&gt;, but that one is just for counting the occurrence of pairs. And the hard truth is, we're not actually counting anything here.&lt;/p&gt;

&lt;p&gt;To be clear, this is not a &lt;strong&gt;count-based&lt;/strong&gt; Bigram model, but rather a &lt;strong&gt;neural net that replicates the bigram&lt;/strong&gt;, just like we built in the second blogpost. But if we optimize it well, then at the end of the day the table will converge to nearly the same thing as the actual count table (and our table represents the &lt;em&gt;log-counts&lt;/em&gt;, not the raw counts; I think we talked about that earlier).&lt;/p&gt;

&lt;p&gt;The last thing to keep in mind here, which is really important later: note that the dimensionality of the data increases after embedding, with an additional dimension for storing the embedding vectors. So the shape of the log-counts, or the logits, will be &lt;code&gt;4x8x65&lt;/code&gt;, and you might notice the &lt;code&gt;B,T,C&lt;/code&gt; next to it. That's important, and we will use it a lot later in the blog. The &lt;code&gt;B&lt;/code&gt; and &lt;code&gt;T&lt;/code&gt;, as you might know, are the &lt;strong&gt;batch dimension&lt;/strong&gt; and the &lt;strong&gt;time dimension&lt;/strong&gt; of our data; what about &lt;code&gt;C&lt;/code&gt;? The &lt;code&gt;C&lt;/code&gt; here stands for the &lt;strong&gt;Channels&lt;/strong&gt;, which denote the &lt;em&gt;depth of the meanings&lt;/em&gt; that each character holds. This shares a resemblance with the &lt;strong&gt;Convolutional Neural Net&lt;/strong&gt; (CNN), in which we have data of the size &lt;code&gt;(H,W,C)&lt;/code&gt;: the first 2 dimensions represent Height and Width, while the last represents the Channels (in a typical RGB image, one pixel has 3 channels: red, green, and blue). Feel free to pause and ponder, and think about &lt;code&gt;T&lt;/code&gt; and the RNN as well; then you will see that this is perhaps the second greatest unification in history (the greatest, in my opinion, is in The Avengers: Endgame).&lt;/p&gt;

&lt;p&gt;And a small note here: when we plug the index of our character into the lookup table, it actually returns the &lt;em&gt;row&lt;/em&gt; at that index, with 65 values corresponding to the logits, and then we can apply softmax to find the probability of the next character.&lt;/p&gt;
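
&lt;p&gt;Here's a tiny standalone illustration of that row lookup (with a fresh, untrained table, so the numbers are random):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch
import torch.nn as nn
from torch.nn import functional as F

table = nn.Embedding(65, 65)     # a 65x65 lookup table, like ours
row = table(torch.tensor([18]))  # the row for character index 18, shape (1, 65)
probs = F.softmax(row, dim=-1)   # turn the 65 logits into probabilities
print(probs.shape, probs.sum())  # torch.Size([1, 65]), and the probabilities sum to 1
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;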

&lt;p&gt;Now, for every neural network implementation, we would like to see the loss. We might simply use:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;But that would cause an error (try it yourself!). The reason is that the &lt;code&gt;cross_entropy&lt;/code&gt; function in PyTorch expects the Channels as the second dimension; in other words, it expects the logits to be of the shape &lt;code&gt;B,C,T&lt;/code&gt; rather than our &lt;code&gt;B,T,C&lt;/code&gt;. Kinda tiring, but now we have to adjust our logits, and the targets correspondingly.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Evaluate the loss
&lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.nn&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;
&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1337&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;BigramLanguageModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="c1"&gt;# each token directly reads off the logits for the next token from a lookup table
&lt;/span&gt;    &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;token_embedding_table&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="c1"&gt;# idx and targets are both (B,T) tensor of integers
&lt;/span&gt;    &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;token_embedding_table&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# B,T,C
&lt;/span&gt;    &lt;span class="c1"&gt;# loss = F.cross_entropy(logits, targets)
&lt;/span&gt;    &lt;span class="c1"&gt;# An error -&amp;gt; Pytorch expect the channel to be the second dim
&lt;/span&gt;    &lt;span class="c1"&gt;# Want B, T, C intead of B, C, T
&lt;/span&gt;    &lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;
    &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;# Stretch the 4x8 into a 1-dim of 32, preserve the channels
&lt;/span&gt;    &lt;span class="c1"&gt;# Do the same for the target (B,T)
&lt;/span&gt;    &lt;span class="n"&gt;targets&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;BigramLanguageModel&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xb&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;yb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor(4.8786, grad_fn=&amp;lt;NllLossBackward0&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We want to know whether this is a good initial loss or not (remember the hockey stick). So let's consider the uniform case, which is when the model assigns equal probabilities to everything. In that case, you can do some arithmetic in your head: the log-loss would be &lt;code&gt;-ln(1/65)&lt;/code&gt;, which is approximately &lt;code&gt;4.17&lt;/code&gt;. So we're starting with a not-too-bad loss! &lt;/p&gt;
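
&lt;p&gt;If you don't trust your mental arithmetic, a one-liner confirms it:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import math
print(-math.log(1/65))  # 4.174..., the expected loss for uniform guessing
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;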

&lt;h2&gt;
  
  
  Sampling
&lt;/h2&gt;

&lt;p&gt;Now we will talk about our method for sampling from the model. A key difference here, though, is that we have an additional dimension for the &lt;em&gt;time steps&lt;/em&gt;, and also, we're processing the predictions &lt;em&gt;in parallel&lt;/em&gt;, so when we get a prediction, it's actually a bunch of predictions in different universes, and we will have to concat all of those things.&lt;/p&gt;

&lt;p&gt;Let's look at the code first:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt; &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;generate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="c1"&gt;# idx is (B, T) array of indices in the current context
&lt;/span&gt;        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="c1"&gt;# get the predictions
&lt;/span&gt;            &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;self&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="c1"&gt;# focus only on the last time step
&lt;/span&gt;            &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="c1"&gt;# becomes (B, C)
&lt;/span&gt;            &lt;span class="c1"&gt;# apply softmax to get probabilities
&lt;/span&gt;            &lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (B, C)
&lt;/span&gt;            &lt;span class="c1"&gt;# sample from the distribution
&lt;/span&gt;            &lt;span class="n"&gt;idx_next&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;multinomial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (B, 1)
&lt;/span&gt;            &lt;span class="c1"&gt;# append sampled index to the running sequence
&lt;/span&gt;            &lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx_next&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (B, T+1)
&lt;/span&gt;        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;decode&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;generate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nb"&gt;long&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;max_new_tokens&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;()))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;First, we have the &lt;code&gt;idx&lt;/code&gt;, which is simply the initial index; note that in the last line of code we pass in &lt;code&gt;idx = torch.zeros((1, 1))&lt;/code&gt;. This means that we have just one batch row, and we're starting with a single time step. The max number of new time steps is predetermined by &lt;code&gt;max_new_tokens&lt;/code&gt;, and we're actually starting with the index &lt;code&gt;0&lt;/code&gt;, which is, interestingly, the &lt;code&gt;newline&lt;/code&gt; character (like: given a new line, give me the full script).&lt;/p&gt;

&lt;p&gt;We perform a forward pass with our &lt;code&gt;idx&lt;/code&gt;, and you can see that we pass in just the &lt;code&gt;idx&lt;/code&gt; alone, no &lt;code&gt;targets&lt;/code&gt;. We have to handle that case in the &lt;code&gt;forward&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

        &lt;span class="c1"&gt;# idx and targets are both (B,T) tensor of integers
&lt;/span&gt;        &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;token_embedding_table&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;idx&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (B,T,C)
&lt;/span&gt;
        &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt; &lt;span class="ow"&gt;is&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
        &lt;span class="k"&gt;else&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
            &lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;
            &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;targets&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;B&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="n"&gt;T&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
            &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;targets&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Then, when taking the logits, remember that the logits give the probability of the next character at &lt;strong&gt;each given position&lt;/strong&gt;, but as we're marching through &lt;strong&gt;time&lt;/strong&gt;, we just need the prediction for the &lt;strong&gt;last time step&lt;/strong&gt;; that's why we have the &lt;code&gt;-1&lt;/code&gt; over there. You may find this approach rather ridiculous and funny, because every time we make a prediction, we just keep the last one and ignore all of the previous ones (which is clearly a waste of computation). And talking about the bigram: remember that the bigram only needs one character to predict another? So what's the point here?&lt;/p&gt;

&lt;p&gt;It turns out that the Transformer works just like that (remember, we're trying to replicate it). And the waste of computation is there in the Transformer too! But researchers dealt with that by introducing the &lt;strong&gt;KV Cache&lt;/strong&gt;, which does the trick. We won't talk about that here; it's beyond the scope of this blog.&lt;/p&gt;

&lt;p&gt;Next is the sampling and concat part, which is easy to understand. Note that in the last line of code we have the &lt;code&gt;[0]&lt;/code&gt;; it's there because the function returns a &lt;strong&gt;batch&lt;/strong&gt; of predictions, one row per batch element. And since we have just one row, indexing at zero is all we need.&lt;/p&gt;

&lt;p&gt;Let's take a look at our sample, shall we?&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Well, a monkey could do better than that. But we haven't trained the model yet, so let's make it better.&lt;/p&gt;

&lt;h2&gt;
  
  
  Run the Model with AdamW Optimizer
&lt;/h2&gt;

&lt;p&gt;I will spend a whole blog talking about this, but for now you just need to know that this is the thing that handles our learning rate and parameter updates, so that we can make the model stronger. Let me introduce this man Adam; he's great at optimizing models:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Create optimizer
&lt;/span&gt;
&lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;AdamW&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;m&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;1e-3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We will change our implementation a bit, to fit with our Optimizer:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;steps&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt; &lt;span class="c1"&gt;# increase number of steps for good results...
&lt;/span&gt;
    &lt;span class="c1"&gt;# sample a batch of data
&lt;/span&gt;    &lt;span class="n"&gt;xb&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;yb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;get_batch&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="c1"&gt;# evaluate the loss
&lt;/span&gt;    &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;m&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xb&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;yb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;set_to_none&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You might need to rerun this numerous times, and at the end, here's what we get:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;2.4956934452056885
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That's a significant improvement! Now let's look at the samples again:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;
As thee.
ARKis ve durre wiouro d wimys:
Thaueerrpanat thathe, s
WAnd Fithee.

Nobys
I ld mb, qun mod y wousa darketwave nghof IORitok has whodirate are atit G t hant m,
HAn
CELUSt y XENTha bu.
Wh blinouke sth thiviglecldewist, trveayokeanguror mepes ' wexfalle spprdswhyealaiate foulokiou:
BE an, dop
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;It's developing structure! And that's good. Honestly, sometimes I think building this kind of thing is like having a baby: you get to watch him grow through time, through your dedication and a few lessons. I really feel like a responsible father right now!&lt;/p&gt;

&lt;h2&gt;
  
  
  Summary
&lt;/h2&gt;

&lt;p&gt;We're done with our model! Even though we know that the Bigram is weak and the result is not really beautiful, throughout this blog we got to know a lot of new notions: from the sequential and parallel processing of the model (which shares some resemblance with the RNN), to the Big Three &lt;code&gt;B,T,C&lt;/code&gt; and the crossover with CNNs. We also learned about the sampling method, and about creating an optimizer for our model. All in all, our first steps are good.&lt;/p&gt;

&lt;p&gt;Thanks for reading, and have a good day!&lt;/p&gt;

</description>
      <category>deeplearning</category>
      <category>tutorial</category>
      <category>architecture</category>
      <category>beginners</category>
    </item>
    <item>
      <title>BATCHNORM IN LANGUAGE MODELS</title>
      <dc:creator>Hưng Lê Tiến</dc:creator>
      <pubDate>Mon, 24 Nov 2025 09:20:19 +0000</pubDate>
      <link>https://dev.to/blackbidz/batchnorm-in-language-models-539c</link>
      <guid>https://dev.to/blackbidz/batchnorm-in-language-models-539c</guid>
      <description>&lt;p&gt;Welcome to the next chapter! This is a really important topic for LLMs and even DL in general. Batchnorm, alongside with other innovative normalization techniques, is a must-know in Deep Learning. Maybe we should jump right in to the topic, as there is a lot to talk about.&lt;/p&gt;

&lt;p&gt;A note here, this is still the part from the Makemore series of Andrej Karpathy, so shout out to him with a big respect.&lt;/p&gt;

&lt;p&gt;Also, I assume that you already have the code for the previous model, because we won't code everything all over again or make huge modifications; we will just take the previous one and make some small changes. For the sake of convenience, I won't show you the previous code (it's in the previous chapters, I'm lazy), but rather the small parts of it that need to be modified.&lt;/p&gt;

&lt;h2&gt;
  
  
  Problems with the Previous Model
&lt;/h2&gt;

&lt;p&gt;When scrutinizing a newly-invented technique, we should really think about the problems that it solves. Normalization has been linked with a wide range of problems regarding neural networks, so let's see &lt;em&gt;why people need to normalize things when playing with neural networks&lt;/em&gt;, or, as a similar question, &lt;em&gt;what are the problems if we let everything behave freely in the net?&lt;/em&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  The Logits
&lt;/h3&gt;

&lt;p&gt;We are creating the &lt;strong&gt;weights&lt;/strong&gt; and &lt;strong&gt;biases&lt;/strong&gt; of the output layer like a total rookie in the neural net game, and it's just so bad. First let's look at our first iteration and see the loss that we get:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;    0/ 200000: 27.8817
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;That is TOO FAR away from the minimal achievable loss, and to make it more alarming, we should look at the graph of the loss over the iterations:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffuw0ljh7o5ui9ccn88ek.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ffuw0ljh7o5ui9ccn88ek.png" alt="Lossi"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Look familiar? That is a hockey-stick graph, but for a loss curve it spells disaster: we have a way-too-high loss right from the beginning, and the model struggles for thousands of iterations just to get that unfathomable loss down. Can we prevent that? Right from the beginning?&lt;/p&gt;

&lt;p&gt;Well, actually we can bring a little bit of Maths into the initialization process, to put our model at ease in its very first step, and maybe even make it more efficient and arrive at a better loss. Researchers think like that too, and oftentimes, in the process of initializing the parameters, they have a kind of &lt;strong&gt;Expectation&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;What do we expect from the first step of the model? In the most naive way? Well, we may think that the model should treat each character &lt;strong&gt;equally&lt;/strong&gt;, right? Every character gets the same probability, and then we can drive some of them up or down to make the model stronger. As there are 27 of them, each character in the output layer is supposedly assigned a probability of &lt;code&gt;1/27&lt;/code&gt;. Now let's see what the loss is given that expectation, which is just the negative log of 1/27 (no magic number or 'aha' moment here, you can work it out in your head, it's fairly straightforward):&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;&amp;gt;&amp;gt;&amp;gt;-torch.tensor(1/27.0).log()
tensor(3.2958)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;It's a GREAT loss to start! So why did our model start so poorly? Think about it.&lt;/p&gt;

&lt;p&gt;It turns out that, in the initialization process, when we randomly sample from the Gaussian distribution, there are definitely cases where some characters are given &lt;em&gt;a higher value&lt;/em&gt; than others while they don't deserve to be. So it's not a good idea to let that &lt;strong&gt;bias&lt;/strong&gt; affect us in the first step; actually, we want everything to be &lt;strong&gt;equally judged&lt;/strong&gt;, which produces a smaller initial loss.&lt;/p&gt;

&lt;p&gt;Let's see a simple example:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;  &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;(tensor([0.2500, 0.2500, 0.2500, 0.2500]), tensor(1.3863))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;It's good, but what if we introduce some &lt;strong&gt;extreme values&lt;/strong&gt;?&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;10.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mf"&gt;0.0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mf"&gt;15.0&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;  &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;(tensor([1.3888e-11, 3.0590e-07, 3.0590e-07, 1.0000e+00]), tensor(15.0000))
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;The loss increases &lt;em&gt;significantly&lt;/em&gt;! So the problem here is the &lt;strong&gt;extreme values&lt;/strong&gt;, or in other words (we should really translate our language into math-like statements so that we can modify the code), we don't want some values to be &lt;strong&gt;too far away from the average&lt;/strong&gt;. To put it more precisely, we want the &lt;strong&gt;standard deviation&lt;/strong&gt; of these statistics to be really small, like 1 for example.&lt;/p&gt;

&lt;p&gt;Also, the &lt;strong&gt;mean&lt;/strong&gt; should be set around zero; it is useful later and we will discuss that. Actually, if you read enough ML/DL books or Statistics books, you might be familiar with the &lt;em&gt;Mean 0, Std 1&lt;/em&gt; kind of stuff. Well, for some reason, researchers love that thing, and it has a beautiful name: the &lt;strong&gt;Unit Gaussian Distribution&lt;/strong&gt;. That's exactly what we do when we normalize our data; it's really about making the data look like the &lt;strong&gt;Unit Gaussian&lt;/strong&gt;.&lt;/p&gt;
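
&lt;p&gt;As a quick sanity check (my own toy experiment, reusing the extreme logits from above), watch the loss fall towards the uniform baseline as we scale the logits towards zero:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

logits = torch.tensor([-10.0, 0.0, 0.0, 15.0])   # the extreme logits from before
for scale in [1.0, 0.1, 0.01]:
    probs = torch.softmax(logits * scale, dim=0)
    print(scale, -probs[2].log().item())         # ~15.0, then ~1.92, then ~1.40
# the uniform baseline is -log(1/4) ~ 1.386
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;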

&lt;p&gt;Now let's turn back to our code: what will we do with the weights and biases to make the values &lt;strong&gt;less extreme&lt;/strong&gt;?&lt;/p&gt;

&lt;p&gt;First, the output &lt;code&gt;bias&lt;/code&gt; (&lt;code&gt;b2&lt;/code&gt;) should be set to zero, as we discussed earlier. Second, the output &lt;code&gt;weight&lt;/code&gt; (&lt;code&gt;W2&lt;/code&gt;) should be smaller, because we don't want some numbers to 'explode' after the matrix multiplication; we want to keep the logits as small as possible. To this end, we multiply the &lt;code&gt;W2&lt;/code&gt; matrix by &lt;code&gt;0.01&lt;/code&gt;, which makes sense.&lt;/p&gt;

&lt;p&gt;You may wonder: &lt;em&gt;Why don't we set the weight matrix to 0?&lt;/em&gt; Actually, even though we said that &lt;strong&gt;everything should be equal&lt;/strong&gt;, we won't do that in an extreme way. &lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;"It's not good to go extreme, my friend"&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Researchers didn't say it that way, but that's what they mean when they put it like: &lt;em&gt;We won't set the matrices at initialization to all zeros because we want some &lt;strong&gt;entropy&lt;/strong&gt; in our loss.&lt;/em&gt; &lt;br&gt;
I won't go deep into this; just think of it as a kind of regularization on the other end.&lt;/p&gt;

&lt;p&gt;Now let's modify our code&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt; &lt;span class="c1"&gt;# the dimensionality of the character embedding vectors
&lt;/span&gt;&lt;span class="n"&gt;n_hidden&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;200&lt;/span&gt; &lt;span class="c1"&gt;# the number of neurons in the hidden layer of the MLP
&lt;/span&gt;
&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# for reproducibility
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;            &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;b1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;                        &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;          &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt;
&lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;                      &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Notice the &lt;code&gt;*0.01&lt;/code&gt; and the &lt;code&gt;*0&lt;/code&gt; added to the last two matrices (those of the output layer). Let's see our loss now:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;      0/ 200000: 3.3221
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Perfect! Now look at the graph&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7j4ieqxyqsxgxzqwipot.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F7j4ieqxyqsxgxzqwipot.png" alt="Graph loss"&gt;&lt;/a&gt;&lt;br&gt;
No hockey stick! And maybe our model even performs better, as it doesn't waste thousands of initial iterations. But we need to move on to another problem in our model.&lt;/p&gt;
&lt;h3&gt;
  
  
  Saturated Tanh &amp;amp; Vanishing Gradient
&lt;/h3&gt;

&lt;p&gt;Logits are not the only thing that went wrong in the process. Let's look at the hidden states, which are simply the &lt;code&gt;tanh&lt;/code&gt; of the pre-activations coming out of the first linear layer:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fypagnrldrpf51eng0fcc.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fypagnrldrpf51eng0fcc.png" alt="Tanh"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Woah, what happened? The values are polarized into the two ends: &lt;code&gt;-1.00&lt;/code&gt; and &lt;code&gt;1.00&lt;/code&gt;. We know that something has gone wrong here, and like some brilliant detectives, we trace one step back, to the hidden states &lt;em&gt;before the tanh&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F3zzpagqbl1g9ns6p9x8i.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F3zzpagqbl1g9ns6p9x8i.png" alt="Tanh"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;And for the last piece of the picture, let me remind you of the &lt;code&gt;tanh()&lt;/code&gt; function:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F3kd2dypr5gfrg746274s.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F3kd2dypr5gfrg746274s.png" alt="Tanh"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;code&gt;tanh()&lt;/code&gt; is a &lt;strong&gt;squashing function&lt;/strong&gt;: it squashes and caps every input into the range between &lt;code&gt;-1.0&lt;/code&gt; and &lt;code&gt;1.0&lt;/code&gt;. Seeing the graph of the hidden states before activation, we know exactly what the problem is: before the activation there are large numbers, and those numbers enter the &lt;code&gt;tanh()&lt;/code&gt; and get stuck at one of the two ends. This phenomenon is called &lt;strong&gt;Saturation&lt;/strong&gt;, where values are mostly pushed towards the extreme ends of the function.&lt;/p&gt;

&lt;p&gt;And how does that affect our model exactly? Think about the &lt;em&gt;gradient&lt;/em&gt;. The gradient of &lt;code&gt;tanh()&lt;/code&gt; is actually &lt;code&gt;1-t^2&lt;/code&gt;, with &lt;code&gt;t&lt;/code&gt; being the output value of &lt;code&gt;tanh()&lt;/code&gt;, so for every output close to 1 or -1, the gradient will be close to zero! In fact, after backpropagating through many layers of neurons, the gradient can really &lt;strong&gt;shrink to zero&lt;/strong&gt;. That is the &lt;strong&gt;Vanishing Gradient&lt;/strong&gt; problem, and when no gradient is passed down, things get updated really slowly, driving down the efficiency of the model.&lt;/p&gt;
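
&lt;p&gt;Here's a minimal sketch of that effect (my own toy check, not from the lecture): we push a few values through &lt;code&gt;tanh()&lt;/code&gt; and let autograd hand back the gradients, which are exactly &lt;code&gt;1-t^2&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

x = torch.tensor([-5.0, -2.0, 0.0, 2.0, 5.0], requires_grad=True)
t = torch.tanh(x)
t.sum().backward()   # d/dx of tanh(x) is 1 - tanh(x)**2, element-wise
print(t.data)        # outputs:   ~[-1.0, -0.96, 0.0, 0.96, 1.0]
print(x.grad)        # gradients: ~[0.0002, 0.07, 1.0, 0.07, 0.0002]
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;The saturated inputs at the ends barely pass any gradient back; that's exactly the "dead neuron" situation we are about to visualize.&lt;/p&gt;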

&lt;p&gt;Now let's see what I mean here; just look at the proportion of the hidden states that have an absolute value larger than &lt;code&gt;0.99&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mf"&gt;0.99&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;cmap&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;interpolation&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;nearest&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcuo46nn89dfhqnhclerg.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fcuo46nn89dfhqnhclerg.png" alt="Hidden states"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The white part is where the gradients zero out and the neurons are dead. You can see that our model is largely saturated.&lt;/p&gt;

&lt;p&gt;So far we've been chilling with our model because it is just a small one. But oftentimes we will have to deal with neural nets with tons of layers, and the problem can accumulate really fast.&lt;/p&gt;

&lt;p&gt;Now let's fix this: We should bring everything close to zero, and make the variance small so that the &lt;code&gt;tanh()&lt;/code&gt; does not push anything to the ends (you can see why this is a good choice by again looking at the graph of the &lt;code&gt;tanh()&lt;/code&gt;). In doing so, we're &lt;strong&gt;normalizing&lt;/strong&gt; the hidden states before activation:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt; &lt;span class="c1"&gt;# the dimensionality of the character embedding vectors
&lt;/span&gt;&lt;span class="n"&gt;n_hidden&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;200&lt;/span&gt; &lt;span class="c1"&gt;# the number of neurons in the hidden layer of the MLP
&lt;/span&gt;
&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# for reproducibility
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;  &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_embd&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;            &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt;
&lt;span class="n"&gt;b1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;                        &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt;
&lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;          &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt;
&lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;vocab_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;                      &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

&lt;span class="n"&gt;parameters&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;nelement&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="c1"&gt;# number of parameters in total
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;That is the same as what we did last time, but in the process we scale &lt;code&gt;W1&lt;/code&gt; down as well, and we multiply &lt;code&gt;b1&lt;/code&gt; by &lt;code&gt;0.01&lt;/code&gt;, not &lt;code&gt;0&lt;/code&gt;; this is, again, an attempt to "&lt;em&gt;introduce some entropy to the model&lt;/em&gt;". Now let's run and see our hidden states, after and before activation respectively.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fp0umvpsh80o9b1eofc3h.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fp0umvpsh80o9b1eofc3h.png" alt="h"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvn8it5e4shc50tcdkwoy.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fvn8it5e4shc50tcdkwoy.png" alt="hpreact"&gt;&lt;/a&gt;&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# And we should also see if there are some dead neurons
&lt;/span&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;abs&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="o"&gt;&amp;gt;&lt;/span&gt;&lt;span class="mf"&gt;0.99&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;cmap&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;interpolation&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;nearest&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fj2a1jxu3iszz3nw8y7zj.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fj2a1jxu3iszz3nw8y7zj.png" alt="h.abs()"&gt;&lt;/a&gt;&lt;br&gt;
Everything looks great! The &lt;code&gt;tanh()&lt;/code&gt; doesn't behave badly like before, as we successfully normalized its input, and we don't see any saturated values in the last plot (actually, we will have to bump the weight's scalar up a little bit so that a few neurons come close to saturation; another attempt to prevent going too extreme).&lt;/p&gt;

&lt;p&gt;As for the loss, we don't expect much here, because we're just dealing with a fairly small MLP.&lt;/p&gt;
&lt;h2&gt;
  
  
  Initialization
&lt;/h2&gt;
&lt;h3&gt;
  
  
  A small example
&lt;/h3&gt;

&lt;p&gt;You might question me: &lt;em&gt;"How do you know those magic 0.01 scalars that are multiplied with the weights?"&lt;/em&gt; Actually that is just a random small value (and yet it worked!), but there is a lot of research around this, really interesting stuff.&lt;/p&gt;

&lt;p&gt;Let's consider a small set of data:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;
&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;121&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;density&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;subplot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;122&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;hist&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;tolist&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;density&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor(-0.0027) tensor(1.0000)
tensor(0.0031) tensor(3.1318)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F5txv309sthr1foe780p1.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F5txv309sthr1foe780p1.png" alt="Illustration"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You can see that while &lt;code&gt;x&lt;/code&gt; is Unit Gaussian, everything gets messed up after the multiplication with &lt;code&gt;w&lt;/code&gt;. Specifically, the &lt;em&gt;mean&lt;/em&gt; is good, it's around zero, but the &lt;em&gt;standard deviation&lt;/em&gt; spells disaster.&lt;/p&gt;

&lt;p&gt;Remember, our goal is to make everything "Unit-Gaussian-like", so the real task here is to &lt;em&gt;find the value of the matrix &lt;code&gt;w&lt;/code&gt; such that the Unit Gaussian is preserved after matrix multiplication&lt;/em&gt;. Now let's try this thing: divide the weight by &lt;code&gt;sqrt(fan_in)&lt;/code&gt;, where the &lt;em&gt;fan_in&lt;/em&gt;, by convention, is the number of inputs.&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Scale the weights by 1/sqrt(fan_in)
&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="mf"&gt;0.5&lt;/span&gt;  
&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt;
&lt;span class="bp"&gt;...&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fwhqeh91c688y7kcvgq74.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fwhqeh91c688y7kcvgq74.png" alt="After dividing"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The output looks beautiful! But how? &lt;/p&gt;
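
&lt;p&gt;Before the official answer, here is the back-of-the-envelope reasoning (my own sketch of the standard argument): each output element is &lt;code&gt;y = x1*w1 + ... + x10*w10&lt;/code&gt;, a sum of &lt;code&gt;fan_in&lt;/code&gt; independent zero-mean terms, so &lt;code&gt;var(y) = fan_in * var(w)&lt;/code&gt; when &lt;code&gt;x&lt;/code&gt; is unit Gaussian. With &lt;code&gt;w&lt;/code&gt; drawn from N(0,1), the std of &lt;code&gt;y&lt;/code&gt; is &lt;code&gt;sqrt(10) ~ 3.16&lt;/code&gt;, close to the &lt;code&gt;3.1318&lt;/code&gt; we measured; scaling &lt;code&gt;w&lt;/code&gt; by &lt;code&gt;1/sqrt(10)&lt;/code&gt; brings it back to 1:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

fan_in = 10
x = torch.randn(10000, fan_in)            # unit Gaussian input
for scale in [1.0, fan_in ** -0.5]:
    w = torch.randn(fan_in, 200) * scale
    print(scale, (x @ w).std().item())    # ~3.16 for scale 1, ~1.0 after scaling
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
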
&lt;h3&gt;
  
  
  Kaiming_Init
&lt;/h3&gt;

&lt;p&gt;One of the proposed initialization strategies is the Kaiming Initialization, or the He Initialization. It is also available in the PyTorch library as &lt;code&gt;torch.nn.init.kaiming_normal_&lt;/code&gt;. This method has a &lt;code&gt;nonlinearity&lt;/code&gt; argument, which specifies the activation function that we would like to use. Note that each activation function has its own &lt;strong&gt;gain&lt;/strong&gt;, which is simply a scalar to combat the "squish" that the activation function causes. The gain for our &lt;code&gt;tanh()&lt;/code&gt; is approximately 5/3. Also, there is an argument called &lt;code&gt;mode&lt;/code&gt;: we can pass in either &lt;code&gt;'fan_in'&lt;/code&gt; or &lt;code&gt;'fan_out'&lt;/code&gt;, depending on the flow, forward or backward. In our case we use &lt;code&gt;mode = 'fan_in'&lt;/code&gt;, which is the default value. The standard deviation is calculated following the formula &lt;code&gt;std = gain / sqrt(fan)&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;You can see the whole paper &lt;a href="https://arxiv.org/abs/1502.01852" rel="noopener noreferrer"&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;Now let's change our &lt;code&gt;W1&lt;/code&gt; matrix:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;n_embd&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mf"&gt;0.5&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
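
&lt;p&gt;If you'd rather lean on the library, this should be roughly equivalent (a sketch under the shapes of our model; I'm hard-coding &lt;code&gt;block_size = 3&lt;/code&gt; here just for the demo):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

n_embd, block_size, n_hidden = 10, 3, 200   # assumed shapes, matching the model above

W1 = torch.empty(n_embd * block_size, n_hidden)
torch.nn.init.kaiming_normal_(W1, mode='fan_in', nonlinearity='tanh')
print(W1.std().item())   # ~0.304, i.e. (5/3) / (n_embd * block_size)**0.5
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;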


&lt;p&gt;And that's basically it! We solved some of the problems regarding neural nets. Note that all of these are just tweaks that scale down the weights and even zero out the biases, in order to make the data more &lt;strong&gt;Unit Gaussian&lt;/strong&gt;, which is crucial to maintain the &lt;strong&gt;stability&lt;/strong&gt; of the model.&lt;/p&gt;
&lt;h2&gt;
  
  
  Batch Normalization
&lt;/h2&gt;

&lt;p&gt;Now let's talk about today's topic, shall we? We went such a long way to get here. Actually, this is one of the modern innovations that solves all of our previous problems, and the idea behind it is fairly simple. &lt;/p&gt;

&lt;p&gt;If we want the hidden states to be Unit Gaussian, why bother with initializing and tweaking the weight and bias matrices? Why don't we just &lt;strong&gt;SIMPLY NORMALIZE&lt;/strong&gt; them? And that's how a brilliant idea came into place.&lt;/p&gt;

&lt;p&gt;You can see the paper &lt;a href="https://arxiv.org/abs/1502.03167" rel="noopener noreferrer"&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;And this is the implementation of it:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fel5e5yra1thtw56ipacj.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fel5e5yra1thtw56ipacj.png" alt="BatchNorm"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;It's just the procedure of normalization: subtract the mean, and divide by the standard deviation (which is the square root of the variance). Note that in the dividing part (line 3), we have the term &lt;strong&gt;epsilon&lt;/strong&gt; in the denominator: that term is added to prevent the case of zero variance, in which we would have to divide by zero. Usually we have &lt;code&gt;epsilon = 1e-5&lt;/code&gt;.&lt;/p&gt;
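
&lt;p&gt;Here is a minimal sketch of those normalization lines on a fake batch (my own toy example, without the scale-and-shift yet; that part comes next):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

eps = 1e-5
hpreact = torch.randn(32, 200) * 5.0 + 3.0       # a messy batch: mean ~3, std ~5
mean = hpreact.mean(0, keepdim=True)             # per-neuron mean over the batch
var = hpreact.var(0, keepdim=True)               # per-neuron variance over the batch
hnorm = (hpreact - mean) / torch.sqrt(var + eps)
print(hnorm.mean().item(), hnorm.std().item())   # ~0.0 and ~1.0
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;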

&lt;p&gt;Another important thing to note here is the last line: we have two parameters called &lt;strong&gt;gamma&lt;/strong&gt; and &lt;strong&gt;beta&lt;/strong&gt;, which are there to &lt;em&gt;"scale and shift"&lt;/em&gt;. What are they for?&lt;/p&gt;

&lt;p&gt;Well, we don't really want our data to be &lt;em&gt;exactly Gaussian&lt;/em&gt;; we would like the neural net to be able to move it around a bit, to make it more diffuse. Those parameters are learned during training, so the model decides how to move the data around. Maybe it's just some kind of flexibility added to the model.&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# We call those parameters "gain" and "bias", "gain"
# is like the "weight", but it's not a matrix
# In fact, both are 1-dim row vectors
&lt;/span&gt;
&lt;span class="c1"&gt;# In the initialize part, we expect to keep things unchanged
# So everything is multiplied by 1 and added with zero
&lt;/span&gt;
&lt;span class="n"&gt;bngain&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;bnbias&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Also, remember to include these in the &lt;code&gt;parameters&lt;/code&gt;, as they are trainable (notice that &lt;code&gt;b1&lt;/code&gt; is left out: the mean subtraction in Batch Norm cancels any per-neuron constant, so &lt;code&gt;bnbias&lt;/code&gt; takes over its job):&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Include in the parameters
&lt;/span&gt;&lt;span class="n"&gt;parameters&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bngain&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;bnbias&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Now let's modify our model a bit: Before applying the &lt;code&gt;tanh()&lt;/code&gt; function, we normalize the hidden states:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# same optimization as last time
&lt;/span&gt;&lt;span class="n"&gt;max_steps&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;200000&lt;/span&gt;
&lt;span class="n"&gt;batch_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;32&lt;/span&gt;
&lt;span class="n"&gt;lossi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;max_steps&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

  &lt;span class="c1"&gt;# minibatch construct
&lt;/span&gt;  &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;,),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;Xb&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Yb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;Ytr&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# batch X,Y
&lt;/span&gt;
  &lt;span class="c1"&gt;# forward pass
&lt;/span&gt;  &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Xb&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# embed the characters into vectors
&lt;/span&gt;  &lt;span class="n"&gt;embcat&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# concatenate the vectors
&lt;/span&gt;  &lt;span class="c1"&gt;# Linear layer
&lt;/span&gt;  &lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embcat&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt; &lt;span class="c1"&gt;# hidden layer pre-activation
&lt;/span&gt;  &lt;span class="c1"&gt;# Batch Norm
&lt;/span&gt;  &lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bngain&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;bnbias&lt;/span&gt;
  &lt;span class="c1"&gt;# Non-linearity
&lt;/span&gt;  &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# hidden layer
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# output layer
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Yb&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# loss function
&lt;/span&gt;
  &lt;span class="c1"&gt;# backward pass
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
  &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="c1"&gt;# update
&lt;/span&gt;  &lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt; &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;&amp;lt;&lt;/span&gt; &lt;span class="mi"&gt;100000&lt;/span&gt; &lt;span class="k"&gt;else&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt; &lt;span class="c1"&gt;# step learning rate decay
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;

  &lt;span class="c1"&gt;# track stats
&lt;/span&gt;  &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="o"&gt;%&lt;/span&gt; &lt;span class="mi"&gt;10000&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="c1"&gt;# print every once in a while
&lt;/span&gt;    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="n"&gt;d&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;/&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;max_steps&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="n"&gt;d&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;lossi&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log10&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Notice the &lt;code&gt;Batch Norm&lt;/code&gt; layer sitting between the &lt;code&gt;Linear&lt;/code&gt; and the &lt;code&gt;Non-linearity&lt;/code&gt; layers.&lt;/p&gt;

&lt;p&gt;Batch Norm must also be applied on the &lt;strong&gt;test set&lt;/strong&gt;, i.e., in evaluation mode.&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="nd"&gt;@torch.no_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;split_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ytr&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;val&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Xdev&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ydev&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
    &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;test&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Xte&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Yte&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
  &lt;span class="p"&gt;}[&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# (N, block_size, n_embd)
&lt;/span&gt;  &lt;span class="n"&gt;embcat&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# concat into (N, block_size * n_embd)
&lt;/span&gt;  &lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embcat&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;  &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;

  &lt;span class="c1"&gt;# Adding the BatchNorm layer here
&lt;/span&gt;  &lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bngain&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;bnbias&lt;/span&gt;
  &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (N, n_hidden)
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (N, vocab_size)
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

&lt;span class="nf"&gt;split_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;train&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;split_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;val&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;In fact, when a network has many hidden layers, we should sprinkle BatchNorm across the whole net, right after each linear layer, as sketched below. &lt;/p&gt;
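
&lt;p&gt;As a rough sketch of that pattern (my own toy example using PyTorch's built-in modules, with made-up layer sizes, not code from the video):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch.nn as nn

# A hypothetical deeper MLP: every Linear is followed by BatchNorm1d,
# and only then by the tanh non-linearity.
model = nn.Sequential(
    nn.Linear(30, 200), nn.BatchNorm1d(200), nn.Tanh(),
    nn.Linear(200, 200), nn.BatchNorm1d(200), nn.Tanh(),
    nn.Linear(200, 200), nn.BatchNorm1d(200), nn.Tanh(),
    nn.Linear(200, 27),  # output layer, no BatchNorm needed here
)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;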

&lt;p&gt;So far we've been praising this technique a lot. But does it live up to the hype?&lt;/p&gt;
&lt;h3&gt;
  
  
  Some issues with BatchNorm
&lt;/h3&gt;

&lt;p&gt;First, just look at the name of the technique: &lt;em&gt;BatchNorm&lt;/em&gt;. Yes, we normalize our data, but we do it &lt;strong&gt;in batches&lt;/strong&gt;. What is the problem with that?&lt;/p&gt;

&lt;p&gt;Our batches are generated in a &lt;strong&gt;purely random&lt;/strong&gt; way, and we pass our data through in &lt;em&gt;groups&lt;/em&gt;. In our previous model this was fine: we never did anything with the &lt;em&gt;group an instance belongs to&lt;/em&gt;, but rather treated each one &lt;strong&gt;independently&lt;/strong&gt;, so even though we were dealing with batches, the examples stayed independent. But look at what BatchNorm does: we normalize the data &lt;em&gt;with respect to their batches&lt;/em&gt;, so each example's activations become &lt;strong&gt;dependent&lt;/strong&gt; on the other examples in the same group.&lt;/p&gt;

&lt;p&gt;That is rather ugly: it is unnatural to couple examples into groups like that, and we're even generating the groups randomly. But, big moment here, it turns out that in practice this coupling acts as &lt;strong&gt;a regularization technique&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;Because each example is normalized with the statistics of a randomly drawn batch, its activations get jittered a little every time it is seen, which acts as a kind of &lt;strong&gt;noise&lt;/strong&gt;. That makes it harder for the model to memorize individual examples and pushes it to learn robust features instead, and that means regularization. (Actually I'm not really clear about this, so you can search on Google or ask ChatGPT for a better explanation.)&lt;/p&gt;

&lt;p&gt;Over time, numerous normalization techniques (like LayerNorm) have been developed to replace BatchNorm, mainly because people dislike this batch-coupling quirk. But let's stick with this amazing normalization technique for this blog post.&lt;/p&gt;
&lt;h3&gt;
  
  
  Problems with Inference/Evaluation
&lt;/h3&gt;

&lt;p&gt;This is a hard-to-grasp concept, and it took me hours to understand. I will do my best to give you a good sense of what this is.&lt;/p&gt;

&lt;p&gt;Remember how in test mode we still used &lt;code&gt;hpreact.mean&lt;/code&gt; and &lt;code&gt;hpreact.std&lt;/code&gt;? But what should our model be doing when making predictions? It should predict based on &lt;strong&gt;only that input, independently&lt;/strong&gt;, right? So again, what our model is doing is rather unnatural and strange: somehow we're taking the &lt;strong&gt;batch&lt;/strong&gt; into consideration when predicting a single instance.&lt;/p&gt;

&lt;p&gt;So, in order to preserve the &lt;em&gt;independence&lt;/em&gt; of the data when evaluating, we should use a &lt;strong&gt;population mean and std&lt;/strong&gt; to normalize our data. One way to implement this is to wait until the end of training and run the whole training set through the network once more, this time computing the mean and std over the entire population.&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# calibrate the batch norm at the end of training
&lt;/span&gt;
&lt;span class="c1"&gt;# Remember to use torch.no_grad because we won't backpropagate through this during training, this will save a lot of memory
&lt;/span&gt;&lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
  &lt;span class="c1"&gt;# pass the training set through
&lt;/span&gt;  &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="n"&gt;embcat&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;embcat&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="c1"&gt;# + b1
&lt;/span&gt;  &lt;span class="c1"&gt;# measure the mean/std over the entire training set
&lt;/span&gt;  &lt;span class="n"&gt;bnmean&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;bnstd&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;And then in the evaluation mode, we replace the mean and std of the batch with the population mean and std:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bngain&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;bnmean&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;bnstd&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;bnbias&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;That would solve the problem!&lt;/p&gt;

&lt;p&gt;But researchers would not settle for that: running the whole dataset through the network again just to find the mean and std is laborious. In practice there is a nicer way to &lt;em&gt;approximate&lt;/em&gt; those values, right during training.&lt;/p&gt;

&lt;p&gt;The technique is to keep an &lt;strong&gt;exponential moving average&lt;/strong&gt; (EMA) throughout training. We update the running mean and std as each batch arrives, and the strength of each update is determined by the hyperparameter &lt;strong&gt;momentum&lt;/strong&gt;, the &lt;strong&gt;beta&lt;/strong&gt; in the formulas below:&lt;/p&gt;

&lt;div class="katex-element"&gt;
\mu_{\text{running}} = \beta \cdot \mu_{\text{running}} + (1-\beta) \cdot \mu_B
&lt;/div&gt;



&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;σrunning2=β.σrunning2+(1−β).σB2
\sigma_{\text{running}}^2 = \beta.\sigma_{\text{running}}^2 + (1-\beta).\sigma_B^2
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;σ&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord text mtight"&gt;&lt;span class="mord mtight"&gt;running&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;β&lt;/span&gt;&lt;span class="mord"&gt;.&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;σ&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord text mtight"&gt;&lt;span class="mord mtight"&gt;running&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;+&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;1&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mbin"&gt;−&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;β&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mord"&gt;.&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;σ&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;B&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;
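
&lt;p&gt;As a quick sanity check (my own toy snippet, not from the video), here's the EMA of per-batch means converging toward the true population mean:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

torch.manual_seed(42)
data = torch.randn(100000) + 3.0   # population mean is roughly 3.0

beta = 0.99                        # the momentum hyperparameter
mu_running = torch.tensor(0.0)
for batch in data.split(32):       # a stream of small batches
    mu_running = beta * mu_running + (1 - beta) * batch.mean()

print(mu_running)                  # close to 3.0 after ~3000 updates
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;p&gt;With a momentum as high as 0.999 (what we use below), the running estimate moves very slowly, which is fine because our training loop runs for many thousands of steps.&lt;/p&gt;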



&lt;h3&gt;
  
  
  Implementing the Running Mean and Standard Deviation
&lt;/h3&gt;

&lt;p&gt;Now let's turn back to our code for the (nearly) last optimization. When training our model, we have two jobs to do:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;We need to calculate the mean and std for each batch, and then normalize the values just like we did when implementing BatchNorm; this is for training.&lt;/li&gt;
&lt;li&gt;We also have to update two values during this process, which are later used in testing (or inference): &lt;code&gt;bnmean_running&lt;/code&gt; and &lt;code&gt;bnstd_running&lt;/code&gt;, calculated by the &lt;em&gt;exponential moving average&lt;/em&gt;. Remember that these are not parameters, so update them under &lt;code&gt;torch.no_grad()&lt;/code&gt;.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;With everything prepared, let's jump right into the code. First we initialize the running mean and std.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Initialize the running mean and std
&lt;/span&gt;&lt;span class="n"&gt;bnmean_running&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;bnstd_running&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;ones&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_hidden&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now let's change the BatchNorm layer a bit:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# BatchNorm layer
&lt;/span&gt;  &lt;span class="c1"&gt;# -------------------------------------------------------------
&lt;/span&gt;  &lt;span class="c1"&gt;# This is for normalization
&lt;/span&gt;  &lt;span class="n"&gt;bnmeani&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;bnstdi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;std&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bngain&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;bnmeani&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;bnstdi&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;bnbias&lt;/span&gt;

 &lt;span class="c1"&gt;# This is later used for testing, notice the torch.no_grad()
&lt;/span&gt;  &lt;span class="c1"&gt;# We set the momentum = 0.999
&lt;/span&gt;  &lt;span class="k"&gt;with&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;no_grad&lt;/span&gt;&lt;span class="p"&gt;():&lt;/span&gt;
    &lt;span class="n"&gt;bnmean_running&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.999&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;bnmean_running&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mf"&gt;0.001&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;bnmeani&lt;/span&gt;
    &lt;span class="n"&gt;bnstd_running&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.999&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;bnstd_running&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mf"&gt;0.001&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;bnstdi&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now we plug the running mean and std into our evaluation code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="nd"&gt;@torch.no_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# this decorator disables gradient tracking
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;split_loss&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="bp"&gt;...&lt;/span&gt;
  &lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;bngain&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;bnmean_running&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;bnstd_running&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;bnbias&lt;/span&gt;
  &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;hpreact&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (N, n_hidden)
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (N, vocab_size)
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;split&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And that's basically it! We can compare our approximation of the mean (the std behaves the same way) against the exact &lt;code&gt;bnmean&lt;/code&gt; we calculated earlier:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;&amp;gt;&amp;gt;&amp;gt;bnmean_running
tensor([[-2.3338,  0.6988, -0.9011,  0.9966,  1.0906,  1.0759,  1.7426, -2.1253,...
&amp;gt;&amp;gt;&amp;gt;bnmean
tensor([[-2.3145,  0.6885, -0.9134,  0.9972,  1.0878,  1.0841,  1.7470, -2.1102,...
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The approximations were really good! And let's see our final training and validation loss:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;train 2.0666308403015137
val 2.1051523685455322
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Great!&lt;/p&gt;

&lt;h3&gt;
  
  
  Removing the Bias
&lt;/h3&gt;

&lt;p&gt;One small thing to note here, maybe the last thing, is that we can actually &lt;strong&gt;remove the bias&lt;/strong&gt; if we're going to apply BatchNorm. This is simply because normalization doesn't care about our data being offset by a constant: we subtract the mean from everything anyway. (Think of the bias as something that shifts the distribution left and right, which is meaningless under normalization because we re-center the distribution around zero anyway. A quick numerical check follows below.)&lt;/p&gt;
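
&lt;p&gt;Here's that check (again my own snippet, not from the video): a bias added before BatchNorm is cancelled exactly by the mean subtraction.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

torch.manual_seed(0)
x = torch.randn(32, 100)      # a batch of pre-activations
b = torch.randn(100) * 5.0    # an arbitrary (large) bias

def batchnorm(h):
    return (h - h.mean(0, keepdim=True)) / h.std(0, keepdim=True)

# The bias shifts every column by a constant, and the mean subtraction
# removes that constant exactly, so both versions normalize identically.
print(torch.allclose(batchnorm(x), batchnorm(x + b), atol=1e-6))  # True
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;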

&lt;p&gt;So let's just comment out the initialization of &lt;code&gt;b1&lt;/code&gt;, which saves a bit of memory:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;#b1 = torch.randn(n_hidden, generator=g) * 0.01
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;If you scrutinize some advanced models that use BatchNorm, like ResNet, you will notice that the layers feeding into a BatchNorm are created with &lt;code&gt;bias = False&lt;/code&gt;.&lt;/p&gt;
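
&lt;p&gt;With PyTorch's built-in modules the pattern looks roughly like this (a generic sketch, not ResNet's actual code):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch.nn as nn

# The bias is omitted because the BatchNorm right after it would cancel
# it anyway; the BatchNorm's own shift (our bnbias) plays that role.
block = nn.Sequential(
    nn.Linear(100, 200, bias=False),
    nn.BatchNorm1d(200),
    nn.Tanh(),
)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;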

&lt;h2&gt;
  
  
  Summary
&lt;/h2&gt;

&lt;p&gt;When I first learned about BatchNorm in the book, I didn't even know why we needed to normalize anything, and I was so confused at the time. But now, actually implementing a model of my own (well, it's from Andrej, sorry if that bothers you), I realize that BatchNorm is a really significant milestone that solves a lot of problems in neural nets, and those problems are easy to notice once you hit them.&lt;/p&gt;

&lt;p&gt;Today we fixed a lot of bugs in our model and brought the holy BatchNorm to light: we went from problems with initialization and the vanishing gradient to BatchNorm and its implementation, its pros and cons, the EMA trick, removing the bias, and more.&lt;/p&gt;

&lt;p&gt;Well, another productive day for us! Thanks for reading and have a good day. &lt;/p&gt;

</description>
      <category>deeplearning</category>
      <category>llm</category>
      <category>tutorial</category>
      <category>machinelearning</category>
    </item>
    <item>
      <title>LANGUAGE MODELS USING MLP (Part 2)</title>
      <dc:creator>Hưng Lê Tiến</dc:creator>
      <pubDate>Thu, 20 Nov 2025 22:21:42 +0000</pubDate>
      <link>https://dev.to/blackbidz/language-models-using-mlp-part-2-g8a</link>
      <guid>https://dev.to/blackbidz/language-models-using-mlp-part-2-g8a</guid>
      <description>&lt;p&gt;Welcome back! Today we will train our multi-layer perceptron, as well as exploring some techniques to fine-tune the model. Remember the terrifying loss that we have in our previous chapter? Keep that in mind, we will try our best to minimize that loss in this blog.&lt;/p&gt;

&lt;h2&gt;
  
  
  CROSS-ENTROPY
&lt;/h2&gt;

&lt;p&gt;We first introduce a new loss function called &lt;code&gt;cross_entropy&lt;/code&gt;. Let's see what it actually computes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor(14.3920)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;It is the same as our negative log likelihood! In fact, &lt;code&gt;cross_entropy&lt;/code&gt; is just a compact implementation of our loss. Rather than hand-coding all of the intermediate steps, we wrap everything up in one line of code after calculating the logits:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Calculating the logits
&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;

&lt;span class="c1"&gt;# Our previous code
&lt;/span&gt;&lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;keepdims&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;  &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="c1"&gt;# New version
&lt;/span&gt;&lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The function &lt;code&gt;cross_entropy&lt;/code&gt; actually offers benefits beyond convenience:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;By eliminating the intermediate steps, we also &lt;strong&gt;save a lot of memory&lt;/strong&gt;. In our previous code we created several intermediate tensors, storing the &lt;code&gt;counts&lt;/code&gt;, then the &lt;code&gt;prob&lt;/code&gt;, and eventually the &lt;code&gt;loss&lt;/code&gt;, which is itself a waste of memory. The &lt;code&gt;cross_entropy&lt;/code&gt; function, however, fuses these steps (at least that's what I was told), so there's a big memory saving.&lt;/li&gt;
&lt;li&gt;The function also allows for a more &lt;strong&gt;efficient backward pass&lt;/strong&gt;. We don't have to backpropagate through numerous intermediate operations like in our previous code; instead, we backpropagate through a single fused computation, which is faster.&lt;/li&gt;
&lt;li&gt;It is more &lt;strong&gt;numerically well-behaved&lt;/strong&gt;. This is the point we really need to dive into, so let's spend a few minutes on what &lt;em&gt;numerically well-behaved&lt;/em&gt; means.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;Let's consider the softmax probabilities of a sample tensor:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;
&lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([1.5123e-08, 3.3311e-04, 6.6906e-03, 9.9298e-01])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;So far so good, no problems here. But let's see what happens when we have some &lt;strong&gt;large positive values&lt;/strong&gt; in our tensor:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;13&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;
&lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([0., 0., 0., nan])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Something is wrong! Scrutinizing the counts further, we find that:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;&amp;gt;&amp;gt;&amp;gt;counts
tensor([5.6028e-09, 1.2341e-04, 2.4788e-03,        inf])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The count for the value &lt;code&gt;100&lt;/code&gt; goes to infinity! That is understandable, because we took the &lt;em&gt;exponential&lt;/em&gt; of each value, so huge positive numbers tend to blow up, which is called &lt;strong&gt;numerical overflow&lt;/strong&gt; (note that exp behaves well with negative numbers, which simply underflow toward zero). &lt;/p&gt;

&lt;p&gt;We don't want that bug in our project, so we need a way to &lt;strong&gt;scale down&lt;/strong&gt; all of the values. This relies on a subtle property of the softmax: &lt;em&gt;the output doesn't change when we offset our logits by a constant value&lt;/em&gt;. Take a moment to think about why: exp(x - c) = exp(x) / exp(c), so the constant factor exp(c) appears in both the numerator and the denominator and cancels out. &lt;/p&gt;

&lt;p&gt;With that in mind, the simplest fix is to &lt;strong&gt;subtract the maximum value&lt;/strong&gt; of the whole tensor. From then on, every entry is at most zero, and exp is happy to deal with that. &lt;/p&gt;
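
&lt;p&gt;Here's a small demonstration of that trick on the tensor that blew up above (my own snippet):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

logits = torch.tensor([-13., -3., 0., 100.]) - 6

def softmax(x):
    counts = x.exp()
    return counts / counts.sum()

# Subtracting the max changes nothing mathematically, but it keeps
# exp() from overflowing: all inputs are now at most zero.
print(softmax(logits - logits.max()))  # roughly [0, 0, 0, 1], no nan
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;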

&lt;p&gt;Turning back to &lt;code&gt;cross_entropy&lt;/code&gt;: the max-subtraction trick above is built into the function itself, so it handles this for us automatically.&lt;/p&gt;

&lt;p&gt;Now we shall turn back to our project. &lt;/p&gt;

&lt;h2&gt;
  
  
  Putting Things Together
&lt;/h2&gt;

&lt;p&gt;While implementing our project, we initialized tensors and functions in a stream-of-thought fashion, so our code is rather messy and unorganized. In practice, organizing the project is crucial for &lt;em&gt;readability&lt;/em&gt; and &lt;em&gt;debugging&lt;/em&gt; later on. Moreover, when we want to modify our model, we will know exactly where to go. So if you're working on your own project, don't be lazy: take time to put things together, as it will be immensely useful later on.&lt;/p&gt;

&lt;h3&gt;
  
  
  Initializing all the parameters
&lt;/h3&gt;

&lt;p&gt;We put all of the matrices we have to initialize in a single block, and we also fix the &lt;code&gt;seed&lt;/code&gt; of the Generator so that our results are reproducible.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# for reproducibility
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;b1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;parameters&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now we shall build our model.&lt;/p&gt;

&lt;h3&gt;
  
  
  Build the model
&lt;/h3&gt;

&lt;p&gt;First, remember to set &lt;code&gt;requires_grad = True&lt;/code&gt; so that we can run our backward pass.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Then we build the training loop with the forward and backward passes, and we also update the parameters:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

  &lt;span class="c1"&gt;# forward pass
&lt;/span&gt;  &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# (32, 3, 2)
&lt;/span&gt;  &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;30&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (32, 100)
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (32, 27)
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

  &lt;span class="c1"&gt;# backward pass
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
  &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="c1"&gt;# update
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Run the code
&lt;/h3&gt;

&lt;p&gt;So now we shall run the code and see what happens. Here are the last few results:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;0.26593196392059326
0.26574623584747314
0.265565425157547
0.26538926362991333
0.2652176320552826
0.2650502920150757
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Astonishing? No. Remember that the data is tiny (only 32 context windows), so when the model arrives at an extremely small loss, we shouldn't praise it; we should suspect that the model is &lt;strong&gt;overfitting&lt;/strong&gt;. &lt;/p&gt;

&lt;p&gt;Overfitting means the model tries to memorize the data rather than capture the underlying patterns. So even though the model is extremely good on its own training set, it will perform poorly on real-world data, where it sees things it hasn't learnt before and can do nothing with them, since everything it has done so far is pure memorization (imagine a student who memorizes the homework answers and then faces new questions on the test).&lt;/p&gt;

&lt;p&gt;But why doesn't the loss converge to 0? If the model can memorize everything in the training set, why doesn't it reach 100% accuracy? Normally, overfitting this hard would drive the loss to zero. However, in this particular project, there is a subtle detail that prevents our network from guessing perfectly.&lt;/p&gt;

&lt;p&gt;Let's look at our data again:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;... ---&amp;gt; e
..e ---&amp;gt; m
.em ---&amp;gt; m
emm ---&amp;gt; a
mma ---&amp;gt; .
... ---&amp;gt; o
..o ---&amp;gt; l
.ol ---&amp;gt; i
oli ---&amp;gt; v
liv ---&amp;gt; i
ivi ---&amp;gt; a
via ---&amp;gt; .
... ---&amp;gt; a
..a ---&amp;gt; v
.av ---&amp;gt; a
ava ---&amp;gt; .
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Can you notice the key here? Take a guess.&lt;/p&gt;

&lt;p&gt;So here's the answer: look at the instance &lt;code&gt;...&lt;/code&gt;, denoting the start of a word. You simply cannot guess the character that comes after it! Try it yourself: the same context &lt;code&gt;...&lt;/code&gt; maps to different targets (e, o, a above), so the best the model can do is output a probability distribution over them, which keeps the loss above zero. That's where our model struggled. &lt;/p&gt;

&lt;h3&gt;
  
  
  Working with the whole dataset
&lt;/h3&gt;

&lt;p&gt;So far we've been dealing with a mere handful of words from the dataset, so let's go big this time.&lt;/p&gt;

&lt;p&gt;From the code that constructs our dataset, we should remove the &lt;code&gt;words[:5]&lt;/code&gt; and pass in everything.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# rebuild the dataset
&lt;/span&gt;&lt;span class="n"&gt;block_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt; 
&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

&lt;span class="c1"&gt;# Now we iterate through all the words
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="c1"&gt;#print(''.join(itos[i] for i in context), '---&amp;gt;', itos[ix])
&lt;/span&gt;    &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# crop and append
&lt;/span&gt;
&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now let's look at the size of our data&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;torch.Size([228146, 3]) torch.Size([228146])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That's huge. &lt;/p&gt;

&lt;p&gt;Also, when calculating the loss, make sure to remove the number &lt;code&gt;32&lt;/code&gt; and replace it with the appropriate dimension. That's actually why we always try to avoid hard-coding numbers: it saves a lot of time otherwise spent going around and fixing minor bugs.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;  &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now let's run our model and see what happens. I recommend setting the number of iterations to 10; we will discuss why shortly.&lt;br&gt;
Here are the results:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;19.505226135253906
17.08449363708496
15.776531219482422
14.833340644836426
14.002603530883789
13.253260612487793
12.57991886138916
11.983101844787598
11.47049331665039
11.051856994628906
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;When running the code, it turns out that the model becomes slower, and that's understandable because it has to calculate the gradient for the &lt;strong&gt;whole dataset of roughly 228,000 instances&lt;/strong&gt;; that's really a big deal. And we're just iterating 10 times, so this approach is really questionable at scale.&lt;/p&gt;

&lt;p&gt;A clever way to combat this is passing the input in &lt;strong&gt;mini batches&lt;/strong&gt;, rather than the whole batch of data. So we're dealing with a group of instances at a time, not every single one. The gradient this gives us is, surely, not as accurate as the full-batch one, but it is far cheaper to compute and puts less strain on the computer, and it is actually what's used in practice. In other words, it's much better to approximate the gradient and take many more steps than to compute the exact gradient and take fewer steps. &lt;/p&gt;

&lt;p&gt;So let us construct the mini-batch: We will randomly generate a tensor of 32 indices ranging from zero to our number of training instances. Here's how it goes:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([200670, 191458, 142413, 156993, 217108, 174176, 143298,  30653, 148878,
        158381,  11828,  75183, 115824,  49455,  91737, 216958, 142564, 224086,
         73948, 217108, 174951, 170926, 180371, 224631, 167595, 173195, 116182,
        192239, 158702,  43879,  45633, 165950])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That looks tasty. Let's include that in our code, and modify our model a little bit so that it takes exactly 32 examples, no more, no less.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="c1"&gt;# mini batch construct
&lt;/span&gt;  &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
  &lt;span class="c1"&gt;# forward pass
&lt;/span&gt;  &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt; &lt;span class="c1"&gt;# (32, 3, 2)
&lt;/span&gt;  &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (32, 100)
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (32, 27)
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

  &lt;span class="c1"&gt;# backward pass
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
  &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="c1"&gt;# update
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mf"&gt;0.1&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Here are the last few losses:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;3.3569700717926025
3.668323516845703
3.5134356021881104
3.7804458141326904
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Promising! And it takes just a few seconds to arrive at that loss. &lt;/p&gt;

&lt;h2&gt;
  
  
  Learning Rate
&lt;/h2&gt;

&lt;p&gt;Let's look at the implementation again: there's actually one thing that we haven't talked about yet. It's the number &lt;code&gt;0.1&lt;/code&gt; we multiply by when we update the parameters. It is called the &lt;strong&gt;learning rate&lt;/strong&gt;, and we will spend this section discussing it.&lt;/p&gt;

&lt;p&gt;The learning rate, as you can easily notice by yourself, is simply a number that controls the &lt;em&gt;length&lt;/em&gt; of our step. If the learning rate is too low, which means that we take small steps, then it will take a huge number of iterations to converge to the minimum. By contrast, if the learning rate is too high, indicating a big step size, then we will potentially "slip" past the minimum and end up bouncing back and forth, not converging but rather &lt;em&gt;diverging&lt;/em&gt;. If it is hard for you to visualize, here's a cool visualization that might help:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fdir3bb7f85nm23nqs2j0.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fdir3bb7f85nm23nqs2j0.png" alt="Learning rate" width="800" height="310"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;*Note: Actually the image is from a really dedicated blog post about learning rates (100x better than mine); I recommend reading it, as it dives much deeper into the topic. &lt;a href="https://www.jeremyjordan.me/nn-learning-rate/" rel="noopener noreferrer"&gt;Here's the link.&lt;/a&gt;&lt;/p&gt;
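
&lt;p&gt;To make this concrete, here's a tiny toy demo (my own illustration, not from the video): gradient descent on the function &lt;code&gt;f(x) = x**2&lt;/code&gt;, whose gradient is &lt;code&gt;2*x&lt;/code&gt; and whose minimum sits at 0. Watch what different learning rates do:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def descend(lr, steps=20, x=1.0):
    # each step moves against the gradient of f(x) = x**2
    for _ in range(steps):
        x = x - lr * 2 * x
    return x

print(descend(0.01))  # ~0.667: tiny steps, still far from the minimum
print(descend(0.1))   # ~0.012: reasonable, nearly converged
print(descend(1.1))   # ~38.3: too big, |x| grows every step (divergence)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;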

&lt;h3&gt;
  
  
  Learning rate decay
&lt;/h3&gt;

&lt;p&gt;There are many ways to schedule the learning rate during training in order to get the most out of the model. The simplest one, which we will use in our project, is perhaps &lt;strong&gt;step decay&lt;/strong&gt; (a form of &lt;strong&gt;learning rate decay&lt;/strong&gt;). The technique involves decreasing the learning rate after a fixed number of steps (like 100000), thus making the model move slower when the loss is small and potentially near the minimum. This is kind of intuitive: we take big steps when we're far from the minimum, and smaller ones when we're close.&lt;/p&gt;

&lt;p&gt;Even though there are numerous variants of step decay, we will keep it simple by just dividing the learning rate by 10 after the first 10000 iterations, as sketched below.&lt;/p&gt;
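
&lt;p&gt;Here's a minimal sketch of that schedule (the helper name &lt;code&gt;lr_at&lt;/code&gt; is mine, and the step counts are just the ones mentioned above):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;def lr_at(step):
    # step decay: 0.1 for the first 10000 steps, then a 10x smaller step
    return 0.1 if step &lt; 10000 else 0.01

# inside the training loop, the update would become:
#   lr = lr_at(i)
#   for p in parameters:
#       p.data += -lr * p.grad
print(lr_at(0), lr_at(9999), lr_at(10000))  # 0.1 0.1 0.01
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;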

&lt;h3&gt;
  
  
  Find the right learning rate
&lt;/h3&gt;

&lt;p&gt;Now we know the importance of choosing the right learning rate and adjusting it so that the model performs well. But how exactly can we do that? &lt;/p&gt;

&lt;p&gt;We will follow Andrej's way of finding a good learning rate; it's actually not the most optimal method, but it's really intuitive and we can easily follow all the steps. First we have to determine a &lt;strong&gt;range&lt;/strong&gt; of possible good values. There will be some points where the learning rate is too low and the model barely decreases the loss, and there will be points where the learning rate is too high and the loss is bouncing around. We can find those points by plugging in different values; this is pure trial and error. And after playing around with some values of the learning rate (actually Andrej did this in his video, not me), we have our range: from &lt;code&gt;0.001&lt;/code&gt; to &lt;code&gt;1&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;The next step is examining values across that range and seeing what happens to the loss; the optimal value should be the one where the loss is at its minimum. So let's run the model with different values of the learning rate. We can use PyTorch to construct an array of values in that range:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Want to examine the range from 0.001 to 1
&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.001&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Plotting the lr-loss Curve
&lt;/h3&gt;

&lt;p&gt;After we have the array we want, let's plug it into our model. Also, we need to construct two lists, storing the values of the learning rate and the loss respectively, in order to graph things out.&lt;br&gt;
&lt;strong&gt;A BIG NOTE:&lt;/strong&gt; After messing things up in my own project, I have to warn you so that you won't make the same silly mistake: &lt;em&gt;REMEMBER TO RERUN THE ENTIRE MODEL AGAIN&lt;/em&gt;. Don't reuse the previous weights and biases for a rerun; you have to INITIALIZE THEM ALL OVER AGAIN. Keep this in mind every time we rerun the model (see the sketch after the code below for one way to make this automatic).&lt;br&gt;
Here's the code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Keep track of lr and loss for each iter
&lt;/span&gt;&lt;span class="n"&gt;lri&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="n"&gt;lossi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;requires_grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="c1"&gt;# mini batch construct
&lt;/span&gt;  &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;
  &lt;span class="c1"&gt;# forward pass
&lt;/span&gt;  &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt; &lt;span class="c1"&gt;# (32, 3, 2)
&lt;/span&gt;  &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (32, 100)
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (32, 27)
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

  &lt;span class="c1"&gt;# backward pass
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
  &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="c1"&gt;# update
&lt;/span&gt;  &lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;lrs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;
  &lt;span class="c1"&gt;# tracking stats
&lt;/span&gt;  &lt;span class="n"&gt;lri&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;lossi&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
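
&lt;p&gt;By the way, one way to make the big note above foolproof (a sketch of mine, not from the video; the helper name &lt;code&gt;init_params&lt;/code&gt; is hypothetical) is to wrap the initialization in a function and call it before every sweep, so stale weights can never leak into the next experiment:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

def init_params():
    # fresh parameters on every call, same shapes as the network above
    g = torch.Generator().manual_seed(2147483647)
    C  = torch.randn((27, 2), generator=g)
    W1 = torch.randn((6, 100), generator=g)
    b1 = torch.randn(100, generator=g)
    W2 = torch.randn((100, 27), generator=g)
    b2 = torch.randn(27, generator=g)
    parameters = [C, W1, b1, W2, b2]
    for p in parameters:
        p.requires_grad = True
    return parameters

parameters = init_params()  # call this before every rerun
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;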



&lt;p&gt;Here comes the plot:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Plot the lr
&lt;/span&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;lri&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;lossi&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fgc42fu0mj361e377fksy.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fgc42fu0mj361e377fksy.png" alt="LR-LOSS plot" width="556" height="418"&gt;&lt;/a&gt;&lt;br&gt;
This is a hockey-stick-like kind of graph: it has a steep slope at the origin (nearly a straight drop), then small fluctuations along the middle part, before bouncing back and forth uncontrollably toward the end of the range. &lt;/p&gt;

&lt;p&gt;Let's interpret this graph, and you will see that our conclusions so far are all correct. First, notice the beginning of the range, where the learning rate is small: the loss is not even close to the minimum, indicating that we're not making large enough steps to converge. Conversely, in the bulk of the remaining range, where the learning rate increases, the loss bounces between values, and it bounces even harder when the learning rate is near 1, which is an indicator of divergence. &lt;/p&gt;

&lt;p&gt;Now our mission is to look at this graph and find the value where the learning rate gets the loss to its minimum. By eye we can see that the value is somewhere around &lt;code&gt;0.1&lt;/code&gt;. If it's not too obvious, we can plot against the &lt;strong&gt;log&lt;/strong&gt; of the learning rate instead, so that the values spread out more evenly. Let's implement that in our code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Take the log of our range of values
&lt;/span&gt;
&lt;span class="n"&gt;lre&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;linspace&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;3&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1000&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;lrs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt; &lt;span class="o"&gt;**&lt;/span&gt; &lt;span class="n"&gt;lre&lt;/span&gt;

&lt;span class="c1"&gt;# Intilize a new list 
&lt;/span&gt;&lt;span class="n"&gt;lrlog&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

&lt;span class="c1"&gt;# Update this list in when tracking stats
&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="bp"&gt;...&lt;/span&gt;
  &lt;span class="n"&gt;lrlog&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;lre&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
  &lt;span class="n"&gt;lri&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;lossi&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

&lt;span class="c1"&gt;# Plotting things out
&lt;/span&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;lrlog&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;lossi&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And here's the plot:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Favvvz27vse10niyz45yt.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Favvvz27vse10niyz45yt.png" alt="LRLOG-LOSS curve" width="543" height="418"&gt;&lt;/a&gt;&lt;br&gt;
Now it's getting obvious. So we've come full circle in this section: the magic number &lt;code&gt;0.1&lt;/code&gt; from the beginning is actually the optimal learning rate for this model.&lt;/p&gt;

&lt;p&gt;And after rerunning the model with this optimal learning rate, plus the &lt;strong&gt;learning rate decay&lt;/strong&gt; technique, we have our loss:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;2.3681914806365967
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;It's the best result we've got so far! It's amazing that just this one small number can have such a huge effect on our model, even when we're playing with the simple methods (imagine how crazy it gets when you try some advanced ones; worth a try!).&lt;/p&gt;

&lt;h2&gt;
  
  
  Splitting the Data
&lt;/h2&gt;

&lt;p&gt;So far we've been getting the model to study a lot, and we've also tuned it in order to bring down the loss. But does this loss mean that our model will do well in a real-world setting? Actually, we don't really know, because there's no &lt;strong&gt;test&lt;/strong&gt; yet. How can we create a test for the model?&lt;/p&gt;

&lt;p&gt;The answer is that we will stop feeding the model every data point; instead, we will hold back a small portion of the dataset. That's going to be our test, because the model doesn't get to see that data during training. Normally the split is around &lt;code&gt;8:2&lt;/code&gt; or &lt;code&gt;7:3&lt;/code&gt; for training:testing. The evaluation metric on the test set is still our &lt;code&gt;cross_entropy&lt;/code&gt;. One thing to keep in mind: if the validation loss is &lt;em&gt;high&lt;/em&gt; while the training loss is &lt;em&gt;low&lt;/em&gt;, then there's a high chance of &lt;em&gt;overfitting&lt;/em&gt; (the model does well only on the data it has seen before while performing poorly on the test, indicating that it is &lt;strong&gt;memorizing&lt;/strong&gt; rather than &lt;strong&gt;generalizing&lt;/strong&gt;).&lt;/p&gt;

&lt;p&gt;Moreover, when training, we also want to hold out another bit of data in order to tune the model; it's like another test set, used not for the final test but for evaluations while searching for good hyperparameters (just like what we did with the learning rate). &lt;em&gt;"But why another held-out set? Why don't we evaluate the hyperparameters right on the training set, like we've been doing?"&lt;/em&gt; - you may ask. The thing is, you're tweaking and tuning the hyperparameters so that they make the model powerful &lt;strong&gt;on just the training data&lt;/strong&gt;; there's no guarantee that these hyperparameters will work in practice, on data the model has never seen before. So we create another split called the &lt;strong&gt;dev set&lt;/strong&gt;, designed specifically for choosing the hyperparameters.&lt;/p&gt;

&lt;p&gt;Another side note here: what happens if the dataset is too small to split into 3 sets? How can we accurately evaluate the model then? The solution is quite interesting: if the data is too small, we have to "reuse" it when testing. To be more specific, rather than having just one fixed split and evaluating the model on that, we create a kind of parallel universe for each of several different splits, then evaluate the model on each one and look at the big picture. This clever method is &lt;strong&gt;cross-validation&lt;/strong&gt;, which is widely used in ML tasks. &lt;/p&gt;

&lt;p&gt;Concretely, we first randomly split the data into &lt;code&gt;k&lt;/code&gt; folds; then we train &lt;code&gt;k&lt;/code&gt; times, each time holding out a different one of the &lt;code&gt;k&lt;/code&gt; subsets for testing, which gives us &lt;code&gt;k&lt;/code&gt; evaluation scores. That is called &lt;strong&gt;k-fold cross validation&lt;/strong&gt;. &lt;/p&gt;
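
&lt;p&gt;Here's a minimal sketch of the idea using scikit-learn's &lt;code&gt;KFold&lt;/code&gt; (an assumption on my part; this post itself doesn't use scikit-learn):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from sklearn.model_selection import KFold
import numpy as np

data = np.arange(100)  # stand-in for a small dataset
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, test_idx) in enumerate(kf.split(data)):
    # train on data[train_idx], evaluate on data[test_idx]
    print(f"fold {fold}: train={len(train_idx)}, test={len(test_idx)}")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;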

&lt;h3&gt;
  
  
  Split the data
&lt;/h3&gt;

&lt;p&gt;Just to remind you, here's a brief recap of our three sets:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Training set: This is for training the &lt;strong&gt;parameters&lt;/strong&gt;, i.e. the model itself.&lt;/li&gt;
&lt;li&gt;Dev set: This is for tuning the &lt;strong&gt;hyperparameters&lt;/strong&gt;.&lt;/li&gt;
&lt;li&gt;Test set: This is for the final evaluation of the model.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Now let's implement the split in our code: We will first shuffle the data and compute the counts for each of our sets, depending on the split we want (in this project we will use &lt;code&gt;8:1:1&lt;/code&gt;). Then we build the three sets of data from the counts we've got. Here's how it looks:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;block_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt; 

&lt;span class="c1"&gt;# Create a function for the dataset construction 
# For convenience purposes
&lt;/span&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;build_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

    &lt;span class="c1"&gt;#print(w)
&lt;/span&gt;    &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;
    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
      &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="c1"&gt;#print(''.join(itos[i] for i in context), '---&amp;gt;', itos[ix])
&lt;/span&gt;      &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# crop and append
&lt;/span&gt;
  &lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;

&lt;span class="c1"&gt;# Finding the counts in each sets
&lt;/span&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;random&lt;/span&gt;
&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;42&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;shuffle&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;n1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.8&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;n2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;int&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mf"&gt;0.9&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

&lt;span class="c1"&gt;# Split
&lt;/span&gt;&lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ytr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;build_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="n"&gt;n1&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;Xdev&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ydev&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;build_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;n1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;n2&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
&lt;span class="n"&gt;Xte&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Yte&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;build_dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;n2&lt;/span&gt;&lt;span class="p"&gt;:])&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Those are our splits! Now we have to retrain the model, and remember to use the data in &lt;strong&gt;the training set only&lt;/strong&gt;.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Now we train on just the Xtr and Ytr
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;30000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

  &lt;span class="c1"&gt;# minibatch construct
&lt;/span&gt;  &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;

  &lt;span class="c1"&gt;# forward pass
&lt;/span&gt;  &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt; &lt;span class="c1"&gt;# (32, 3, 2)
&lt;/span&gt;  &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (32, 100)
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (32, 27)
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ytr&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
  &lt;span class="c1"&gt;#print(loss.item())
&lt;/span&gt;
  &lt;span class="c1"&gt;# backward pass
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
  &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="c1"&gt;# update
&lt;/span&gt;  &lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;2.3609843254089355
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And now comes the evaluation; remember to &lt;strong&gt;use the dev set&lt;/strong&gt; for this task:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Xdev&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# (32, 3, 2)
&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (32, 100)
&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (32, 27)
&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ydev&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor(2.4578, grad_fn=&amp;lt;NllLossBackward0&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You can see that the validation loss is a bit higher than the training loss, which indicates that the model is slightly overfitting. When a model is overfitting, we often reach for &lt;strong&gt;regularization techniques&lt;/strong&gt;; we actually used one in our previous project. But let's leave that part here; in the next section we will continue exploring ways to tune the model.&lt;/p&gt;
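
&lt;p&gt;For the curious, here's what one simple option would look like (a sketch of mine, not something we'll apply here): &lt;strong&gt;L2 weight decay&lt;/strong&gt; adds a penalty on the squared weights to the loss, nudging the model toward smaller, smoother weights. The strength &lt;code&gt;alpha&lt;/code&gt; is a hypothetical hyperparameter, and the other names come from the training loop above:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# inside the training loop, the loss line would become:
alpha = 0.01  # regularization strength (hypothetical value)
loss = F.cross_entropy(logits, Ytr[ix]) \
       + alpha * ((W1**2).mean() + (W2**2).mean())
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;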

&lt;h2&gt;
  
  
  Fine-tuning the Model
&lt;/h2&gt;

&lt;p&gt;How can we make the model more powerful? The first thing that comes to mind is perhaps: &lt;strong&gt;Just scale it!&lt;/strong&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  Scale the model
&lt;/h3&gt;

&lt;p&gt;Let's modify the matrices that we initialize: We will increase the number of neurons in our hidden layer from &lt;code&gt;100&lt;/code&gt; to &lt;code&gt;300&lt;/code&gt;. Here are the modifications to the matrices:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
&lt;span class="n"&gt;C&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;300&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;b1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;300&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;300&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;parameters&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now let's run our model again (remember to set &lt;code&gt;requires_grad&lt;/code&gt;), this time keeping track of the steps and the loss, just to see how the loss changes over the iterations. We might take the &lt;strong&gt;log&lt;/strong&gt; of the loss, for better visualization.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Now we track the step, use the log for the loss
&lt;/span&gt;&lt;span class="n"&gt;lossi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="n"&gt;stepi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;30000&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

  &lt;span class="c1"&gt;# minibatch construct
&lt;/span&gt;  &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randint&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,))&lt;/span&gt;

  &lt;span class="c1"&gt;# forward pass
&lt;/span&gt;  &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Xtr&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt; 
  &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; 
  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;
  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ytr&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
  &lt;span class="c1"&gt;#print(loss.item())
&lt;/span&gt;
  &lt;span class="c1"&gt;# backward pass
&lt;/span&gt;  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt;
  &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="c1"&gt;# update
&lt;/span&gt;  &lt;span class="c1"&gt;#lr = lrs[i]
&lt;/span&gt;  &lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.1&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;lr&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;

  &lt;span class="c1"&gt;# track stats
&lt;/span&gt;
  &lt;span class="n"&gt;stepi&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
  &lt;span class="n"&gt;lossi&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;plot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;stepi&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;lossi&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F0ophudlwy41wcnm79ep2.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F0ophudlwy41wcnm79ep2.png" alt="STEP-LOSS" width="547" height="413"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You can see how the loss is quickly driven down and then fluctuates. The jitter is expected: each point is the loss on a single minibatch of 32 examples, not on the whole training set.&lt;/p&gt;
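
&lt;p&gt;A quick way to read through that noise, as a sketch that reuses the &lt;code&gt;lossi&lt;/code&gt; list we just logged: average the last chunk of logged log-losses and exponentiate, which gives a steadier estimate of where training ended up than a single minibatch loss.&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Each point of the curve is the loss on one minibatch of 32 examples,
# hence the jitter. Averaging the tail of the log gives a steadier
# estimate. (Sketch; assumes lossi holds the loss.log().item() values.)
recent = torch.tensor(lossi[-1000:])
print(recent.mean().exp().item())  # back from log scale to loss scale
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;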

&lt;p&gt;We should see our final loss:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;2.799872875213623
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;It's worse; maybe that's because we held some data out for the splits, so the model now trains on less data. Let's evaluate this model on the &lt;strong&gt;dev set&lt;/strong&gt;.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Evaluate using the dev set
&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Xdev&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# (Ndev, 3, 2)
&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (Ndev, 100)
&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (Ndev, 27)
&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ydev&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor(2.5693, grad_fn=&amp;lt;NllLossBackward0&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Not bad! Our model performs even better on the validation set than on the training batch. That's a huge improvement!&lt;/p&gt;

&lt;h3&gt;
  
  
  Changing the embedding size
&lt;/h3&gt;

&lt;p&gt;We scaled the net and yet we did not get a really good result. So one thing that comes to mind is that maybe the bottleneck is not the size of the net, but the &lt;strong&gt;embedding size&lt;/strong&gt;. Remember that in the first place we squished all 27 characters (27-dimensional, if one-hot encoded) into just two dimensions, so there might be a lot of information loss. &lt;/p&gt;

&lt;p&gt;But before we change the embedding size, let's take a look at how our model learned during the training phase. Let's visualize our embedding matrix &lt;code&gt;C&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# visualize dimensions 0 and 1 of the embedding matrix C for all characters
&lt;/span&gt;&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;scatter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;200&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]):&lt;/span&gt;
    &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;],&lt;/span&gt; &lt;span class="n"&gt;ha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;center&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;va&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;center&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;white&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;grid&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;minor&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F4k1p7vjeae8wnuy7k7xa.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F4k1p7vjeae8wnuy7k7xa.png" alt="EMB MAP" width="671" height="659"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The model learns to cluster characters, and if you look closely, it even separates the vowels &lt;code&gt;a,i,u,e,o&lt;/code&gt;, which is astonishing to see. And &lt;code&gt;g&lt;/code&gt; sits far apart; maybe the model thinks &lt;code&gt;g&lt;/code&gt; is not really a common character in names? It is still amazing that meaningful structure emerges in our net, just from tweaking and tuning a whole bunch of numbers.&lt;/p&gt;
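
&lt;p&gt;If you want to quantify what the plot shows, here is a small sketch (it assumes the trained &lt;code&gt;C&lt;/code&gt; and the &lt;code&gt;stoi&lt;/code&gt;/&lt;code&gt;itos&lt;/code&gt; mappings we built earlier) that lists the nearest neighbours of a character in embedding space:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Nearest neighbours in the 2-d embedding space, by Euclidean distance.
# (The exact neighbours depend on your training run.)
def neighbours(ch, k=5):
    v = C.data[stoi[ch]]
    d = (C.data - v).norm(dim=1)  # distance from ch to every character
    idx = d.argsort()[1:k+1]      # skip the closest row: ch itself
    return [itos[i.item()] for i in idx]

print(neighbours('a'))  # if the plot is right, mostly other vowels
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;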

&lt;p&gt;Now let's increase the embedding size to &lt;code&gt;10&lt;/code&gt;, which changes the shapes of the &lt;code&gt;C&lt;/code&gt; matrix and &lt;code&gt;W1&lt;/code&gt;.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# for reproducibility
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;30&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;300&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;b1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;300&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;300&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;parameters&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
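
&lt;p&gt;The training loop itself barely changes; the only thing to adjust is the flattened width of the context, since 3 characters times 10 dimensions is now 30 features. A sketch of the affected lines:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Only the view() width changes: block_size * embedding size = 3 * 10 = 30.
# (The rest of the training loop stays exactly as before.)
emb = C[Xtr[ix]]                            # (32, 3, 10)
h = torch.tanh(emb.view(-1, 30) @ W1 + b1)  # (32, 300)
logits = h @ W2 + b2                        # (32, 27)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;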



&lt;p&gt;After adjusting the dimensions accordingly (as in the sketch above) and rerunning the code, here is our loss:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# training loss
&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;2.395942211151123
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# validation loss
&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Xdev&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# (Ndev, 3, 10)
&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;30&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (Ndev, 300)
&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (Ndev, 27)
&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Ydev&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor(2.4487, grad_fn=&amp;lt;NllLossBackward0&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;Xte&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# (32, 3, 2)
&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;30&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# (32, 100)
&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="c1"&gt;# (32, 27)
&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cross_entropy&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Yte&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;loss&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor(2.4514, grad_fn=&amp;lt;NllLossBackward0&amp;gt;)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Both are better than the previous results! We did great, man.&lt;/p&gt;

&lt;h2&gt;
  
  
  Sampling from the model
&lt;/h2&gt;

&lt;p&gt;Now we should meet our babies. The sampling method is the same as in the previous project, so let's just look at the results:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# sample from the model
&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;_&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;20&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

    &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
    &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt; &lt;span class="c1"&gt;# initialize with all ...
&lt;/span&gt;    &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;])]&lt;/span&gt; &lt;span class="c1"&gt;# (1,block_size,d)
&lt;/span&gt;      &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;
      &lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;multinomial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
      &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
      &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
        &lt;span class="k"&gt;break&lt;/span&gt;

    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;''&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;carpah.
quelle.
khi.
mila.
tety.
salaysa.
jazhnen.
amerahtia.
qui.
nellana.
chaiiv.
kaneel.
hham.
pein.
quinn.
sron.
taivanbi.
watell.
dearisi.
fine.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;It's way better! Even though some names are nonsense, a lot of them start to sound name-like, and that is a huge improvement for our model. We're moving forward.&lt;/p&gt;

&lt;h2&gt;
  
  
  Summary
&lt;/h2&gt;

&lt;p&gt;We've come so far in this journey: we learned numerous things about the MLP, manipulated data using tensors, then fine-tuned our model with a wide range of methods, including feeding data in minibatches, tweaking the learning rate, splitting the data, scaling the model, and increasing the embedding size. It's a huge amount of knowledge! &lt;/p&gt;

&lt;p&gt;Thanks for reading, see you in the next chapter!&lt;/p&gt;

</description>
      <category>llm</category>
      <category>machinelearning</category>
      <category>tutorial</category>
    </item>
    <item>
      <title>LANGUAGE MODELS USING MLP (Part 1)</title>
      <dc:creator>Hưng Lê Tiến</dc:creator>
      <pubDate>Mon, 17 Nov 2025 08:31:48 +0000</pubDate>
      <link>https://dev.to/blackbidz/language-models-using-mlp-part-1-39bd</link>
      <guid>https://dev.to/blackbidz/language-models-using-mlp-part-1-39bd</guid>
      <description>&lt;p&gt;Welcome to the third part of the series! Just to remind, this blog follows the series from Andrej Karpathy on Youtube, and I'm just taking notes from his videos. Today we will explore deeper about the neural nets, and have great improvements regarding our language model.&lt;/p&gt;

&lt;p&gt;We will have two chapters for this topic, mainly because the content is way too long and there are so many things to mention in Andrej's video. I will try my best to provide a comprehensive view of this topic, as well as clearly explain every single concept involved.&lt;/p&gt;

&lt;p&gt;Here's the link to the video: &lt;a href="https://www.youtube.com/watch?v=TCH_1BHY58I&amp;amp;pp=ygURYnVpbGRpbmcgbWFrZW1vcmU%3D" rel="noopener noreferrer"&gt;Building makemore Part 2: MLP&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  MLP (Multi-layer Perceptron)
&lt;/h2&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6t49fitatqt2b1a6wnu1.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F6t49fitatqt2b1a6wnu1.png" alt="MLPs" width="597" height="324"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Remember the architecture you saw in the previous chapter? That was a multi-layer perceptron. Actually this term is nothing fancy if you're already familiar with neural nets.&lt;/p&gt;

&lt;p&gt;But in the last chapter we built just one layer of perceptrons, which explains why the model is so simple and produces such disappointing results. In this blog we will build a multilayer perceptron (actually just 2 layers, as sketched below), and we shall see how much more powerful the model becomes with hidden layers. &lt;/p&gt;
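
&lt;p&gt;To make the term concrete, here is a minimal sketch of a 2-layer forward pass with made-up sizes and random weights: a hidden layer with a nonlinearity, followed by a linear output layer. This is exactly the shape of the model we are about to build.&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

x = torch.randn(32, 6)                          # a batch of 32 inputs, 6 features each
W1, b1 = torch.randn(6, 100), torch.randn(100)  # layer 1 (hidden)
W2, b2 = torch.randn(100, 27), torch.randn(27)  # layer 2 (output)
h = torch.tanh(x @ W1 + b1)                     # hidden activations
logits = h @ W2 + b2                            # output scores
print(logits.shape)                             # torch.Size([32, 27])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;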

&lt;h2&gt;
  
  
  A Revolutionizing Approach
&lt;/h2&gt;

&lt;p&gt;Before we jump right into coding and blindly adjusting parameters to find the best result, I think we really need to take a step back and address the main problems of our past approach, and why the model is so bad at generating names even when the loss is fairly well optimized. So take a sip, and we will go through some important insights for our models:&lt;/p&gt;

&lt;h3&gt;
  
  
  Why Don't We Scale the Model?
&lt;/h3&gt;

&lt;p&gt;This is the most naive question we could come up with when we want to make a model more powerful. Maybe we would love to have a 4-gram or 5-gram language model rather than just a bigram, so that the model can take into account more preceding characters and hence arrive at better results. &lt;/p&gt;

&lt;p&gt;But imagine what the table of counts would look like. It scales &lt;strong&gt;exponentially&lt;/strong&gt;, with a base of 27. For your lovely 4-gram or 5-gram, that would be 27^4 or 27^5 rows, which is quite intimidating to deal with.&lt;/p&gt;
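
&lt;p&gt;A quick back-of-the-envelope check of that growth:&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Rows in the count table for an n-gram over 27 characters: 27**n
for n in [2, 3, 4, 5]:
    print(n, 27**n)
# 2 729
# 3 19683
# 4 531441
# 5 14348907
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;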

&lt;p&gt;So naive scaling is not really an option, but are there any clever approaches? Well yes, a lot of research has been carried out, and the most ground-breaking work, I would say, is the one from Bengio and his colleagues.&lt;/p&gt;

&lt;h3&gt;
  
  
  A Brief Overview of &lt;a href="https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf" rel="noopener noreferrer"&gt;Bengio's Paper&lt;/a&gt;
&lt;/h3&gt;

&lt;p&gt;Reading papers can be boring, and intimidating at times. But a paper is just an approach to some problem that got on the nerves of scientists, so we don't really need to know all the math or the experiments to understand one. When reading a paper, we should focus on &lt;em&gt;What problems are being addressed?&lt;/em&gt; and &lt;em&gt;What is the intuition behind the solution to those problems?&lt;/em&gt; Of course, it would be great if we could further understand the implementation and the proofs through long stretches of formulas and theorems. For now, let's talk about just the intuition.&lt;/p&gt;

&lt;p&gt;The paper first points out two main problems with the dominant language model of the time, the n-gram, and those problems were, to your surprise, already mentioned in our previous chapters when we built the model:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;The model does not take into account &lt;strong&gt;context&lt;/strong&gt; farther than 1 or 2 words. Well, we talked about that.&lt;/li&gt;
&lt;li&gt;The model does not take into account the &lt;strong&gt;similarity&lt;/strong&gt; between words.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;Hold up. You may object that similarity, or synonyms, are not really relevant to our task of predicting the next word. But actually they are of &lt;em&gt;paramount importance&lt;/em&gt; to our main task, specifically in improving the model's ability to &lt;em&gt;generalize&lt;/em&gt;.&lt;/p&gt;

&lt;p&gt;Imagine the model encounters the sequence "A dog was running in the ..." and has no data about that sequence in the training set. What will it do? Suppose it has a training instance "The cat is walking in the bedroom", that it was trained on a huge text corpus, and that it knows "cat" is similar to "dog", "walking" is just like "running", "A" and "The" are virtually the same, and so on. With that, maybe the model will come up with something like "room" or "living room", which share similarities with "bedroom" from the training set. And that would be a great prediction! The model performs extremely well even though it encounters data it has never seen before. In other words, we say that the model has &lt;strong&gt;generalized&lt;/strong&gt; well to the test set, just by using the &lt;strong&gt;similarities&lt;/strong&gt; between words.&lt;/p&gt;

&lt;p&gt;So how can we create a model which can perform that magical task? Now, we shall see the most ground-breaking part of the paper.&lt;/p&gt;

&lt;h3&gt;
  
  
  Feature vectors - Embeddings
&lt;/h3&gt;

&lt;p&gt;We start by associating each word in the sequence with a vector, usually of lower dimension than the vocabulary size (that's actually a clever way to combat the &lt;strong&gt;Curse of Dimensionality&lt;/strong&gt;), and we work with these vectors from then on.&lt;/p&gt;

&lt;p&gt;In our previous chapter, we did this kind of conversion with &lt;em&gt;one-hot encoding&lt;/em&gt;, which turns each character into a one-hot vector to feed into the neural net. But a one-hot vector captures nothing but the &lt;em&gt;position&lt;/em&gt; of the character in the alphabet, which is of no avail to our prediction task.&lt;/p&gt;

&lt;p&gt;So now imagine their approach as a smarter way to encode our words. They convert words to vectors just like us, but first they store the words in a lower-dimensional space, and second, they arrange the vectors so that they capture &lt;strong&gt;similarities&lt;/strong&gt;. What I mean by "capture the similarities" is that in the vector space, vectors representing words with similar meanings, or the same &lt;strong&gt;semantics&lt;/strong&gt;, end up &lt;strong&gt;closer&lt;/strong&gt; to each other, and that is the main point of all of this. We get huge clusters of synonyms, and not only that, but also some cool tricks that play with the meaning of words.&lt;/p&gt;

&lt;p&gt;This "feature vector" trick is still applied in our modern world, but it have been improved significantly, now it is prevalent with the name of &lt;strong&gt;&lt;em&gt;Embeddings&lt;/em&gt;&lt;/strong&gt;, so we will call this method as &lt;em&gt;embeddings&lt;/em&gt; from now. Embeddings itself involves some complicated mechanisms, but from the very root, it is just a modern version of the feature vectors.&lt;/p&gt;

&lt;h3&gt;
  
  
  Explore Bengio's model's architecture
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fd0t7u2vfm4x0vky8ynpm.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fd0t7u2vfm4x0vky8ynpm.png" alt="Bengio's architecture" width="800" height="682"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;First, in the &lt;strong&gt;Input layer&lt;/strong&gt;, the inputs are still the indices of the words (or in our case, the characters); this is just the lookup table we created with the mapping from a-z to 1-26. So no big-brain stuff here.&lt;/p&gt;

&lt;p&gt;In the next layer, interesting things happen; I will call this the &lt;strong&gt;Embedding layer&lt;/strong&gt;. This is different from our previous model, and it's the key point that brings about the significant improvements later on. We have a matrix &lt;code&gt;C&lt;/code&gt; for our embedding process; its parameters are tweaked and tuned during the learning process, and the matrix is shared across all words. Later on, in our model, we will explore some interesting patterns that our neural net learned during the training phase by looking at this matrix.  &lt;/p&gt;

&lt;p&gt;A notable thing to mention here is the &lt;em&gt;embedding size&lt;/em&gt;: we decide what dimension we "squish" our words into. In the paper, they embedded a vocabulary of 17,000 words into just a 30-dimensional vector space. This is called &lt;strong&gt;Dimensionality Reduction&lt;/strong&gt;, and there are numerous other methods that implement it. But you should also note that when we map a high-dimensional vector space to a lower-dimensional one, there will definitely be some information loss. So there's a tradeoff, and we should choose carefully.&lt;/p&gt;

&lt;p&gt;The next layer is the &lt;strong&gt;Hidden layer&lt;/strong&gt; of our net. This time we can choose any size we want for the layer. Note that in this layer they used an &lt;strong&gt;activation function&lt;/strong&gt; called &lt;strong&gt;tanh&lt;/strong&gt;, a classic one; we will talk about the whole family of activation functions perhaps in the next blog. &lt;/p&gt;

&lt;p&gt;The last layer is the &lt;strong&gt;Output layer&lt;/strong&gt;, and you can see the note &lt;em&gt;"most computation here"&lt;/em&gt;. It is an expensive layer, as we have to compute a logit for &lt;strong&gt;&lt;em&gt;every word in the vocabulary&lt;/em&gt;&lt;/strong&gt; (a total of 17,000 logits), and then apply the &lt;em&gt;Softmax&lt;/em&gt; over all of them to get the final probability distribution. &lt;/p&gt;
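
&lt;p&gt;In miniature, the output layer does this (a sketch with random logits, using the paper's vocabulary size):&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch
import torch.nn.functional as F

V = 17000                         # vocabulary size from the paper
logits = torch.randn(V)           # one logit per word: the expensive part
probs = F.softmax(logits, dim=0)  # normalize into probabilities
print(probs.sum())                # sums to 1: a valid distribution
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;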

&lt;p&gt;So that is the bare bones of their model, and we will now rebuild it in our small project. Let's begin!&lt;/p&gt;

&lt;h2&gt;
  
  
  Load the data &amp;amp; Libraries
&lt;/h2&gt;

&lt;p&gt;Moving on to the third chapter, we need some improvements in the structure of our project. For convenience, we should import all of the libraries at the very beginning, and then load the dataset.&lt;/p&gt;

&lt;h3&gt;
  
  
  Importing libraries
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt; &lt;span class="c1"&gt;# for making figures
&lt;/span&gt;&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="n"&gt;matplotlib&lt;/span&gt; &lt;span class="n"&gt;inline&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Loading the dataset
&lt;/h3&gt;



&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# download the names.txt file from github
&lt;/span&gt;&lt;span class="err"&gt;!&lt;/span&gt;&lt;span class="n"&gt;wget&lt;/span&gt; &lt;span class="n"&gt;https&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="o"&gt;//&lt;/span&gt;&lt;span class="n"&gt;raw&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;githubusercontent&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;com&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;karpathy&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;makemore&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;master&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;names&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;txt&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;words&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;names.txt&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;read&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;splitlines&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="mi"&gt;8&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;So now we have our data!&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']


&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  Prepare the Data
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Indexing
&lt;/h3&gt;

&lt;p&gt;First things first: we need to reuse the mapping that we created for the characters, to index them and feed them into our model.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# build the vocabulary of characters and mappings to/from integers
&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;sorted&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;set&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;''&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;))))&lt;/span&gt;
&lt;span class="n"&gt;stoi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt;&lt;span class="p"&gt;)}&lt;/span&gt;
&lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

&lt;span class="c1"&gt;# Also remember to map backward
&lt;/span&gt;&lt;span class="n"&gt;itos&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()}&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Build the dataset
&lt;/h3&gt;

&lt;p&gt;There is an important note here: we will no longer condition on just one previous character as in the bigram model; our new model will look at the &lt;strong&gt;three previous characters&lt;/strong&gt;, so we're scaling the &lt;em&gt;context window&lt;/em&gt; a little bit.&lt;/p&gt;

&lt;p&gt;With that in mind, we will need some modifications when preparing our data. We have a new variable &lt;code&gt;block_size&lt;/code&gt;, which is simply the number of characters in the context window. We set it to 3 for our model.&lt;/p&gt;

&lt;p&gt;We also need a technique to capture 3 characters at a time in each iteration rather than just one. It is fairly simple: we start with a window of size 3 filled with 0, and then we &lt;strong&gt;slide the window&lt;/strong&gt; across the training data, storing the contexts in our tensor &lt;code&gt;X&lt;/code&gt; and the corresponding labels in the tensor &lt;code&gt;Y&lt;/code&gt;. &lt;/p&gt;

&lt;p&gt;Let's look at how we can implement it in Python; it is just cropping a list and appending a new element:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# build the dataset
&lt;/span&gt;&lt;span class="n"&gt;block_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt; &lt;span class="c1"&gt;# context length: how many characters do we take to predict the next one?
&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;

&lt;span class="c1"&gt;# We will just deal with 5 names for now
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
  &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;block_size&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
    &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;''&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;---&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="n"&gt;context&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;context&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="c1"&gt;# crop and append
&lt;/span&gt;
&lt;span class="n"&gt;X&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;Y&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Here's our data:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;... ---&amp;gt; e
..e ---&amp;gt; m
.em ---&amp;gt; m
emm ---&amp;gt; a
mma ---&amp;gt; .
... ---&amp;gt; o
..o ---&amp;gt; l
.ol ---&amp;gt; i
oli ---&amp;gt; v
liv ---&amp;gt; i
ivi ---&amp;gt; a
via ---&amp;gt; .
...
torch.Size([32, 3]) torch.Size([32])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We have 32 windows, each of size 3. That is enough to move on to the next stage.&lt;/p&gt;

&lt;h2&gt;
  
  
  Embeddings &amp;amp; Data Manipulation with Pytorch
&lt;/h2&gt;

&lt;p&gt;This section will take us somewhere off the track: we will not dive into the project or fine-tuning, but will instead focus on &lt;strong&gt;Data Manipulation&lt;/strong&gt;, which is really interesting. &lt;/p&gt;

&lt;p&gt;Pytorch is often the go-to place for neural nets and even more complicated models, as it offers some really convenient operations and better use of memory when storing data. In our previous project we explored just a bit of the magic that Pytorch offers, namely the &lt;strong&gt;Broadcasting rules&lt;/strong&gt; and &lt;strong&gt;the backward() function&lt;/strong&gt;. Those are just the surface: there is a whole world of convenience in Pytorch, and some interesting manipulations that I think are worth learning, given that we will have to do a lot of things to our data in our projects.&lt;/p&gt;

&lt;p&gt;So let's walk through some of them while we're building our model:&lt;/p&gt;

&lt;h3&gt;
  
  
  Initializing the embedding matrix
&lt;/h3&gt;

&lt;p&gt;The most important thing here is the &lt;em&gt;size&lt;/em&gt; of this matrix: we need to determine the dimension that we want to project our data onto. In this case, we will choose 2.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Initilize randomly
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now we need to have a good understanding of the sizes of the matrices that we create, and thus understand the underlying process. This matrix &lt;code&gt;C&lt;/code&gt; has size &lt;code&gt;27x2&lt;/code&gt;, which means it stores 27 rows, where each row is a &lt;em&gt;vector&lt;/em&gt; that embeds one character. The number &lt;code&gt;27&lt;/code&gt; appears because we have 27 characters in total, and the number &lt;code&gt;2&lt;/code&gt; is the dimension of the vector that represents each character. &lt;/p&gt;

&lt;p&gt;What if we want to embed a character using this matrix? There are two ways for doing this:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Using indexing
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;ul&gt;
&lt;li&gt;Multiplying a one-hot vector of size 27 by the C matrix
&lt;/li&gt;
&lt;/ul&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# One-hot -&amp;gt; Remember to convert to float
&lt;/span&gt;&lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;one_hot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Both return a 2-dimensional vector:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([0.5262, 1.0655])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;For convenience, we will stick with the former; Pytorch actually has some great indexing techniques that offer even more flexibility.&lt;/p&gt;

&lt;h3&gt;
  
  
  Indexing with Pytorch
&lt;/h3&gt;

&lt;p&gt;We can index using lists: we pass in a list of integers representing the indices we want, and it returns the corresponding sequence of 2-dimensional embedding vectors.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Index using List
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[[&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="p"&gt;]]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[ 0.5262,  1.0655],
        [-2.2277, -0.5293],
        [-0.6665, -1.0212]])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The list can contain duplicates:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Can be a tensor of int, we can repeat
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;7&lt;/span&gt;&lt;span class="p"&gt;])]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[ 0.5262,  1.0655],
        [-2.2277, -0.5293],
        [-0.6665, -1.0212],
        [-0.6665, -1.0212],
        [-0.6665, -1.0212],
        [-0.6665, -1.0212],
        [-0.6665, -1.0212]])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;What if we pass in a whole matrix? Pytorch can handle that too! It will iterate through the whole matrix and replace each entry with the corresponding embedding vector. So the indexing should return a new matrix with an additional dimension at the end, equal to the embedding size. Take a moment to think about it, and then look at the implementation to understand what I mean:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Indexing with a multidimensionl tensor
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;torch.Size([32, 3, 2])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Let's take a look at what it is actually doing inside the matrix:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;&amp;gt;&amp;gt;&amp;gt;X[:5]
tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        [ 5, 13, 13],
        [13, 13,  1]])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;&amp;gt;&amp;gt;&amp;gt;C[X][:5]
tensor([[[-2.3175, -0.5157],
         [-2.3175, -0.5157],
         [-2.3175, -0.5157]],

        [[-2.3175, -0.5157],
         [-2.3175, -0.5157],
         [ 0.6433,  0.8121]],

        [[-2.3175, -0.5157],
         [ 0.6433,  0.8121],
         [ 1.0138, -0.7526]],

        [[ 0.6433,  0.8121],
         [ 1.0138, -0.7526],
         [ 1.0138, -0.7526]],

        [[ 1.0138, -0.7526],
         [ 1.0138, -0.7526],
         [ 1.5965, -0.8861]]])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;So in X we have a bunch of 3-element arrays, and in C[X] we have a bunch of 3x2 matrices: Pytorch simply adds an additional dimension to store the embedding of each element! I also got confused at this part, so don't be shy to take a moment to think about it.&lt;/p&gt;

&lt;p&gt;And we can get a specific vector by indexing, just like when we're working with multidimensional arrays in Python:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# We get the embedding for the 13th window, at the third character
&lt;/span&gt;&lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;13&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([-0.0371,  0.8457])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now we should give the embedding of X a name:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;emb&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;C&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;X&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;shape&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
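&lt;p&gt;This gives back the shape we saw earlier:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;torch.Size([32, 3, 2])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;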



&lt;h2&gt;
  
  
  Creating the first layer
&lt;/h2&gt;

&lt;p&gt;As you can see from the architecture, a pack of 3 characters is fed into the model, but we embedded those into 2-dimensional vectors, so that contributes to a total of 3 x 2 = 6 input values. &lt;/p&gt;

&lt;p&gt;Okay, so our layer takes in 6 input values, and what is the output, or more precisely, how many &lt;em&gt;neurons&lt;/em&gt; does it have? This is totally up to us, and I will choose 100. And we're ready to go!&lt;/p&gt;

&lt;p&gt;You may question why we need to think about all of this in the first place. Well, it is crucial to think about the matrix that we want to create and how we can create it; there are lots of insights to gain when we look at the &lt;em&gt;size&lt;/em&gt; of the matrix. Moreover, when we perform matrix multiplication, it is important to keep track of the dimensions, as we cannot perform that operation if the dimensions don't match. Also, I want to clarify each step that we're doing, so that we won't get lost among some randomly-popped-up numbers; it's a way to assure ourselves that we're on track.&lt;/p&gt;
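&lt;p&gt;As a quick illustration of that habit (a toy sanity check of my own, not part of the original notebook), you can verify that the inner dimensions agree before multiplying:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

# Matrix multiplication only works when the inner dimensions agree:
# [32, 6] @ [6, 100] is fine, [32, 6] @ [5, 100] raises a RuntimeError
A = torch.randn(32, 6)
B = torch.randn(6, 100)
assert A.shape[1] == B.shape[0]   # inner dims match
print((A @ B).shape)              # torch.Size([32, 100])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;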

&lt;h3&gt;
  
  
  Initializing the Weights and Bias
&lt;/h3&gt;

&lt;p&gt;In the previous chapter, we dealt with the weights only, but in practice, we need an additional term called &lt;strong&gt;Bias&lt;/strong&gt;, so that we will have the classic formula of "wx + b" that appears in virtually every ML/DL book.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Creating the first layer
&lt;/span&gt;
&lt;span class="c1"&gt;# Number of input : 3x2 because we have 2 dim embedd and 3 chars
# Number of neurons: 100 (totally up to us)
&lt;/span&gt;&lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;b1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Fitting the Dimension
&lt;/h3&gt;

&lt;p&gt;So now we need to feed our data into the first layer, right? We will do that by matrix multiplication. But wait, there's something wrong.&lt;/p&gt;

&lt;p&gt;We can't do matrix multiplication because the dimensions don't fit! Note that the embedding matrix is of the size &lt;code&gt;[32,3,2]&lt;/code&gt; while our W1 matrix is of the size &lt;code&gt;[6,100]&lt;/code&gt;. Those are not even matrices of the same type. So what will we do? We would love for our embedding matrix to have the size &lt;code&gt;[32,6]&lt;/code&gt;; specifically, we want to find a way to "merge" all of the embedding vectors in the 3-char window into one. In other words, the number of trigrams remains, but for each of them, we want to convert a sequence of three 2-d vectors into a sequence of 6 values in order to feed our first hidden layer. And guess what? Pytorch saves the day again!&lt;/p&gt;

&lt;p&gt;There is a function in the Pytorch library called &lt;code&gt;cat&lt;/code&gt; (short for concatenate), and it does the merging stuff that we discussed above. First we take the sequence of embedding vectors for &lt;em&gt;each character&lt;/em&gt; in the window, then apply the &lt;code&gt;cat&lt;/code&gt; function to it. The code goes like this:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Take the three and concat, in the second dim
&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,:],&lt;/span&gt; &lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,:],&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;[:,&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,:]],&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Let's see the shape of this tensor:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;torch.Size([32, 6])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Exactly what we want! But hard-coding the sequence of embedding vectors to pass into the function is not really good, and we want to fix that too. Another function comes to the rescue which does exactly that task; it is called &lt;code&gt;unbind&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Unbind take the lists, return tuples of tensors
&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;unbind&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Just spend a few seconds contemplating how our library helps us implement a labor-intensive task in one line of code. That's amazing.&lt;/p&gt;

&lt;h3&gt;
  
  
  Internals of Tensors
&lt;/h3&gt;

&lt;p&gt;Actually there is an even &lt;em&gt;more convenient&lt;/em&gt; way to implement the task above, which doesn't require any fancy functions. I will first introduce the &lt;strong&gt;Internals of Tensors&lt;/strong&gt;, but I would recommend reading ezyang's blog for a more comprehensive understanding. &lt;a href="https://blog.ezyang.com/2019/05/pytorch-internals/" rel="noopener noreferrer"&gt;This is the blog.&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Briefly speaking, the tensor in Pytorch has an interesting way of storing data. Specifically, it stores every value in the matrix in a one-dimensional array, &lt;em&gt;irrespective of the shape&lt;/em&gt;. So you can imagine that everything gets flattened out into one single long series. We can access this storage via the function &lt;code&gt;storage&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;storage&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt; -0.6637181043624878
 0.31748151779174805
 -0.6637181043624878
 0.31748151779174805
 -0.6637181043624878
 0.31748151779174805
 ...
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;But what is the point of all of this? It turns out that there is a specific method called &lt;code&gt;view()&lt;/code&gt; that can return a matrix of any shape, using the data from our matrix. It's a much more efficient way since it's just representing the same data differently; it doesn't create new tensors to work with. A note here is that your new matrix can be of any shape that you want &lt;em&gt;as long as the total number of elements matches the original&lt;/em&gt;. Let's implement this black magic:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[-1.3302, -1.3333, -1.3302, -1.3333, -1.3302, -1.3333],
        [-1.3302, -1.3333, -1.3302, -1.3333, -0.1033, -1.4972],
        [-1.3302, -1.3333, -0.1033, -1.4972, -0.9485,  0.6885],
        [-0.1033, -1.4972, -0.9485,  0.6885, -0.9485,  0.6885],
        ......
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We can also use &lt;code&gt;emb.view(-1,6)&lt;/code&gt;; it generates the same result and also eliminates the need to hard-code the number 32 (that number won't hold later when we consider the whole dataset). But remember, we can't have matrices like &lt;code&gt;emb.view(5,6)&lt;/code&gt;: the product of the new dimensions must equal the number of elements in the original matrix (32 x 3 x 2 = 192 = 32 x 6).&lt;/p&gt;
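&lt;p&gt;Here is a small sketch of what works and what doesn't (the failing case is my own illustration, commented out so the snippet runs):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Both describe the same 32 x 6 layout over the same storage
assert torch.equal(emb.view(32, 6), emb.view(-1, 6))

# This would fail: 5 * 6 = 30 elements, but emb holds 32 * 3 * 2 = 192
# emb.view(5, 6)   # RuntimeError: shape '[5, 6]' is invalid for input of size 192
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;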

&lt;p&gt;Now we need to name the new matrix, and also get the &lt;strong&gt;tanh&lt;/strong&gt; of the values, just like in the architecture.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Get the tanh
&lt;/span&gt;&lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tanh&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;emb&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;view&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W1&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;If you are thinking about the dimensions here, you may question why the &lt;code&gt;b1&lt;/code&gt; term can be added at all: it's just a vector of size 100, while the others are two-dimensional. This is actually a valid operation (at least in Pytorch), and it involves a thing that we already know: &lt;strong&gt;Broadcasting&lt;/strong&gt;. The &lt;code&gt;b1&lt;/code&gt; vector is broadcast as a row vector, so the same 100 bias values are added to every row of the matrix, and it's exactly what we want! &lt;/p&gt;
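&lt;p&gt;A toy demonstration of that broadcast, with shapes matching ours (my own example, not from the notebook):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

scores = torch.randn(32, 100)   # stands in for emb.view(-1,6) @ W1
bias = torch.randn(100)         # like b1: a plain 1-d vector
out = scores + bias             # bias is treated as [1, 100] and stretched to [32, 100]
assert out.shape == (32, 100)
# every row i receives the same bias: out[i] equals scores[i] + bias
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;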

&lt;p&gt;And that is everything about the magic of the Pytorch library for manipulating data. I hope you get a sense of how amazing this library is; you can try some operations yourself, they are all documented &lt;a href="https://docs.pytorch.org/docs/stable/torch.html" rel="noopener noreferrer"&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;h2&gt;
  
  
  Create the Final Layer
&lt;/h2&gt;

&lt;p&gt;Think about the dimensions again: what should the &lt;code&gt;W2&lt;/code&gt; matrix look like? First, it should output the probability for 27 characters, so there should be a 27 in its size. Moreover, it has to match the output of the previous layer, which produces 100 values per example (since &lt;code&gt;W1&lt;/code&gt; is of size &lt;code&gt;6x100&lt;/code&gt;). Hence, we have the size of the final weight matrix: &lt;code&gt;100x27&lt;/code&gt;, and the bias is, surely, an array of 27 elements.&lt;/p&gt;

&lt;p&gt;Let's implement that in our code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;b2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And now, we're ready to calculate our probability!&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;h&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W2&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;b2&lt;/span&gt;

&lt;span class="c1"&gt;# Applying softmax
&lt;/span&gt;&lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;keepdims&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
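&lt;p&gt;As an aside (not something this notebook relies on), this exponentiate-and-normalize pair is exactly what the built-in softmax computes, so the two lines above should agree with:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch.nn.functional as F

# Same result as counts / counts.sum(1, keepdims=True)
prob = F.softmax(logits, dim=1)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;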



&lt;p&gt;And finally, the negative log likelihood:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;32&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;&lt;span class="n"&gt;Y&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
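&lt;p&gt;For what it's worth, Pytorch also ships a fused routine that goes straight from logits to this loss; it should give the same value (up to numerical precision), and we could swap it in here:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch.nn.functional as F

# softmax + pick out the label + negative log + mean, all in one call
loss = F.cross_entropy(logits, Y)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;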



&lt;p&gt;The result is (maybe some drum roll for this moment):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor(14.3920)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;That is TERRIFYING. But we haven't done any optimization yet, zero, zip, nada. And that, my friend, is gonna be the story of the next part. In this part, we've already equipped ourselves with a whole bunch of knowledge, from the revolutionary approach of Bengio, to generating a trigram dataset, to great data manipulation techniques in Pytorch. That's really enough for today, so congratulate yourself for reaching this far.&lt;/p&gt;

&lt;p&gt;Thanks for reading, and see you in the next blog!&lt;/p&gt;

</description>
      <category>deeplearning</category>
      <category>llm</category>
      <category>tutorial</category>
      <category>machinelearning</category>
    </item>
    <item>
      <title>BIGRAM LANGUAGE MODELS USING A NEURAL NET</title>
      <dc:creator>Hưng Lê Tiến</dc:creator>
      <pubDate>Fri, 14 Nov 2025 14:13:12 +0000</pubDate>
      <link>https://dev.to/blackbidz/bigram-language-models-using-a-neural-net-452f</link>
      <guid>https://dev.to/blackbidz/bigram-language-models-using-a-neural-net-452f</guid>
<description>&lt;p&gt;Welcome to the second chapter in the series. Today I will present a different approach to our task of predicting names, based on the video from Andrej Karpathy. In this blog, we will do more than the mere task of counting and normalizing the sum to get the probability, which is not really machine-learning-like; we will take a step further, a big step, and jump right into Neural Networks.&lt;/p&gt;

&lt;h2&gt;
  
  
  A Brief Introduction for Neural Nets
&lt;/h2&gt;

&lt;p&gt;I will give a short overview of this wonderful architecture, just some basics to make you feel comfortable when reading this blog. Further details will be discussed in later blogs, like the MLP, or the CNN, LSTM, things like that. So let's begin, shall we?&lt;/p&gt;

&lt;p&gt;A neural network is a kind of architecture that some brilliant people created in order to replicate the neural system of humans. To be more precise, we're talking about the &lt;em&gt;Artificial Neural Network (ANN)&lt;/em&gt;, and it consists of a bunch of nodes connected to each other, as you can see here:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9keqmygw8oh2ucbnvfhm.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2F9keqmygw8oh2ucbnvfhm.png" alt="Neural Network" width="800" height="400"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You just need to remember that:&lt;/p&gt;

&lt;ol&gt;
&lt;li&gt;Each node in the neural net is called a &lt;em&gt;neuron&lt;/em&gt;, and from the image we can see that those are nicely separated into parts that are the &lt;em&gt;layers&lt;/em&gt;. The layers in the neural net are often the Input layer, the Output layer, and some Hidden layers in the middle (but in this blog today we will just need to construct one hidden layer, so it's not really hidden)&lt;/li&gt;
&lt;li&gt;"What are those neurons doing and why do we need to pass information through layers of a bunch of neurons?" might be the question in your mind. Well this is hard to tell, I would say that the neurons take the inputs and apply a function to it, like wx + b, so we have a  bunch of functions (sometimes those can be non-linear, we will talk about that later) and the values of w,b that we assign for each function is what we would call the &lt;em&gt;parameters&lt;/em&gt;. A complex neural net has thousands of parameters, and those are changeable. This is perhaps the most important thing, &lt;em&gt;we can change the parameters in the neural net&lt;/em&gt;.&lt;/li&gt;
&lt;li&gt;Our goal is to &lt;em&gt;tweak and tune&lt;/em&gt; the parameters in order to get our desired result, and we use a technique called &lt;em&gt;backpropagation&lt;/em&gt; to do so. Backpropagation is just a method we use when we want to minimize the loss (the difference) between our prediction and the actual result, where we work backward from the prediction all the way to the input, probably tweaking the parameters along the way. But to get the prediction, first we need to plug in the input and pass it through layers of neurons, that's called a &lt;em&gt;forward pass&lt;/em&gt;.&lt;/li&gt;
&lt;li&gt;But how can we really "get the desired result" from a neural net? What do we expect from the output layer? It really depends on the task, but normally, in the classification problem or in this specific makemore project, the output layer would produce a kind of &lt;em&gt;probability distribution&lt;/em&gt; for the possible values. We will discuss about that later when we delve into the code.&lt;/li&gt;
&lt;/ol&gt;
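&lt;p&gt;To make point 2 concrete, here is a tiny sketch of a single neuron with made-up numbers (my own toy illustration):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

x = torch.tensor([0.5, -1.2, 2.0])   # inputs to the neuron
w = torch.tensor([0.1, 0.4, -0.3])   # weights: parameters we can change
b = torch.tensor(0.25)               # bias: also a parameter

out = w @ x + b                      # the classic "wx + b", a single number
print(out)                           # tensor(-0.7800)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;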

&lt;p&gt;Okay if you understand all of the things above, then you are ready for the next part! (actually I think my explanation is not that good)&lt;/p&gt;

&lt;h2&gt;
  
  
  Building Makemore with a Neural Net
&lt;/h2&gt;

&lt;p&gt;When dealing with neural nets, things get complicated. We can't just "tell" the net to do the things that we want, I mean, there is a huge language barrier, right? So we need to come up with a task, a really specific task, to give to the machine, such that when the machine does the task, it generates the result we want. &lt;/p&gt;

&lt;p&gt;Normally the task in ML or DL involves &lt;em&gt;minimizing a function&lt;/em&gt;, so we need a function to minimize. Where do we find one? Remember the negative log-likelihood that we talked about in the previous chapter? There it is! That also explains why we use the &lt;em&gt;negative&lt;/em&gt; log-likelihood, not the log-likelihood itself: our objective is to &lt;em&gt;maximize the likelihood&lt;/em&gt;, which is equivalent to &lt;em&gt;minimizing the negative log-likelihood&lt;/em&gt;.&lt;/p&gt;
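&lt;p&gt;A tiny numeric illustration of that equivalence (my own example, not from the video):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

probs = torch.tensor([0.9, 0.8, 0.7])   # probabilities our model gave the true labels
likelihood = probs.prod()               # about 0.504, we want this HIGH
nll = -probs.log().sum()                # about 0.685, equivalently we want this LOW
# log is monotonic, so parameters that maximize one minimize the other
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;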

&lt;p&gt;During the previous chapter, we solved the problem in a rather &lt;em&gt;explicit way&lt;/em&gt;, which is, we are just using mere counts or explicit data in order to get the result. In this chapter and numerous chapters later on, we will solve the task in an &lt;em&gt;implicit way&lt;/em&gt;, which is like using a secret mechanism, a kind of black magic, to arrive at the final answer. &lt;/p&gt;

&lt;p&gt;We need to restructure our data a bit, as we're putting it into a neural network. Now, remember that this is still a bigram language model: we still have one character as an input, and we output the &lt;em&gt;probability&lt;/em&gt; of the next character. We also have to signal the cases where the model is correct, meaning it assigns a high probability to the correct next character, so those correct characters should be our &lt;em&gt;labels&lt;/em&gt;, and our goal is to maximize the probability for the &lt;em&gt;labels&lt;/em&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Create the training set
&lt;/h3&gt;

&lt;p&gt;Now we shall begin:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Create the training set
# Inputs and the labels
&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[],[]&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
  &lt;span class="n"&gt;chs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]):&lt;/span&gt;
    &lt;span class="n"&gt;ix1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;ix2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ix1&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;ys&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ix2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Create tensors
&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;ys&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ys&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;So we have our inputs and labels, the characters are encoded in the form of integers:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([ 0,  5, 13, 13,  1])
tensor([ 5, 13, 13,  1,  0])

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  One-hot encoding
&lt;/h3&gt;

&lt;p&gt;But is this good data to feed into our neural net? The answer is NO. Briefly speaking, our boy doesn't like strings or integers, it prefers &lt;strong&gt;vectors&lt;/strong&gt;. And if we normalize the vector then it is even better!&lt;/p&gt;

&lt;p&gt;Just like when we have to turn text into integers to feed in our bigram language model, in this case, we need to find a way to encode integers into some sorts of vectors to feed in the neural net. A convenient way to do that is &lt;strong&gt;One-hot encoding&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;One-hot encoding is simple: it makes an array full of zeros and turns the i-th element into 1 for an integer with the value i. You can take a look at it in the code below:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch.nn.functional&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;

&lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;one_hot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;xenc&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0]])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We also need to cast the entries of our vector to &lt;strong&gt;floating point numbers&lt;/strong&gt;, which are convenient for computing inside a neural net.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;one_hot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;So the ingredients are fully prepared, let's cook!&lt;/p&gt;

&lt;h2&gt;
  
  
  Making the net
&lt;/h2&gt;

&lt;p&gt;The parameters can be stored in a huge matrix, and the input vectors are parts of a huge input matrix too. For convenience when computing, we are putting things in &lt;strong&gt;matrices&lt;/strong&gt;, and if you scrutinize the behavior of the neural net, or even the MLP or Attention, you can see that the operations are just a bunch of &lt;strong&gt;matrix multiplications&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;Let's create our own matrix of the parameter w. Actually, this guy has a name: it is the &lt;strong&gt;weight&lt;/strong&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  Initialize the weights
&lt;/h3&gt;

&lt;p&gt;There are numerous ways to initialize the weights, and even some strategic ones. But oftentimes we draw the weights randomly from a normal distribution.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Initialize the weights
&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;W&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;randn&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;&lt;span class="n"&gt;generator&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="c1"&gt;# That randn function draw from a normal distribution
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Here's a piece of our weight matrix:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[ 1.5674, -0.2373, -0.0274, -1.1008,  0.2859, -0.0296, -1.5471,  0.6049,
          0.0791,  0.9046, -0.4713,  0.7868, -0.3284, -0.4330,  1.3729,  2.9334,
          1.5618, -1.6261,  0.6772, -0.8404,  0.9849, -0.1484, -1.4795,  0.4483,
         -0.0707,  2.4968,  2.4448],
        [-0.6701, -1.2199,  0.3031, -1.0725,  0.7276,  0.0511,  1.3095, -0.8022,
         -0.8504, -1.8068,  1.2523, -1.2256,  1.2165, -0.9648, -0.2321, -0.3476,
          0.3324, -1.3263,  1.1224,  0.5964,  0.4585,  0.0540, -1.7400,  0.1156,
          0.8032,  0.5411, -1.1646],
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And then we will have our inputs modified by the weights, using matrix multiplication:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;





&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[-0.6701, -1.2199,  0.3031, -1.0725,  0.7276,  0.0511,  1.3095, -0.8022,
         -0.8504, -1.8068,  1.2523, -1.2256,  1.2165, -0.9648, -0.2321, -0.3476,
          0.3324, -1.3263,  1.1224,  0.5964,  0.4585,  0.0540, -1.7400,  0.1156,
          0.8032,  0.5411, -1.1646]])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Adding the non-linearity - Softmax
&lt;/h3&gt;

&lt;p&gt;We won't use just that for our neural net's layers! The problem is this: when you keep multiplying the input by a weight w, and you do that consecutively, then at the end of the day, after hundreds of multiplications, you still get a mere multiple of x. Applying linear functions several times still keeps things linear, which makes our model rather simple and unable to capture any complicated patterns.&lt;/p&gt;

&lt;p&gt;Hence, we should add some non-linearity to the process. We will use a function called &lt;strong&gt;&lt;em&gt;Softmax&lt;/em&gt;&lt;/strong&gt; in this blog. &lt;/p&gt;

&lt;p&gt;So what is Softmax? It is designed to generate the probability distribution that we want. Remember that when we do matrix multiplication, the output of the neural net can be virtually anything, while we just want it to output some positive numbers that sum up to 1. Softmax does the trick: it takes the exponential of the output and normalizes it by the sum of all exponentials. There are some other interesting details about this function, but I will cover those in another blog. Let's see the formula of the function:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fqy3qra1ddg0d20rjo7sd.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fqy3qra1ddg0d20rjo7sd.png" alt="Softmax" width="640" height="390"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Now we create Softmax on our own:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt;  &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# log-counts
&lt;/span&gt;&lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; 
&lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;keepdims&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;probs&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And there it is, nice-looking positive numbers that sum up to 1:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[0.0607, 0.0100, 0.0123, 0.0042, 0.0169, 0.0126, 0.0027, 0.0232, 0.0137,
         0.0313, 0.0079, 0.0278, 0.0091, 0.0082, 0.0500, 0.2372, 0.0604, 0.0025,
         0.0250, 0.0055, 0.0339, 0.0109, 0.0029, 0.0198, 0.0118, 0.1535, 0.1458],
        [0.0290, 0.0796, 0.0248, 0.0521, 0.1983, 0.0289, 0.0094, 0.0335, 0.0097,
         0.0301, 0.0701, 0.0229, 0.0115, 0.0184, 0.0108, 0.0315, 0.0291, 0.0045,
         0.0915, 0.0215, 0.0486, 0.0300, 0.0501, 0.0027, 0.0118, 0.0022, 0.0472],
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h2&gt;
  
  
  Tweak and Tune the Parameters
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Gradient
&lt;/h3&gt;

&lt;p&gt;Before we dive into the main part of our project today, we need to take a step back to think about a very important concept of Machine Learning and Deep Learning in general: The Gradient or Gradient Descent.&lt;/p&gt;

&lt;p&gt;You may have heard about the gradient in your Calculus class: for a multivariable function, the gradient is a vector that points in the direction of steepest ascent. Well, to put it simply, imagine you are on a hill; that Gradient friend is the guy who always points you in the steepest direction and tells you to climb up. When doing ML stuff, we prefer to climb down, because we're minimizing, not maximizing, and that's where Gradient Descent comes in. &lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;Gradient descent is an optimization algorithm that uses the gradient to iteratively find the minimum of a function by taking steps in the direction opposite to the gradient.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;This is the formula for Gradient Descent:&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ftbki5ar483apkwuajku9.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Ftbki5ar483apkwuajku9.png" alt="Gradient Descent" width="800" height="534"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;To be honest, this is just an over-simplified view of the matter; Gradient Descent is much more complicated and has numerous variants. But we understand the main idea: it's a way to get the loss down as much as possible.&lt;/p&gt;
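&lt;p&gt;To see the main idea in action, here is a deliberately tiny gradient-descent loop on f(x) = x**2, sketched by me; the real thing later in this blog steps a 27x27 weight matrix instead of a single number:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Minimize f(x) = x**2 by stepping against the gradient f'(x) = 2x
x = 5.0
lr = 0.1                  # learning rate: how big each step is
for _ in range(50):
    grad = 2 * x          # gradient of x**2 at the current x
    x -= lr * grad        # step downhill, against the gradient
print(x)                  # very close to 0, the minimum
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;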

&lt;h3&gt;
  
  
  Backpropagation
&lt;/h3&gt;

&lt;p&gt;We might need a whole blog to talk about the beauty of this algorithm; this is the thing that revolutionized the world of neural nets, and it lies at the core of almost every model you see in the world today.&lt;/p&gt;

&lt;p&gt;What is it? When calculating the gradient to minimize the loss, we actually cannot calculate it directly. The &lt;em&gt;forward pass&lt;/em&gt;, or the computing process to get to the prediction, involves numerous consecutive computations that are really difficult to connect into one main function to differentiate. &lt;/p&gt;

&lt;p&gt;So a brilliant way to solve that problem was proposed: we adjust the parameters step by step, starting from the neurons at the prediction, and pass the gradient down through each neuron consecutively until we reach the very first one. This is conveniently achieved by the &lt;em&gt;Chain rule&lt;/em&gt; from Calculus, and given the fact that the whole neural net is built from easy-to-differentiate functions (linear, sigmoid, ...), we can work our way backward!&lt;/p&gt;

&lt;p&gt;Another thing that contributes to the success of this algorithm is &lt;code&gt;autodiff&lt;/code&gt;, which stands for &lt;em&gt;automatic differentiation&lt;/em&gt;. Andrej also has a video about this, and maybe I will cover it later in another blog.&lt;/p&gt;
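&lt;p&gt;Here is the chain rule being handled by autodiff in a tiny toy example (mine, not Andrej's):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

x = torch.tensor(2.0, requires_grad=True)
y = (3 * x ** 2).log()    # a small chain of easy-to-differentiate ops
y.backward()              # backpropagation: apply the chain rule backwards
print(x.grad)             # dy/dx = 2/x, so tensor(1.) at x = 2
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;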

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fmaj6ab647rfnac00tkgi.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fmaj6ab647rfnac00tkgi.png" alt="Illustration of Backpropagation" width="800" height="400"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You can watch the videos from 3blue1brown, I really admire him!&lt;br&gt;
&lt;a href="https://www.youtube.com/watch?v=Ilg3gGewQ5U" rel="noopener noreferrer"&gt;Backpropagation (Intuitive)&lt;/a&gt;&lt;br&gt;
&lt;a href="https://www.youtube.com/watch?v=tIeHLnjs5U8" rel="noopener noreferrer"&gt;Backpropagation (Calculus)&lt;/a&gt;&lt;/p&gt;
&lt;h3&gt;
  
  
  Implementation on the code
&lt;/h3&gt;

&lt;p&gt;Enough theory, let's turn back to our main project. Actually, there are not so many things we have to do, because the library already includes some really convenient functions for us. Well, at least we understand the underlying process.&lt;/p&gt;

&lt;p&gt;We start by backpropagating just one time. Remember our forward pass?&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;  &lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;one_hot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# input to the network: one-hot encoding
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt; &lt;span class="c1"&gt;# predict log-counts
&lt;/span&gt;  &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# counts, equivalent to N
&lt;/span&gt;  &lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# probabilities for next character
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;ys&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; 
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Now we just need to perform a backward pass:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt; &lt;span class="c1"&gt;# set to zero the gradient
&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# magical things happen when applying this function
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Remember to set &lt;code&gt;requires_grad = True&lt;/code&gt; when you initialize &lt;code&gt;W&lt;/code&gt;; then when we call &lt;code&gt;backward()&lt;/code&gt;, it applies backpropagation across the whole training set. We can have a look at our &lt;code&gt;W.grad&lt;/code&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;tensor([[ 0.0121,  0.0020,  0.0025,  0.0008,  0.0034, -0.1975,  0.0005,  0.0046,
          0.0027,  0.0063,  0.0016,  0.0056,  0.0018,  0.0016,  0.0100,  0.0476,
          0.0121,  0.0005,  0.0050,  0.0011,  0.0068,  0.0022,  0.0006,  0.0040,
          0.0024,  0.0307,  0.0292],
        [-0.1970,  0.0017,  0.0079,  0.0020,  0.0121,  0.0062,  0.0217,  0.0026,
          0.0025,  0.0010,  0.0205,  0.0017,  0.0198,  0.0022,  0.0046,  0.0041,
          0.0082,  0.0016,  0.0180,  0.0106,  0.0093,  0.0062,  0.0010,  0.0066,
          0.0131,  0.0101,  0.0018],
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;As you can see, it is filled with non-zero values, which means we have successfully performed backpropagation on this training set.&lt;/p&gt;
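&lt;p&gt;For completeness, that &lt;code&gt;requires_grad&lt;/code&gt; flag is a one-argument tweak to the earlier initialization:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Same initialization as before, but now Pytorch tracks gradients for W
W = torch.randn((27, 27), generator=g, requires_grad=True)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;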

&lt;p&gt;Now we need to update our &lt;code&gt;W&lt;/code&gt; matrix by &lt;em&gt;subtracting&lt;/em&gt; the gradient, multiplied by the &lt;em&gt;learning-rate&lt;/em&gt; (imagine it like the length of the step we want to make)&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And we repeat the whole process until we're satisfied with the result. Note that there is this thing called &lt;code&gt;epochs&lt;/code&gt;, which is simply the number of iterations we want to make. (One more detail: in the loop below, the loss gains a small extra term, &lt;code&gt;0.01*(W**2).mean()&lt;/code&gt;, a regularization that nudges the weights toward zero, playing a similar role to the smoothing we used in the counting model.)&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# gradient descent
&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;100&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;k&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

  &lt;span class="c1"&gt;# forward pass
&lt;/span&gt;  &lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;one_hot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;xs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# input to the network: one-hot encoding
&lt;/span&gt;  &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt; &lt;span class="c1"&gt;# predict log-counts
&lt;/span&gt;  &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# counts, equivalent to N
&lt;/span&gt;  &lt;span class="n"&gt;probs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# probabilities for next character
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;ys&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;())&lt;/span&gt;

  &lt;span class="c1"&gt;# backward pass
&lt;/span&gt;  &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;None&lt;/span&gt; &lt;span class="c1"&gt;# set to zero the gradient
&lt;/span&gt;  &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

  &lt;span class="c1"&gt;# update
&lt;/span&gt;  &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt; &lt;span class="o"&gt;*&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;grad&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And that's basically it! We should see how our model performs. Here are some of its last iterations:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;2.4908623695373535
2.4906723499298096
2.4904870986938477
2.4903063774108887
2.4901304244995117
2.489959478378296
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Wonderful! Actually it's even better than counting! So we're good now.&lt;/p&gt;

&lt;h3&gt;
  
  
  Sampling
&lt;/h3&gt;

&lt;p&gt;And lastly, here is how we sample names from this model; it's just the same stuff we did to sample from the previous bigram model:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# finally, sample from the 'neural net' model
&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;

  &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
  &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
  &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;

    &lt;span class="c1"&gt;# ----------
&lt;/span&gt;    &lt;span class="c1"&gt;# BEFORE:
&lt;/span&gt;    &lt;span class="c1"&gt;#p = P[ix]
&lt;/span&gt;    &lt;span class="c1"&gt;# ----------
&lt;/span&gt;    &lt;span class="c1"&gt;# NOW:
&lt;/span&gt;    &lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;F&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;one_hot&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt; &lt;span class="n"&gt;num_classes&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;xenc&lt;/span&gt; &lt;span class="o"&gt;@&lt;/span&gt; &lt;span class="n"&gt;W&lt;/span&gt; &lt;span class="c1"&gt;# predict log-counts
&lt;/span&gt;    &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;exp&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="c1"&gt;# counts, equivalent to N
&lt;/span&gt;    &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;keepdims&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="c1"&gt;# probabilities for next character
&lt;/span&gt;    &lt;span class="c1"&gt;# ----------
&lt;/span&gt;
    &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;multinomial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;num_samples&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;replacement&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="k"&gt;break&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;''&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Let's look at our results:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;cexze.
momasurailezityha.
konimittain.
llayn.
ka.
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;It's the same as the bigram model! Well, another disappointing result. Then again, this is just a one-layer neural network, so maybe it's doing its best for now.&lt;/p&gt;

&lt;h2&gt;
  
  
  Key notes
&lt;/h2&gt;

&lt;h3&gt;
  
  
  Bigram versus Neural Net
&lt;/h3&gt;

&lt;p&gt;The samples from both models are the same, which means that even though they take different paths, they still share huge similarities. In the neural net, the loss is slightly lower, which indicates slightly better performance.&lt;/p&gt;

&lt;p&gt;Even though neural nets can be hard to interpret, they really outperform their counterparts (remember, this is just a plain vanilla one-layer neural net; imagine when the big bros MLP and Transformer come in), and they are also flexible: we can apply them to numerous tasks.&lt;/p&gt;

&lt;h3&gt;
  
  
  Some insights from the W matrix
&lt;/h3&gt;

&lt;p&gt;So far we've treated this matrix as nothing but a complicated black box that is tweaked and tuned all the time to produce some satisfying results. But there is an intuitive way to understand these parameters, which can help us gain insights into what the model is actually doing.&lt;/p&gt;

&lt;p&gt;The parameter &lt;code&gt;w&lt;/code&gt; can be read as the &lt;em&gt;importance&lt;/em&gt; or the &lt;em&gt;influence&lt;/em&gt; of one neuron on another. Since the next neuron is just a linear combination of the previous ones, we get a function of numerous variables, each multiplied by its corresponding weight. So when a weight is large and positive, its neuron greatly influences the next one, and vice versa.&lt;/p&gt;

&lt;p&gt;When we want the model to produce, for example, the &lt;code&gt;'.'&lt;/code&gt; that marks the end of a word, we may not want it to be greatly influenced by, let's say, the letter &lt;code&gt;'j'&lt;/code&gt;, because there are just a few names that end with &lt;code&gt;'j'&lt;/code&gt;; we want it to be more influenced by &lt;code&gt;'n'&lt;/code&gt; or &lt;code&gt;'m'&lt;/code&gt;, because those are common ending characters. So what do we do? We just lower the weight for &lt;code&gt;j&lt;/code&gt; and increase the weights for &lt;code&gt;n&lt;/code&gt; and &lt;code&gt;m&lt;/code&gt;. And that's what our model is trying to do under the hood.&lt;/p&gt;
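
&lt;p&gt;Here is a tiny sketch (not from the video) of how you could probe this yourself, assuming the trained &lt;code&gt;W&lt;/code&gt; and the &lt;code&gt;stoi&lt;/code&gt; lookup from earlier in the post, with &lt;code&gt;'.'&lt;/code&gt; at index 0:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# hypothetical probe: row ix of W holds the logits for the next
# character given the current character ix (since logits = xenc @ W)
print(W[stoi['n'], stoi['.']].item())  # expect a larger value: many names end in 'n'
print(W[stoi['j'], stoi['.']].item())  # expect a smaller value: few names end in 'j'
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;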

&lt;h3&gt;
  
  
  Regularization
&lt;/h3&gt;

&lt;p&gt;If our model is confident about its predictions, that's exactly what we want. But is it good when the model is &lt;em&gt;over-confident&lt;/em&gt;? Not really; sometimes we allow for a little bit of uncertainty in order to make the model perform better in real life. That is to &lt;em&gt;prevent overfitting&lt;/em&gt;, and we call it by a fancy name: &lt;strong&gt;Regularization&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;Allowing for some uncertainty in the decision-making process is equivalent to smoothing out the probability distribution. And, perhaps to your surprise, it's actually similar to the &lt;strong&gt;Smoothing technique&lt;/strong&gt; we discussed in the previous chapter. This time we will do it in a slightly different way.&lt;/p&gt;

&lt;p&gt;Notice that when the &lt;code&gt;W&lt;/code&gt; matrix is full of zeros, the probability distribution is uniform (we'll sanity-check this claim right after the code), so we will try to drive the values in that matrix down to get a smoother, more uniform distribution. This is done in the loss function, where we now have two objectives, and the technique is called &lt;strong&gt;&lt;em&gt;Weight decay Regularization&lt;/em&gt;&lt;/strong&gt;:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;probs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;arange&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;num&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;ys&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mf"&gt;0.01&lt;/span&gt;&lt;span class="o"&gt;*&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;W&lt;/span&gt;&lt;span class="o"&gt;**&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;mean&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
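
&lt;p&gt;As a quick sanity check of the uniformity claim, here is a minimal sketch, independent of the trained model:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch
import torch.nn.functional as F

# an all-zero W gives logits of 0, so exp() gives equal counts,
# which normalize to a uniform distribution over the 27 characters
W0 = torch.zeros((27, 27))
xenc = F.one_hot(torch.tensor([5]), num_classes=27).float()
logits = xenc @ W0
counts = logits.exp()
p = counts / counts.sum(1, keepdims=True)
print(p[0, :5])  # tensor([0.0370, 0.0370, 0.0370, 0.0370, 0.0370]), i.e. 1/27 each
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;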



&lt;p&gt;And when you run the code again, you can see that the loss is a bit higher than the loss without regularization. It's a small tradeoff we make to ensure that the model performs well in a real-world setting. (You can read more about the Bias/Variance tradeoff.)&lt;/p&gt;

&lt;p&gt;Regularization is actually an interesting topic that deserves a blog post of its own (well, just 2 blog posts in and my to-do list has already grown considerably). There are numerous regularization techniques, each with its own beauty. You can explore them if you are curious.&lt;/p&gt;

&lt;h2&gt;
  
  
  Summary
&lt;/h2&gt;

&lt;p&gt;We've come so far in our journey! We learned tons of things: about neural nets, the Softmax function, Gradient Descent and Backpropagation, and Regularization too. There are a lot of things packed into this tiny session about a small, simple language model. We broadened a whole new horizon of knowledge just by diving deep enough, and we had some fun playing with the code along the way. Life is good, my friend!&lt;/p&gt;

&lt;p&gt;In the next chapter we will discuss the MLP (multi-layer perceptron), following Andrej's series, and I think it will be a lot of fun!&lt;/p&gt;

&lt;p&gt;Stay tuned, and thanks for reading. Have a good day!&lt;/p&gt;

</description>
      <category>beginners</category>
      <category>deeplearning</category>
      <category>machinelearning</category>
    </item>
    <item>
      <title>BUILDING A BIGRAM LANGUAGE MODEL</title>
      <dc:creator>Hưng Lê Tiến</dc:creator>
      <pubDate>Thu, 13 Nov 2025 11:48:20 +0000</pubDate>
      <link>https://dev.to/blackbidz/building-a-bigram-language-model-48f6</link>
      <guid>https://dev.to/blackbidz/building-a-bigram-language-model-48f6</guid>
      <description>&lt;p&gt;&lt;em&gt;This is deeply inspired by Andrej Karpathy's series about makemore, actually this post is just the notes of Andrej's first video, together with some insights that I came up with.&lt;/em&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  An Introduction To N-gram Language Model
&lt;/h2&gt;

&lt;p&gt;The N-gram language model is perhaps the simplest form of language model that any human could think of. Imagine the predictive text that pops up above your keyboard for each word you are typing: it's just a plain prediction given the previous word or phrase. The N-gram language model is just that!&lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;An N-gram language model considers your &lt;em&gt;n-1&lt;/em&gt; previous words and makes a prediction about the next word in the sequence.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Some N-gram language models that we should know:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;
&lt;em&gt;Unigram language models&lt;/em&gt;: 1-gram, considers no previous words to make a prediction; honestly, this is a funny language model.&lt;/li&gt;
&lt;li&gt;
&lt;em&gt;Bigram language models&lt;/em&gt;: considers just the previous word, no more no less; this is the model we will examine today.&lt;/li&gt;
&lt;li&gt;
&lt;em&gt;Trigram language models&lt;/em&gt;: looks back at the 2 previous words.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;In the world of powerful and ever-developing LLMs, N-gram language models are just children's toys. As you can immediately notice from the way they work, and later when we scrutinize the making of one, they do not really take into account the &lt;strong&gt;&lt;em&gt;context&lt;/em&gt;&lt;/strong&gt; of the sequence. Sometimes they don't even know the meaning of each word, which is sad given the complex nature of language.&lt;/p&gt;

&lt;p&gt;So why should we stick with this overly-simplified language model? Well, everyone starts somewhere, mate. This is a great start for learning about language models, and later we will gradually dive deeper, step by step, so feel free to skip this blog if you already know it all too well.&lt;/p&gt;

&lt;h2&gt;
  
  
  How It Works (The Intuition Behind)
&lt;/h2&gt;

&lt;p&gt;Just by counting, no more no less.&lt;/p&gt;

&lt;p&gt;That is basically the main idea! We just need a huge text corpus to train on, and we count the &lt;strong&gt;relative frequency&lt;/strong&gt; of each combination of words. Next, we assign a probability to each upcoming word based on the frequencies we counted earlier, and our model is ready to go!&lt;/p&gt;

&lt;p&gt;If your training set consists of conversations between two people who love each other, for example, then when the model sees &lt;strong&gt;"I"&lt;/strong&gt; it will confidently think that the next word is &lt;strong&gt;"love"&lt;/strong&gt;, and when there is &lt;strong&gt;"love"&lt;/strong&gt; the next should probably be &lt;strong&gt;"you"&lt;/strong&gt;. The model works by pure memorization and a bit of probability.&lt;/p&gt;

&lt;p&gt;But don't criticize it too much; we can still get some insights from the model, like some of the &lt;strong&gt;syntactic nature&lt;/strong&gt; of language, e.g. Nouns and Adjectives often come after Verbs, and Verbs often come after Nouns.&lt;/p&gt;

&lt;p&gt;So let's make things complicated by doing some maths.&lt;/p&gt;

&lt;h2&gt;
  
  
  How It Works (The Math Behind)
&lt;/h2&gt;

&lt;p&gt;Consider the task of predicting the next word of the sequence &lt;strong&gt;"Your eyes are..."&lt;/strong&gt;&lt;br&gt;
Well, maybe your eyes are beautiful (are they?), but in a probabilistic sense, we should ask:&lt;/p&gt;

&lt;p&gt;&lt;em&gt;What is the probability that the word &lt;strong&gt;"beautiful"&lt;/strong&gt; appears after the sequence of &lt;strong&gt;"Your eyes are"&lt;/strong&gt;?&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;Mathematicians have a nice way to ask this question, using conditional probability:&lt;/p&gt;

&lt;p&gt;

&lt;/p&gt;
&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;P(beautiful∣Your eyes are)
        P(\text{beautiful}|\text{Your eyes are}) 
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;beautiful&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;Your eyes are&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;Another way to interpret this: in our text corpus, among all of the cases where "Your eyes are" appears, how many times is the word "beautiful" put right after it?&lt;/p&gt;


&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;P(beautiful∣Your eyes are)=C(Your eyes are beautiful)C(Your eyes are)
        P(\text{beautiful}|\text{Your eyes are}) = \frac{C(\text{Your eyes are beautiful})}{C(\text{Your eyes are})} 
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;beautiful&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;Your eyes are&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;C&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;Your eyes are&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;C&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;Your eyes are beautiful&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;&lt;strong&gt;But this won't work well!&lt;/strong&gt;, you may think. Well, you are right. We cannot count like that, because we won't get a good estimate from a medium-sized corpus: many perfectly reasonable sequences never appear in it even once. Another reason is that language is &lt;strong&gt;creative&lt;/strong&gt;, which means that new things are invented all the time. So, all in all, the approach should be more flexible.&lt;/p&gt;

&lt;p&gt;A cleverer way is to calculate the probability recursively and take the product of the factors, using the &lt;strong&gt;&lt;em&gt;Chain rule&lt;/em&gt;&lt;/strong&gt; of Probability:&lt;/p&gt;


&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;P(w1:n)=P(w1)P(w2∣w1)P(w3∣w1:2)...P(wn∣w1:n−1)=∏k=1nP(wk∣w1:k−1)
        P(w_{1:n}) = P(w_1)P(w_2|w_1)P(w_3|w_{1:2})...P(w_n|w_{1:n-1}) = \prod\limits_{k=1}^n P(w_k|w_{1:k-1}) 
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;span class="mrel mtight"&gt;:&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span 
class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;3&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;span class="mrel mtight"&gt;:&lt;/span&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mord"&gt;...&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;span class="mrel mtight"&gt;:&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mop op-limits"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;k&lt;/span&gt;&lt;span class="mrel mtight"&gt;=&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="mop op-symbol 
large-op"&gt;∏&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;k&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;span class="mrel mtight"&gt;:&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;k&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;Well, actually this doesn't solve the problem; it even makes it worse, as we now need to keep track of a bunch of long prefixes before we come up with the final probabilities. And this is where our friend Markov comes in.&lt;/p&gt;

&lt;h2&gt;
  
  
  The Markov Assumption
&lt;/h2&gt;

&lt;blockquote&gt;
&lt;p&gt;Rather than considering the whole history of our sequence, we can approximate it by just looking at the last few words.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Markov didn't say exactly that, but it was his main idea. In our example, rather than taking the whole sentence "Your eyes are so" into account, we just need to look at the word "so" to make a prediction (we're using the bigram model, of course; if it were a trigram we would need to consider the 2 preceding words).&lt;/p&gt;

&lt;p&gt;To wrap up, we made this assumption (for the bigram):&lt;br&gt;

&lt;/p&gt;
&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;P(wn∣w1:n−1)≈P(wn∣wn−1)
P(w_n|w_{1:n-1}) \approx P(w_n|w_{n-1})
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;span class="mrel mtight"&gt;:&lt;/span&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;≈&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;And with that in mind, we can easily calculate the probability that we want, just by counting:&lt;br&gt;

&lt;/p&gt;
&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;P(wn∣wn−1)=C(wn−1wn)∑wC(wn−1wn)=C(wn−1wn)C(wn−1)
P (w_n|w_{n-1}) = \frac{C(w_{n-1}w_n)}{\sum_w C(w_{n-1}w_n)} = \frac{C(w_{n-1}w_n)}{C(w_{n-1})} 
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;∣&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mop"&gt;&lt;span class="mop op-symbol small-op"&gt;∑&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;w&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mord mathnormal"&gt;C&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span 
class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;C&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;C&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mbin 
mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;C&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;span class="mbin mtight"&gt;−&lt;/span&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;


&lt;p&gt;In other words, we calculate the ratio between the frequency of the combination we have in mind and the frequency of the preceding word, and that's what we call &lt;strong&gt;&lt;em&gt;relative frequency&lt;/em&gt;&lt;/strong&gt;.&lt;/p&gt;
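
&lt;p&gt;Here is a minimal sketch of that counting on a made-up toy corpus (the corpus and the numbers are purely for illustration):&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;from collections import Counter

corpus = "your eyes are beautiful . your eyes are kind .".split()

# count bigrams C(w_{n-1} w_n) and unigrams C(w_{n-1})
bigrams = Counter(zip(corpus, corpus[1:]))
unigrams = Counter(corpus)

# P(beautiful | are) = C(are beautiful) / C(are)
p = bigrams[('are', 'beautiful')] / unigrams['are']
print(p)  # 0.5 -- 'are' appears twice, once followed by 'beautiful'
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;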

&lt;p&gt;It's getting boring talking about theoretical concepts, so let's jump right in and play around with the code.&lt;/p&gt;

&lt;h2&gt;
  
  
  Making Makemore - A Bigram Language Model
&lt;/h2&gt;

&lt;p&gt;You should check this video out; it is the main inspiration for this post, and shout out to Andrej Karpathy!&lt;br&gt;
&lt;a href="https://www.youtube.com/watch?v=PaCmpygFfXo&amp;amp;t=3847s" rel="noopener noreferrer"&gt;The spelled-out intro to language modeling: building makemore&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;To start, makemore is a simple language model that generates new versions of whatever data is fed into it. In this case we are dealing with names, so you can think of makemore as a kind of name-generating language model.&lt;/p&gt;

&lt;p&gt;Makemore works just like a bigram LM, but instead of looking at the previous word (it can't), it looks at the previous character and predicts the next one.&lt;/p&gt;

&lt;p&gt;We first load and explore our dataset:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;words&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;open&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;names.txt&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;r&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;read&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;splitlines&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;[:&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;So we have this, which is a list of American names:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Now we create our own bigram:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Creating the bigram, just a kind of data preprocessing
&lt;/span&gt;
&lt;span class="c1"&gt;# Make a dictionary to store all the pairs
&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{}&lt;/span&gt;

&lt;span class="c1"&gt;# We have the zip function for convenience pairing, and we should
# signal the start/end of words by some additional characters
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;chs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;S&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;E&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]):&lt;/span&gt;
    &lt;span class="n"&gt;bigram&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;bigram&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;bigram&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Let's look at our bigram, shall we?&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# We sort the pairs by their frequency
&lt;/span&gt;&lt;span class="nf"&gt;sorted&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;b&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="k"&gt;lambda&lt;/span&gt; &lt;span class="n"&gt;kv&lt;/span&gt; &lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;kv&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;])[:&lt;/span&gt;&lt;span class="mi"&gt;5&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

&lt;span class="c1"&gt;# Here's the result
&lt;/span&gt;&lt;span class="p"&gt;[((&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;E&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;6763&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
 &lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;a&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;E&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;6640&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
 &lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;a&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;n&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;5438&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
 &lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;S&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;a&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;4410&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
 &lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;e&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;E&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="mi"&gt;3983&lt;/span&gt;&lt;span class="p"&gt;)]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;

&lt;h3&gt;
  
  
  Tensor
&lt;/h3&gt;

&lt;p&gt;For better data manipulation, we use the Tensor from PyTorch, which is a special data structure that works great for ML/DL projects.&lt;/p&gt;

&lt;p&gt;First we need to create a Tensor of our own:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="c1"&gt;# Create the tensor for our project
&lt;/span&gt;
&lt;span class="c1"&gt;# We have 26 letters in the alphabet and 2 charr &amp;lt;S&amp;gt; and &amp;lt;E&amp;gt; -&amp;gt; 28
&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;int32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# This should be a 27x27 dimensional array storing 32-bit integers, normally the array will store the floating points number so we need to specify the dtype when calling the function
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
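
&lt;p&gt;A quick sanity check (just a tiny sketch, assuming the cell above has run) that the tensor looks the way we expect:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;print(N.shape)  # torch.Size([28, 28])
print(N.dtype)  # torch.int32
print(N.sum())  # tensor(0) -- nothing counted yet
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;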

&lt;h3&gt;
  
  
  Indexing characters - from strings to integers
&lt;/h3&gt;

&lt;p&gt;Computers talk in numbers, not words, so we need a way to "encode" our characters as integers. The simplest way is perhaps to index each character by its order in the alphabet, so each pair becomes just a tuple of 2 integers.&lt;/p&gt;

&lt;p&gt;Here's how it goes:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Creating a lookup table from characters to integers
&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;sorted&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;set&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;''&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;))))&lt;/span&gt;

&lt;span class="c1"&gt;# Map from char to int, indexing
&lt;/span&gt;&lt;span class="n"&gt;stoi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt;&lt;span class="p"&gt;)}&lt;/span&gt;

&lt;span class="c1"&gt;# Remember the additional characters
&lt;/span&gt;&lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;S&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;26&lt;/span&gt;
&lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;E&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;27&lt;/span&gt;

&lt;span class="c1"&gt;# Take out our previous code
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;chs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;S&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;&amp;lt;S&amp;gt;&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]):&lt;/span&gt;
    &lt;span class="n"&gt;ix1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;ix2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ix2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

&lt;span class="c1"&gt;# Reverse mapping
&lt;/span&gt;
&lt;span class="n"&gt;itos&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;items&lt;/span&gt;&lt;span class="p"&gt;()}&lt;/span&gt;

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;It's basically done! We can use &lt;code&gt;matplotlib&lt;/code&gt; to contemplate our beautifully crafted bigram.&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;matplotlib.pyplot&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;plt&lt;/span&gt;
&lt;span class="o"&gt;%&lt;/span&gt;&lt;span class="n"&gt;matplotlib&lt;/span&gt; &lt;span class="n"&gt;inline&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;cmap&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Blues&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;j&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;28&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
      &lt;span class="n"&gt;chstr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
      &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;chstr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;center&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;va&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;bottom&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;ha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;center&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;va&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;top&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axis&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;off&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;And there you have it!&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fxqmeayft33zfkzv27ebk.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fxqmeayft33zfkzv27ebk.png" alt="Our lovely bigram"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h3&gt;
  
  
  A small modification
&lt;/h3&gt;

&lt;p&gt;As you can see from our bigram, the last row is all zeros, and so is the column for &lt;code&gt;&amp;lt;S&amp;gt;&lt;/code&gt;. What happened?&lt;/p&gt;

&lt;p&gt;Notice that when we use &lt;code&gt;&amp;lt;S&amp;gt;&lt;/code&gt; and &lt;code&gt;&amp;lt;E&amp;gt;&lt;/code&gt; to mark the start and end of a word, &lt;code&gt;&amp;lt;S&amp;gt;&lt;/code&gt; can never appear as the second character of a pair (nothing ever comes before the start), and &lt;code&gt;&amp;lt;E&amp;gt;&lt;/code&gt; can never appear as the first (nothing ever comes after the end). That is exactly why we get one column and one row of zeros.&lt;/p&gt;
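&lt;p&gt;If you want to verify that on the tensor itself, here's a quick check (a sketch, assuming the 28x28 &lt;code&gt;N&lt;/code&gt; from above with &lt;code&gt;stoi['&amp;lt;S&amp;gt;'] = 26&lt;/code&gt; and &lt;code&gt;stoi['&amp;lt;E&amp;gt;'] = 27&lt;/code&gt;):&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# Nothing ever follows the end token, so its row is all zeros ...
print(N[27].sum().item())     # 0
# ... and nothing ever precedes the start token, so its column is all zeros
print(N[:, 26].sum().item())  # 0
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;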

&lt;p&gt;To solve this, following our friend Andrej in the video, we replace these two tokens with a single one: just a &lt;code&gt;.&lt;/code&gt; marking both the start and the end of a word.&lt;/p&gt;

&lt;p&gt;So we should make some changes to our code&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Now the tensor is just 27x27
&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;((&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;dtype&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;int32&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# We also want to set the dot to 0 (personal preference)
# And with that we have to adjust our dictionary
&lt;/span&gt;&lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
&lt;span class="n"&gt;stoi&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;s&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chars&lt;/span&gt;&lt;span class="p"&gt;)}&lt;/span&gt;

&lt;span class="c1"&gt;# Now change the chs
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;chs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]):&lt;/span&gt;
    &lt;span class="n"&gt;ix1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;ix2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;ix2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

&lt;span class="c1"&gt;# And adjust the range of i,j, then we're done
&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;figure&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;figsize&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="mi"&gt;16&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;imshow&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;cmap&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Blues&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;j&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;27&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
      &lt;span class="n"&gt;chstr&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
      &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;chstr&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;center&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;va&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;bottom&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
      &lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;text&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;j&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;ha&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;center&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;va&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="s"&gt;top&lt;/span&gt;&lt;span class="sh"&gt;"&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;color&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;gray&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;plt&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;axis&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;off&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;So now we have our final result:&lt;br&gt;
&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fbntsquv4plr0r1y77xym.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Farticles%2Fbntsquv4plr0r1y77xym.png" alt="Perfect!"&gt;&lt;/a&gt;&lt;/p&gt;
&lt;h3&gt;
  
  
  Calculating the probability
&lt;/h3&gt;

&lt;p&gt;Now the serious things kick in: we need to calculate the relative frequency of each pair of characters. The most convenient way is to create a probability matrix &lt;code&gt;P&lt;/code&gt; of the same size as &lt;code&gt;N&lt;/code&gt;, where each entry holds the probability of the corresponding pair counted in &lt;code&gt;N&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;So how do we actually calculate the probability? Remember the formula where we divide the frequency of the pair by the frequency of the &lt;em&gt;preceding character&lt;/em&gt;? In matrix terms, that is just the count of the pair divided by the &lt;strong&gt;sum of its row&lt;/strong&gt; in the frequency matrix (&lt;code&gt;N&lt;/code&gt;). Take a moment to convince yourself of that.&lt;/p&gt;
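&lt;p&gt;To make that concrete before we vectorize it, here is the scalar version for one single pair (a small sketch using the &lt;code&gt;stoi&lt;/code&gt; and &lt;code&gt;N&lt;/code&gt; we already built):&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;# P(next='b' | current='a') = count of the pair 'ab'
# divided by the count of everything that starts with 'a'
ix1, ix2 = stoi['a'], stoi['b']
p_b_given_a = N[ix1, ix2] / N[ix1].sum()
print(p_b_given_a)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
&lt;p&gt;The vectorized version below does exactly this for all 27x27 entries at once.&lt;/p&gt;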
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;
&lt;span class="c1"&gt;# Create the P matrix
&lt;/span&gt;&lt;span class="n"&gt;P&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="c1"&gt;# Take the sum for each row
# Notice that it is really convenient when using tensor
&lt;/span&gt;
&lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;keepdim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# And then we normalize every entry
# Notice, again, how clean the code is
&lt;/span&gt;
&lt;span class="n"&gt;P&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;P&lt;/span&gt; &lt;span class="o"&gt;/&lt;/span&gt; &lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;sum&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;keepdim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Just a small note: why do we set &lt;code&gt;keepdim = True&lt;/code&gt;? This could take a whole blog post to explain (and maybe I will write it if I'm not lazy). It has to do with something called &lt;strong&gt;&lt;em&gt;Broadcasting&lt;/em&gt;&lt;/strong&gt;, a really convenient mechanism for combining tensors of different shapes, used all over the place in data processing. You can check Andrej's video if you're curious; he goes into it really deeply. For now, a quick taste of the pitfall before we move on.&lt;/p&gt;
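&lt;p&gt;Here is a minimal sketch of that pitfall (a toy tensor, not our real &lt;code&gt;P&lt;/code&gt;): with &lt;code&gt;keepdim = False&lt;/code&gt; the row sums come out with shape &lt;code&gt;(27,)&lt;/code&gt;, broadcasting lines them up against the columns, and the division silently normalizes the wrong axis.&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

P_demo = torch.rand(27, 27)

# (27,27) / (27,1): each ROW is divided by its own sum -- what we want
good = P_demo / P_demo.sum(1, keepdim=True)

# (27,27) / (27,): the (27,) is treated as (1,27), so each COLUMN
# gets divided by a row sum instead -- silently wrong
bad = P_demo / P_demo.sum(1)

print(good[0].sum())  # tensor(1.) -- a proper probability distribution
print(bad[0].sum())   # generally not 1
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;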
&lt;h3&gt;
  
  
  Sampling from our model
&lt;/h3&gt;

&lt;p&gt;We've created our probabilistic model! It may not be obvious, but we've already done all the essential parts. Now you might ask: "How can we get new names from this model?" You take a character (we start from the dot) and work recursively to predict the next. By that I mean: given the current character, look at the &lt;strong&gt;row&lt;/strong&gt; where it is the preceding character, &lt;strong&gt;sample&lt;/strong&gt; from that distribution to get the next character, then repeat with the new character until we hit the dot again, which marks the end of the word.&lt;/p&gt;

&lt;p&gt;All of that is neatly nested in our code below:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# We need to use a Generator in Pytorch to get the samples
&lt;/span&gt;&lt;span class="n"&gt;g&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Generator&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;manual_seed&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2147483647&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="c1"&gt;# Get the first 50 names
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;50&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
  &lt;span class="n"&gt;out&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[]&lt;/span&gt;
  &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;
  &lt;span class="k"&gt;while&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
   &lt;span class="c1"&gt;# Get the distribution of the word and sample from it
&lt;/span&gt;    &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;multinomial&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;&lt;span class="n"&gt;num_samples&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;replacement&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;generator&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;g&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;itos&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;
   &lt;span class="c1"&gt;# End the loop when we see a dot
&lt;/span&gt;    &lt;span class="k"&gt;if&lt;/span&gt; &lt;span class="n"&gt;ix&lt;/span&gt; &lt;span class="o"&gt;==&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
      &lt;span class="k"&gt;break&lt;/span&gt;
  &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;''&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;join&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;out&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;If you use the same seed for the Generator as me, you will get this list of names:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;cexze.
momasurailezitynn.
konimittain.
llayn.
ka.
da.
staiyaubrtthrigotai.
...
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;That is disappointing! But some of the names are actually acceptable, and we can't expect too much from such a simple model. &lt;/p&gt;
&lt;h2&gt;
  
  
  Evaluation
&lt;/h2&gt;

&lt;p&gt;We need some kind of metric to evaluate this model, and don't even think about MSE, MAE, or RMSE: those are regression metrics, not made for probabilities, let alone language.&lt;/p&gt;

&lt;p&gt;When dealing with probabilities, a few evaluation ideas should come to mind. One of them comes from &lt;strong&gt;&lt;em&gt;Maximum Likelihood Estimation&lt;/em&gt;&lt;/strong&gt;.&lt;/p&gt;
&lt;h3&gt;
  
  
  Maximum Likelihood Estimation (MLE)
&lt;/h3&gt;

&lt;p&gt;This is another topic that deserves a blog post of its own, given how widely it is applied in ML/DL and beyond. To be honest, I don't know it too well myself, so I will just scratch the surface here.&lt;/p&gt;

&lt;p&gt;The &lt;em&gt;likelihood&lt;/em&gt; of the data is a function, and it is not the same thing as a &lt;em&gt;probability&lt;/em&gt;; in many tasks we fit the model by maximizing it. For our purposes, you just need to know that the likelihood is the product of a bunch of probabilities, one per bigram.&lt;/p&gt;

&lt;p&gt;But notice the scaling problem: every probability lies between 0 and 1, so when we multiply many of them together, the result becomes &lt;strong&gt;extremely small&lt;/strong&gt; and eventually rounds to exactly zero. That is called &lt;strong&gt;numerical underflow&lt;/strong&gt;.&lt;/p&gt;
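&lt;p&gt;You can watch the underflow happen with plain Python floats (a tiny sketch, unrelated to our actual model):&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import math

# Multiply 200 probabilities of 0.01: the true value is 1e-400,
# far below what a float can represent, so we get exactly 0.0
prod = 1.0
for _ in range(200):
    prod *= 0.01
print(prod)                  # 0.0 -- underflow

# Summing the logs instead stays perfectly well-behaved
print(200 * math.log(0.01))  # about -921.03
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;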

&lt;p&gt;To solve the problem, we use the &lt;em&gt;log&lt;/em&gt; of the likelihood instead, which turns the giant product into a well-behaved sum. We will also compute the &lt;em&gt;negative log likelihood&lt;/em&gt;; it is not strictly needed for today's task, but it will be of immense importance in the next blog, when we train a neural network and need a loss to minimize. And that's it!&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Calculate the log_likelihood
&lt;/span&gt;&lt;span class="n"&gt;log_likelihood&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mf"&gt;0.0&lt;/span&gt;
&lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;0&lt;/span&gt;

&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;
  &lt;span class="n"&gt;chs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ch2&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]):&lt;/span&gt;
    &lt;span class="n"&gt;ix1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;ix2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ix2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;logprob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;log_likelihood&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;logprob&lt;/span&gt;
    &lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;

&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;log_likelihood&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;nll&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;log_likelihood&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;nll&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;nll&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;So we have our log-likelihood, and we should normalize it by &lt;code&gt;n&lt;/code&gt;, the number of bigrams we scored.&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;log_likelihood=tensor(-559891.7500)
nll=tensor(559891.7500)
2.454094171524048
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;


&lt;p&gt;Not bad! This is about as well as a model of this kind can do at char-to-char prediction; perhaps we expected too much from it.&lt;/p&gt;
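&lt;p&gt;For a sense of scale (a quick back-of-the-envelope check, not from the video): a model that guessed uniformly over our 27 tokens would have an average negative log-likelihood of log(27).&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import math

# Average NLL of a uniform guess over 27 characters
print(math.log(27))  # about 3.2958 -- our bigram's 2.4541 beats this
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;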
&lt;h2&gt;
  
  
  Perplexity
&lt;/h2&gt;

&lt;p&gt;Another metric we can use is Perplexity (PP or PPL), which shares its core idea with the notion of "Entropy".&lt;/p&gt;

&lt;p&gt;Think of this metric as the &lt;em&gt;level of surprise&lt;/em&gt; the model experiences on new text. If the model does well and assigns a high probability to a word, it is &lt;em&gt;less surprised&lt;/em&gt; when that word actually shows up, and vice versa. So a good model is one that is &lt;strong&gt;&lt;em&gt;least surprised&lt;/em&gt;&lt;/strong&gt; by the test set (less perplexed, which I think is the main idea behind the term Perplexity).&lt;/p&gt;

&lt;p&gt;So how do we measure surprise? When the model assigns a &lt;em&gt;high probability&lt;/em&gt;, it is &lt;em&gt;less surprised&lt;/em&gt;, and vice versa. So we can think of an &lt;strong&gt;inverse relationship&lt;/strong&gt;, and that's exactly what scientists do! They take the inverse of the probability, normalized by the sequence length via the N-th root.&lt;/p&gt;

&lt;p&gt;Here's the formula:&lt;br&gt;

&lt;/p&gt;
&lt;div class="katex-element"&gt;
  &lt;span class="katex-display"&gt;&lt;span class="katex"&gt;&lt;span class="katex-mathml"&gt;perplexity(W)=P(w1w2...wn)−1N=1P(w1w2...wn)N
\text{perplexity}(W) = {P(w_1w_2...w_n)}^{\frac{-1}{N}} = \sqrt[N]{\frac{1}{P(w_1w_2...w_n)}}
&lt;/span&gt;&lt;span class="katex-html"&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord text"&gt;&lt;span class="mord"&gt;perplexity&lt;/span&gt;&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord mathnormal"&gt;W&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;...&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mopen nulldelimiter sizing reset-size3 size6"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;N&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="frac-line mtight"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size3 size1 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mtight"&gt;−&lt;/span&gt;&lt;span class="mord 
mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter sizing reset-size3 size6"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;span class="mrel"&gt;=&lt;/span&gt;&lt;span class="mspace"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="base"&gt;&lt;span class="strut"&gt;&lt;/span&gt;&lt;span class="mord sqrt"&gt;&lt;span class="root"&gt;&lt;span class="vlist-t"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size1 mtight"&gt;&lt;span class="mord mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;N&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span class="svg-align"&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;&lt;span class="mopen nulldelimiter"&gt;&lt;/span&gt;&lt;span class="mfrac"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;P&lt;/span&gt;&lt;span class="mopen"&gt;(&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mtight"&gt;2&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mord"&gt;...&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord mathnormal"&gt;w&lt;/span&gt;&lt;span class="msupsub"&gt;&lt;span class="vlist-t vlist-t2"&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="sizing reset-size6 size3 mtight"&gt;&lt;span class="mord mathnormal mtight"&gt;n&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose"&gt;)&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span 
class="frac-line"&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="mord"&gt;&lt;span class="mord"&gt;1&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="mclose nulldelimiter"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span&gt;&lt;span class="pstrut"&gt;&lt;/span&gt;&lt;span class="hide-tail"&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-s"&gt;​&lt;/span&gt;&lt;/span&gt;&lt;span class="vlist-r"&gt;&lt;span class="vlist"&gt;&lt;span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;&lt;/span&gt;
&lt;/div&gt;



&lt;p&gt;Now that we've come this far, I shall tell you this fact: &lt;/p&gt;

&lt;blockquote&gt;
&lt;p&gt;Perplexity is actually the exponential of the &lt;em&gt;average&lt;/em&gt; negative log-likelihood. The two metrics carry &lt;strong&gt;&lt;em&gt;the same&lt;/em&gt;&lt;/strong&gt; information.&lt;/p&gt;
&lt;/blockquote&gt;
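&lt;p&gt;We can check that claim directly with the &lt;code&gt;nll&lt;/code&gt; and &lt;code&gt;n&lt;/code&gt; we computed earlier (a small sketch):&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

# Perplexity is exp of the AVERAGE negative log-likelihood per bigram
ppl = torch.exp(nll / n)
print(ppl)  # exp(2.4541) is about 11.64
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;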

&lt;p&gt;We'll meet this fact again soon.&lt;/p&gt;

&lt;h2&gt;
  
  
  Smoothing
&lt;/h2&gt;

&lt;p&gt;Our code is not quite perfect yet! What if we need the log-likelihood of a name containing a pair the model has never seen in training? Let's see:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# SMOOTHING
&lt;/span&gt;
&lt;span class="c1"&gt;# Check this word and calculate the loss
&lt;/span&gt;&lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;bidjz&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]:&lt;/span&gt;
  &lt;span class="n"&gt;chs&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="nf"&gt;list&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;+&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;.&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
  &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ch2&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;zip&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;chs&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:]):&lt;/span&gt;
    &lt;span class="n"&gt;ix1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;ix2&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;stoi&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;prob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;P&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;ix1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;ix2&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
    &lt;span class="n"&gt;logprob&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;log&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;log_likelihood&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="n"&gt;logprob&lt;/span&gt;
    &lt;span class="n"&gt;n&lt;/span&gt; &lt;span class="o"&gt;+=&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;
    &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;ch1&lt;/span&gt;&lt;span class="si"&gt;}{&lt;/span&gt;&lt;span class="n"&gt;ch2&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt;: &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;prob&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;logprob&lt;/span&gt;&lt;span class="si"&gt;:&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="n"&gt;f&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;log_likelihood&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;nll&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="n"&gt;log_likelihood&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;nll&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sa"&gt;f&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="si"&gt;{&lt;/span&gt;&lt;span class="n"&gt;nll&lt;/span&gt;&lt;span class="o"&gt;/&lt;/span&gt;&lt;span class="n"&gt;n&lt;/span&gt;&lt;span class="si"&gt;}&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Let's look at the result&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;b: 0.0408 -3.1998
bi: 0.0820 -2.5005
id: 0.0249 -3.6946
dj: 0.0016 -6.4146
jz: 0.0000 -inf
z.: 0.0667 -2.7072
log_likelihood=tensor(-inf)
nll=tensor(inf)
inf
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;The log_likelihood goes to negative infinity! When the model encounters a pair it has never seen before, that pair's probability is zero, its log is negative infinity, and that single term sinks the whole sum. Of course we don't want that, so we need to get rid of all the zeros.&lt;/p&gt;
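&lt;p&gt;The culprit, in a single line:&lt;br&gt;
&lt;/p&gt;
&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;import torch

# The log of a zero probability is negative infinity,
# and one -inf term sinks the whole sum
print(torch.log(torch.tensor(0.0)))  # tensor(-inf)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;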

&lt;p&gt;The simplest way to do that is to add 1 to every entry of the &lt;code&gt;N&lt;/code&gt; matrix before normalizing. This is called &lt;strong&gt;Laplace Smoothing&lt;/strong&gt; (well, guess how brilliant Laplace felt when he came up with that!).&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="c1"&gt;# Just a small modification here
&lt;/span&gt;&lt;span class="n"&gt;P&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;N&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;float&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;And when we run our code again, we should have:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;.b: 0.0408 -3.1999
bi: 0.0816 -2.5061
id: 0.0249 -3.6939
dj: 0.0018 -6.3141
jz: 0.0003 -7.9817
z.: 0.0664 -2.7122
log_likelihood=tensor(-559977.9375)
nll=tensor(559977.9375)
2.454407215118408
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;We're good for now!&lt;/p&gt;

&lt;h2&gt;
  
  
  Summary
&lt;/h2&gt;

&lt;p&gt;For this model there actually isn't much to say; the reason this post got so long is that I wanted to dig into some surrounding topics as well, like the Markov assumption, tensors and broadcasting, and simply playing around with the model. All in all, I just wanted to wrap up the knowledge I picked up from Andrej's video and share these cool things with you, my friend.&lt;/p&gt;

&lt;p&gt;Thanks for reading, and have a good day!&lt;/p&gt;

</description>
      <category>tutorial</category>
      <category>beginners</category>
      <category>ai</category>
      <category>machinelearning</category>
    </item>
  </channel>
</rss>
