Rijul Rajesh

Neural Networks with PyTorch and Lightning AI Part 2: Creating DataLoaders for Training

In the previous article, we started building our neural network with Lightning AI.

In this article, we will continue with the same network. The class is almost identical to the one we defined before, with a few important changes.

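import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from torch.utils.data import TensorDataset, DataLoader

class BasicLightningTrain(L.LightningModule):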
    def __init__(self):
        super().__init__()

        self.w00 = nn.Parameter(
            torch.tensor(1.7),
            requires_grad=False
        )
        self.b00 = nn.Parameter(
            torch.tensor(-0.85),
            requires_grad=False
        )
        self.w01 = nn.Parameter(
            torch.tensor(-40.8),
            requires_grad=False
        )

        self.w10 = nn.Parameter(
            torch.tensor(12.6),
            requires_grad=False
        )
        self.b10 = nn.Parameter(
            torch.tensor(0.0),
            requires_grad=False
        )
        self.w11 = nn.Parameter(
            torch.tensor(2.7),
            requires_grad=False
        )

        self.final_bias = nn.Parameter(
            torch.tensor(0.0),
            requires_grad=True
        )

        self.learning_rate = 0.01

    def forward(self, input):
        input_to_top_relu = (
            input * self.w00 + self.b00
        )

        top_relu_output = F.relu(
            input_to_top_relu
        )

        scaled_top_relu_output = (
            top_relu_output * self.w01
        )

        input_to_bottom_relu = (
            input * self.w10 + self.b10
        )

        bottom_relu_output = F.relu(
            input_to_bottom_relu
        )

        scaled_bottom_relu_output = (
            bottom_relu_output * self.w11
        )

        input_to_final_relu = (
            scaled_top_relu_output
            + scaled_bottom_relu_output
            + self.final_bias
        )

        output = F.relu(
            input_to_final_relu
        )

        return output

What Changed?

The main difference is that we changed the parent class from nn.Module to LightningModule.

We also introduced a new variable called learning_rate:

self.learning_rate = 0.01

For now, we are just setting this as a placeholder value.


The Current Problem

At the moment, when the dose is 0.5, the predicted effectiveness is about 17, which is much higher than the expected value of 1.
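The value of 17 can be checked by hand. The sketch below repeats the forward pass in plain Python, using the weights and biases from the class above. The helper names `relu` and `predict` are hypothetical and are not part of the article's code.

```python
def relu(x: float) -> float:
    """Rectified linear unit: negative values become 0."""
    return max(0.0, x)


def predict(dose: float, final_bias: float = 0.0) -> float:
    # Top path: w00 = 1.7, b00 = -0.85, w01 = -40.8
    top = relu(dose * 1.7 + -0.85) * -40.8
    # Bottom path: w10 = 12.6, b10 = 0.0, w11 = 2.7
    bottom = relu(dose * 12.6 + 0.0) * 2.7
    # Sum both paths, add the (still untrained) final bias, apply the last ReLU
    return relu(top + bottom + final_bias)


for dose in (0.0, 0.5, 1.0):
    print(dose, round(predict(dose), 2))
```

At dose 0.5 the top path contributes 0 and the bottom path contributes 6.3 * 2.7, which is roughly 17.01. The doses 0 and 1 both come out as 0.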

This means we still need to optimize final_bias.
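Only final_bias was created with requires_grad=True, so it is the only value an optimizer is allowed to change. A minimal sketch of how to confirm this, assuming a cut-down stand-in class (the hypothetical `FrozenExceptBias`, built on plain nn.Module so it runs without Lightning) with one frozen weight and the trainable bias:

```python
import torch
import torch.nn as nn


class FrozenExceptBias(nn.Module):  # hypothetical stand-in, not the article's class
    def __init__(self):
        super().__init__()
        # Frozen: excluded from gradient computation
        self.w00 = nn.Parameter(torch.tensor(1.7), requires_grad=False)
        # Trainable: gradients will be computed for this parameter
        self.final_bias = nn.Parameter(torch.tensor(0.0), requires_grad=True)


model = FrozenExceptBias()

# Collect the names of all parameters that can still be optimized
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)
```

The same named_parameters() loop should work on BasicLightningTrain, since a LightningModule is also an nn.Module.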


Creating the Training Data

Just like before, we first create the training inputs:

inputs = torch.tensor([0., 0.5, 1.])

Next, we define the labels, which represent the known output values:

labels = torch.tensor([0., 1., 0.])

Wrapping the Data with a DataLoader

Since we are using Lightning, there is one additional step.

We need to wrap the training data inside a DataLoader.

First, we combine the inputs and labels into a TensorDataset called dataset.

Then, we use that dataset to create a DataLoader.

dataset = TensorDataset(
    inputs,
    labels
)

dataloader = DataLoader(dataset)
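
To see what the DataLoader actually hands out, we can loop over it. This is a small sketch that rebuilds the same dataset. The names `batches`, `input_i`, and `label_i` are only for illustration. With the default settings (batch_size=1, shuffle=False), each step yields one input and its label, in the original order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

inputs = torch.tensor([0., 0.5, 1.])
labels = torch.tensor([0., 1., 0.])

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)  # defaults: batch_size=1, shuffle=False

# Each batch is an (inputs, labels) pair of tensors with shape [1]
batches = list(dataloader)
for input_i, label_i in batches:
    print(input_i, label_i)
```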

Why Use a DataLoader?

DataLoaders are useful because:

  • They make it easier to access data in batches.
  • They make it easy to shuffle data every epoch.
  • They allow us to work with smaller subsets of data for debugging.
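
The three points above can be sketched with our tiny dataset. The variable names are hypothetical, and using torch.utils.data.Subset for the debugging case is one possible approach, not necessarily the one this series will use.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

dataset = TensorDataset(
    torch.tensor([0., 0.5, 1.]),
    torch.tensor([0., 1., 0.])
)

# Batching: two samples per batch; the last batch holds the leftover sample
batched = DataLoader(dataset, batch_size=2)
batch_sizes = [len(x) for x, _ in batched]

# Shuffling: the sample order is redrawn every time we iterate (every epoch)
shuffled = DataLoader(dataset, batch_size=1, shuffle=True)
seen = sorted(x.item() for x, _ in shuffled)

# Debugging: load only the first two samples of the dataset
small = DataLoader(Subset(dataset, [0, 1]))

print(batch_sizes, seen, len(small))
```

Shuffling changes the order of the samples but not which samples are seen, so sorting the shuffled inputs gives back the original three values.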

Now that our data is loaded into a DataLoader, we are ready to optimize final_bias.

In the next article, we will explore how Lightning simplifies the optimization process.
