<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Domas Bitvinskas</title>
    <description>The latest articles on DEV Community by Domas Bitvinskas (@nedomas).</description>
    <link>https://dev.to/nedomas</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F166553%2F06b2272b-e177-4385-9e5b-2e307d56e2c8.jpeg</url>
      <title>DEV Community: Domas Bitvinskas</title>
      <link>https://dev.to/nedomas</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/nedomas"/>
    <language>en</language>
    <item>
      <title>ELU Activation Function</title>
      <dc:creator>Domas Bitvinskas</dc:creator>
      <pubDate>Wed, 22 Jul 2020 08:37:52 +0000</pubDate>
      <link>https://dev.to/nedomas/elu-activation-function-12dl</link>
      <guid>https://dev.to/nedomas/elu-activation-function-12dl</guid>
      <description>&lt;p&gt;Exponential Linear Unit (ELU) is a popular activation function that speeds up learning and produces more accurate results. This article is an introduction to ELU and its position when compared to other popular activation functions. It also includes an interactive example and usage with &lt;a href="https://pytorch.org" rel="noopener noreferrer"&gt;PyTorch&lt;/a&gt; and &lt;a href="https://www.tensorflow.org" rel="noopener noreferrer"&gt;Tensorflow&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://www.bioinf.jku.at/people/clevert/" rel="noopener noreferrer"&gt;Djork-Arné Clevert&lt;/a&gt;, &lt;a href="https://twitter.com/TomUnterthiner" rel="noopener noreferrer"&gt;Thomas Unterthiner&lt;/a&gt;, &lt;a href="https://www.jku.at/en/institute-for-machine-learning/about-us/team/sepp-hochreiter/" rel="noopener noreferrer"&gt;Sepp Hochreiter&lt;/a&gt; introduced ELU in &lt;a href="https://arxiv.org/abs/1511.07289v5" rel="noopener noreferrer"&gt;Nov 2015&lt;/a&gt;. It outperformed ReLU-based CIFAR-100 networks at the time. To this day, ELUs are still popular among Machine Learning engineers and are well studied by now.&lt;/p&gt;

&lt;h2&gt;
  
  
  What is ELU?
&lt;/h2&gt;

&lt;p&gt;ELU is an activation function based on ReLU that has an extra alpha constant (α) that defines function smoothness when inputs are negative. Play with an interactive example below to understand how α influences the curve for the negative part of the function.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://closeheat.com/blog/elu-activation-function" rel="noopener noreferrer"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fo6tzgptt25vbom5soi1g.png" alt="ELU activation" width="800" height="409"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  ELU calculation
&lt;/h3&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2F09axyqzxbqjtejh88sjp.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2F09axyqzxbqjtejh88sjp.png" alt="ELU formula" width="606" height="174"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;The ELU output for positive input is the input (identity). If the input is negative, the output curve is slightly smoothed towards the alpha constant (α). The higher the alpha constant, the more negative the output for negative inputs gets.&lt;/p&gt;

&lt;h3&gt;
  
  
  ELU vs ReLU
&lt;/h3&gt;

&lt;p&gt;ELU and ReLU are the most popular activation functions used. Here are the advantages and disadvantages of using it when compared to other popular activation functions.&lt;/p&gt;

&lt;h4&gt;
  
  
  Advantages of ELU
&lt;/h4&gt;

&lt;ul&gt;
&lt;li&gt;Tend to converge faster than ReLU (because mean ELU activations are closer to zero)&lt;/li&gt;
&lt;li&gt;Better generalization performance than ReLU&lt;/li&gt;
&lt;li&gt;Fully continuous&lt;/li&gt;
&lt;li&gt;Fully differentiable&lt;/li&gt;
&lt;li&gt;Does not have a &lt;em&gt;vanishing gradients&lt;/em&gt; problem&lt;/li&gt;
&lt;li&gt;Does not have an &lt;em&gt;exploding gradients&lt;/em&gt; problem&lt;/li&gt;
&lt;li&gt;Does not have a &lt;em&gt;dead relu&lt;/em&gt; problem&lt;/li&gt;
&lt;/ul&gt;

&lt;h4&gt;
  
  
  Disadvantages of ELU
&lt;/h4&gt;

&lt;ul&gt;
&lt;li&gt;Slower to compute (because of non-linearity for negative input values)&lt;/li&gt;
&lt;/ul&gt;

