Rijul Rajesh

Building LSTMs with PyTorch and Lightning AI Part 7: Resuming Training with Checkpoints

In the previous article, we used TensorBoard to analyze the training process. Based on the graphs, we concluded that the model had not fully converged and could benefit from additional training epochs.

Let's continue with that in this article.

One of the advantages of Lightning is that we can continue training without starting from scratch.

This is possible because Lightning automatically saves checkpoints during training.

These checkpoints allow us to resume training from where we left off and continue optimizing the model.

Getting the Checkpoint

First, we need the path to the checkpoint that Lightning saved during the previous training run.

path_to_best_checkpoint = trainer.checkpoint_callback.best_model_path

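Here, best_model_path gives us the path to the best checkpoint tracked by the trainer's checkpoint callback. With the default settings no metric is monitored, so this is simply the most recently saved checkpoint.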
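If the original trainer object is no longer in memory (for example, after restarting the notebook), the checkpoint can also be located on disk. The following is a minimal sketch, assuming Lightning's default layout of `lightning_logs/version_N/checkpoints/*.ckpt`; the helper name `find_latest_checkpoint` is hypothetical and not part of Lightning.

```python
from pathlib import Path
from typing import Optional


def find_latest_checkpoint(log_dir: str = "lightning_logs") -> Optional[Path]:
    """Return the most recently modified .ckpt file under log_dir, or None.

    Assumes the default layout: <log_dir>/version_N/checkpoints/*.ckpt
    """
    root = Path(log_dir)
    if not root.is_dir():
        return None
    checkpoints = list(root.glob("version_*/checkpoints/*.ckpt"))
    if not checkpoints:
        return None
    # The newest file by modification time is treated as the latest checkpoint.
    return max(checkpoints, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print("Latest checkpoint:", find_latest_checkpoint())
```

If the helper returns a path, it can be passed as ckpt_path in the same way as best_model_path. If it returns None, there is nothing to resume from.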


Increasing the Number of Epochs

Now we create a new trainer and increase the maximum number of epochs to 3000.

trainer = L.Trainer(max_epochs=3000)

Instead of starting from the beginning, we resume training from the saved checkpoint.

trainer.fit(
    model,
    train_dataloaders=dataloader,
    ckpt_path=path_to_best_checkpoint
)

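By specifying ckpt_path, Lightning restores the model weights, optimizer state, and epoch counter from the checkpoint and continues from there instead of starting over. Note that max_epochs is the total number of epochs, including those already completed, so training only runs for the remaining epochs up to 3000.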
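To make the idea of resuming more concrete, here is a plain PyTorch sketch of what a checkpoint needs to hold. It is an illustration under simplified assumptions, not Lightning's internal implementation: the toy data, the `train` helper, and the dictionary keys are all hypothetical.

```python
import io

import torch
from torch import nn


def train(model, optimizer, x, y, start_epoch, max_epochs):
    """Train for epochs start_epoch .. max_epochs - 1; return total epochs completed."""
    loss_fn = nn.MSELoss()
    for _ in range(start_epoch, max_epochs):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return max(start_epoch, max_epochs)


torch.manual_seed(0)
x = torch.tensor([[0.0], [1.0]])  # toy inputs (hypothetical)
y = torch.tensor([[0.0], [1.0]])  # toy targets (hypothetical)

# First run: 5 epochs from scratch.
model = nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
epochs_done = train(model, optimizer, x, y, start_epoch=0, max_epochs=5)

# Save weights, optimizer state, and the epoch counter (in memory, for brevity).
buffer = io.BytesIO()
torch.save(
    {
        "epoch": epochs_done,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    },
    buffer,
)

# Resume: load the saved state into freshly created objects.
buffer.seek(0)
checkpoint = torch.load(buffer, weights_only=True)
resumed_model = nn.Linear(1, 1)
resumed_optimizer = torch.optim.Adam(resumed_model.parameters(), lr=0.1)
resumed_model.load_state_dict(checkpoint["model"])
resumed_optimizer.load_state_dict(checkpoint["optimizer"])
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.parameters(), resumed_model.parameters())
)

# The new limit is a total, so only the remaining epochs are run.
new_max_epochs = 8
remaining = new_max_epochs - checkpoint["epoch"]
total_epochs = train(
    resumed_model, resumed_optimizer, x, y, checkpoint["epoch"], new_max_epochs
)
print(f"Resumed at epoch {checkpoint['epoch']}, ran {remaining} more, total {total_epochs}")
```

The sketch stores the optimizer state alongside the weights because optimizers such as Adam keep running statistics. Restoring only the weights would restart those statistics from zero.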


Checking the Updated Predictions

Now let's print the predictions once again.

print("\nComparing observed and predicted values")

print(
    "Company A: Observed = 0, Predicted =",
    model(torch.tensor([0., 0.5, 0.25, 1.])).detach()
)

print(
    "Company B: Observed = 1, Predicted =",
    model(torch.tensor([1., 0.5, 0.25, 1.])).detach()
)

This produces the following output:

Comparing observed and predicted values

Company A: Observed = 0, Predicted = tensor(0.0009)
Company B: Observed = 1, Predicted = tensor(0.9423)

The prediction for Company A has moved even closer to the target value of 0.

Similarly, the prediction for Company B has moved closer to the target value of 1.
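To put numbers on "closer", the gap between each prediction and its target can be computed directly from the output shown above. This is a small illustrative sketch; the dictionary names are hypothetical and the values are copied from the printed output.

```python
# Observed targets and the predictions printed after the additional training.
observed = {"Company A": 0.0, "Company B": 1.0}
predicted = {"Company A": 0.0009, "Company B": 0.9423}

# Absolute error: how far each prediction is from its target.
errors = {name: abs(predicted[name] - observed[name]) for name in observed}

for name, err in errors.items():
    print(f"{name}: absolute error = {err:.4f}")
# Company A: absolute error = 0.0009
# Company B: absolute error = 0.0577
```

Company B is still roughly 0.058 away from its target, which is the larger of the two remaining gaps.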


Comparing the TensorBoard Graphs

Let's look at TensorBoard again.

Company A

After training for more epochs, the prediction has moved closer to the desired value of 0.

Company B

Similarly, the prediction for Company B has moved closer to the desired value of 1.

Training Loss

The train_loss graph also shows that the loss has decreased further after the additional training.


Although the model has improved, we can still train it for more epochs to refine the predictions even further.

In the next article, we will continue improving the model and see how PyTorch's built-in nn.LSTM module can simplify our LSTM implementation in Lightning.
