DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

EMNIST in PyTorch

Buy Me a Coffee

*My post explains EMNIST.

EMNIST() can use EMNIST dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Required-Type:str). *"byclass", "bymerge", "balanced", "letters", "digits" or "mnist" can be set to it.
  • There is train argument(Optional-Default:True-Type:bool): *Memos:
    • train= must be used.
    • For split="byclass" and split="bymerge", if it's True, train data(697,932 images) is used while if it's False, test data(116,323 images) is used.
    • For split="balanced", if it's True, train data(112,800 images) is used while if it's False, test data(18,800 images) is used.
    • For split="letters", if it's True, train data(124,800 images) is used while if it's False, test data(20,800 images) is used.
    • For split="digits", if it's True, train data(240,000 images) is used while if it's False, test data(40,000 images) is used.
    • For split="mnist", if it's True, train data(60,000 images) is used while if it's False, test data(10,000 images) is used.
  • There is transform argument(Optional-Default:None-Type:callable). *transform= must be used.
  • There is target_transform argument(Optional-Default:None-Type:callable). *target_transform= must be used.
  • There is download argument(Optional-Default:False-Type:bool): *Memos:
    • download= must be used.
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • You can manually download and extract the dataset(emnist-byclass-train-images-idx3-ubyte.gz, emnist-byclass-train-labels-idx1-ubyte.gz, emnist-byclass-test-images-idx3-ubyte.gz, emnist-byclass-test-labels-idx1-ubyte.gz, etc) from here to data/EMNIST/raw/.
  • There is a bug in which the images are flipped and rotated 90 degrees anticlockwise by default, so they should be transformed.
from torchvision.datasets import EMNIST

# Minimal form: only the two required arguments; `train` defaults to True.
byclass_train_data = EMNIST(root="data", split="byclass")

# The same dataset again, with every optional argument spelled out.
byclass_train_data = EMNIST(
    root="data", split="byclass", train=True,
    transform=None, target_transform=None, download=False,
)

byclass_test_data = EMNIST(root="data", split="byclass", train=False)

# Train/test pair for each of the remaining splits used below.
balanced_train_data = EMNIST(root="data", split="balanced", train=True)
balanced_test_data = EMNIST(root="data", split="balanced", train=False)

letters_train_data = EMNIST(root="data", split="letters", train=True)
letters_test_data = EMNIST(root="data", split="letters", train=False)

digits_train_data = EMNIST(root="data", split="digits", train=True)
digits_test_data = EMNIST(root="data", split="digits", train=False)

mnist_train_data = EMNIST(root="data", split="mnist", train=True)
mnist_test_data = EMNIST(root="data", split="mnist", train=False)

# Sample counts of the train and test subsets, one pair per split.
len(byclass_train_data), len(byclass_test_data)
# (697932, 116323)

len(balanced_train_data), len(balanced_test_data)
# (112800, 18800)

len(letters_train_data), len(letters_test_data)
# (124800, 20800)

len(digits_train_data), len(digits_test_data)
# (240000, 40000)

len(mnist_train_data), len(mnist_test_data)
# (60000, 10000)

# The repr summarises the size, the root and which subset was loaded.
byclass_train_data
# Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train

# The constructor arguments are readable back as attributes.
byclass_train_data.root
# 'data'

byclass_train_data.split
# 'byclass'

byclass_train_data.train
# True

print(byclass_train_data.transform)
# None

print(byclass_train_data.target_transform)
# None

# NOTE: `download` is a bound method here, not the bool passed in.
byclass_train_data.download
# <bound method EMNIST.download of Dataset EMNIST
#     Number of datapoints: 697932
#     Root location: data
#     Split: Train>

# "byclass" has 62 classes: digits, uppercase and lowercase letters.
len(byclass_train_data.classes)
# 62

byclass_train_data.classes
# ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
#  'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
#  'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
#  'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
#  'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