&lt;blockquote&gt;
&lt;p&gt;ELU is slower to compute, but ELU compensates this by faster convergence during training. During test time ELU is slower to compute than ReLU though.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h2&gt;
  
  
  Usage in machine learning frameworks
&lt;/h2&gt;

&lt;h3&gt;
  
  
  PyTorch ELU usage
&lt;/h3&gt;

&lt;p&gt;To use ELU in PyTorch, use &lt;code&gt;torch.nn.ELU&lt;/code&gt; function. Here's an example model that uses ELU:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layer1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Conv3d&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;in_channels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;out_channels&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kernel_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
            &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ELU&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;layer1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Note that &lt;code&gt;alpha&lt;/code&gt; is &lt;code&gt;1&lt;/code&gt; by default. To learn more, read &lt;a href="https://pytorch.org/docs/master/generated/torch.nn.ELU.html" rel="noopener noreferrer"&gt;PyTorch ELU documentation&lt;/a&gt;.&lt;/p&gt;

&lt;h3&gt;
  
  
  TensorFlow ELU usage
&lt;/h3&gt;

&lt;p&gt;The easiest way to use ELU in TensorFlow is to use Keras layers. Example TensorFlow model below that uses ELU:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;tensorflow&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;keras&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layer1&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;keras&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;models&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Sequential&lt;/span&gt;&lt;span class="p"&gt;([&lt;/span&gt;
            &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;keras&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layers&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Conv2D&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;kernel_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;input_shape&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;6&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;)),&lt;/span&gt;
            &lt;span class="n"&gt;tf&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;keras&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;layers&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ELU&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;alpha&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;2.0&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="p"&gt;])&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;call&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;layer1&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Note that &lt;code&gt;alpha&lt;/code&gt; is &lt;code&gt;1&lt;/code&gt; by default. More details in &lt;a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/ELU" rel="noopener noreferrer"&gt;TensorFlow ELU documentation&lt;/a&gt;.&lt;/p&gt;

&lt;h2&gt;
  
  
  Next steps
&lt;/h2&gt;

&lt;p&gt;Now you know how ELU works and where it stands compared to other popular activation functions.&lt;/p&gt;

&lt;p&gt;Here's what you can do next:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Learn more about &lt;a href="https://ml-cheatsheet.readthedocs.io/en/latest/activation_functions.html" rel="noopener noreferrer"&gt;other activation functions&lt;/a&gt;.&lt;/li&gt;
&lt;li&gt;Try different values for &lt;code&gt;alpha&lt;/code&gt; and see how it influences your model accuracy and training speed.&lt;/li&gt;
&lt;li&gt;Use ELU in your machine learning models (try &lt;a href="https://closeheat.com/blog/pytorch-lstm-text-generation-tutorial" rel="noopener noreferrer"&gt;LSTMs&lt;/a&gt;).&lt;/li&gt;
&lt;/ul&gt;

</description>
      <category>machinelearning</category>
      <category>python</category>
      <category>pytorch</category>
      <category>tensorflow</category>
    </item>
    <item>
      <title>PyTorch LSTM: Text Generation Tutorial</title>
      <dc:creator>Domas Bitvinskas</dc:creator>
      <pubDate>Mon, 29 Jun 2020 18:45:41 +0000</pubDate>
      <link>https://dev.to/nedomas/pytorch-lstm-text-generation-tutorial-2nf5</link>
      <guid>https://dev.to/nedomas/pytorch-lstm-text-generation-tutorial-2nf5</guid>
      <description>&lt;p&gt;Long Short Term Memory (LSTM) is a popular Recurrent Neural Network (RNN) architecture. This tutorial covers using LSTMs on PyTorch for generating text; in this case - pretty lame jokes.&lt;/p&gt;

&lt;p&gt;For this tutorial you need:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Basic familiarity with Python, PyTorch, and machine learning&lt;/li&gt;
&lt;li&gt;A locally installed &lt;a href="https://www.python.org" rel="noopener noreferrer"&gt;Python&lt;/a&gt; v3+,
&lt;a href="https://pytorch.org" rel="noopener noreferrer"&gt;PyTorch&lt;/a&gt; v1+, &lt;a href="https://numpy.org" rel="noopener noreferrer"&gt;NumPy&lt;/a&gt; v1+&lt;/li&gt;
&lt;/ul&gt;

&lt;h2&gt;
  
  
  What is LSTM?
