Rijul Rajesh

Building LSTMs with PyTorch and Lightning AI Part 5: Improving Predictions Through Training

In the previous article, we ran our model and checked how accurate its predictions were.

In this article, we will train the model.

First, we create the training data.

import torch

inputs = torch.tensor([
    [0., 0.5, 0.25, 1.],  # Company A, Days 1-4
    [1., 0.5, 0.25, 1.]   # Company B, Days 1-4
])

Each row holds one company's stock prices for Days 1 through 4: the first row is Company A and the second row is Company B.

Next, we create the labels, which are the values we want the LSTM to predict.

labels = torch.tensor([0., 1.])

Here, we want the LSTM to predict:

  • 0 for Company A
  • 1 for Company B

Now we combine the inputs and labels into a TensorDataset called dataset.

from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(inputs, labels)  # pairs each input row with its label
dataloader = DataLoader(dataset)

As we discussed in previous articles, DataLoaders are useful because:

  • They make it easy to access the data in batches.
  • They can shuffle the data at the beginning of each epoch (when created with shuffle=True; by default the order is kept).
  • They allow us to use a small subset of the data when we want to quickly debug the training process.
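
To make the batching and shuffling points concrete, here is a minimal sketch that reuses the two-company data from this article. The `batch_size=2` and `shuffle=True` settings are illustrative assumptions only; the article itself creates the DataLoader with its default arguments.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

inputs = torch.tensor([
    [0., 0.5, 0.25, 1.],  # Company A, Days 1-4
    [1., 0.5, 0.25, 1.],  # Company B, Days 1-4
])
labels = torch.tensor([0., 1.])

dataset = TensorDataset(inputs, labels)

# Indexing the dataset returns one (input, label) pair.
first_input, first_label = dataset[0]

# With default arguments, the DataLoader yields batches of size 1
# in the same order as the dataset.
default_shapes = [(x.shape, y.shape) for x, y in DataLoader(dataset)]

# Illustrative settings (not used in the article): one batch holding
# both companies, with the order reshuffled at the start of every epoch.
shuffled_loader = DataLoader(dataset, batch_size=2, shuffle=True)
batch_inputs, batch_labels = next(iter(shuffled_loader))

print(len(dataset), default_shapes)
print(batch_inputs.shape, batch_labels.shape)
```

Note that each default batch has an extra leading dimension of size 1, because the DataLoader stacks samples into a batch even when there is only one sample per batch.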

Next, we create a Lightning trainer.

import lightning as L

trainer = L.Trainer(max_epochs=2000)

Here, we tell Lightning to train the model for a maximum of 2,000 epochs.

During training, backpropagation computes the gradient of the loss with respect to every trainable weight and bias in the LSTM, and the optimizer uses those gradients to update them.
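
Lightning runs this loop for us, but it can help to see roughly what one epoch involves. The sketch below is written in plain PyTorch under several assumptions: a single linear layer stands in for the LSTM, the loss is the squared residual, and the optimizer is plain SGD with a learning rate of 0.1. None of these choices come from this article; the real model, loss, and optimizer are the ones defined in the earlier parts of the series.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the LSTM: one linear layer, 4 prices in, 1 value out.
model = nn.Linear(4, 1)

# Assumed optimizer and learning rate, chosen only for this illustration.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.tensor([[0., 0.5, 0.25, 1.], [1., 0.5, 0.25, 1.]])
labels = torch.tensor([0., 1.])


def mean_squared_error(net):
    # Evaluate without tracking gradients.
    with torch.no_grad():
        return ((net(inputs).squeeze(-1) - labels) ** 2).mean().item()


loss_before = mean_squared_error(model)

for epoch in range(200):
    for x, y in zip(inputs, labels):
        optimizer.zero_grad()            # clear gradients from the previous step
        prediction = model(x).squeeze()  # forward pass
        loss = (prediction - y) ** 2     # squared residual (assumed loss)
        loss.backward()                  # backpropagation: compute gradients
        optimizer.step()                 # update weights and biases

loss_after = mean_squared_error(model)
print(loss_before, loss_after)
```

Each pass through the inner loop is one training step: a forward pass, a loss calculation, backpropagation, and a parameter update. An epoch is one full pass over the training data.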

To begin training, we simply call the trainer's fit() method.

trainer.fit(model, train_dataloaders=dataloader)

Once training is complete, we can print the predictions just as we did before.

print("\nComparing observed and predicted values")

print(
    "Company A: Observed = 0, Predicted =",
    model(torch.tensor([0., 0.5, 0.25, 1.])).detach()
)

print(
    "Company B: Observed = 1, Predicted =",
    model(torch.tensor([1., 0.5, 0.25, 1.])).detach()
)

This produces the following output:

Comparing observed and predicted values

Company A: Observed = 0, Predicted = tensor(0.0003)
Company B: Observed = 1, Predicted = tensor(0.9287)

As you can see, the predictions have improved significantly after training. The model now predicts 0.0003 for Company A (target 0) and 0.9287 for Company B (target 1), both close to the expected outputs.
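
To put a number on "close", we can compute the residual for each company from the printed output. This is a small sketch in plain Python; using the squared residual as the measure is an assumption made here for illustration, since this article does not state which loss the model uses.

```python
# Observed labels and the predictions printed after training.
observed = {"Company A": 0.0, "Company B": 1.0}
predicted = {"Company A": 0.0003, "Company B": 0.9287}

squared_residuals = {}
for company, target in observed.items():
    residual = target - predicted[company]
    squared_residuals[company] = residual ** 2
    print(f"{company}: residual = {residual:+.4f}, "
          f"squared residual = {residual ** 2:.2e}")
```

The residual for Company B (about 0.07) is much larger than the one for Company A, so most of the remaining error comes from the Company B prediction.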

In the next article, we will explore TensorBoard to analyze what happened during training.
