DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

EMNIST in PyTorch

Buy Me a Coffee

*Memos:

EMNIST() can load the EMNIST dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Required-Type:str). *"byclass", "bymerge", "balanced", "letters", "digits" or "mnist" can be set to it.
  • There is train argument(Optional-Default:True-Type:bool): *Memos:
    • It must be passed as a keyword argument (train=).
    • For split="byclass" and split="bymerge", if it's True, train data(697,932 images) is used while if it's False, test data(116,323 images) is used.
    • For split="balanced", if it's True, train data(112,800 images) is used while if it's False, test data(18,800 images) is used.
    • For split="letters", if it's True, train data(124,800 images) is used while if it's False, test data(20,800 images) is used.
    • For split="digits", if it's True, train data(240,000 images) is used while if it's False, test data(40,000 images) is used.
    • For split="mnist", if it's True, train data(60,000 images) is used while if it's False, test data(10,000 images) is used.
  • There is transform argument(Optional-Default:None-Type:callable). *It must be passed as a keyword argument (transform=).
  • There is target_transform argument(Optional-Default:None-Type:callable). *It must be passed as a keyword argument (target_transform=).
  • There is download argument(Optional-Default:False-Type:bool): *Memos:
    • It must be passed as a keyword argument (download=).
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • You can manually download and extract the dataset(emnist-byclass-train-images-idx3-ubyte.gz, emnist-byclass-train-labels-idx1-ubyte.gz, emnist-byclass-test-images-idx3-ubyte.gz, emnist-byclass-test-labels-idx1-ubyte.gz, etc) from here to data/EMNIST/raw/.
  • By default, the images come out flipped and rotated 90 degrees anticlockwise (a known issue), so they should be transformed before use.
from torchvision.datasets import EMNIST

# Minimal call: only the two required arguments, root and split.
byclass_train_data = EMNIST(root="data", split="byclass")

# The same call with every optional keyword argument written out explicitly.
byclass_train_data = EMNIST(
    root="data", split="byclass", train=True,
    transform=None, target_transform=None, download=False
)

# "byclass" split: test subset.
byclass_test_data = EMNIST(root="data", split="byclass", train=False)

# "balanced" split: train and test subsets.
balanced_train_data = EMNIST(root="data", split="balanced", train=True)
balanced_test_data = EMNIST(root="data", split="balanced", train=False)

# "letters" split: train and test subsets.
letters_train_data = EMNIST(root="data", split="letters", train=True)
letters_test_data = EMNIST(root="data", split="letters", train=False)

# "digits" split: train and test subsets.
digits_train_data = EMNIST(root="data", split="digits", train=True)
digits_test_data = EMNIST(root="data", split="digits", train=False)

# "mnist" split: train and test subsets.
mnist_train_data = EMNIST(root="data", split="mnist", train=True)
mnist_test_data = EMNIST(root="data", split="mnist", train=False)

# --- Number of samples in the (train, test) subsets of each split ---
# The comment under each expression is the recorded output.

len(byclass_train_data), len(byclass_test_data)
# (697932, 116323)

len(balanced_train_data), len(balanced_test_data)
# (112800, 18800)

len(letters_train_data), len(letters_test_data)
# (124800, 20800)

len(digits_train_data), len(digits_test_data)
# (240000, 40000)

len(mnist_train_data), len(mnist_test_data)
# (60000, 10000)

# --- repr of a dataset object: size, root directory and train/test subset ---

byclass_train_data
# Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train

# --- Attributes mirroring the constructor arguments ---

byclass_train_data.root
# 'data'

byclass_train_data.split
# 'byclass'

byclass_train_data.train
# True

# print() is used here because a bare None expression shows no output.
print(byclass_train_data.transform)
# None

print(byclass_train_data.target_transform)
# None

# Note: on the instance, `download` is a bound method, not the boolean
# that was passed to the constructor.
byclass_train_data.download
# <bound method EMNIST.download of Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train>

# --- Class names of the "byclass" split: digits, uppercase, lowercase ---

len(byclass_train_data.classes), byclass_train_data.classes
# (62,
#  ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
#   'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
#   'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
#   'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
#   'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'])

# --- Indexing yields an (image, label) tuple: a 28x28 mode-"L" PIL image
# --- and an integer label.

byclass_train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 35)

byclass_train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 36)

byclass_train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 6)

byclass_train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