&lt;/h2&gt;

&lt;p&gt;LSTM is a variant of RNN used in deep learning. You can use LSTMs if you are working on sequences of data.&lt;/p&gt;

&lt;p&gt;Here are the most straightforward use-cases for LSTM networks you might be familiar with:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Time series forecasting (for example, stock prediction)&lt;/li&gt;
&lt;li&gt;Text generation&lt;/li&gt;
&lt;li&gt;Video classification&lt;/li&gt;
&lt;li&gt;Music generation&lt;/li&gt;
&lt;li&gt;Anomaly detection&lt;/li&gt;
&lt;/ul&gt;

&lt;h3&gt;
  
  
  RNN
&lt;/h3&gt;

&lt;p&gt;Before you start using LSTMs, you need to understand how RNNs work.&lt;/p&gt;

&lt;p&gt;RNNs are neural networks that are good with sequential data. It can be video, audio, text, stock market time series or even a single image cut into a sequence of its parts.&lt;/p&gt;

&lt;p&gt;Standard neural networks (convolutional or vanilla) have one major shortcoming when compared to RNNs - they cannot reason about previous inputs to inform later ones. You cannot solve some machine learning problems without some kind of memory of past inputs.&lt;/p&gt;

&lt;p&gt;For example, you might run into a problem when you have some video frames of a ball moving and want to predict the direction of the ball. The way a standard neural network sees the problem is: you have a ball in one image and then you have a ball in another image. It does not have a mechanism for connecting these two images as a sequence. Standard neural networks cannot connect two separate images of the ball to the concept of “the ball is moving.” All it sees is that there is a ball in the image #1 and that there's a ball in the image #2, but network outputs are separate.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2F845ux8fgtlezz4yhgueo.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2F845ux8fgtlezz4yhgueo.png" alt="Convolutional Neural Network prediction" width="800" height="441"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;Compare this to the RNN, which remembers the last frames and can use that to inform its next prediction.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fhrn0k2pnr5yr1fvpfe5g.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fhrn0k2pnr5yr1fvpfe5g.png" alt="Recurrent Neural Network prediction" width="800" height="441"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h3&gt;
  
  
  LSTM vs RNN
&lt;/h3&gt;

&lt;p&gt;Typical RNNs can't memorize long sequences. The effect called “vanishing gradients” happens during the backpropagation phase of the RNN cell network. The gradients of cells that carry information from the start of a sequence goes through matrix multiplications by small numbers and reach close to 0 in long sequences. In other words - information at the start of the sequence has almost no effect at the end of the sequence.&lt;/p&gt;

&lt;p&gt;You can see that illustrated in the Recurrent Neural Network example. Given long enough sequence, the information from the first element of the sequence has no impact on the output of the last element of the sequence.&lt;/p&gt;

&lt;p&gt;LSTM is an RNN architecture that can memorize long sequences - up to 100 s of elements in a sequence. LSTM has a memory gating mechanism that allows the long term memory to continue flowing into the LSTM cells.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fbgw47010k1mzhceid2dx.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media2.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fbgw47010k1mzhceid2dx.png" alt="Long Short Term Memory cell" width="800" height="372"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;h2&gt;
  
  
  Text generation with PyTorch
&lt;/h2&gt;

&lt;p&gt;You will train a joke text generator using LSTM networks in PyTorch and follow the best practices. Start by creating a new folder where you'll store the code:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="nv"&gt;$ &lt;/span&gt;&lt;span class="nb"&gt;mkdir &lt;/span&gt;text-generation
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Model
&lt;/h3&gt;

&lt;p&gt;To create an LSTM model, create a file &lt;code&gt;model.py&lt;/code&gt; in the &lt;code&gt;text-generation&lt;/code&gt; folder with the following content:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Module&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;super&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm_size&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding_dim&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;128&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="mi"&gt;3&lt;/span&gt;

        &lt;span class="n"&gt;n_vocab&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;uniq_words&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;num_embeddings&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;n_vocab&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;embedding_dim&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;LSTM&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;input_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;hidden_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
            &lt;span class="n"&gt;dropout&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.2&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;fc&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Linear&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm_size&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;n_vocab&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;forward&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prev_state&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;embed&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;embedding&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;lstm&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;embed&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;prev_state&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;fc&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;output&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;init_state&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sequence_length&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;return &lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sequence_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm_size&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt;
                &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zeros&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;num_layers&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;sequence_length&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;lstm_size&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This is a standard looking PyTorch model. &lt;code&gt;Embedding&lt;/code&gt; layer&lt;br&gt;
