In the previous article, we saw how to build a simpler version of our LSTM using PyTorch's built-in nn.LSTM().
In this article, we will continue building the simplified LSTM and test how it performs.
Implementing the forward() Method
Let's start by implementing the forward() method.
def forward(self, input):
    input_trans = input.view(len(input), 1)
    lstm_out, temp = self.lstm(input_trans)
    prediction = lstm_out[-1]
    return prediction
The call to view() reshapes the input into input_trans. The first argument, len(input), gives one row for each data point, regardless of how many data points we have.
The second argument, 1, specifies that the input should have one column, since each data point contains only a single feature.
This reshaped input is then passed to the LSTM.
The output is stored in lstm_out.
lstm_out contains the short-term memory values produced by each LSTM unit as the sequence is processed.
In our example, the sequence contains four input values, so the LSTM is unrolled four times and lstm_out contains four outputs.
Next, we extract the prediction from the final LSTM unit by selecting the last element in the sequence using the index -1.
Finally, we return this prediction.
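To make the unrolling concrete, here is a minimal pure-Python sketch of what happens conceptually inside forward(). It is not PyTorch's implementation: the lstm_step() helper, the parameter names, and the parameter values are all hypothetical and chosen only for illustration, and the sketch uses one bias per gate for brevity.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x, long_mem, short_mem, p):
    """One LSTM unit with a single input feature and a single hidden value."""
    forget = sigmoid(p["w_h"] * short_mem + p["w_x"] * x + p["b_forget"])
    remember = sigmoid(p["w_h"] * short_mem + p["w_x"] * x + p["b_input"])
    candidate = math.tanh(p["w_h"] * short_mem + p["w_x"] * x + p["b_cand"])
    long_mem = forget * long_mem + remember * candidate
    output_gate = sigmoid(p["w_h"] * short_mem + p["w_x"] * x + p["b_output"])
    short_mem = output_gate * math.tanh(long_mem)
    return long_mem, short_mem

def forward(values, p):
    # One row per data point, one column per feature,
    # analogous to input.view(len(input), 1).
    rows = [[v] for v in values]
    long_mem, short_mem = 0.0, 0.0
    lstm_out = []
    for (x,) in rows:
        long_mem, short_mem = lstm_step(x, long_mem, short_mem, p)
        lstm_out.append(short_mem)  # one short-term memory per unrolled unit
    return lstm_out, lstm_out[-1]   # the last element is the prediction

# Hypothetical, untrained parameters (shared across gates to keep this short).
params = {"w_h": 0.5, "w_x": 0.5, "b_forget": 0.0,
          "b_input": 0.0, "b_cand": 0.0, "b_output": 0.0}

lstm_out, prediction = forward([0.0, 0.5, 0.25, 1.0], params)
print(len(lstm_out), prediction)
```

Four input values produce four entries in lstm_out, and indexing with -1 picks the short-term memory from the final unrolled unit.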
Configuring the Optimizer
Next, let's implement configure_optimizers().
It is almost identical to the previous implementation, except that we increase the learning rate from its default value of 0.001 to 0.1.
This allows us to observe how the Adam optimizer converges to the optimal weights and biases.
def configure_optimizers(self):
    return Adam(self.parameters(), lr=0.1)
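To get a feel for why the larger learning rate helps in a short run, here is a rough pure-Python sketch of the Adam update rule applied to a toy one-parameter loss, (w - 1)**2. The adam_minimize() helper and the toy loss are assumptions made for illustration; they are not the article's model or PyTorch's optimizer code.

```python
import math

def adam_minimize(lr, steps=300, beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimize the toy loss (w - 1)**2, starting from w = 0, with Adam."""
    w, m, v = 0.0, 0.0, 0.0
    for t in range(1, steps + 1):
        grad = 2.0 * (w - 1.0)                    # derivative of (w - 1)**2
        m = beta1 * m + (1 - beta1) * grad        # running mean of gradients
        v = beta2 * v + (1 - beta2) * grad ** 2   # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)              # bias correction
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w

w_default = adam_minimize(lr=0.001)  # Adam's default learning rate
w_faster = adam_minimize(lr=0.1)     # the value used in configure_optimizers()
print("lr=0.001:", w_default)
print("lr=0.1:  ", w_faster)
```

On this toy problem, 300 steps at the default rate leave the parameter far from the optimum of 1, while the larger rate gets close to it in the same number of steps.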
Implementing training_step()
Finally, we implement training_step().
This method is exactly the same as before. It calculates the loss and logs the training progress.
def training_step(self, batch, batch_idx):
    input_i, label_i = batch
    output_i = self.forward(input_i[0])
    loss = (output_i - label_i)**2
    self.log("train_loss", loss)
    if label_i == 0:
        self.log("out_0", output_i)
    else:
        self.log("out_1", output_i)
    return loss
At this point, our model contains everything it needs:
__init__(), forward(), configure_optimizers(), and training_step().
Testing the Model
Let's run the model before training and check its predictions.
model = LightningLSTM()
print("\nComparing observed and predicted values")
print(
    "Company A: Observed = 0, Predicted =",
    model(torch.tensor([0., 0.5, 0.25, 1.])).detach()
)
print(
    "Company B: Observed = 1, Predicted =",
    model(torch.tensor([1., 0.5, 0.25, 1.])).detach()
)
This produces:
Comparing observed and predicted values
Company A: Observed = 0, Predicted = tensor([0.0131])
Company B: Observed = 1, Predicted = tensor([0.0102])
Training the Model
Now let's create a Lightning trainer.
trainer = L.Trainer(max_epochs=300, log_every_n_steps=2)
We train for 300 epochs and set log_every_n_steps to 2.
By default, Lightning logs every 50 steps, which is too infrequent for a small training run like this.
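A quick back-of-the-envelope calculation shows the difference. This sketch assumes the dataloader yields two training examples per epoch (Company A and Company B) with a batch size of 1, and that one value is recorded every log_every_n_steps steps. The logged_points() helper is hypothetical and only does the arithmetic.

```python
def logged_points(max_epochs, batches_per_epoch, log_every_n_steps):
    """Approximate number of logged values over a whole training run."""
    total_steps = max_epochs * batches_per_epoch
    return total_steps // log_every_n_steps

# Assumption: 2 batches per epoch (one per company).
for n in (50, 2):
    print(f"log_every_n_steps={n}: about {logged_points(300, 2, n)} logged points")
```

Under these assumptions, the default setting would leave only a handful of points on each TensorBoard graph, while logging every 2 steps gives one point per epoch.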
Next, we simply call fit().
trainer.fit(model, train_dataloaders=dataloader)
Checking the Results
Once training is complete, we check the predictions again.
Comparing observed and predicted values
Company A: Observed = 0, Predicted = tensor([0.0001])
Company B: Observed = 1, Predicted = tensor([0.9857])
The prediction for Company A is now very close to 0, and the prediction for Company B is very close to 1.
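To put a number on "very close", we can compute the same squared residual that training_step() uses as its loss, applied to the predictions printed before and after training. The squared_residual() helper is just an illustrative name.

```python
def squared_residual(predicted, observed):
    """The loss from training_step(): (output - label) ** 2."""
    return (predicted - observed) ** 2

# (predicted, observed) pairs copied from the printed output above.
before = {"Company A": (0.0131, 0.0), "Company B": (0.0102, 1.0)}
after = {"Company A": (0.0001, 0.0), "Company B": (0.9857, 1.0)}

for name in before:
    loss_before = squared_residual(*before[name])
    loss_after = squared_residual(*after[name])
    print(f"{name}: loss before = {loss_before:.6f}, after = {loss_after:.6f}")
```

The loss for Company B drops from nearly 1 to a small fraction of that, which matches the flattening we expect to see in the training graphs.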
Analyzing the Training
Let's open TensorBoard once again.
Notice that the graphs have flattened out as the model converged toward the desired predictions.
With that, we have completed our exploration of LSTMs.
We built an LSTM from scratch, then implemented a much simpler version using PyTorch's built-in functionality. Along the way, we also learned how to analyze the training process using TensorBoard and how to use those insights to make training decisions.
In the next series of articles, we will explore how to implement word embeddings using PyTorch and Lightning AI.