DEV Community

Jay Codes

Building Your First Neural Network: Image Classification with MNIST Fashion Dataset

As technology advances, machine learning has become an essential tool for solving various real-world problems. One fascinating area of machine learning is neural networks, which take inspiration from the human brain's neural connections. In this article, we will guide you through the process of building your first neural network for image classification using the MNIST Fashion Dataset.

What is a Neural Network?

At its core, a neural network is a type of machine learning model that consists of interconnected artificial neurons organized into layers. Each neuron processes input data and passes the output to the next layer, gradually extracting meaningful patterns and relationships from the data. This allows the neural network to make predictions or decisions based on new, unseen data.
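To make "each neuron processes input data" concrete, here is a minimal sketch of what one artificial neuron computes. It takes a weighted sum of its inputs, adds a bias, and passes the result through an activation function (ReLU here, the same one used later in the model). The inputs, weights, and bias are made-up numbers, and `neuron` is a hypothetical helper for illustration, not a Keras API.

```python
def relu(z):
    # ReLU passes positive values through and clips negative values to 0
    return max(0.0, z)

def neuron(inputs, weights, bias):
    # Weighted sum of the inputs plus a bias, followed by the activation
    z = sum(x * w for x, w in zip(inputs, weights)) + bias
    return relu(z)

# Illustrative values only
print(neuron([1.0, 2.0, 3.0], [0.5, 1.0, 0.25], 0.1))   # positive sum, passed through
print(neuron([1.0, 2.0, 3.0], [0.5, -1.0, 0.25], 0.1))  # negative sum, clipped to 0.0
```

A layer is simply many such neurons reading the same inputs, each with its own weights and bias.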

In the real world, neural networks have proven to be incredibly powerful and versatile tools. They excel in various applications, such as image recognition, natural language processing, speech recognition, recommendation systems, and more. Their ability to handle complex and non-linear relationships makes them invaluable for solving challenging problems across different domains.

The MNIST Fashion Dataset

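To begin our journey into neural networks, we will use the MNIST Fashion Dataset (usually called Fashion-MNIST), a set of clothing images released by Zalando Research as a drop-in replacement for the original MNIST handwritten-digit dataset. Keras ships a loader for it, keras.datasets.fashion_mnist, which downloads the data on first use. The dataset is commonly used for training and evaluating deep learning models, particularly for image classification.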

The MNIST Fashion Dataset contains 70,000 grayscale images, each measuring 28x28 pixels. These images are divided into 60,000 training samples and 10,000 test samples. The dataset comprises 10 different classes, representing various fashion items such as T-shirts, trousers, pullovers, dresses, coats, sandals, shirts, sneakers, bags, and ankle boots.

Setting up the Environment

Before we delve into building the neural network, let's set up our development environment. We will use Python with TensorFlow (which includes the Keras API), NumPy for array operations, and Matplotlib for plotting.

# Importing the necessary libraries and frameworks
import tensorflow as tf 
from tensorflow import keras 
import numpy as np
import matplotlib.pyplot as plt 

# We will use a built-in dataset from Keras
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
print(train_images.shape) # Finding out the shape of the training data
print(train_images[0, 23, 23]) # Let's look at 1 pixel
print(train_labels[:10])  # Let's look at the first 10 training labels

# Let's create an array of the label names
class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", 
               "Shirt", "Sneaker", "Bag", "Ankle Boot"]

# Using Matplotlib to visualize our data
plt.figure()
plt.imshow(train_images[8])
plt.colorbar()
plt.grid(False)
plt.show() 

In the above code, we imported the necessary libraries and loaded the MNIST Fashion Dataset through Keras. We then inspected some basic information about the dataset: the shape of the training array, the value of a single pixel, and the first ten labels. Finally, we used Matplotlib to display one training image (index 8), with a colorbar showing its pixel intensities.
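The labels printed above are plain integers from 0 to 9, and each one indexes into `class_names`. The sketch below, assuming NumPy and repeating the `class_names` list so it runs on its own, decodes integer labels into names and counts how many samples fall into each class. `describe_labels` is a hypothetical helper and the short label array is made up; with the real data you would pass `train_labels` instead, which is a quick way to check whether the classes are balanced.

```python
import numpy as np

class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal",
               "Shirt", "Sneaker", "Bag", "Ankle Boot"]

def describe_labels(labels, names):
    """Return the decoded class names and a {class name: count} dictionary."""
    labels = np.asarray(labels)
    decoded = [names[i] for i in labels]
    # minlength makes sure every class appears in the counts, even with 0 samples
    counts = np.bincount(labels, minlength=len(names))
    return decoded, dict(zip(names, counts.tolist()))

decoded, counts = describe_labels([9, 0, 0, 3, 0], class_names)
print(decoded)
print(counts)
```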

Preprocessing the Data

Before feeding the data into our neural network, we need to preprocess it into a format suitable for training. The key step here is scaling the pixel values, which are stored as 8-bit integers from 0 to 255, into the range 0 to 1. Inputs that are small and on a common scale tend to make gradient-based training more stable and faster to converge.

# Preprocessing our data
train_images = train_images / 255.0
test_images = test_images / 255.0

Dividing every pixel value by 255.0 maps it from the integer range 0 to 255 to a floating-point number between 0 and 1, a simple form of normalization. Because the training and test images are divided by the same constant, the model sees test data on exactly the same scale as the data it was trained on.
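A minimal sketch of what the division does, assuming NumPy and a tiny made-up uint8 array standing in for an image. The values end up between 0 and 1, and the dtype changes from uint8 to a float type (float64). Converting to float32 first, as in the second line, is a common alternative that halves memory use; it is an optional variation, not something the original code does.

```python
import numpy as np

# A tiny stand-in for an image: 8-bit pixel values in [0, 255]
pixels = np.array([[0, 64], [128, 255]], dtype=np.uint8)

scaled = pixels / 255.0                        # true division promotes uint8 to float64
scaled32 = pixels.astype(np.float32) / 255.0   # optional: keep 32-bit floats

print(scaled.dtype, scaled32.dtype)
print(scaled.min(), scaled.max())
```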

Building the Neural Network Architecture

With the data preprocessed, we can now proceed to construct our neural network architecture. Our model will consist of three layers:

Input Layer (Layer 1): This is the first layer of our neural network. We use the Flatten layer to reshape each 28x28 image into a one-dimensional vector of 784 (28 × 28) values, one per pixel. Flatten has no trainable weights; it only changes the shape of the data before it reaches the dense layers.

Hidden Layer (Layer 2): The second layer is a dense layer with 128 neurons. It is fully connected, meaning each neuron from the previous layer connects to each neuron in this layer. The ReLU (Rectified Linear Unit) activation function is used, which introduces non-linearity to the model, allowing it to learn complex patterns in the data.

Output Layer (Layer 3): This is the final layer of our neural network, consisting of 10 neurons, one per class. The softmax activation turns the layer's raw outputs into a probability distribution: each value lies between 0 and 1 and all ten values sum to 1, so the neuron with the largest value corresponds to the model's predicted class.
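To see how data flows through these three layers, here is a NumPy-only sketch of a single forward pass with randomly initialized weights. It is not how Keras implements the layers internally; it only mirrors the shapes and math: flatten 28x28 to 784, a 784-to-128 dense layer with ReLU, then a 128-to-10 dense layer with softmax. `forward` and the random weights are hypothetical, so the resulting probabilities are meaningless until the weights are trained.

```python
import numpy as np

rng = np.random.default_rng(0)

# Randomly initialized parameters, shaped like the Keras model's layers
W1 = rng.normal(0, 0.05, size=(784, 128))  # hidden Dense(128) weights
b1 = np.zeros(128)                         # hidden Dense(128) biases
W2 = rng.normal(0, 0.05, size=(128, 10))   # output Dense(10) weights
b2 = np.zeros(10)                          # output Dense(10) biases

def softmax(z):
    # Subtract each row's max for numerical stability; rows then sum to 1
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def forward(images):
    x = images.reshape(len(images), -1)   # Flatten: (n, 28, 28) -> (n, 784)
    h = np.maximum(0.0, x @ W1 + b1)      # Dense + ReLU: (n, 128)
    return softmax(h @ W2 + b2)           # Dense + softmax: (n, 10)

batch = rng.random((3, 28, 28))  # three fake "images" with values in [0, 1)
probs = forward(batch)
print(probs.shape)
print(probs.sum(axis=1))  # each row sums to 1, up to rounding
```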

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)), # Input layer 1
    keras.layers.Dense(128, activation='relu'),   # Hidden layer 2
    keras.layers.Dense(10, activation='softmax')  # Output layer 3
])

In the above code snippet, we used the Keras Sequential model to create our neural network. We added three layers: the input layer with Flatten, the hidden layer with 128 neurons and ReLU activation, and the output layer with 10 neurons and softmax activation.
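The size of this network follows directly from the layer shapes: a dense layer with n inputs and m neurons has an n × m weight matrix plus m biases, while Flatten has no trainable parameters. A small arithmetic sketch, where `dense_params` is a hypothetical helper:

```python
def dense_params(n_inputs, n_units):
    # One weight per input-to-neuron connection, plus one bias per neuron
    return n_inputs * n_units + n_units

hidden = dense_params(28 * 28, 128)  # 784 * 128 + 128
output = dense_params(128, 10)       # 128 * 10 + 10
total = hidden + output              # Flatten contributes no parameters

print(hidden, output, total)  # 100480 1290 101770
```

Calling `model.summary()` on the Keras model should report the same total of 101,770 trainable parameters.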

Compiling the Model

Before we can start training our model, we need to compile it. Compiling involves specifying the optimizer, loss function, and metrics to monitor during training.

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In this example, we use the Adam optimizer ('adam'), a popular default for training neural networks. The sparse_categorical_crossentropy loss suits our multi-class problem because the labels are integers from 0 to 9 rather than one-hot vectors. The accuracy metric reports the fraction of images whose highest-probability class matches the true label.
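Sparse categorical cross-entropy is the ordinary cross-entropy loss computed from integer labels instead of one-hot vectors: for each sample it takes the negative log of the probability the model assigned to the correct class, then averages over the batch. The NumPy sketch below uses made-up predictions; the function names are hypothetical, since Keras computes the loss and accuracy for you.

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs):
    # Pick each row's probability for its true class, then average -log(p)
    labels = np.asarray(labels)
    p_true = probs[np.arange(len(labels)), labels]
    return -np.mean(np.log(p_true))

def categorical_crossentropy(one_hot, probs):
    # The same loss, written for one-hot targets
    return -np.mean(np.sum(one_hot * np.log(probs), axis=1))

probs = np.array([[0.7, 0.2, 0.1],    # correct: true class 0 is most likely
                  [0.1, 0.8, 0.1],    # correct: true class 1 is most likely
                  [0.3, 0.4, 0.3]])   # wrong: true class 2 gets only 0.3
labels = np.array([0, 1, 2])

sparse = sparse_categorical_crossentropy(labels, probs)
dense = categorical_crossentropy(np.eye(3)[labels], probs)
accuracy = np.mean(np.argmax(probs, axis=1) == labels)

print(round(float(sparse), 4), np.isclose(sparse, dense), accuracy)
```

Both formulations give the same number; the sparse version simply skips building the one-hot matrix.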

Training the Model

With the model compiled, we can now proceed to train it using the training data. We specify the number of epochs, which determines how many times the model will go through the entire training dataset.

model.fit(train_images, train_labels, epochs=5)

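In this example, we train the model for 5 epochs. Within each epoch, fit shuffles the training data and splits it into mini-batches (32 samples by default, so 60,000 / 32 = 1,875 updates per epoch). After each batch, the optimizer adjusts the weights and biases to reduce the loss on that batch.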
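The loop inside each epoch can be sketched in plain NumPy. This sketch is deliberately simplified and is not the Keras implementation: it trains a single softmax layer (no hidden layer) with plain mini-batch stochastic gradient descent instead of Adam, on a small synthetic dataset. It does share the same structure: shuffle the data, walk through it in batches of 32, compute the gradient of the cross-entropy loss on each batch, and move the parameters against that gradient. All names, the learning rate, and the data are illustrative assumptions.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def train(x, y, num_classes, epochs=5, batch_size=32, lr=0.1, seed=0):
    """Mini-batch SGD for a single softmax layer; returns weights and per-epoch loss."""
    rng = np.random.default_rng(seed)
    n, d = x.shape
    w = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    history = []
    for _ in range(epochs):
        order = rng.permutation(n)                       # reshuffle every epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            grad = softmax(x[idx] @ w + b)
            grad[np.arange(len(idx)), y[idx]] -= 1.0     # dLoss/dlogits for softmax + cross-entropy
            grad /= len(idx)                             # average over the batch
            w -= lr * x[idx].T @ grad
            b -= lr * grad.sum(axis=0)
        probs = softmax(x @ w + b)                       # full-dataset loss after this epoch
        history.append(-np.mean(np.log(probs[np.arange(n), y])))
    return w, b, history

# Synthetic two-class data: the label depends only on the sign of the first feature
rng = np.random.default_rng(42)
x = rng.normal(size=(300, 4))
y = (x[:, 0] > 0).astype(int)

w, b, history = train(x, y, num_classes=2)
accuracy = np.mean(np.argmax(x @ w + b, axis=1) == y)
print([round(float(h), 3) for h in history], accuracy)
```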

Evaluating the Model

After training, we evaluate the model's performance using the test dataset. This helps us understand how well the model generalizes to new, unseen data.

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=1)
print('Test accuracy:', test_acc)

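The evaluate method returns the loss followed by each metric passed to compile, so here it yields the test loss and the test accuracy. The test accuracy is the fraction (between 0 and 1) of test images classified correctly. Comparing it with the accuracy reported in the last training epoch gives a rough sense of overfitting: a noticeably lower test accuracy means the model generalizes less well than its training numbers suggest.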
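Overall accuracy can hide weak spots, since some classes (for example shirts and T-shirts) are easier to confuse than others. Here is a minimal NumPy sketch of a confusion matrix and per-class accuracy, assuming you already have integer true labels and predicted labels; with the trained model, `np.argmax(model.predict(test_images), axis=1)` gives the predicted labels. The short arrays are made up, and the helper names are hypothetical.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    # Rows are true classes, columns are predicted classes
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

def per_class_accuracy(cm):
    # Diagonal (correct predictions) divided by each class's sample count;
    # assumes every class appears at least once in y_true
    return np.diag(cm) / cm.sum(axis=1)

y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])

cm = confusion_matrix(y_true, y_pred, num_classes=3)
print(cm)
print(per_class_accuracy(cm))
print(np.mean(y_true == y_pred))  # overall accuracy
```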

Making Predictions

Finally, we can use our trained model to make predictions based on new data. Let's predict the class of the first image from the test dataset.

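predictions = model.predict(test_images)       # shape: (10000, 10)
print(class_names[np.argmax(predictions[0])])  # predicted class of the first test image
print(class_names[test_labels[0]])             # true class, for comparison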

The model.predict function returns an array with one row per image and one column per class (shape (10000, 10) for the test set), where each row holds the softmax probabilities for that image. We use np.argmax to find the index of the highest-probability class in a row, and then use the class_names list to map that index to the corresponding fashion item label.
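Since each row of `predictions` is a full probability distribution, you can look beyond the single best guess. The sketch below, assuming NumPy, returns the k most likely classes for one row. The probability row is invented for illustration, and `top_k` is a hypothetical helper.

```python
import numpy as np

class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal",
               "Shirt", "Sneaker", "Bag", "Ankle Boot"]

def top_k(prob_row, names, k=3):
    # argsort sorts ascending, so reverse it and keep the first k indices
    best = np.argsort(prob_row)[::-1][:k]
    return [(names[i], float(prob_row[i])) for i in best]

# A made-up probability row for one image (sums to 1)
row = np.array([0.02, 0.0, 0.05, 0.01, 0.02, 0.0, 0.10, 0.0, 0.0, 0.80])
for name, p in top_k(row, class_names):
    print(f"{name}: {p:.2f}")
```

With the trained model, `top_k(predictions[0], class_names)` would list the three most likely labels for the first test image.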

Interactively Predicting Images

As an exciting addition to our image classification model, we can now interactively select an image from the test dataset and view the model's prediction for that image. Let's explore this feature step-by-step:

Setting Up the Visualization

First, we'll set up the visualization to display the images and predictions in a visually appealing manner.

COLOR = 'white'
plt.rcParams['text.color'] = COLOR
plt.rcParams['axes.labelcolor'] = COLOR

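These lines set the text and axis-label colors to white. That only reads well on a dark background, such as a notebook running in a dark theme; on Matplotlib's default white figure background, white titles and labels are invisible, so skip these lines or choose a darker color in that case.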

Creating the Prediction Function

Next, we define a function called predict that takes the trained model, an image, and its correct label as inputs. The function predicts the image's class and visualizes the image along with the expected label and the model's prediction.

def predict(model, image, correct_label):
    class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                   'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    prediction = model.predict(np.array([image]))
    predicted_class = class_names[np.argmax(prediction)]

    show_image(image, class_names[correct_label], predicted_class)

The predict function uses the trained model to predict the class of the input image. The correct_label parameter represents the true label of the image, allowing us to display it alongside the model's prediction.
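The `np.array([image])` call matters because Keras models predict on batches: the model expects input shaped (batch_size, 28, 28), while a single test image is just (28, 28). Wrapping it in a list adds a leading batch axis of size 1; `np.expand_dims` and `np.newaxis` indexing are equivalent, more explicit alternatives. A shape-only sketch, with a zero array standing in for an image:

```python
import numpy as np

image = np.zeros((28, 28))            # stand-in for one test image

batch_a = np.array([image])           # the approach used in predict()
batch_b = np.expand_dims(image, 0)    # same result, axis added explicitly
batch_c = image[np.newaxis, ...]      # indexing-based equivalent

print(image.shape, batch_a.shape)     # (28, 28) (1, 28, 28)
```

The model then returns an array of shape (1, 10). Calling `np.argmax` without an axis flattens it first, which is safe here only because there is a single row.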

Displaying the Image

The show_image function takes an image, its expected label, and the model's prediction as inputs. It uses Matplotlib to display the image and relevant information.

def show_image(img, label, guess):
    plt.figure()
    plt.imshow(img, cmap=plt.cm.binary)
    plt.title("Expected: " + label)
    plt.xlabel("Guess: " + guess)
    plt.colorbar()
    plt.grid(False)
    plt.show()

The function creates a figure and draws the image with the binary colormap, which renders low pixel values as white and high values as black (an inverted grayscale). It sets the title to the expected label and the x-axis label to the model's guess, adds a colorbar showing the pixel-intensity scale, and turns off the grid lines for clarity.

Interactive Number Selection

To let users pick an image from the test dataset interactively, we create a function called get_number. It keeps prompting until the user enters a whole number between 0 and 1000 (inclusive). The test set actually holds 10,000 images, so any index up to 9,999 would be valid; 1000 is simply an arbitrary cap.

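def get_number():
    # Keep asking until the user enters a whole number between 0 and 1000
    while True:
        num = input("Pick a number: ")
        if num.isdigit() and 0 <= int(num) <= 1000:
            return int(num)
        print("Try again...")  # shown for non-numbers and out-of-range numbers alike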

This interactive feature allows users to experience the neural network's predictions on different test images and gain insights into its performance.
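Because get_number mixes prompting with validation, it is awkward to test. One possible refactor, sketched here as an assumption rather than part of the original code, moves the check into a pure function (`parse_index`, a hypothetical name) that returns an int or None, and makes the upper bound a parameter so it can be tied to the dataset size.

```python
def parse_index(text, max_index):
    """Return text as an int if it is a whole number in [0, max_index], else None."""
    try:
        num = int(text)   # int() also tolerates surrounding whitespace
    except ValueError:
        return None
    return num if 0 <= num <= max_index else None

def get_number(max_index=1000, prompt="Pick a number: "):
    # Keep prompting until parse_index accepts the input
    while True:
        num = parse_index(input(prompt), max_index)
        if num is not None:
            return num
        print("Try again...")
```

With the real data, `get_number(len(test_images) - 1)` would allow any of the 10,000 test indices, while a bare `get_number()` keeps the original cap of 1000.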

Putting It All Together

Now, we can combine the functions to create an interactive experience for users. Users will be prompted to enter a number, and the corresponding image from the test dataset will be displayed along with its expected label and the model's prediction.

num = get_number()
image = test_images[num]
label = test_labels[num]
predict(model, image, label)

This final part of the code lets users explore the model's predictions and get a feel for how it performs on different fashion items.

Conclusion

Well done! You have successfully built your first neural network for image classification using the MNIST Fashion Dataset. I'm glad you completed this journey with me. You've learned the basics of neural networks: how to preprocess data, construct a neural network architecture, train the model, evaluate it on test data, and use it to make predictions.

Neural networks are a powerful tool in machine learning, and understanding their inner workings opens up a world of possibilities for solving complex real-world problems.

In the next steps of your machine learning journey, you can experiment with different model architectures, hyperparameters, and datasets to further enhance your understanding and skills in the fascinating field of deep learning.

Remember, practice makes perfect, so keep exploring, learning, and building!
