DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

1

EMNIST in PyTorch

Buy Me a Coffee

*Memos:

EMNIST() can load the EMNIST dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Required-Type:str). *"byclass", "bymerge", "balanced", "letters", "digits" or "mnist" can be set to it.
  • There is train argument(Optional-Default:True-Type:bool): *Memos:
    • train= must be used.
    • For split="byclass" and split="bymerge", if it's True, train data(697,932 images) is used while if it's False, test data(116,323 images) is used.
    • For split="balanced", if it's True, train data(112,800 images) is used while if it's False, test data(18,800 images) is used.
    • For split="letters", if it's True, train data(124,800 images) is used while if it's False, test data(20,800 images) is used.
    • For split="digits", if it's True, train data(240,000 images) is used while if it's False, test data(40,000 images) is used.
    • For split="mnist", if it's True, train data(60,000 images) is used while if it's False, test data(10,000 images) is used.
  • There is transform argument(Optional-Default:None-Type:callable). *transform= must be used.
  • There is target_transform argument(Optional-Default:None-Type:callable). *target_transform= must be used.
  • There is download argument(Optional-Default:False-Type:bool): *Memos:
    • download= must be used.
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • You can manually download and extract the dataset(emnist-byclass-train-images-idx3-ubyte.gz, emnist-byclass-train-labels-idx1-ubyte.gz, emnist-byclass-test-images-idx3-ubyte.gz, emnist-byclass-test-labels-idx1-ubyte.gz, etc) from here to data/EMNIST/raw/.
  • There is a bug in which the images are flipped and rotated 90 degrees anticlockwise by default, so they should be transformed.
from torchvision.datasets import EMNIST

# "byclass" train data built from the two required arguments only.
byclass_train_data = EMNIST(root="data", split="byclass")

# The same dataset again, with the optional arguments written out explicitly.
byclass_train_data = EMNIST(
    root="data", split="byclass", train=True,
    transform=None, target_transform=None, download=False
)
byclass_test_data = EMNIST(root="data", split="byclass", train=False)

# "balanced" split: train and test data.
balanced_train_data = EMNIST(root="data", split="balanced", train=True)
balanced_test_data = EMNIST(root="data", split="balanced", train=False)

# "letters" split: train and test data.
letters_train_data = EMNIST(root="data", split="letters", train=True)
letters_test_data = EMNIST(root="data", split="letters", train=False)

# "digits" split: train and test data.
digits_train_data = EMNIST(root="data", split="digits", train=True)
digits_test_data = EMNIST(root="data", split="digits", train=False)

# "mnist" split: train and test data.
mnist_train_data = EMNIST(root="data", split="mnist", train=True)
mnist_test_data = EMNIST(root="data", split="mnist", train=False)

# Interactive inspection of the datasets built above. Each bare expression is
# meant to be evaluated in a REPL/notebook; the comment underneath it records
# the output that was displayed.

# Number of samples in the (train, test) data of each split.
len(byclass_train_data), len(byclass_test_data)
# (697932, 116323)

len(balanced_train_data), len(balanced_test_data)
# (112800, 18800)

len(letters_train_data), len(letters_test_data)
# (124800, 20800)

len(digits_train_data), len(digits_test_data)
# (240000, 40000)

len(mnist_train_data), len(mnist_test_data)
# (60000, 10000)

# The dataset's repr summarizes its size, root location and train/test part.
byclass_train_data
# Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train

# The constructor arguments are readable back as attributes.
byclass_train_data.root
# 'data'

byclass_train_data.split
# 'byclass'

byclass_train_data.train
# True

# print() is used here because a bare None is not echoed interactively.
print(byclass_train_data.transform)
# None

print(byclass_train_data.target_transform)
# None

# NOTE: `download` on the instance is a bound method, not the boolean
# `download=` flag that was passed to the constructor.
byclass_train_data.download
# <bound method EMNIST.download of Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train>

# Class names of the "byclass" split: digits, uppercase, then lowercase.
len(byclass_train_data.classes), byclass_train_data.classes
# (62,
#  ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
#   'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
#   'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
#   'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
#   'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'])

# Indexing yields an (image, label) pair: a 28x28 mode-"L" PIL image and an
# integer label (presumably an index into `.classes` -- confirm in the docs).
byclass_train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 35)

byclass_train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 36)

byclass_train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 6)

byclass_train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

