DEV Community

Rijul Rajesh


Building LSTMs with PyTorch and Lightning AI Part 1: First Steps with LSTMs

In this article, we will explore how to implement an LSTM using PyTorch and Lightning.

For more details about LSTMs, there is a separate series of articles available here.


Imports

To begin, we import the required modules.

import torch
import torch.nn as nn
import torch.nn.functional as F

Introducing a New Optimizer

We also introduce a new optimizer:

from torch.optim import Adam

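Adam is the optimizer we use to fit the network's weights and biases to the data.

Like SGD, it updates each parameter using the gradients computed by backpropagation. Unlike plain SGD, Adam keeps running averages of each parameter's gradients and squared gradients and uses them to adapt the step size for every parameter individually, which in practice often lets it converge faster.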
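To see that Adam slots in wherever SGD would, here is a minimal sketch that fits a single weight to the toy relationship y = 3x with each optimizer. The data, learning rate, and the `fit` helper are illustrative assumptions, not part of the article's model; the point is that only the optimizer class changes while the training loop stays the same.

```python
import torch
from torch.optim import SGD, Adam

# Toy data (illustrative): the target relationship is y = 3 * x, so the best weight is 3.
x = torch.linspace(-1.0, 1.0, 20)
y = 3.0 * x


def fit(optimizer_cls, lr, steps=500):
    """Hypothetical helper: fit one weight w so that w * x approximates y; return the final w."""
    w = torch.nn.Parameter(torch.tensor(0.0))
    optimizer = optimizer_cls([w], lr=lr)
    for _ in range(steps):
        optimizer.zero_grad()             # clear gradients from the previous step
        loss = ((w * x - y) ** 2).mean()  # mean squared error
        loss.backward()                   # backpropagation fills in w.grad
        optimizer.step()                  # the optimizer decides how far to move w
    return w.item()


# Only the optimizer class differs between the two runs.
w_sgd = fit(SGD, lr=0.1)
w_adam = fit(Adam, lr=0.1)
print(f"SGD:  w = {w_sgd:.3f}")
print(f"Adam: w = {w_adam:.3f}")
```

Both runs should end close to 3; on a problem this small the difference in speed is not representative of larger networks.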


Lightning and Data Utilities

Next, we continue with the remaining imports:

import lightning as L
from torch.utils.data import TensorDataset, DataLoader
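`TensorDataset` pairs inputs with their labels, and `DataLoader` serves those pairs in batches during training. The sketch below uses two hypothetical four-step sequences with one target each; the values are made up for illustration and are not the article's training data.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy data: two sequences of four values, one target per sequence.
inputs = torch.tensor([[0.0, 0.5, 0.25, 1.0],
                       [1.0, 0.5, 0.25, 1.0]])
labels = torch.tensor([0.0, 1.0])

# TensorDataset pairs row i of inputs with element i of labels.
dataset = TensorDataset(inputs, labels)

# DataLoader groups dataset items into batches (here, one sequence per batch).
dataloader = DataLoader(dataset, batch_size=1)

for batch_inputs, batch_labels in dataloader:
    print(batch_inputs.shape, batch_labels.shape)  # torch.Size([1, 4]) torch.Size([1])
```

A Lightning `Trainer` can consume a `DataLoader` like this one directly when fitting a model.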

Defining the LSTM Model

We define the neural network by creating a Lightning module.

class LSTMByHand(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Create and initialize weight and bias tensors

    def lstm_unit(self, input_value, long_memory, short_memory):
        # LSTM computations
        pass

    def forward(self, input):
        # Forward pass through the unrolled LSTM
        pass

    def configure_optimizers(self):
        # Configure Adam optimizer
        pass

    def training_step(self, batch, batch_idx):
        # Compute loss and log training progress
        pass
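For orientation, here is a rough sketch of what `lstm_unit` and the unrolled `forward` pass typically compute, written as plain functions over scalar tensors. It assumes the standard LSTM equations (sigmoid gates, tanh candidate) and reuses the parameter names that the `__init__` method below creates. The gate names in the comments, and which weight multiplies the input versus the short-term memory, are our assumptions; the `unroll` helper and the dict `p` are hypothetical, and the next part of this series may structure things differently.

```python
import torch


def lstm_unit(input_value, long_memory, short_memory, p):
    # Forget gate (assumed meaning of "lr"): fraction of long-term memory to keep
    long_remember_percent = torch.sigmoid(short_memory * p["wlr1"] + input_value * p["wlr2"] + p["blr1"])

    # Input gate (assumed "pr") and cell candidate (assumed "p")
    potential_remember_percent = torch.sigmoid(short_memory * p["wpr1"] + input_value * p["wpr2"] + p["bpr1"])
    potential_memory = torch.tanh(short_memory * p["wp1"] + input_value * p["wp2"] + p["bp1"])

    # New long-term memory: the kept old memory plus the gated candidate
    updated_long_memory = long_memory * long_remember_percent + potential_remember_percent * potential_memory

    # Output gate scales the squashed long-term memory to give the new short-term memory
    output_percent = torch.sigmoid(short_memory * p["wo1"] + input_value * p["wo2"] + p["bo1"])
    updated_short_memory = torch.tanh(updated_long_memory) * output_percent
    return updated_long_memory, updated_short_memory


def unroll(sequence, p):
    # Hypothetical helper: both memories start at zero; the same weights are reused at every step
    long_memory = torch.tensor(0.0)
    short_memory = torch.tensor(0.0)
    for input_value in sequence:
        long_memory, short_memory = lstm_unit(input_value, long_memory, short_memory, p)
    return short_memory  # the final short-term memory serves as the prediction


torch.manual_seed(0)
weight_names = ["wlr1", "wlr2", "wpr1", "wpr2", "wp1", "wp2", "wo1", "wo2"]
params = {name: torch.normal(mean=torch.tensor(0.0), std=torch.tensor(1.0)) for name in weight_names}
params.update({name: torch.tensor(0.0) for name in ["blr1", "bpr1", "bp1", "bo1"]})

prediction = unroll(torch.tensor([0.0, 0.5, 0.25, 1.0]), params)
print(prediction)
```

Because the final short-term memory is a tanh value multiplied by a sigmoid value, the prediction always lies strictly between -1 and 1.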

Initializing the Model

Now let’s implement the __init__ method.

This is where we initialize all weights and biases.

class LSTMByHand(L.LightningModule):
    def __init__(self):
        super().__init__()

        mean = torch.tensor(0.0)  # Mean of the normal distribution
        std = torch.tensor(1.0)   # Standard deviation

        # -------------------------
        # Forget gate (lr): how much of the long-term memory to remember
        # -------------------------
        self.wlr1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wlr2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.blr1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        # -------------------------
        # Input gate (pr): how much of the potential memory to add
        # -------------------------
        self.wpr1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wpr2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bpr1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        # -------------------------
        # Cell candidate (p): the potential long-term memory
        # -------------------------
        self.wp1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wp2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bp1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)

        # -------------------------
        # Output gate (o): how much of the long-term memory becomes the new short-term memory
        # -------------------------
        self.wo1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wo2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bo1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)
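Assigning an `nn.Parameter` as a module attribute registers it automatically, which is how the optimizer later finds these 12 values. The sketch below checks that with a hypothetical stand-in class built on plain `nn.Module` (a `LightningModule` is itself an `nn.Module`, so registration works the same way), which lets it run without Lightning installed. It uses a loop with `setattr` for brevity instead of the article's explicit assignments.

```python
import torch
import torch.nn as nn


class LSTMParameters(nn.Module):
    """Hypothetical stand-in for LSTMByHand.__init__, using nn.Module instead of L.LightningModule."""

    def __init__(self):
        super().__init__()
        mean = torch.tensor(0.0)
        std = torch.tensor(1.0)
        # The same four gate suffixes used in LSTMByHand: lr, pr, p, o
        for gate in ["lr", "pr", "p", "o"]:
            # Setting an nn.Parameter attribute registers it with the module
            setattr(self, f"w{gate}1", nn.Parameter(torch.normal(mean=mean, std=std)))
            setattr(self, f"w{gate}2", nn.Parameter(torch.normal(mean=mean, std=std)))
            setattr(self, f"b{gate}1", nn.Parameter(torch.tensor(0.0)))


torch.manual_seed(42)  # fix the random draws so runs are reproducible
model = LSTMParameters()

names = [name for name, _ in model.named_parameters()]
print(len(names))                                  # 12
print(sum(p.numel() for p in model.parameters()))  # 12: each parameter is a single number
```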

Why Use a Normal Distribution?

Unlike earlier examples, we initialize each weight with a random value drawn from a normal distribution, while every bias starts at 0.

Before moving further, let’s understand what that means.

What is a Normal Distribution?

Imagine measuring the heights of a large group of people:

  • Most people are around the average height
  • Very tall and very short people are rare

When plotted, this forms a symmetric bell-shaped curve.

This is called a normal distribution.


Key Properties

  • The center represents the most common values
  • The curve is symmetric
  • The tails represent rare values
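These three properties can be checked empirically. The sketch below draws 100,000 samples with mean 0 and standard deviation 1 (the same settings the model uses) and measures how they spread; the sample size and seed are arbitrary choices.

```python
import torch

torch.manual_seed(0)  # reproducible draws
samples = torch.normal(0.0, 1.0, size=(100_000,))

# Symmetric: about half the values fall on each side of the mean
above_mean = (samples > 0).float().mean().item()
# Most values sit near the center: roughly 68% lie within one standard deviation
within_one_std = (samples.abs() < 1).float().mean().item()
# The tails are rare: well under 1% lie more than three standard deviations out
beyond_three_std = (samples.abs() > 3).float().mean().item()

print(f"above the mean:      {above_mean:.3f}")
print(f"within 1 std:        {within_one_std:.3f}")
print(f"beyond 3 std:        {beyond_three_std:.4f}")
```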

Mean and Standard Deviation

  • Mean → the average value
  • Standard deviation → how spread out the values are

Small vs Large Standard Deviation

Small Standard Deviation

  • Values are tightly clustered around the mean
  • Example: Class A scores mostly between 55–65

Large Standard Deviation

  • Values are widely spread
  • Example: Class B scores range from 20–90
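The two classes can be simulated to make the contrast concrete. The means and standard deviations below (60 and 3 for Class A, 55 and 15 for Class B) are hypothetical values chosen to roughly match the ranges in the example.

```python
import torch

torch.manual_seed(0)
# Hypothetical scores for two classes of 10,000 students each
class_a = torch.normal(60.0, 3.0, size=(10_000,))   # small standard deviation
class_b = torch.normal(55.0, 15.0, size=(10_000,))  # large standard deviation

for name, scores in [("Class A", class_a), ("Class B", class_b)]:
    # Share of scores that land in the 55 to 65 band
    in_band = ((scores >= 55) & (scores <= 65)).float().mean().item()
    print(f"{name}: mean={scores.mean().item():.1f}, "
          f"std={scores.std().item():.1f}, between 55 and 65: {in_band:.0%}")
```

With the small standard deviation, most of Class A lands in the 55 to 65 band; Class B's scores are spread much more widely.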


In Our Code

We use:

  • Mean = 0
  • Standard deviation = 1

All parameters also have requires_grad=True (the default for nn.Parameter), which means PyTorch computes their gradients during backpropagation so the optimizer can update them during training.
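A small sketch shows what requires_grad changes in practice. It draws one weight exactly as `__init__` does, pairs it with an ordinary tensor (whose requires_grad defaults to False), and runs backpropagation through a toy loss; the loss and the value 2.0 are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same call as in __init__: 0-dim mean and std tensors give a single random value
w = nn.Parameter(torch.normal(mean=torch.tensor(0.0), std=torch.tensor(1.0)))
scale = torch.tensor(2.0)  # a plain tensor: requires_grad defaults to False

print(w.numel())                             # 1
print(w.requires_grad, scale.requires_grad)  # True False

# Toy loss that depends on both tensors
loss = (w * scale - 1.0) ** 2
loss.backward()

print(w.grad)      # d(loss)/dw = 2 * (w * scale - 1) * scale, filled in by backpropagation
print(scale.grad)  # None: PyTorch did not track gradients for this tensor
```

Only `w` receives a gradient, so only `w` would be changed by an optimizer such as Adam.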

Next, we will explore the lstm_unit function and how the LSTM actually processes information step by step.