converts word indexes to word vectors. &lt;code&gt;LSTM&lt;/code&gt; is the main learnable part of the network - PyTorch implementation has the gating mechanism implemented inside the &lt;code&gt;LSTM&lt;/code&gt; cell that can learn long sequences of data.&lt;/p&gt;

&lt;p&gt;As described in the earlier What is LSTM? section - RNNs and LSTMs have extra state information they carry between training episodes.&lt;/p&gt;

&lt;p&gt;&lt;code&gt;forward&lt;/code&gt; function has a &lt;code&gt;prev_state&lt;/code&gt; argument. This state is kept outside the model and passed manually.&lt;/p&gt;

&lt;p&gt;It also has &lt;code&gt;init_state&lt;/code&gt; function. Calling this at the start of every epoch to initializes the right shape of the state.&lt;/p&gt;
&lt;h3&gt;
  
  
  Dataset
&lt;/h3&gt;

&lt;p&gt;For this tutorial, we use Reddit clean jokes dataset to train the network. &lt;a href="https://raw.githubusercontent.com/amoudgl/short-jokes-dataset/master/data/reddit-cleanjokes.csv" rel="noopener noreferrer"&gt;Download (139KB)&lt;/a&gt; the dataset and put it in the &lt;code&gt;text-generation/data/&lt;/code&gt; folder.&lt;/p&gt;

&lt;p&gt;The dataset has 1623 jokes and looks like this:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;ID,Joke
1,What did the bartender say to the jumper cables? You better not try to start anything.
2,Don't you hate jokes about German sausage? They're the wurst!
3,Two artists had an art contest... It ended in a draw
…
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;To load the data into PyTorch, use PyTorch &lt;code&gt;Dataset&lt;/code&gt; class. Create a &lt;code&gt;dataset.py&lt;/code&gt; file with the following content:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;pandas&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;collections&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Counter&lt;/span&gt;

&lt;span class="k"&gt;class&lt;/span&gt; &lt;span class="nc"&gt;Dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;utils&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;data&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;Dataset&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__init__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
        &lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt;
    &lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;args&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;load_words&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;uniq_words&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;get_uniq_words&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;index_to_word&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;uniq_words&lt;/span&gt;&lt;span class="p"&gt;)}&lt;/span&gt;
        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;word_to_index&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;{&lt;/span&gt;&lt;span class="n"&gt;word&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;index&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;word&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;uniq_words&lt;/span&gt;&lt;span class="p"&gt;)}&lt;/span&gt;

        &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;words_indexes&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;word_to_index&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;load_words&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;train_df&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;pd&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;read_csv&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;data/reddit-cleanjokes.csv&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;text&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;train_df&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Joke&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;].&lt;/span&gt;&lt;span class="nb"&gt;str&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;cat&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;sep&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;get_uniq_words&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;word_counts&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Counter&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;sorted&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;key&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;word_counts&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;get&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;reverse&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="bp"&gt;True&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__len__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;words_indexes&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;-&lt;/span&gt; &lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sequence_length&lt;/span&gt;

    &lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;__getitem__&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="nf"&gt;return &lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;
            &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;words_indexes&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sequence_length&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt;
            &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;words_indexes&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt;&lt;span class="n"&gt;index&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="n"&gt;self&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sequence_length&lt;/span&gt;&lt;span class="o"&gt;+&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]),&lt;/span&gt;
        &lt;span class="p"&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;This &lt;code&gt;Dataset&lt;/code&gt; inherits from the PyTorch's &lt;code&gt;torch.utils.data.Dataset&lt;/code&gt; class and defines two important methods &lt;code&gt;__len__&lt;/code&gt; and &lt;code&gt;__getitem__&lt;/code&gt;. Read more about how &lt;code&gt;Dataset&lt;/code&gt; classes work in PyTorch &lt;a href="https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class" rel="noopener noreferrer"&gt;Data loading tutorial&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;code&gt;load_words&lt;/code&gt; function loads the dataset. Unique words are calculated in the dataset to define the size of the network's vocabulary and embedding size. &lt;code&gt;index_to_word&lt;/code&gt; and &lt;code&gt;word_to_index&lt;/code&gt; converts words to number indexes and visa versa.&lt;/p&gt;

