DEV Community

Rijul Rajesh

PyTorch + Lightning AI for Neural Networks Part 1: Building Better Workflows

In this article, we will explore how to create neural networks using PyTorch together with Lightning AI.

In the previous series of articles, we built a simple neural network for a small dataset.

The class we created contained:

  • the code for the weights and biases
  • the logic for running data through the neural network
  • the code required to optimize the model using backpropagation
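A class with those three parts might look roughly like the plain-PyTorch sketch below. It is a hypothetical reconstruction, not the code from the previous series: the class name `BasicNNTrain`, the placeholder weight values, the three-point toy dataset, and the learning rate of 0.1 are all assumptions chosen for illustration. Only `final_bias` is trained, and the optimization loop is written out by hand.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD


class BasicNNTrain(nn.Module):  # hypothetical class name
    def __init__(self):
        super().__init__()
        # 1. Weights and biases: placeholder values (assumed); all frozen except final_bias
        self.w00 = nn.Parameter(torch.tensor(1.7), requires_grad=False)
        self.b00 = nn.Parameter(torch.tensor(-0.85), requires_grad=False)
        self.w01 = nn.Parameter(torch.tensor(-40.8), requires_grad=False)
        self.w10 = nn.Parameter(torch.tensor(12.6), requires_grad=False)
        self.b10 = nn.Parameter(torch.tensor(0.0), requires_grad=False)
        self.w11 = nn.Parameter(torch.tensor(2.7), requires_grad=False)
        self.final_bias = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, x):
        # 2. Run the input through two ReLU units and combine their scaled outputs
        top = F.relu(x * self.w00 + self.b00) * self.w01
        bottom = F.relu(x * self.w10 + self.b10) * self.w11
        return F.relu(top + bottom + self.final_bias)


# Hypothetical toy data: low, medium, and high inputs with their labels
inputs = torch.tensor([0.0, 0.5, 1.0])
labels = torch.tensor([0.0, 1.0, 0.0])

model = BasicNNTrain()
optimizer = SGD(model.parameters(), lr=0.1)  # learning rate picked by hand

# 3. Optimize with backpropagation, one step per pass over the data
for epoch in range(100):
    optimizer.zero_grad()
    loss = ((model(inputs) - labels) ** 2).sum()  # sum of squared residuals
    loss.backward()
    optimizer.step()

print(f"trained final_bias: {model.final_bias.item():.2f}")
```

The hand-picked learning rate and the explicit `zero_grad` / `backward` / `step` loop are the parts that Lightning is meant to help with.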

Even though the code worked as expected, there are still a few challenges we have to deal with.

1. Finding the Right Learning Rate

We had to manually figure out a suitable learning rate for gradient descent, which is not always easy.

If there were a tool that could help us find a good learning rate automatically, that would make things much easier.
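The usual idea behind such a tool is a learning-rate range test: train briefly while increasing the learning rate exponentially, record the loss at every step, and pick a rate from the region where the loss falls fastest. The sketch below is a bare-bones version in plain PyTorch. The helper names `lr_range_test` and `suggest_lr`, the single-layer model, the synthetic data, and the "largest one-step drop" selection rule are all assumptions for illustration.

```python
import copy

import torch
import torch.nn as nn
from torch.optim import SGD


def lr_range_test(model, inputs, labels, min_lr=1e-4, max_lr=1.0, num_steps=50):
    """Hypothetical helper: return (learning_rates, losses) from a simple range test."""
    model = copy.deepcopy(model)  # leave the caller's model untouched
    optimizer = SGD(model.parameters(), lr=min_lr)
    # Multiply the learning rate by a constant factor after each step
    factor = (max_lr / min_lr) ** (1 / (num_steps - 1))
    lrs, losses = [], []
    lr = min_lr
    for _ in range(num_steps):
        for group in optimizer.param_groups:
            group["lr"] = lr
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
        lrs.append(lr)
        losses.append(loss.item())  # loss measured *before* this step's update
        lr *= factor
    return lrs, losses


def suggest_lr(lrs, losses):
    """Hypothetical rule: the rate whose update produced the largest loss drop."""
    drops = [losses[i] - losses[i + 1] for i in range(len(losses) - 1)]
    return lrs[drops.index(max(drops))]


# Synthetic data (assumed): a noiseless line y = 3x + 1
torch.manual_seed(0)
inputs = torch.rand(64, 1)
labels = 3 * inputs + 1
model = nn.Linear(1, 1)

lrs, losses = lr_range_test(model, inputs, labels)
print(f"suggested learning rate: {suggest_lr(lrs, losses):.4f}")
```

In Lightning, this idea is available out of the box as `Tuner.lr_find` (from `lightning.pytorch.tuner`), so we do not have to maintain a helper like this ourselves.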

2. Training Code Can Become Bulky

The code for training the neural network was already becoming quite large.

As neural networks grow in complexity, the training code can become harder to read and maintain.

Having a cleaner and easier way to organize training logic would be helpful.

3. Supporting GPUs and TPUs

In real-world scenarios, we often train neural networks faster by using hardware accelerators such as:

  • GPUs (Graphics Processing Units)
  • TPUs (Tensor Processing Units)

To support these devices in plain PyTorch, we usually have to change our code, for example by explicitly moving the model and every batch of data to the device.
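A minimal plain-PyTorch sketch of that device handling is shown below. The toy model (`nn.Linear(1, 1)`) and data are assumptions for illustration. It only checks for CUDA GPUs; Apple's `mps` backend would need another check, and TPUs additionally require the separate `torch_xla` package.

```python
import torch
import torch.nn as nn

# Pick a CUDA GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy model (assumed); .to(device) moves its parameters onto the chosen device
model = nn.Linear(1, 1).to(device)

# Toy data (assumed)
inputs = torch.tensor([[0.0], [0.5], [1.0]])
labels = torch.tensor([[0.0], [1.0], [0.0]])

# Every batch must live on the same device as the model,
# otherwise PyTorch raises a device-mismatch error
inputs, labels = inputs.to(device), labels.to(device)

loss = nn.functional.mse_loss(model(inputs), labels)
print(f"loss computed on: {loss.device}")
```

Every place that creates or loads tensors needs the same treatment, which is easy to forget as a project grows.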


Why Use Lightning?

To solve these problems, we can combine PyTorch with Lightning.

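Lightning runs the training loop for us, which keeps our own code smaller, and it can move the same model between CPUs, GPUs, and TPUs through Trainer settings rather than changes to the model code. It also includes a learning rate finder.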


Importing the Required Modules

First, we will import the familiar modules that we used in the previous articles.

import torch  # To create tensors
import torch.nn as nn  # Make weight and bias tensors part of the neural network
import torch.nn.functional as F  # For activation functions
from torch.optim import SGD  # For fitting the neural network to the data

Importing Lightning

Now, let us import Lightning.

import lightning as L

Working with Larger Datasets

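We will also import two utilities that make it easier to work with larger datasets: TensorDataset, which pairs each input with its label, and DataLoader, which serves the data in batches and can shuffle it between epochs.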

from torch.utils.data import (
    TensorDataset,
    DataLoader
)
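Here is a small sketch of how the two utilities fit together, using a hypothetical three-point dataset (the values are assumptions for illustration):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy inputs and labels (assumed values)
inputs = torch.tensor([0.0, 0.5, 1.0])
labels = torch.tensor([0.0, 1.0, 0.0])

# TensorDataset pairs the i-th input with the i-th label
dataset = TensorDataset(inputs, labels)
print(dataset[1])  # the second (input, label) pair

# DataLoader serves the dataset in batches; shuffle=False keeps the order fixed
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
for batch_inputs, batch_labels in dataloader:
    print(batch_inputs, batch_labels)
```

With three samples and `batch_size=2`, the loader yields one batch of two and a final batch of one. Setting `shuffle=True` would change the order every epoch, which is usually what we want during training.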

Graphing Results

We will continue using Matplotlib and Seaborn to visualize our outputs.

import matplotlib.pyplot as plt
import seaborn as sns

Creating the Neural Network

To create a neural network in Lightning, we do almost the same thing as before.

The main difference is that we change the parent class from nn.Module to L.LightningModule (which itself inherits from nn.Module):

class BasicLightning(L.LightningModule):
    ...  # weights, biases, and forward() go here, as before

Everything else remains mostly the same.


Training Example

To demonstrate the benefits of Lightning, we will reuse the same training example from the previous series.

This will be the class we use:

class BasicLightningTrain(L.LightningModule):
    ...  # weights, biases, forward(), and training logic go here

We will continue exploring how Lightning helps with this training process in the next article.