# Indexing returns an (image, label) pair: a 28x28 PIL image and an int.
byclass_train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 35)

byclass_train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 36)

byclass_train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 6)

byclass_train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

byclass_train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 22)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Plot the first ten samples of `data` in a 2x5 grid.

    `data` is iterated for (image, label) pairs; each label becomes the
    title of its subplot and `main_title` is drawn above the figure.
    """
    rows, cols = 2, 5
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Pairing with a bounded range stops after rows * cols samples
    # (or earlier when `data` runs out first).
    for position, (image, label) in zip(range(1, rows * cols + 1), data):
        plt.subplot(rows, cols, position)
        plt.title(label=label)
        plt.imshow(X=image)
    plt.tight_layout()
    plt.show()

# Show each dataset in turn: train subset first, then test, per split.
show_images(byclass_train_data, "byclass_train_data")
show_images(byclass_test_data, "byclass_test_data")

show_images(balanced_train_data, "balanced_train_data")
show_images(balanced_test_data, "balanced_test_data")

show_images(letters_train_data, "letters_train_data")
show_images(letters_test_data, "letters_test_data")

show_images(digits_train_data, "digits_train_data")
show_images(digits_test_data, "digits_test_data")

show_images(mnist_train_data, "mnist_train_data")
show_images(mnist_test_data, "mnist_test_data")
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

Image description

Image description

from torchvision.datasets import EMNIST
from torchvision.transforms import v2

# Correct the stored orientation: mirror every image (p=1.0 means always),
# then rotate by exactly 90 degrees (both ends of the range are 90).
tran = v2.Compose([
    v2.RandomHorizontalFlip(p=1.0),
    v2.RandomRotation(degrees=(90, 90)),
])

# Rebuild every train/test dataset, this time passing `tran` as the
# transform for each one.
byclass_train_data = EMNIST(
    root="data", split="byclass", train=True, transform=tran,
)
byclass_test_data = EMNIST(
    root="data", split="byclass", train=False, transform=tran,
)

balanced_train_data = EMNIST(
    root="data", split="balanced", train=True, transform=tran,
)
balanced_test_data = EMNIST(
    root="data", split="balanced", train=False, transform=tran,
)

letters_train_data = EMNIST(
    root="data", split="letters", train=True, transform=tran,
)
letters_test_data = EMNIST(
    root="data", split="letters", train=False, transform=tran,
)

digits_train_data = EMNIST(
    root="data", split="digits", train=True, transform=tran,
)
digits_test_data = EMNIST(
    root="data", split="digits", train=False, transform=tran,
)

mnist_train_data = EMNIST(
    root="data", split="mnist", train=True, transform=tran,
)
mnist_test_data = EMNIST(
    root="data", split="mnist", train=False, transform=tran,
)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Plot the first ten samples of `data` in a 2x5 grid.

    `data` is iterated for (image, label) pairs; each label becomes the
    title of its subplot and `main_title` is drawn above the figure.
    """
    rows, cols = 2, 5
    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Pairing with a bounded range stops after rows * cols samples
    # (or earlier when `data` runs out first).
    for position, (image, label) in zip(range(1, rows * cols + 1), data):
        plt.subplot(rows, cols, position)
        plt.title(label=label)
        plt.imshow(X=image)
    plt.tight_layout()
    plt.show()

# Show each dataset in turn: train subset first, then test, per split.
show_images(byclass_train_data, "byclass_train_data")
show_images(byclass_test_data, "byclass_test_data")

show_images(balanced_train_data, "balanced_train_data")
show_images(balanced_test_data, "balanced_test_data")

show_images(letters_train_data, "letters_train_data")
show_images(letters_test_data, "letters_test_data")

show_images(digits_train_data, "digits_train_data")
show_images(digits_test_data, "digits_test_data")

show_images(mnist_train_data, "mnist_train_data")
show_images(mnist_test_data, "mnist_test_data")
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

Image description

Image description

Top comments (0)