DEV Community

Rijul Rajesh

Neural Networks with PyTorch and Lightning AI Part 4: From Manual Training to Automated Training

In this article, we will continue with optimizing our neural network using Lightning.

Creating the Model and Trainer

First, we create the model and the trainer.

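import lightning as L

# BasicLightningTrain is the LightningModule built in the earlier parts of this series.
model = BasicLightningTrain()

trainer = L.Trainer(max_epochs=34)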

Here, we set the maximum number of epochs to 34.

We chose 34 because, from our earlier experiments, we know that 34 epochs are sufficient for fitting the training data.

Even if the number of epochs turns out to be insufficient, we do not need to start training from scratch.

Lightning allows us to continue training later by increasing the number of epochs and resuming from where we left off.


Finding a Better Learning Rate

Now that we have a trainer, we can use it to find an improved learning rate.

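from lightning.pytorch.tuner import Tuner

# Lightning 2.x no longer exposes trainer.tuner; wrap the trainer in a Tuner instead.
# (On Lightning 1.x, trainer.tuner.lr_find(...) accepts the same arguments.)
tuner = Tuner(trainer)

lr_find_results = tuner.lr_find(
    model,
    train_dataloaders=dataloader,
    min_lr=0.001,
    max_lr=1.0,
    early_stop_threshold=None
)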

Here, we pass:

  • The model
  • The training DataLoader
  • A minimum learning rate of 0.001
  • A maximum learning rate of 1.0
  • early_stop_threshold=None, which tells Lightning not to stop the search early

How lr_find() Works

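The lr_find() function runs a short trial training pass, raising the learning rate step by step from the minimum to the maximum value and recording the loss at each candidate.

By default, Lightning stops this search once the loss exceeds early_stop_threshold times the best loss seen so far. Setting early_stop_threshold=None disables that check, so Lightning evaluates the entire range.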
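
To make the search concrete, here is a minimal sketch of the idea in plain Python. It assumes the candidates are spaced geometrically, in line with lr_find()'s default exponential mode, and that the early-stop check compares the current loss against a multiple of the best loss so far. The helper names candidate_learning_rates and should_stop_early are hypothetical and not part of Lightning, and the real implementation differs in detail.

```python
def candidate_learning_rates(min_lr, max_lr, num_steps):
    """Return num_steps learning rates spaced geometrically from min_lr to max_lr."""
    ratio = max_lr / min_lr
    return [min_lr * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]


def should_stop_early(current_loss, best_loss, threshold):
    """Divergence check; threshold=None disables it so every candidate is tried."""
    if threshold is None:
        return False
    return current_loss > threshold * best_loss


# Same bounds as in the article, but only 4 steps to keep the output short.
rates = candidate_learning_rates(min_lr=0.001, max_lr=1.0, num_steps=4)
print([round(lr, 3) for lr in rates])  # [0.001, 0.01, 0.1, 1.0]

# With threshold=None, even a loss 100x worse than the best does not stop the search.
print(should_stop_early(current_loss=100.0, best_loss=1.0, threshold=None))  # False
```

Geometric spacing means each candidate is a fixed multiple of the previous one, so small and large learning rates get equal coverage.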

The results are stored in:

lr_find_results

Getting the Suggested Learning Rate

We can retrieve Lightning's suggested learning rate using the suggestion() method.

new_lr = lr_find_results.suggestion()

print(
    f"lr_find() suggests "
    f"{new_lr:.5f} "
    f"for the learning rate"
)

After that, we can assign the suggested value to our model's learning_rate variable.

Now we have a learning rate that is likely to work better than our original placeholder value.
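
The assignment itself is not shown above, so here is a minimal sketch. It assumes the model stores its rate in a learning_rate attribute, as the text describes, and it guards against suggestion() returning None, which can happen when no suggestion can be computed. ModelStandIn and apply_suggested_lr are hypothetical names used only so the snippet runs on its own, and the numeric values are made up.

```python
class ModelStandIn:
    """Hypothetical stand-in for BasicLightningTrain; only the attribute matters."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate


def apply_suggested_lr(model, suggested_lr):
    """Copy the suggestion onto the model; keep the old rate if there is none."""
    if suggested_lr is None:
        return False
    model.learning_rate = suggested_lr
    return True


model = ModelStandIn(learning_rate=0.1)  # 0.1 is an assumed placeholder value
new_lr = 0.002  # assumed stand-in for lr_find_results.suggestion()

if apply_suggested_lr(model, new_lr):
    print(f"learning rate set to {model.learning_rate:.5f}")
else:
    print("no suggestion available; keeping the placeholder")
```

With the real objects, the equivalent step would be model.learning_rate = new_lr, once new_lr is known not to be None.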


Training the Model

With the improved learning rate in place, we are ready to train the neural network.

trainer.fit(
    model,
    train_dataloaders=dataloader
)

The fit() function requires:

  • The model
  • The training DataLoader

What Happens Inside fit()?

When we call fit(), Lightning automatically performs many of the steps that we previously had to write ourselves.

First, Lightning calls:

configure_optimizers()

This creates and configures the SGD optimizer using the learning rate we specified.

Next, Lightning repeatedly calls:

training_step()

to calculate the loss for each batch.

Behind the scenes, Lightning also performs the following operations automatically:

optimizer.zero_grad()

This clears the gradients left over from the previous batch, because PyTorch accumulates gradients by default.

loss.backward()

This backpropagates the loss to calculate the gradient for each parameter.

optimizer.step()

This updates the parameters using those gradients, taking a step toward values that lower the loss.

Lightning repeats these steps for every batch in every epoch, until it reaches the maximum number of epochs we requested.
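
For comparison, the steps above correspond roughly to the following hand-written PyTorch loop. This is a sketch under stated assumptions rather than Lightning's actual internals: the one-layer model, the toy data, the MSE loss, and the learning rate of 0.1 are all made up for illustration and are not the network from this series.

```python
import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # make the random initial weights repeatable

# Toy data (y = 2x) and a one-layer model, assumed for illustration only.
inputs = torch.tensor([[0.0], [0.5], [1.0]])
labels = torch.tensor([[0.0], [1.0], [2.0]])
dataloader = DataLoader(TensorDataset(inputs, labels), batch_size=1)
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()

# Roughly what configure_optimizers() hands back to Lightning.
optimizer = SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(34):  # plays the role of max_epochs=34
    total_loss = 0.0
    for batch_inputs, batch_labels in dataloader:
        # Roughly what training_step() does: compute the loss for one batch.
        loss = loss_fn(model(batch_inputs), batch_labels)

        optimizer.zero_grad()  # clear gradients from the previous batch
        loss.backward()        # compute new gradients
        optimizer.step()       # update the parameters

        total_loss += loss.item()
    epoch_losses.append(total_loss)

print(f"first epoch loss: {epoch_losses[0]:.4f}")
print(f"last epoch loss:  {epoch_losses[-1]:.4f}")
```

Everything inside the two for loops is what trainer.fit() takes over, which leaves only the model definition and the data to write by hand.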


Comparing This to Raw PyTorch

In our earlier PyTorch implementation, we had to manually write:

  • The training loops
  • Gradient calculations
  • Gradient resets
  • Parameter updates

With Lightning, most of that logic is handled automatically.

As a result, the training code becomes much shorter and easier to read.


Now that the model has been trained, the next step is to verify that final_bias was optimized correctly.

We will explore that, along with a few additional Lightning features, in the next article.
