DEV Community

Super Kai (Kazuya Ito)


EMNIST in PyTorch



EMNIST() can use the EMNIST dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Required-Type:str). *"byclass", "bymerge", "balanced", "letters", "digits" or "mnist" can be set to it.
  • There is train argument(Optional-Default:True-Type:bool): *Memos:
    • train= must be used.
    • For split="byclass" and split="bymerge", if it's True, train data(697,932 images) is used while if it's False, test data(116,323 images) is used.
    • For split="balanced", if it's True, train data(112,800 images) is used while if it's False, test data(18,800 images) is used.
    • For split="letters", if it's True, train data(124,800 images) is used while if it's False, test data(20,800 images) is used.
    • For split="digits", if it's True, train data(240,000 images) is used while if it's False, test data(40,000 images) is used.
    • For split="mnist", if it's True, train data(60,000 images) is used while if it's False, test data(10,000 images) is used.
  • There is transform argument(Optional-Default:None-Type:callable). *transform= must be used.
  • There is target_transform argument(Optional-Default:None-Type:callable). *target_transform= must be used.
  • There is download argument(Optional-Default:False-Type:bool): *Memos:
    • download= must be used.
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • You can manually download and extract the dataset(emnist-byclass-train-images-idx3-ubyte.gz, emnist-byclass-train-labels-idx1-ubyte.gz, emnist-byclass-test-images-idx3-ubyte.gz, emnist-byclass-test-labels-idx1-ubyte.gz, etc.) from here to data/EMNIST/raw/.
  • There is a bug in which the images are flipped and rotated 90 degrees anticlockwise by default, so they should be transformed back.
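The download memos above suggest keeping download=False once the files are in place. As a minimal sketch, a small check can decide the flag for you. `emnist_files_present` is a hypothetical helper, and it assumes torchvision looks for the extracted (un-gzipped) files named like the ones listed above, e.g. emnist-byclass-train-images-idx3-ubyte, in <root>/EMNIST/raw/.

```python
from pathlib import Path

def emnist_files_present(root, split, train=True):
    """Return True if the extracted image and label files for one EMNIST
    split and subset seem to exist under <root>/EMNIST/raw/.

    Hypothetical helper: the file-name pattern is an assumption based on the
    names listed above with the .gz suffix removed.
    """
    subset = "train" if train else "test"
    raw_dir = Path(root) / "EMNIST" / "raw"
    prefix = f"emnist-{split}-{subset}"
    return all(
        (raw_dir / f"{prefix}-{suffix}").is_file()
        for suffix in ("images-idx3-ubyte", "labels-idx1-ubyte")
    )

# Only ask torchvision to download when the files seem to be missing, e.g.:
# EMNIST(root="data", split="byclass",
#        download=not emnist_files_present("data", "byclass"))
```

Since download=True is documented above to do nothing when the dataset is already downloaded and extracted, this check is only a convenience, not a requirement.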
from torchvision.datasets import EMNIST

byclass_train_data = EMNIST(
    root="data",
    split="byclass"
)

byclass_train_data = EMNIST(
    root="data",
    split="byclass",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

byclass_test_data = EMNIST(
    root="data",
    split="byclass",
    train=False
)

balanced_train_data = EMNIST(
    root="data",
    split="balanced",
    train=True
)

balanced_test_data = EMNIST(
    root="data",
    split="balanced",
    train=False
)

letters_train_data = EMNIST(
    root="data",
    split="letters",
    train=True
)

letters_test_data = EMNIST(
    root="data",
    split="letters",
    train=False
)

digits_train_data = EMNIST(
    root="data",
    split="digits",
    train=True
)

digits_test_data = EMNIST(
    root="data",
    split="digits",
    train=False
)

mnist_train_data = EMNIST(
    root="data",
    split="mnist",
    train=True
)

mnist_test_data = EMNIST(
    root="data",
    split="mnist",
    train=False
)

len(byclass_train_data), len(byclass_test_data)
# (697932, 116323)

len(balanced_train_data), len(balanced_test_data)
# (112800, 18800)

len(letters_train_data), len(letters_test_data)
# (124800, 20800)

len(digits_train_data), len(digits_test_data)
# (240000, 40000)

len(mnist_train_data), len(mnist_test_data)
# (60000, 10000)

byclass_train_data
# Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train

byclass_train_data.root
# 'data'

byclass_train_data.split
# 'byclass'

byclass_train_data.train
# True

print(byclass_train_data.transform)
# None

print(byclass_train_data.target_transform)
# None

byclass_train_data.download
# <bound method EMNIST.download of Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train>

len(byclass_train_data.classes), byclass_train_data.classes
# (62,
#  ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
#   'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
#   'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
#   'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
#   'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'])

byclass_train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 35)

byclass_train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 36)

byclass_train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 6)

byclass_train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

byclass_train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 22)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, (im, lab) in zip(range(1, 11), data):
        plt.subplot(2, 5, i)
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout()
    plt.show()

show_images(data=byclass_train_data, main_title="byclass_train_data")
show_images(data=byclass_test_data, main_title="byclass_test_data")

show_images(data=balanced_train_data, main_title="balanced_train_data")
show_images(data=balanced_test_data, main_title="balanced_test_data")

show_images(data=letters_train_data, main_title="letters_train_data")
show_images(data=letters_test_data, main_title="letters_test_data")

show_images(data=digits_train_data, main_title="digits_train_data")
show_images(data=digits_test_data, main_title="digits_test_data")

show_images(data=mnist_train_data, main_title="mnist_train_data")
show_images(data=mnist_test_data, main_title="mnist_test_data")
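The labels printed by byclass_train_data[0] to byclass_train_data[4] are class indices, not characters. Going by the byclass_train_data.classes list printed above ('0'-'9', then 'A'-'Z', then 'a'-'z'), a minimal sketch of decoding them without loading the dataset could look like this. `BYCLASS_CLASSES` and `decode_byclass_label` are hypothetical names, and the mapping only applies to split="byclass".

```python
import string

# Class order for split="byclass", matching byclass_train_data.classes above:
# '0'-'9' are labels 0-9, 'A'-'Z' are 10-35 and 'a'-'z' are 36-61.
BYCLASS_CLASSES = list(string.digits + string.ascii_uppercase + string.ascii_lowercase)

def decode_byclass_label(label):
    """Return the character for an integer label from split="byclass"."""
    return BYCLASS_CLASSES[label]

# Labels of byclass_train_data[0] to byclass_train_data[4] shown above.
print([decode_byclass_label(lab) for lab in (35, 36, 6, 3, 22)])
# ['Z', 'a', '6', '3', 'M']
```

With a dataset object at hand, data.classes[lab] should give the same result for any split, so using plt.title(label=data.classes[lab]) inside show_images() would display characters instead of indices.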

from torchvision.datasets import EMNIST
from torchvision.transforms import v2

tran = v2.Compose([v2.RandomHorizontalFlip(p=1.0),
                   v2.RandomRotation(degrees=(90, 90))])

byclass_train_data = EMNIST(
    root="data",
    split="byclass",
    train=True,
    transform=tran
)

byclass_test_data = EMNIST(
    root="data",
    split="byclass",
    train=False,
    transform=tran
)

balanced_train_data = EMNIST(
    root="data",
    split="balanced",
    train=True,
    transform=tran
)

balanced_test_data = EMNIST(
    root="data",
    split="balanced",
    train=False,
    transform=tran
)

letters_train_data = EMNIST(
    root="data",
    split="letters",
    train=True,
    transform=tran
)

letters_test_data = EMNIST(
    root="data",
    split="letters",
    train=False,
    transform=tran
)

digits_train_data = EMNIST(
    root="data",
    split="digits",
    train=True,
    transform=tran
)

digits_test_data = EMNIST(
    root="data",
    split="digits",
    train=False,
    transform=tran
)

mnist_train_data = EMNIST(
    root="data",
    split="mnist",
    train=True,
    transform=tran
)

mnist_test_data = EMNIST(
    root="data",
    split="mnist",
    train=False,
    transform=tran
)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, (im, lab) in zip(range(1, 11), data):
        plt.subplot(2, 5, i)
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout()
    plt.show()

show_images(data=byclass_train_data, main_title="byclass_train_data")
show_images(data=byclass_test_data, main_title="byclass_test_data")

show_images(data=balanced_train_data, main_title="balanced_train_data")
show_images(data=balanced_test_data, main_title="balanced_test_data")

show_images(data=letters_train_data, main_title="letters_train_data")
show_images(data=letters_test_data, main_title="letters_test_data")

show_images(data=digits_train_data, main_title="digits_train_data")
show_images(data=digits_test_data, main_title="digits_test_data")

show_images(data=mnist_train_data, main_title="mnist_train_data")
show_images(data=mnist_test_data, main_title="mnist_test_data")
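The v2 pipeline above (a horizontal flip with p=1.0, then a fixed 90-degree rotation) undoes the orientation problem because a left-right mirror followed by a 90-degree anticlockwise rotation is the same as a transpose (swapping rows and columns). The sketch below checks that identity with NumPy on a stand-in array instead of real EMNIST images. `flip_then_rotate` and `transpose` are hypothetical helper names, and it relies on `np.rot90` rotating anticlockwise by default, as torchvision's rotation does for positive angles.

```python
import numpy as np

def flip_then_rotate(img):
    """Mirror left-right, then rotate 90 degrees anticlockwise: the same two
    steps as v2.RandomHorizontalFlip(p=1.0) + v2.RandomRotation(degrees=(90, 90))."""
    return np.rot90(np.fliplr(img))

def transpose(img):
    """Swap rows and columns (reflect across the main diagonal)."""
    return img.T

# A non-symmetric 28x28 stand-in for one EMNIST image.
img = np.arange(28 * 28).reshape(28, 28)
print(np.array_equal(flip_then_rotate(img), transpose(img)))  # True
```

If that holds, the two-step tran could presumably be replaced by a single transpose, for example a callable that calls Pillow's Image.transpose(Image.Transpose.TRANSPOSE) (Pillow 9.1 or later); this is an untested alternative rather than the approach used above.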