byclass_train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 22)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Display the first samples of ``data`` in a 2x5 grid of subplots.

    Args:
        data: Iterable of ``(image, label)`` pairs; at most the first ten
            pairs are drawn.
        main_title: Text placed above the whole figure.
    """
    rows, cols = 2, 5
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Subplot positions are 1-based; zip stops once the grid is full or the
    # data runs out, whichever comes first.
    slots = range(1, rows * cols + 1)
    for slot, sample in zip(slots, data):
        image, label = sample
        plt.subplot(rows, cols, slot)
        plt.imshow(X=image)
        plt.title(label=label)
    plt.tight_layout()
    plt.show()

# One figure per dataset: train data first, then test data, for each split.
show_images(byclass_train_data, main_title="byclass_train_data")
show_images(byclass_test_data, main_title="byclass_test_data")

show_images(balanced_train_data, main_title="balanced_train_data")
show_images(balanced_test_data, main_title="balanced_test_data")

show_images(letters_train_data, main_title="letters_train_data")
show_images(letters_test_data, main_title="letters_test_data")

show_images(digits_train_data, main_title="digits_train_data")
show_images(digits_test_data, main_title="digits_test_data")

show_images(mnist_train_data, main_title="mnist_train_data")
show_images(mnist_test_data, main_title="mnist_test_data")
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

Image description

Image description

from torchvision.datasets import EMNIST
from torchvision.transforms import v2

# Orientation fix applied to every sample: always flip horizontally (p=1.0),
# then rotate by a fixed 90 degrees (the range (90, 90) leaves no randomness).
tran = v2.Compose(
    [
        v2.RandomHorizontalFlip(p=1.0),
        v2.RandomRotation(degrees=(90, 90)),
    ]
)

# "byclass" split with the orientation fix.
byclass_train_data = EMNIST(
    root="data", split="byclass", train=True, transform=tran
)
byclass_test_data = EMNIST(
    root="data", split="byclass", train=False, transform=tran
)

# "balanced" split with the orientation fix.
balanced_train_data = EMNIST(
    root="data", split="balanced", train=True, transform=tran
)
balanced_test_data = EMNIST(
    root="data", split="balanced", train=False, transform=tran
)

# "letters" split with the orientation fix.
letters_train_data = EMNIST(
    root="data", split="letters", train=True, transform=tran
)
letters_test_data = EMNIST(
    root="data", split="letters", train=False, transform=tran
)

# "digits" split with the orientation fix.
digits_train_data = EMNIST(
    root="data", split="digits", train=True, transform=tran
)
digits_test_data = EMNIST(
    root="data", split="digits", train=False, transform=tran
)

# "mnist" split with the orientation fix.
mnist_train_data = EMNIST(
    root="data", split="mnist", train=True, transform=tran
)
mnist_test_data = EMNIST(
    root="data", split="mnist", train=False, transform=tran
)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Display the first samples of ``data`` in a 2x5 grid of subplots.

    Args:
        data: Iterable of ``(image, label)`` pairs; at most the first ten
            pairs are drawn.
        main_title: Text placed above the whole figure.
    """
    rows, cols = 2, 5
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Subplot positions are 1-based; zip stops once the grid is full or the
    # data runs out, whichever comes first.
    slots = range(1, rows * cols + 1)
    for slot, sample in zip(slots, data):
        image, label = sample
        plt.subplot(rows, cols, slot)
        plt.imshow(X=image)
        plt.title(label=label)
    plt.tight_layout()
    plt.show()

# One figure per transformed dataset: train data first, then test data.
show_images(byclass_train_data, main_title="byclass_train_data")
show_images(byclass_test_data, main_title="byclass_test_data")

show_images(balanced_train_data, main_title="balanced_train_data")
show_images(balanced_test_data, main_title="balanced_test_data")

show_images(letters_train_data, main_title="letters_train_data")
show_images(letters_test_data, main_title="letters_test_data")

show_images(digits_train_data, main_title="digits_train_data")
show_images(digits_test_data, main_title="digits_test_data")

show_images(mnist_train_data, main_title="mnist_train_data")
show_images(mnist_test_data, main_title="mnist_test_data")
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

Image description

Image description

Sentry image

Make it make sense

Make sense of fixing your code with straight-forward application monitoring.

Start debugging →

Top comments (0)

Sentry image

Make it make sense

Make sense of fixing your code with straight-forward application monitoring.

Start debugging →

👋 Kindness is contagious

Explore a trove of insights in this engaging article, celebrated within our welcoming DEV Community. Developers from every background are invited to join and enhance our shared wisdom.

A genuine "thank you" can truly uplift someone’s day. Feel free to express your gratitude in the comments below!

On DEV, our collective exchange of knowledge lightens the road ahead and strengthens our community bonds. Found something valuable here? A small thank you to the author can make a big difference.

Okay