byclass_train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 22)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Plot the first ten (image, label) pairs of ``data`` in a 2x5 grid.

    Args:
        data: Iterable yielding ``(image, label)`` pairs; at most the
            first ten items are drawn.
        main_title: Text shown above the whole figure (optional).
    """
    n_rows, n_cols = 2, 5
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Subplot positions are 1-based; zipping with them stops after ten
    # items, or earlier when ``data`` is shorter.
    slots = range(1, n_rows * n_cols + 1)
    for slot, (image, label) in zip(slots, data):
        plt.subplot(n_rows, n_cols, slot)
        plt.imshow(X=image)
        plt.title(label=label)
    plt.tight_layout()
    plt.show()

# Draw one figure per dataset, titled with the dataset's variable name.
for title, dataset in (
    ("byclass_train_data", byclass_train_data),
    ("byclass_test_data", byclass_test_data),
    ("balanced_train_data", balanced_train_data),
    ("balanced_test_data", balanced_test_data),
    ("letters_train_data", letters_train_data),
    ("letters_test_data", letters_test_data),
    ("digits_train_data", digits_train_data),
    ("digits_test_data", digits_test_data),
    ("mnist_train_data", mnist_train_data),
    ("mnist_test_data", mnist_test_data),
):
    show_images(data=dataset, main_title=title)
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

Image description

Image description

from torchvision.datasets import EMNIST
from torchvision.transforms import v2

# Correct the image orientation: a horizontal flip applied with
# probability 1.0, followed by a rotation whose degree range is the
# single value 90.
tran = v2.Compose([
    v2.RandomHorizontalFlip(p=1.0),
    v2.RandomRotation(degrees=(90, 90)),
])

# "byclass" split: train and test subsets, with the transform applied.
byclass_train_data = EMNIST(root="data", split="byclass", train=True, transform=tran)
byclass_test_data = EMNIST(root="data", split="byclass", train=False, transform=tran)

# "balanced" split.
balanced_train_data = EMNIST(root="data", split="balanced", train=True, transform=tran)
balanced_test_data = EMNIST(root="data", split="balanced", train=False, transform=tran)

# "letters" split.
letters_train_data = EMNIST(root="data", split="letters", train=True, transform=tran)
letters_test_data = EMNIST(root="data", split="letters", train=False, transform=tran)

# "digits" split.
digits_train_data = EMNIST(root="data", split="digits", train=True, transform=tran)
digits_test_data = EMNIST(root="data", split="digits", train=False, transform=tran)

# "mnist" split.
mnist_train_data = EMNIST(root="data", split="mnist", train=True, transform=tran)
mnist_test_data = EMNIST(root="data", split="mnist", train=False, transform=tran)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Plot the first ten (image, label) pairs of ``data`` in a 2x5 grid.

    Args:
        data: Iterable yielding ``(image, label)`` pairs; at most the
            first ten items are drawn.
        main_title: Text shown above the whole figure (optional).
    """
    n_rows, n_cols = 2, 5
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Subplot positions are 1-based; zipping with them stops after ten
    # items, or earlier when ``data`` is shorter.
    slots = range(1, n_rows * n_cols + 1)
    for slot, (image, label) in zip(slots, data):
        plt.subplot(n_rows, n_cols, slot)
        plt.imshow(X=image)
        plt.title(label=label)
    plt.tight_layout()
    plt.show()

# Draw one figure per dataset, titled with the dataset's variable name.
for title, dataset in (
    ("byclass_train_data", byclass_train_data),
    ("byclass_test_data", byclass_test_data),
    ("balanced_train_data", balanced_train_data),
    ("balanced_test_data", balanced_test_data),
    ("letters_train_data", letters_train_data),
    ("letters_test_data", letters_test_data),
    ("digits_train_data", digits_train_data),
    ("digits_test_data", digits_test_data),
    ("mnist_train_data", mnist_train_data),
    ("mnist_test_data", mnist_test_data),
):
    show_images(data=dataset, main_title=title)
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

Image description

Image description

Image of Timescale

Timescale – the developer's data platform for modern apps, built on PostgreSQL

Timescale Cloud is PostgreSQL optimized for speed, scale, and performance. Over 3 million IoT, AI, crypto, and dev tool apps are powered by Timescale. Try it free today! No credit card required.

Try free

Top comments (0)

Sentry image

See why 4M developers consider Sentry, “not bad.”

Fixing code doesn’t have to be the worst part of your day. Learn how Sentry can help.

Learn more

👋 Kindness is contagious

Please leave a ❤️ or a friendly comment on this post if you found it helpful!

Okay