DEV Community

Super Kai (Kazuya Ito)


MNIST in PyTorch



MNIST() loads the MNIST dataset (28x28 grayscale images of handwritten digits) as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path can be used.
  • The 2nd argument is train(Optional-Default:True-Type:bool). *If it's True, the training data(60,000 images) is used; if it's False, the test data(10,000 images) is used.
  • The 3rd argument is transform(Optional-Default:None-Type:callable). *It's applied to each image when a sample is accessed.
  • The 4th argument is target_transform(Optional-Default:None-Type:callable). *It's applied to each label when a sample is accessed.
  • The 5th argument is download(Optional-Default:False-Type:bool): *Memos:
    • If it's True and the dataset isn't in root yet, it's downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded but not extracted, it's only extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because that's faster.
    • You can also manually download the dataset files(train-images-idx3-ubyte.gz and train-labels-idx1-ubyte.gz, t10k-images-idx3-ubyte.gz and t10k-labels-idx1-ubyte.gz) and extract them to data/MNIST/raw/ (for root="data").
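
The download rules above suggest a small pre-check before the dataset is constructed. The following is only a minimal sketch using the standard library: needs_download() is a hypothetical helper, not part of torchvision, and it assumes the extracted files sit in <root>/MNIST/raw/ under the archive names listed above without the .gz suffix.

```python
from pathlib import Path

# File names after extraction, i.e. the archive names listed above without ".gz".
RAW_FILES = (
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
)

def needs_download(root="data"):
    """Hypothetical helper: True if any extracted MNIST file is missing."""
    raw_dir = Path(root) / "MNIST" / "raw"
    return not all((raw_dir / name).is_file() for name in RAW_FILES)
```

The result could then be passed to MNIST() as download=needs_download("data"), so that download stays False whenever the extracted files are already in place.
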
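from torchvision.datasets import MNIST

# Only root is required; the other arguments take their defaults.
# With download=False(the default), the dataset must already exist in root,
# otherwise an error is raised.
train_data = MNIST(
    root="data"
)

# The same call with every default written out.
train_data = MNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

# train=False selects the test data(10,000 images).
test_data = MNIST(
    root="data",
    train=False
)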

len(train_data), len(test_data)
# (60000, 10000)

train_data
# Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

len(train_data.classes), train_data.classes
# (10,
#  ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
#   '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'])

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 5)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 4)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 1)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 9)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    # Show the first 10 (image, label) samples in a 2x5 grid.
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # zip() stops after 10 items, so only the first 10 samples are read.
    for i, (im, lab) in zip(range(1, 11), data):
        plt.subplot(2, 5, i)
        plt.imshow(X=im)
        plt.title(label=lab)  # The label(0-9) is used as the subplot title.
    plt.tight_layout()
    plt.show()

show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")
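
The transform and target_transform arguments are left as None in the examples above, so each sample is a (PIL image, int) pair. As a rough sketch of how they might be used, the callables below scale the pixel values to the range 0 to 1 and one-hot encode the label. to_unit_floats(), one_hot() and build_dataset() are hypothetical names, not part of torchvision, and the sketch assumes the dataset is already downloaded and extracted under root.

```python
def to_unit_floats(image):
    """Hypothetical transform: flatten a mode "L" PIL image to floats in [0, 1]."""
    # For mode "L", tobytes() gives one byte (0-255) per pixel in row-major order.
    return [value / 255 for value in image.tobytes()]

def one_hot(label, num_classes=10):
    """Hypothetical target_transform: integer label -> one-hot list."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

def build_dataset(root="data"):
    """Assumes MNIST is already downloaded and extracted under root."""
    from torchvision.datasets import MNIST  # imported lazily; needs torchvision
    return MNIST(
        root=root,
        train=True,
        transform=to_unit_floats,
        target_transform=one_hot
    )
```

With these callables, build_dataset()[0] would return a list of 784 floats (28x28 pixels) and a one-hot list with 1.0 at index 5, matching the first training label shown above.
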