&lt;p&gt;This is part of the process is &lt;em&gt;tokenization&lt;/em&gt;. In the future, &lt;a href="https://github.com/pytorch/text" rel="noopener noreferrer"&gt;torchtext&lt;/a&gt; team plan to&lt;br&gt;
improve this part, but they are re-designing it and the new API is too unstable for this tutorial today.&lt;/p&gt;
&lt;h3&gt;
  
  
  Training
&lt;/h3&gt;

&lt;p&gt;Create a &lt;code&gt;train.py&lt;/code&gt; file and define a &lt;code&gt;train&lt;/code&gt; function.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;argparse&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;
&lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;numpy&lt;/span&gt; &lt;span class="k"&gt;as&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;optim&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;torch.utils.data&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;DataLoader&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Model&lt;/span&gt;
&lt;span class="kn"&gt;from&lt;/span&gt; &lt;span class="n"&gt;dataset&lt;/span&gt; &lt;span class="kn"&gt;import&lt;/span&gt; &lt;span class="n"&gt;Dataset&lt;/span&gt;

&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="n"&gt;dataloader&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;DataLoader&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;batch_size&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;criterion&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;CrossEntropyLoss&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
    &lt;span class="n"&gt;optimizer&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;optim&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;Adam&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parameters&lt;/span&gt;&lt;span class="p"&gt;(),&lt;/span&gt; &lt;span class="n"&gt;lr&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mf"&gt;0.001&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;max_epochs&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;state_h&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;init_state&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;sequence_length&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

        &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;enumerate&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataloader&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
            &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;zero_grad&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

            &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_h&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_c&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_h&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_c&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
            &lt;span class="n"&gt;loss&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;criterion&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;transpose&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="mi"&gt;2&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;y&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

            &lt;span class="n"&gt;state_h&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;state_h&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;detach&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="n"&gt;state_c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;state_c&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;detach&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

            &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;backward&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
            &lt;span class="n"&gt;optimizer&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;step&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

            &lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;({&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;epoch&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;epoch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;batch&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;batch&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;loss&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;:&lt;/span&gt; &lt;span class="n"&gt;loss&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;item&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt; &lt;span class="p"&gt;})&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Use PyTorch &lt;code&gt;DataLoader&lt;/code&gt; and &lt;code&gt;Dataset&lt;/code&gt; abstractions to load the jokes data.&lt;/p&gt;

&lt;p&gt;Use &lt;code&gt;CrossEntropyLoss&lt;/code&gt; as a loss function and &lt;code&gt;Adam&lt;/code&gt; as an optimizer with default params. You can tweak it later.&lt;/p&gt;

&lt;p&gt;In his famous &lt;a href="https://karpathy.github.io/2019/04/25/recipe/" rel="noopener noreferrer"&gt;post&lt;/a&gt; Andrew Karpathy also recommends keeping this part simple at first.&lt;/p&gt;

&lt;h3&gt;
  
  
  Text generation
&lt;/h3&gt;

&lt;p&gt;Add &lt;code&gt;predict&lt;/code&gt; function to the &lt;code&gt;train.py&lt;/code&gt; file:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="k"&gt;def&lt;/span&gt; &lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_words&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;100&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
    &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;eval&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

    &lt;span class="n"&gt;words&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;split&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt; &lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
    &lt;span class="n"&gt;state_h&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_c&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;init_state&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

    &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;i&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="nf"&gt;range&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;next_words&lt;/span&gt;&lt;span class="p"&gt;):&lt;/span&gt;
        &lt;span class="n"&gt;x&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;tensor&lt;/span&gt;&lt;span class="p"&gt;([[&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;word_to_index&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;w&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt; &lt;span class="k"&gt;for&lt;/span&gt; &lt;span class="n"&gt;w&lt;/span&gt; &lt;span class="ow"&gt;in&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;i&lt;/span&gt;&lt;span class="p"&gt;:]]])&lt;/span&gt;
        &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_h&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_c&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nf"&gt;model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;x&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;state_h&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;state_c&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;

        &lt;span class="n"&gt;last_word_logits&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;y_pred&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;][&lt;/span&gt;&lt;span class="o"&gt;-&lt;/span&gt;&lt;span class="mi"&gt;1&lt;/span&gt;&lt;span class="p"&gt;]&lt;/span&gt;
        &lt;span class="n"&gt;p&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;torch&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;nn&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;functional&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;softmax&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;last_word_logits&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;dim&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;0&lt;/span&gt;&lt;span class="p"&gt;).&lt;/span&gt;&lt;span class="nf"&gt;detach&lt;/span&gt;&lt;span class="p"&gt;().&lt;/span&gt;&lt;span class="nf"&gt;numpy&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
        &lt;span class="n"&gt;word_index&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;np&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;random&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;choice&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;len&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;last_word_logits&lt;/span&gt;&lt;span class="p"&gt;),&lt;/span&gt; &lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="n"&gt;p&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
        &lt;span class="n"&gt;words&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;append&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="n"&gt;index_to_word&lt;/span&gt;&lt;span class="p"&gt;[&lt;/span&gt;&lt;span class="n"&gt;word_index&lt;/span&gt;&lt;span class="p"&gt;])&lt;/span&gt;

    &lt;span class="k"&gt;return&lt;/span&gt; &lt;span class="n"&gt;words&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;h3&gt;
  
  
  Execute predictions
