In this article, we will explore how to implement an LSTM using PyTorch and Lightning.
For more details about LSTMs, there is a separate series of articles available here.
Imports
To begin, we first import the required modules.
import torch
import torch.nn as nn
import torch.nn.functional as F
Introducing a New Optimizer
We also introduce a new optimizer:
from torch.optim import Adam
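Adam is the optimizer we use to fit the neural network to the data.
Like SGD, it updates each parameter using its gradient, but it also keeps running estimates of the gradient's mean and squared magnitude and uses them to scale the step size for each parameter individually. In practice, this often makes Adam converge in fewer steps than plain SGD.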
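To make the "works similarly to SGD" point concrete, here is a minimal sketch showing that the two optimizers are drop-in replacements for each other. The toy objective `(w - 3)^2`, the helper name `fit`, the learning rate, and the step count are all assumptions made for illustration; this tiny problem is not meant to demonstrate any speed difference between the two.

```python
import torch
from torch.optim import SGD, Adam

def fit(optimizer_cls, steps=300, lr=0.1):
    """Minimize (w - 3)^2 with the given optimizer class; return the final w."""
    w = torch.nn.Parameter(torch.tensor(0.0))
    optimizer = optimizer_cls([w], lr=lr)   # same constructor call for both
    for _ in range(steps):
        optimizer.zero_grad()               # clear the previous gradient
        loss = (w - 3.0) ** 2
        loss.backward()                     # compute d(loss)/dw
        optimizer.step()                    # update w
    return w.item()

w_sgd = fit(SGD)
w_adam = fit(Adam)
print(w_sgd, w_adam)  # both should end up close to 3
```

The training loop is identical in both cases; only the optimizer class changes.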
Lightning and Data Utilities
Next, we continue with the remaining imports:
import lightning as L
from torch.utils.data import TensorDataset, DataLoader
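As a rough illustration of what these two utilities do, the sketch below wraps a pair of tensors in a `TensorDataset` and iterates over them with a `DataLoader`. The input values and labels are made up for this example and are not the article's training data.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy data: two sequences of four values, with one label each.
inputs = torch.tensor([[0.0, 0.5, 0.25, 1.0],
                       [1.0, 0.5, 0.25, 1.0]])
labels = torch.tensor([0.0, 1.0])

dataset = TensorDataset(inputs, labels)  # pairs inputs[i] with labels[i]
dataloader = DataLoader(dataset)         # default batch_size is 1

batches = list(dataloader)
first_inputs, first_labels = batches[0]
print(first_inputs.shape, first_labels.shape)
```

Each batch carries a leading batch dimension, so a single four-value sequence arrives with shape `(1, 4)`.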
Defining the LSTM Model
We define the neural network by creating a Lightning module.
class LSTMByHand(L.LightningModule):
    def __init__(self):
        # Create and initialize weight and bias tensors
        pass
    def lstm_unit(self, input_value, long_memory, short_memory):
        # LSTM computations
        pass
    def forward(self, input):
        # Forward pass through the unrolled LSTM
        pass
    def configure_optimizers(self):
        # Configure Adam optimizer
        pass
    def training_step(self, batch, batch_idx):
        # Compute loss and log training progress
        pass
Initializing the Model
Now let’s implement the __init__ method.
This is where we initialize all weights and biases.
class LSTMByHand(L.LightningModule):
    def __init__(self):
        super().__init__()
        mean = torch.tensor(0.0)  # Mean of the normal distribution
        std = torch.tensor(1.0)   # Standard deviation of the normal distribution
        # -------------------------
        # Forget gate ("lr" parameters)
        # -------------------------
        self.wlr1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wlr2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.blr1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)
        # -------------------------
        # Input gate ("pr" parameters)
        # -------------------------
        self.wpr1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wpr2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bpr1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)
        # -------------------------
        # Cell candidate ("p" parameters)
        # -------------------------
        self.wp1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wp2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bp1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)
        # -------------------------
        # Output gate ("o" parameters)
        # -------------------------
        self.wo1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.wo2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)
        self.bo1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)
Why Use Normal Distribution?
Unlike earlier examples, we initialize weights using a normal distribution.
Before moving further, let’s understand what that means.
What is a Normal Distribution?
Imagine measuring the heights of a large group of people:
- Most people are around the average height
- Very tall and very short people are rare
When plotted, this forms a symmetric bell-shaped curve.
This is called a normal distribution.
Key Properties
- The center represents the most common values
- The curve is symmetric
- The tails represent rare values
Mean and Standard Deviation
- Mean → the average value
- Standard deviation → how spread out the values are
Small vs Large Standard Deviation
Small Standard Deviation
- Values are tightly clustered around the mean
- Example: Class A scores mostly between 55–65
Large Standard Deviation
- Values are widely spread
- Example: Class B scores range from 20–90
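The difference between the two classes can be checked numerically. Here is a minimal sketch using Python's standard `statistics` module; the individual scores are invented to fit the ranges above, so treat them as assumptions rather than real data.

```python
import statistics

# Hypothetical scores, chosen so that both classes have the same mean.
class_a = [55, 58, 60, 60, 62, 65]  # tightly clustered
class_b = [20, 40, 55, 65, 90, 90]  # widely spread

mean_a = statistics.mean(class_a)
mean_b = statistics.mean(class_b)
std_a = statistics.pstdev(class_a)  # population standard deviation
std_b = statistics.pstdev(class_b)

print(f"Class A: mean={mean_a}, std={std_a:.2f}")
print(f"Class B: mean={mean_b}, std={std_b:.2f}")
```

Both classes share the same average, yet Class B's standard deviation is several times larger, which is exactly what "widely spread" means.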
In Our Code
We use:
- Mean = 0
- Standard deviation = 1
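To see what these settings produce, the sketch below draws one value the same way the `__init__` method does, then draws many values to check that their average and spread come out near 0 and 1. The seed and the sample size of 10,000 are arbitrary choices for this illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only to make the run repeatable

mean = torch.tensor(0.0)
std = torch.tensor(1.0)

# One draw, as in __init__: the result is a 0-dimensional (scalar) tensor.
single = torch.normal(mean=mean, std=std)

# Many draws: the sample mean and standard deviation should be near 0 and 1.
many = torch.normal(mean=torch.zeros(10_000), std=torch.ones(10_000))

print(single.dim(), many.mean().item(), many.std().item())
```

Each weight therefore starts as a single random number that is most likely to be near 0 and only rarely far from it.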
Also, all parameters have requires_grad=True, which tells PyTorch to compute a gradient for each of them during backpropagation so that the optimizer can update them during training.
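The effect of `requires_grad` can be seen in a small sketch. The parameter names `w` and `frozen`, the input value, and the squared-error loss below are assumptions made purely for illustration; they are not part of the LSTM code.

```python
import torch
import torch.nn as nn

w = nn.Parameter(torch.tensor(0.5), requires_grad=True)        # trainable
frozen = nn.Parameter(torch.tensor(0.5), requires_grad=False)  # not trainable

x = torch.tensor(2.0)
loss = (w * x + frozen - 1.0) ** 2  # (0.5 * 2 + 0.5 - 1)^2 = 0.25
loss.backward()

print(w.grad)       # gradient is 2 * (w*x + frozen - 1) * x = 2.0
print(frozen.grad)  # None: no gradient is tracked for this parameter
```

Only the parameter created with `requires_grad=True` receives a gradient, so only it could be updated by an optimizer.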
Next, we will explore the lstm_unit function and how the LSTM actually processes information step by step.