In this article, we will explore how to build neural networks with PyTorch and Lightning, the training framework from Lightning AI.
In the previous series of articles, we built a simple neural network for a small dataset.
The class we created contained:
- the code for the weights and biases
- the logic for running data through the neural network
- the code required to optimize the model using backpropagation
Even though the code worked as expected, there are still a few challenges we have to deal with.
1. Finding the Right Learning Rate
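We had to find a suitable learning rate for gradient descent by hand, which is not always easy: if the rate is too small, training crawls, and if it is too large, the loss can bounce around or even diverge.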
If there were a tool that could help us find a good learning rate automatically, that would make things much easier.
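To make the idea concrete, here is a minimal sketch of a learning-rate range test written in plain PyTorch. It takes a series of gradient-descent steps while multiplying the learning rate by a constant factor, records the loss at every step, and suggests the rate whose step reduced the loss the most. The function names `lr_range_test` and `suggest_lr`, the toy data, and the default range are hypothetical choices for illustration. This is not Lightning's implementation, although Lightning ships its own finder (`Tuner.lr_find`) built on a similar idea.

```python
import math

import torch
import torch.nn as nn
from torch.optim import SGD


def lr_range_test(model, loss_fn, inputs, labels,
                  min_lr=1e-4, max_lr=1.0, num_steps=50):
    """Take num_steps gradient-descent steps while growing the learning
    rate exponentially from min_lr to max_lr; return (lr, loss) pairs."""
    optimizer = SGD(model.parameters(), lr=min_lr)
    # Constant factor that takes min_lr to max_lr in num_steps - 1 steps.
    gamma = (max_lr / min_lr) ** (1 / (num_steps - 1))
    history = []
    lr = min_lr
    for _ in range(num_steps):
        for group in optimizer.param_groups:
            group["lr"] = lr
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optimizer.step()
        history.append((lr, loss.item()))
        if not math.isfinite(loss.item()):
            break  # the loss blew up, so even larger rates are pointless
        lr *= gamma
    return history


def suggest_lr(history):
    """Return the learning rate whose step reduced the loss the most."""
    best_lr, best_drop = history[0][0], float("-inf")
    # The loss recorded at step i+1 reflects the update made with lr_i.
    for (lr, prev_loss), (_, next_loss) in zip(history, history[1:]):
        drop = prev_loss - next_loss
        if drop > best_drop:
            best_lr, best_drop = lr, drop
    return best_lr


# Hypothetical toy problem: fit y = 2x with a single linear neuron.
torch.manual_seed(0)
model = nn.Linear(1, 1)
inputs = torch.tensor([[0.0], [0.5], [1.0]])
labels = torch.tensor([[0.0], [1.0], [2.0]])
history = lr_range_test(model, nn.MSELoss(), inputs, labels)
print(f"suggested learning rate: {suggest_lr(history):.4g}")
```

A real finder would also restore the model's original weights afterwards, since the test itself changes them.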
2. Training Code Can Become Bulky
The code for training the neural network had already grown quite large.
As neural networks grow in complexity, the training code can become harder to read and maintain.
Having a cleaner and easier way to organize training logic would be helpful.
3. Supporting GPUs and TPUs
In real-world scenarios, we often train neural networks faster by using hardware accelerators such as:
- GPUs (Graphics Processing Units)
- TPUs (Tensor Processing Units)
To support these devices, we would usually need to change our code, for example by moving the model and every batch of data onto the device.
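For comparison, here is roughly what supporting an accelerator looks like in plain PyTorch: pick a device, then move the model and every tensor it touches onto that device. This sketch only checks for a CUDA GPU and otherwise falls back to the CPU; TPUs need an extra package (`torch_xla`) and further changes, which are left out here. The tiny `nn.Linear` model and the input values are hypothetical.

```python
import torch
import torch.nn as nn

# Use a CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical tiny model and inputs, just to show the moves.
model = nn.Linear(1, 1).to(device)                        # weights and bias
inputs = torch.tensor([[0.0], [0.5], [1.0]]).to(device)   # the data as well

# Model and data now live on the same device, so this works on GPU or CPU.
outputs = model(inputs)
print(outputs.device)
```

Forgetting any one of these `.to(device)` calls typically raises a device-mismatch error at run time, which is part of what makes this bookkeeping tedious.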
Why Use Lightning?
To solve these problems, we can combine PyTorch with Lightning.
Lightning helps simplify training code and makes it easier to work with larger models and hardware accelerators.
Importing the Required Modules
First, we will import the familiar modules that we used in the previous articles.
import torch  # To create tensors
import torch.nn as nn  # Make weight and bias tensors part of the neural network
import torch.nn.functional as F  # For activation functions
from torch.optim import SGD  # For fitting the neural network to the data
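As a quick refresher, the sketch below uses each of these imports once: a tensor of inputs, a weight and a bias wrapped in `nn.Parameter`, a ReLU from `F`, and a single `SGD` step. All numbers are made up for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD

x = torch.tensor([0.0, 0.5, 1.0])       # a tensor of inputs
w = nn.Parameter(torch.tensor(1.0))     # a trainable weight
b = nn.Parameter(torch.tensor(-0.25))   # a trainable bias

y = F.relu(w * x + b)                   # ReLU activation function

# One step of gradient descent on the sum of squared residuals.
optimizer = SGD([w, b], lr=0.1)
loss = ((y - torch.tensor([0.0, 1.0, 0.0])) ** 2).sum()
loss.backward()
optimizer.step()
print(y.detach(), w.item())
```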
Importing Lightning
Now, let us import Lightning.
import lightning as L
Working with Larger Datasets
We will also import a few utilities that make it easier to work with larger datasets.
from torch.utils.data import (
    TensorDataset,
    DataLoader,
)
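A small sketch of how these two classes fit together, assuming a hypothetical three-point dataset: `TensorDataset` pairs each input with its label, and `DataLoader` hands the pairs out in batches (here two at a time, without shuffling).

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy data: three inputs and their target outputs.
inputs = torch.tensor([0.0, 0.5, 1.0])
labels = torch.tensor([0.0, 1.0, 0.0])

# TensorDataset pairs the i-th input with the i-th label.
dataset = TensorDataset(inputs, labels)

# DataLoader serves the dataset in mini-batches of the requested size.
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

for batch_inputs, batch_labels in dataloader:
    print(batch_inputs, batch_labels)
```

With `batch_size=2`, the last batch holds only the one remaining pair; for larger datasets you would typically also set `shuffle=True`.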
Graphing Results
We will continue using Matplotlib and Seaborn to visualize our outputs.
import matplotlib.pyplot as plt
import seaborn as sns
Creating the Neural Network
To create a neural network in Lightning, we do almost the same thing as before.
The main difference is that we change the parent class.
class BasicLightning(L.LightningModule):
    ...  # the rest of the class is the same as before
Everything else remains mostly the same.
Training Example
To demonstrate the benefits of Lightning, we will reuse the same training example from the previous series.
This will be the class we use:
class BasicLightningTrain(L.LightningModule):
    ...  # class body omitted here
We will continue exploring how Lightning helps with this training process in the next article.