&lt;/h3&gt;

&lt;p&gt;Add the following code to &lt;code&gt;train.py&lt;/code&gt; file to execute the defined functions:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight python"&gt;&lt;code&gt;&lt;span class="n"&gt;parser&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;argparse&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nc"&gt;ArgumentParser&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;
&lt;span class="n"&gt;parser&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;add_argument&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--max-epochs&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;type&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;default&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;10&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;parser&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;add_argument&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--batch-size&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;type&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;default&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;256&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;parser&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;add_argument&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;--sequence-length&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="nb"&gt;type&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="nb"&gt;int&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;default&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="mi"&gt;4&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;args&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="n"&gt;parser&lt;/span&gt;&lt;span class="p"&gt;.&lt;/span&gt;&lt;span class="nf"&gt;parse_args&lt;/span&gt;&lt;span class="p"&gt;()&lt;/span&gt;

&lt;span class="n"&gt;dataset&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Dataset&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="n"&gt;model&lt;/span&gt; &lt;span class="o"&gt;=&lt;/span&gt; &lt;span class="nc"&gt;Model&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;

&lt;span class="nf"&gt;train&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;args&lt;/span&gt;&lt;span class="p"&gt;)&lt;/span&gt;
&lt;span class="nf"&gt;print&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="nf"&gt;predict&lt;/span&gt;&lt;span class="p"&gt;(&lt;/span&gt;&lt;span class="n"&gt;dataset&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;model&lt;/span&gt;&lt;span class="p"&gt;,&lt;/span&gt; &lt;span class="n"&gt;text&lt;/span&gt;&lt;span class="o"&gt;=&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="s"&gt;Knock knock. Whos there?&lt;/span&gt;&lt;span class="sh"&gt;'&lt;/span&gt;&lt;span class="p"&gt;))&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;Run the &lt;code&gt;train.py&lt;/code&gt; script with:&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="nv"&gt;$ &lt;/span&gt;python train.py
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;You can see the loss along with the epochs. The model predicts the next 100 words after &lt;code&gt;Knock knock. Whos there?&lt;/code&gt; when the training finishes. By default, it runs for 10 epochs and takes around 15 mins to finish training.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight shell"&gt;&lt;code&gt;&lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;'epoch'&lt;/span&gt;: 9, &lt;span class="s1"&gt;'batch'&lt;/span&gt;: 91, &lt;span class="s1"&gt;'loss'&lt;/span&gt;: 5.953955173492432&lt;span class="o"&gt;}&lt;/span&gt;
&lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;'epoch'&lt;/span&gt;: 9, &lt;span class="s1"&gt;'batch'&lt;/span&gt;: 92, &lt;span class="s1"&gt;'loss'&lt;/span&gt;: 6.1532487869262695&lt;span class="o"&gt;}&lt;/span&gt;
&lt;span class="o"&gt;{&lt;/span&gt;&lt;span class="s1"&gt;'epoch'&lt;/span&gt;: 9, &lt;span class="s1"&gt;'batch'&lt;/span&gt;: 93, &lt;span class="s1"&gt;'loss'&lt;/span&gt;: 5.531163215637207&lt;span class="o"&gt;}&lt;/span&gt;
&lt;span class="o"&gt;[&lt;/span&gt;&lt;span class="s1"&gt;'Knock'&lt;/span&gt;, &lt;span class="s1"&gt;'knock.'&lt;/span&gt;, &lt;span class="s1"&gt;'Whos'&lt;/span&gt;, &lt;span class="s1"&gt;'there?'&lt;/span&gt;, &lt;span class="s1"&gt;'3)'&lt;/span&gt;, &lt;span class="s1"&gt;'moostard'&lt;/span&gt;, &lt;span class="s1"&gt;'bird'&lt;/span&gt;, &lt;span class="s1"&gt;'Book,'&lt;/span&gt;,
&lt;span class="s1"&gt;'What'&lt;/span&gt;, &lt;span class="s1"&gt;'when'&lt;/span&gt;, &lt;span class="s1"&gt;'when'&lt;/span&gt;, &lt;span class="s1"&gt;'the'&lt;/span&gt;, &lt;span class="s1"&gt;'Autumn'&lt;/span&gt;, &lt;span class="s1"&gt;'He'&lt;/span&gt;, &lt;span class="s1"&gt;'What'&lt;/span&gt;, &lt;span class="s1"&gt;'did'&lt;/span&gt;, &lt;span class="s1"&gt;'the'&lt;/span&gt;,
&lt;span class="s1"&gt;'psychologist?'&lt;/span&gt;, &lt;span class="s1"&gt;'And'&lt;/span&gt;, &lt;span class="s1"&gt;'look'&lt;/span&gt;, &lt;span class="s1"&gt;'any'&lt;/span&gt;, &lt;span class="s1"&gt;'jokes.'&lt;/span&gt;, &lt;span class="s1"&gt;'Do'&lt;/span&gt;, &lt;span class="s1"&gt;'by'&lt;/span&gt;, &lt;span class="s2"&gt;"Valentine's"&lt;/span&gt;,
&lt;span class="s1"&gt;'Because'&lt;/span&gt;, &lt;span class="s1"&gt;'I'&lt;/span&gt;, &lt;span class="s1"&gt;'papa'&lt;/span&gt;, &lt;span class="s1"&gt;'could'&lt;/span&gt;, &lt;span class="s1"&gt;'believe'&lt;/span&gt;, &lt;span class="s1"&gt;'had'&lt;/span&gt;, &lt;span class="s1"&gt;'a'&lt;/span&gt;, &lt;span class="s1"&gt;'call'&lt;/span&gt;, &lt;span class="s1"&gt;'decide'&lt;/span&gt;,
&lt;span class="s1"&gt;'elephants'&lt;/span&gt;, &lt;span class="s1"&gt;'it'&lt;/span&gt;, &lt;span class="s1"&gt;'my'&lt;/span&gt;, &lt;span class="s1"&gt;'eyes?'&lt;/span&gt;, &lt;span class="s1"&gt;'Why'&lt;/span&gt;, &lt;span class="s1"&gt;'you'&lt;/span&gt;, &lt;span class="s1"&gt;'different'&lt;/span&gt;, &lt;span class="s1"&gt;'know'&lt;/span&gt;, &lt;span class="s1"&gt;'in'&lt;/span&gt;,
&lt;span class="s1"&gt;'an'&lt;/span&gt;, &lt;span class="s1"&gt;'file'&lt;/span&gt;, &lt;span class="s1"&gt;'of'&lt;/span&gt;, &lt;span class="s1"&gt;'a'&lt;/span&gt;, &lt;span class="s1"&gt;'jungle?'&lt;/span&gt;, &lt;span class="s1"&gt;'Rock'&lt;/span&gt;, &lt;span class="s1"&gt;'-'&lt;/span&gt;, &lt;span class="s1"&gt;'and'&lt;/span&gt;, &lt;span class="s1"&gt;'might'&lt;/span&gt;, &lt;span class="s2"&gt;"It's"&lt;/span&gt;,
&lt;span class="s1"&gt;'every'&lt;/span&gt;, &lt;span class="s1"&gt;'out'&lt;/span&gt;, &lt;span class="s1"&gt;'say'&lt;/span&gt;, &lt;span class="s1"&gt;'when'&lt;/span&gt;, &lt;span class="s1"&gt;'to'&lt;/span&gt;, &lt;span class="s1"&gt;'an'&lt;/span&gt;, &lt;span class="s1"&gt;'ghost'&lt;/span&gt;, &lt;span class="s1"&gt;'however:'&lt;/span&gt;, &lt;span class="s1"&gt;'the'&lt;/span&gt;, &lt;span class="s1"&gt;'sex,'&lt;/span&gt;,
&lt;span class="s1"&gt;'in'&lt;/span&gt;, &lt;span class="s1"&gt;'his'&lt;/span&gt;, &lt;span class="s1"&gt;'hose'&lt;/span&gt;, &lt;span class="s1"&gt;'and'&lt;/span&gt;, &lt;span class="s1"&gt;'because'&lt;/span&gt;, &lt;span class="s1"&gt;'joke'&lt;/span&gt;, &lt;span class="s1"&gt;'the'&lt;/span&gt;, &lt;span class="s1"&gt;'month'&lt;/span&gt;, &lt;span class="s1"&gt;'25'&lt;/span&gt;, &lt;span class="s1"&gt;'The'&lt;/span&gt;,
&lt;span class="s1"&gt;'97'&lt;/span&gt;, &lt;span class="s1"&gt;'can'&lt;/span&gt;, &lt;span class="s1"&gt;'eggs.'&lt;/span&gt;, &lt;span class="s1"&gt;'was'&lt;/span&gt;, &lt;span class="s1"&gt;'dead'&lt;/span&gt;, &lt;span class="s1"&gt;'joke'&lt;/span&gt;, &lt;span class="s2"&gt;"I'm"&lt;/span&gt;, &lt;span class="s1"&gt;'a'&lt;/span&gt;, &lt;span class="s1"&gt;'want'&lt;/span&gt;, &lt;span class="s1"&gt;'is'&lt;/span&gt;, &lt;span class="s1"&gt;'you'&lt;/span&gt;,
&lt;span class="s1"&gt;'out'&lt;/span&gt;, &lt;span class="s1"&gt;'to'&lt;/span&gt;, &lt;span class="s1"&gt;'Sorry,'&lt;/span&gt;, &lt;span class="s1"&gt;'the'&lt;/span&gt;, &lt;span class="s1"&gt;'poet,'&lt;/span&gt;, &lt;span class="s1"&gt;'between'&lt;/span&gt;, &lt;span class="s1"&gt;'clean'&lt;/span&gt;, &lt;span class="s1"&gt;'Words'&lt;/span&gt;, &lt;span class="s1"&gt;'car'&lt;/span&gt;,
&lt;span class="s1"&gt;'his'&lt;/span&gt;, &lt;span class="s1"&gt;'wife'&lt;/span&gt;, &lt;span class="s1"&gt;'would'&lt;/span&gt;, &lt;span class="s1"&gt;'1000'&lt;/span&gt;, &lt;span class="s1"&gt;'and'&lt;/span&gt;, &lt;span class="s1"&gt;'Santa'&lt;/span&gt;, &lt;span class="s1"&gt;'oh'&lt;/span&gt;, &lt;span class="s1"&gt;'diving'&lt;/span&gt;, &lt;span class="s1"&gt;'machine?'&lt;/span&gt;,
&lt;span class="s1"&gt;'He'&lt;/span&gt;, &lt;span class="s1"&gt;'was'&lt;/span&gt;&lt;span class="o"&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;If you skipped to this part and want to run the code, here's a Github &lt;a href="https://github.com/closeheat/pytorch-lstm-text-generation-tutorial" rel="noopener noreferrer"&gt;repository&lt;/a&gt; you can clone.&lt;/p&gt;

&lt;h2&gt;
  
  
  Next steps
&lt;/h2&gt;

&lt;p&gt;Congratulations! You've written your first PyTorch LSTM network and generated some jokes.&lt;/p&gt;

&lt;p&gt;Here's what you can do next to improve the model:&lt;/p&gt;

&lt;ul&gt;
&lt;li&gt;Clean up the data by removing non-letter characters.&lt;/li&gt;
&lt;li&gt;Increase the model capacity by adding more &lt;code&gt;Linear&lt;/code&gt; or &lt;code&gt;LSTM&lt;/code&gt; layers.&lt;/li&gt;
&lt;li&gt;Split the dataset into train, test, and validation sets.&lt;/li&gt;
&lt;li&gt;Add checkpoints so you don't have to train the model every time you want to run prediction.&lt;/li&gt;
&lt;/ul&gt;

</description>
      <category>machinelearning</category>
      <category>python</category>
      <category>pytorch</category>
      <category>ai</category>
    </item>
  </channel>
</rss>
