In the previous article, we ran our model and checked how accurate its predictions were.
In this article, we will train the model.
First, we create the training data.
inputs = torch.tensor([
    [0., 0.5, 0.25, 1.],
    [1., 0.5, 0.25, 1.]
])
Each row holds the stock prices for Days 1 through 4 for one company: the first row is Company A and the second row is Company B.
Next, we create the labels, which are the values we want the LSTM to predict.
labels = torch.tensor([0., 1.])
Here, we want the LSTM to predict:
- 0 for Company A
- 1 for Company B
Now we combine the inputs and labels into a TensorDataset called dataset, and wrap it in a DataLoader called dataloader.
dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)
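To see what the DataLoader actually hands to the training loop, here is a minimal sketch that iterates over it. It assumes the imports come from `torch.utils.data`; the `batch_shapes` list is only there for illustration. With the default settings (`batch_size=1`, `shuffle=False`), each iteration yields one company's prices together with its label.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

inputs = torch.tensor([
    [0., 0.5, 0.25, 1.],
    [1., 0.5, 0.25, 1.],
])
labels = torch.tensor([0., 1.])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)  # defaults: batch_size=1, shuffle=False

# Collect the shape of every batch so we can inspect them afterwards.
batch_shapes = []
for input_batch, label_batch in dataloader:
    batch_shapes.append((tuple(input_batch.shape), tuple(label_batch.shape)))
    print("inputs:", input_batch, "label:", label_batch)
```

Each batch carries an extra leading dimension of size 1, so the inputs arrive with shape `(1, 4)` and the label with shape `(1,)`.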
As we discussed in previous articles, DataLoaders are useful because:
- They make it easy to access the data in batches.
- They can shuffle the data at the beginning of each epoch.
- They allow us to use a small subset of the data when we want to quickly debug the training process.
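The three points above can be sketched with standard `torch.utils.data` tools. This is an illustration only: the loader names are made up for the example, and a batch size of 2 is shown to demonstrate batching, not because the model in this series is known to accept batched input.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, Subset

inputs = torch.tensor([
    [0., 0.5, 0.25, 1.],
    [1., 0.5, 0.25, 1.],
])
labels = torch.tensor([0., 1.])
dataset = TensorDataset(inputs, labels)

# 1. Batching: both companies arrive together in a single batch.
batched_loader = DataLoader(dataset, batch_size=2)
input_batch, label_batch = next(iter(batched_loader))

# 2. Shuffling: the order of the samples is re-drawn at the start of each epoch.
shuffled_loader = DataLoader(dataset, batch_size=1, shuffle=True)

# 3. Debugging on a subset: keep only the first sample (Company A).
debug_loader = DataLoader(Subset(dataset, indices=[0]))
debug_input, debug_label = next(iter(debug_loader))

print(input_batch.shape)   # shape of the two-company batch
print(len(debug_loader))   # number of batches in the debug loader
```

Lightning's Trainer also has its own options for limiting the amount of data per epoch; the `Subset` approach here is just the plain PyTorch way of doing it.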
Next, we create a Lightning trainer.
trainer = L.Trainer(max_epochs=2000)
Here, we tell Lightning to train the model for a maximum of 2,000 epochs.
During training, backpropagation is used to optimize all the trainable weights and biases in the LSTM.
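To make the role of backpropagation more concrete, here is a rough sketch of the kind of loop a trainer runs for us, written in plain PyTorch. Everything model-related here is an assumption for illustration: it uses `nn.LSTM` with a single hidden value as a stand-in for the model from this series, a squared-error loss, and the Adam optimizer with a learning rate of 0.05. It is not the Lightning implementation itself.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

inputs = torch.tensor([
    [0., 0.5, 0.25, 1.],
    [1., 0.5, 0.25, 1.],
])
labels = torch.tensor([0., 1.])
dataloader = DataLoader(TensorDataset(inputs, labels))

# Hypothetical stand-in model: one LSTM unit with a single hidden value.
lstm = nn.LSTM(input_size=1, hidden_size=1, batch_first=True)
optimizer = torch.optim.Adam(lstm.parameters(), lr=0.05)

def predict(input_batch):
    # Reshape (batch, days) to (batch, days, 1), run the LSTM,
    # and keep the short-term memory after the last day.
    output, _ = lstm(input_batch.unsqueeze(-1))
    return output[:, -1, 0]

def total_loss():
    # Sum of squared errors over both companies, without tracking gradients.
    with torch.no_grad():
        return ((predict(inputs) - labels) ** 2).sum().item()

initial_loss = total_loss()

for epoch in range(200):
    for input_batch, label_batch in dataloader:
        loss = ((predict(input_batch) - label_batch) ** 2).sum()
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()        # backpropagation: compute gradients
        optimizer.step()       # update the weights and biases

final_loss = total_loss()
print("loss before training:", initial_loss)
print("loss after training:", final_loss)
```

The trainer wraps these same steps (forward pass, loss, backward pass, optimizer step) so that we do not have to write the loop by hand.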
To begin training, we simply call the trainer's fit() method.
trainer.fit(model, train_dataloaders=dataloader)
Once training is complete, we can print the predictions just as we did before.
print("\nComparing observed and predicted values")
print(
    "Company A: Observed = 0, Predicted =",
    model(torch.tensor([0., 0.5, 0.25, 1.])).detach()
)
print(
    "Company B: Observed = 1, Predicted =",
    model(torch.tensor([1., 0.5, 0.25, 1.])).detach()
)
This produces the following output:
Comparing observed and predicted values
Company A: Observed = 0, Predicted = tensor(0.0003)
Company B: Observed = 1, Predicted = tensor(0.9287)
As you can see, the predictions have improved significantly after training. The model now predicts 0.0003 for Company A (observed value 0) and 0.9287 for Company B (observed value 1), both close to the values we wanted.
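One way to put a number on how close the predictions are is to compute the squared error for each company. The sketch below uses the predicted values printed above; treating squared error as the measure is an assumption made for this example, since this article does not state which loss the model uses.

```python
# Observed labels and the predictions printed after training.
observed = {"Company A": 0.0, "Company B": 1.0}
predicted = {"Company A": 0.0003, "Company B": 0.9287}

# Squared difference between each prediction and its observed value.
squared_errors = {
    name: (predicted[name] - observed[name]) ** 2 for name in observed
}
total_squared_error = sum(squared_errors.values())

for name, error in squared_errors.items():
    print(f"{name}: squared error = {error:.8f}")
print(f"Total squared error = {total_squared_error:.8f}")
```

Almost all of the remaining error comes from Company B, whose prediction is about 0.07 below the observed value of 1.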
In the next article, we will explore TensorBoard to analyze what happened during training.