DEV Community

Gavi Narra


Training a Simple Neural Network in PyTorch and Integrating with Gradio for MNIST Digit Recognition

Introduction

In this article, we will walk through training a simple neural network on the MNIST dataset using PyTorch and then deploying it with Gradio for interactive predictions. MNIST is a widely used machine learning dataset of 70,000 grayscale images of handwritten digits (0 to 9). Each image is 28x28 pixels, and the images are split into 60,000 for training and 10,000 for testing.

Training a Neural Network with PyTorch

PyTorch is an open-source deep learning framework originally developed by Facebook's AI Research group. It provides GPU-accelerated tensors, automatic differentiation, and ready-made layers, loss functions, and optimizers for building and training neural networks.

Step 1: Import necessary libraries

First, we import PyTorch, torchvision (a package with popular datasets, model architectures, and common image transformations), and PyTorch's nn (layers and loss functions) and optim (optimizers) modules:

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim

Step 2: Load the dataset

Next, we load the MNIST dataset with torchvision's built-in MNIST class, which downloads the files on first use (download=True). We also apply transformations that convert each image to a tensor and normalize its pixel values:

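# ToTensor converts each image to a float tensor with values in [0, 1];
# Normalize((0.5,), (0.5,)) then maps it to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 60,000 training images, reshuffled every epoch
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# 10,000 test images; shuffling is unnecessary for evaluation
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)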

Step 3: Define the network

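We'll define a simple feed-forward neural network with one hidden layer of 128 ReLU units. The input layer takes the 28 x 28 = 784 pixel values of a flattened image, and the output layer produces 10 scores, one per digit: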

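class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # input layer -> hidden layer (128 units)
        self.fc2 = nn.Linear(128, 10)       # hidden layer -> 10 class scores

    def forward(self, x):
        x = x.view(-1, 28 * 28)      # flatten (N, 1, 28, 28) images to (N, 784)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)              # raw logits; CrossEntropyLoss applies log-softmax itself
        return x

net = Net()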
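A quick sanity check of this architecture is to count its parameters and push a batch through it. In this sketch a random tensor stands in for a batch of MNIST images (an assumption so the snippet runs on its own), and torch.flatten is used as an equivalent of the view call in the class above.

```python
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)  # 784 inputs -> 128 hidden units
        self.fc2 = nn.Linear(128, 10)       # 128 hidden units -> 10 digit scores

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)   # (N, 1, 28, 28) -> (N, 784)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

net = Net()

# Weights + biases: 784*128 + 128 + 128*10 + 10
n_params = sum(p.numel() for p in net.parameters())
print(n_params)  # 101770

dummy_batch = torch.randn(64, 1, 28, 28)  # random stand-in for one MNIST batch
print(net(dummy_batch).shape)  # torch.Size([64, 10])
```

Almost all of the roughly 100k parameters sit in the first layer, which connects every pixel to every hidden unit.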

Step 4: Define the loss function and optimizer

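We'll use nn.CrossEntropyLoss as the loss function. It combines log-softmax and negative log-likelihood loss, so it expects the raw, unnormalized scores (logits) that our network outputs. For the optimizer we use stochastic gradient descent (SGD) with a learning rate of 0.01 and momentum of 0.9: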

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
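The following sketch checks two of these claims numerically with made-up numbers: that nn.CrossEntropyLoss on raw logits equals the mean negative log-softmax probability of the correct class, and how SGD with momentum=0.9 moves a parameter. The update rule in the comments (velocity v <- momentum * v + grad, then p <- p - lr * v) is PyTorch's formulation with the default dampening=0 and nesterov=False; other texts write momentum slightly differently.

```python
import math
import torch
import torch.nn.functional as F
from torch import nn, optim

# Part 1: CrossEntropyLoss on logits == mean of -log_softmax at the target class.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])   # 2 samples, 3 classes (made-up values)
targets = torch.tensor([0, 2])

builtin = nn.CrossEntropyLoss()(logits, targets)
manual = -F.log_softmax(logits, dim=1)[torch.arange(2), targets].mean()
print(torch.allclose(builtin, manual))  # True

# Part 2: two SGD steps with momentum on a single scalar parameter.
p = nn.Parameter(torch.tensor(1.0))
opt = optim.SGD([p], lr=0.01, momentum=0.9)
for _ in range(2):
    opt.zero_grad()
    loss = 2.0 * p      # the gradient d(loss)/dp is always 2
    loss.backward()
    opt.step()

# Step 1: v = 2,           p = 1 - 0.01 * 2      = 0.98
# Step 2: v = 0.9*2 + 2,   p = 0.98 - 0.01 * 3.8 = 0.942
print(math.isclose(p.item(), 0.942, rel_tol=1e-6))  # True
```

Momentum accumulates past gradients, so with a steady gradient the second step is larger than the first.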

Step 5: Train the network

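Now we're ready to train our network. For each mini-batch we clear the old gradients, run a forward pass, compute the loss, backpropagate, and let the optimizer update the weights: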

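for epoch in range(10):  # 10 passes over the training set
    net.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = net(inputs)              # forward pass: (batch, 10) logits
        loss = criterion(outputs, labels)  # mean cross-entropy over the batch
        loss.backward()                    # backpropagate: compute gradients
        optimizer.step()                   # update weights with SGD + momentum

        running_loss += loss.item()
    # Report the average batch loss for this epoch
    print(f'Epoch {epoch + 1}, loss: {running_loss / len(trainloader):.4f}')
print('Finished Training')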
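The test loader defined in Step 2 is not used during training. A minimal sketch of an evaluation function that measures accuracy on it is shown below; the name `evaluate` is our own, not part of PyTorch, and the `AlwaysThree` model and all-zero dataset are hypothetical placeholders so the snippet runs without downloading MNIST.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Return the fraction of correctly classified samples in `loader`."""
    model.eval()                    # disable training-only behaviour (a no-op for this Net)
    correct, total = 0, 0
    with torch.no_grad():           # gradients are not needed for evaluation
        for inputs, labels in loader:
            predicted = model(inputs).argmax(dim=1)  # highest score = predicted digit
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Hypothetical placeholder model that always predicts the digit 3.
class AlwaysThree(nn.Module):
    def forward(self, x):
        scores = torch.zeros(x.size(0), 10)
        scores[:, 3] = 1.0
        return scores

# 8 blank images, 2 of which are labeled 3.
data = TensorDataset(torch.zeros(8, 1, 28, 28), torch.tensor([3, 3, 0, 1, 2, 4, 5, 6]))
acc = evaluate(AlwaysThree(), DataLoader(data, batch_size=4))
print(acc)  # 0.25
```

With the trained network, you would call `print(f"Test accuracy: {evaluate(net, testloader):.2%}")`.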

Deploying with Gradio

Gradio is an open-source Python library for building web interfaces around machine learning models. It wraps an ordinary Python function in an interactive demo, so users can try the model in a browser without any front-end code.

Step 1: Install Gradio

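# Run in a Jupyter/Colab notebook; from a terminal, drop the leading "!".
# Pinned below version 4 because the interface code uses Gradio 3.x options.
!pip install "gradio<4"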

Step 2: Import Gradio and define the prediction function

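import gradio as gr
import numpy as np

def predict(image):
    # Gradio passes the drawing as a uint8 NumPy array (values 0-255), 28x28 with
    # shape=(28, 28) in grayscale mode. Preprocess it exactly like the training data:
    # ToTensor scales to [0, 1], Normalize((0.5,), (0.5,)) maps to [-1, 1].
    x = image.astype(np.float32) / 255.0
    x = (x - 0.5) / 0.5
    x = torch.from_numpy(x).reshape(1, 1, 28, 28)  # (batch, channel, height, width)

    net.eval()
    with torch.no_grad():  # inference only, no gradients needed
        probs = torch.softmax(net(x), dim=1)[0]
    # A {label: confidence} dict lets Gradio's Label output show the top classes.
    return {str(i): float(p) for i, p in enumerate(probs)}

In the predict function, we scale and normalize the input image the same way the training transform did, reshape it to the (batch, channel, height, width) layout the model expects, and run it through the model without tracking gradients. Softmax turns the ten output scores into probabilities, which we return as a dictionary so the Label output can display the most likely digits with their confidences.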

Step 3: Define the Gradio interface

Now we define the interface for our model, using an Image input component for drawing the digit and a Label output component for displaying the prediction:

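iface = gr.Interface(
    fn=predict,
    # Gradio 3.x API (gr.inputs and interpretation were removed in Gradio 4).
    # image_mode="L" yields a single-channel array with 28 * 28 values.
    inputs=gr.Image(shape=(28, 28), image_mode="L", invert_colors=True, source="canvas"),
    outputs="label",
    interpretation="default",
)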

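Setting source="canvas" lets users draw a digit with the mouse instead of uploading a file, and shape=(28, 28) resizes the drawing to MNIST's resolution. We set invert_colors=True because MNIST digits are white on a black background, whereas the Gradio canvas has a white background with dark strokes; inverting the colors makes the drawing resemble the training data.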
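The effect of inverting colors can be illustrated with plain NumPy: for an 8-bit grayscale image, inversion maps each value x to 255 - x. Here is a small sketch with a hypothetical 5x5 canvas; it shows the arithmetic only, not Gradio's internal implementation.

```python
import numpy as np

# Hypothetical 5x5 canvas: white background (255) with a black vertical stroke (0),
# which is how a drawing on a default white canvas looks.
canvas = np.full((5, 5), 255, dtype=np.uint8)
canvas[:, 2] = 0

inverted = 255 - canvas  # 8-bit inversion: background becomes 0, stroke becomes 255
print(inverted)
```

After inversion the stroke is bright on a dark background, matching the MNIST convention.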

Step 4: Launch the interface

Finally, we launch the interface:

iface.launch()

launch() starts a local web server (by default at http://127.0.0.1:7860) and, in a notebook, also renders the interface inline. You can then draw a digit and see the prediction from your PyTorch model.

Conclusion

In this article, we saw how to train a simple neural network using PyTorch and then deploy it with Gradio for interactive predictions. This combination turns a trained model into an interactive demo that anyone can try in a browser, with only a few extra lines of code.
