DEV Community

Rijul Rajesh

Building LSTMs with PyTorch and Lightning AI Part 4: Training Step and Initial Predictions

In the previous article, we finished the LSTM unit, implemented the forward() method, and configured the Adam optimizer for the model.

In this article, we will write the training_step() function and then run the model before training it to see what it predicts.

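The training_step() function receives one batch of training data, containing the input values and the label for one of the two companies, along with the index of that batch (batch_idx).

It then passes the input sequence to forward() to predict the next value for that training example. Because the batch keeps a leading batch dimension, input_i[0] pulls the sequence of four daily values out of it (the code treats each batch as a single example).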

def training_step(self, batch, batch_idx):
    input_i, label_i = batch              # unpack the input sequence and its label
    output_i = self.forward(input_i[0])   # predict Day 5 from the Days 1-4 values
    loss = (output_i - label_i)**2        # squared residual

Next, it calculates the loss, which is the squared residual between the predicted value and the observed value.
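To make the shapes in training_step() concrete, here is a minimal sketch of what one batch looks like. It assumes the training data is wrapped in a TensorDataset and served by a DataLoader with its default batch size of 1, using the Company A and Company B sequences and labels that appear later in this article; that data setup is an assumption for illustration, not code from this part of the series. A fixed placeholder value stands in for self.forward().

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Assumed training data: Days 1-4 for each company, Day 5 value as the label.
inputs = torch.tensor([[0., 0.5, 0.25, 1.],   # Company A
                       [1., 0.5, 0.25, 1.]])  # Company B
labels = torch.tensor([0., 1.])

# Default batch_size=1, so every batch holds exactly one sequence.
dataloader = DataLoader(TensorDataset(inputs, labels))

for batch_idx, batch in enumerate(dataloader):
    input_i, label_i = batch
    # input_i keeps a batch dimension; input_i[0] is the bare 4-value sequence.
    print(batch_idx, tuple(input_i.shape), tuple(input_i[0].shape), tuple(label_i.shape))

    # Placeholder for self.forward(input_i[0]): a scalar prediction.
    output_i = torch.tensor(0.2)
    loss = (output_i - label_i) ** 2  # squared residual, as in training_step()
    print("loss:", loss)
```

Because label_i keeps the batch dimension, the loss comes out as a one-element tensor rather than a 0-dimensional scalar.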

We can also log the loss to easily track how it changes during training.

Lightning provides the self.log() method for this purpose. With the default logger, the logged values are saved in a lightning_logs directory.

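We can log other values as well, such as the model's predictions for each company. Since the label is 0 for Company A and 1 for Company B, checking label_i tells us which company the prediction belongs to, so we log it as out_0 or out_1 accordingly.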

Finally, we return the loss.

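def training_step(self, batch, batch_idx):
    input_i, label_i = batch              # unpack the input sequence and its label
    output_i = self.forward(input_i[0])   # predict Day 5 from the Days 1-4 values
    loss = (output_i - label_i)**2        # squared residual

    # Lightning saves logged values so we can track them during training
    self.log("train_loss", loss)

    # Label 0 is Company A, label 1 is Company B
    if label_i == 0:
        self.log("out_0", output_i)
    else:
        self.log("out_1", output_i)

    return loss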
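Since self.log() is meant to be called while a Lightning Trainer is running, one way to see what this training_step() records for each batch is to call it on a stand-in object. In the sketch below, StubModule, its fixed forward() output, and its logged dictionary are all hypothetical and exist only for illustration; only the body of training_step() comes from the article.

```python
import torch


class StubModule:
    """Hypothetical stand-in for LSTMByHand: fixed prediction, log() records values."""

    def __init__(self, prediction):
        self.prediction = torch.tensor(prediction)
        self.logged = {}

    def forward(self, x):
        # Ignores the input and always returns the same scalar prediction.
        return self.prediction

    def log(self, name, value):
        # Keeps the latest value per name instead of writing to lightning_logs.
        self.logged[name] = value.detach().clone()

    # Same body as the training_step() above.
    def training_step(self, batch, batch_idx):
        input_i, label_i = batch
        output_i = self.forward(input_i[0])
        loss = (output_i - label_i)**2
        self.log("train_loss", loss)
        if label_i == 0:
            self.log("out_0", output_i)
        else:
            self.log("out_1", output_i)
        return loss


stub = StubModule(0.3)
batch_a = (torch.tensor([[0., 0.5, 0.25, 1.]]), torch.tensor([0.]))  # Company A
batch_b = (torch.tensor([[1., 0.5, 0.25, 1.]]), torch.tensor([1.]))  # Company B

loss_a = stub.training_step(batch_a, 0)  # logs train_loss and out_0
loss_b = stub.training_step(batch_b, 1)  # logs train_loss and out_1
print(stub.logged)
```

Each call logs train_loss plus exactly one of out_0 or out_1, which is what lets the two companies' predictions be tracked separately.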

So far, we have implemented the following:

  • Initialized the weight and bias tensors.
  • Implemented the LSTM calculations in lstm_unit().
  • Created the forward() method to perform a forward pass through the unrolled LSTM.
  • Configured the Adam optimizer using configure_optimizers().
  • Calculated and logged the training loss using training_step().

Now let's try the model before any training.

import torch

model = LSTMByHand()  # the model class built across this series

print("\nComparing observed and predicted values")

# Company A: Days 1-4 as input, observed Day 5 value is 0
print(
    "Company A: Observed = 0, Predicted =",
    model(torch.tensor([0., 0.5, 0.25, 1.])).detach()
)

# Company B: Days 1-4 as input, observed Day 5 value is 1
print(
    "Company B: Observed = 1, Predicted =",
    model(torch.tensor([1., 0.5, 0.25, 1.])).detach()
)

Here, we pass a tensor containing the stock prices for Days 1 through 4. The model then predicts the value for Day 5.

The model's output is a tensor that is still attached to the computation graph PyTorch builds for computing gradients. Calling .detach() returns a tensor with the same value but no graph attached, so only the prediction is printed.
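The difference is easy to see on any tensor computed from a value that requires gradients. In this minimal sketch, w is a made-up stand-in for one of the model's weights; it is not part of LSTMByHand.

```python
import torch

# Hypothetical stand-in for a trainable weight.
w = torch.tensor(0.5, requires_grad=True)
x = torch.tensor([0., 0.5, 0.25, 1.])

out = (w * x).sum()   # attached to the computation graph
print(out)            # printed with a grad_fn, like an undetached model output
print(out.requires_grad, out.grad_fn is not None)

value = out.detach()  # same value, no graph attached
print(value)
print(value.requires_grad, value.grad_fn)
```

If a plain Python number is more convenient than a tensor, .item() returns the value of a one-element tensor as a float.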

Running the code produces the following output:

Comparing observed and predicted values
Company A: Observed = 0, Predicted = tensor(-0.2321)
Company B: Observed = 1, Predicted = tensor(-0.2360)

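The prediction for Company A (-0.2321) is reasonably close to the observed value of 0.

However, the prediction for Company B (-0.2360) is far from the observed value of 1. In fact, the untrained model predicts almost the same value for both companies.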
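To put these gaps in terms of the loss that training_step() computes, the snippet below plugs the printed predictions into the same squared-residual formula. The numbers are copied from the output above, so they will differ if your untrained model starts from different weights.

```python
# Predictions printed above for the untrained model, and the observed Day 5 values.
results = {
    "Company A": {"observed": 0.0, "predicted": -0.2321},
    "Company B": {"observed": 1.0, "predicted": -0.2360},
}

total = 0.0
for company, r in results.items():
    loss = (r["predicted"] - r["observed"]) ** 2  # same formula as training_step()
    total += loss
    print(f"{company}: squared residual = {loss:.4f}")

print(f"Sum over both companies = {total:.4f}")
# Company A: squared residual = 0.0539
# Company B: squared residual = 1.5277
# Sum over both companies = 1.5816
```

Almost all of the loss comes from Company B, which is the error training will need to reduce.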

In the next article, we will train the model to improve these predictions.
