EMNIST() can use the EMNIST dataset as shown below:
*Memos:
- The 1st argument is `root` (Required-Type: `str` or `pathlib.Path`). *An absolute or relative path is possible.
- The 2nd argument is `split` (Required-Type: `str`). *`"byclass"`, `"bymerge"`, `"balanced"`, `"letters"`, `"digits"` or `"mnist"` can be set to it.
- There is a `train` argument (Optional-Default: `True`-Type: `bool`): *Memos:
  - `train=` must be used.
  - For `split="byclass"` and `split="bymerge"`, if it's `True`, train data (697,932 images) is used, while if it's `False`, test data (116,323 images) is used.
  - For `split="balanced"`, if it's `True`, train data (112,800 images) is used, while if it's `False`, test data (18,800 images) is used.
  - For `split="letters"`, if it's `True`, train data (124,800 images) is used, while if it's `False`, test data (20,800 images) is used.
  - For `split="digits"`, if it's `True`, train data (240,000 images) is used, while if it's `False`, test data (40,000 images) is used.
  - For `split="mnist"`, if it's `True`, train data (60,000 images) is used, while if it's `False`, test data (10,000 images) is used.
- There is a `transform` argument (Optional-Default: `None`-Type: `callable`). *`transform=` must be used.
- There is a `target_transform` argument (Optional-Default: `None`-Type: `callable`). *`target_transform=` must be used.
- There is a `download` argument (Optional-Default: `False`-Type: `bool`): *Memos:
  - `download=` must be used.
  - If it's `True`, the dataset is downloaded from the internet and extracted (unzipped) to `root`.
  - If it's `True` and the dataset is already downloaded, it's extracted.
  - If it's `True` and the dataset is already downloaded and extracted, nothing happens.
  - It should be `False` if the dataset is already downloaded and extracted, because that is faster.
  - You can manually download and extract the dataset (`emnist-byclass-train-images-idx3-ubyte.gz`, `emnist-byclass-train-labels-idx1-ubyte.gz`, `emnist-byclass-test-images-idx3-ubyte.gz`, `emnist-byclass-test-labels-idx1-ubyte.gz`, etc.) from here to `data/EMNIST/raw/`.
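- There is a bug where the images are flipped and rotated 90 degrees anticlockwise by default, so they should be transformed back (e.g. with a horizontal flip followed by a 90-degree rotation, as shown in the second example below).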
from torchvision.datasets import EMNIST
# Build a train and a test dataset for every EMNIST split from the "data"
# root. download is False here, so the files must already be present
# (pass download=True to fetch them).

# Every optional argument is spelled out below with its default value, so
# this is equivalent to the short form:
#     EMNIST(root="data", split="byclass")
# The dataset is built only once. Constructing it twice would load the
# largest split (697,932 images) into memory twice for no benefit.
byclass_train_data = EMNIST(
    root="data",
    split="byclass",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)
byclass_test_data = EMNIST(
    root="data",
    split="byclass",
    train=False
)
balanced_train_data = EMNIST(
    root="data",
    split="balanced",
    train=True
)
balanced_test_data = EMNIST(
    root="data",
    split="balanced",
    train=False
)
letters_train_data = EMNIST(
    root="data",
    split="letters",
    train=True
)
letters_test_data = EMNIST(
    root="data",
    split="letters",
    train=False
)
digits_train_data = EMNIST(
    root="data",
    split="digits",
    train=True
)
digits_test_data = EMNIST(
    root="data",
    split="digits",
    train=False
)
mnist_train_data = EMNIST(
    root="data",
    split="mnist",
    train=True
)
mnist_test_data = EMNIST(
    root="data",
    split="mnist",
    train=False
)
# These checks read like an interactive session, where a bare expression
# echoes its value. In a script, a bare expression is evaluated and then
# discarded, so each value is printed explicitly. The comment under each
# line is the output that was recorded.
print((len(byclass_train_data), len(byclass_test_data)))
# (697932, 116323)
print((len(balanced_train_data), len(balanced_test_data)))
# (112800, 18800)
print((len(letters_train_data), len(letters_test_data)))
# (124800, 20800)
print((len(digits_train_data), len(digits_test_data)))
# (240000, 40000)
print((len(mnist_train_data), len(mnist_test_data)))
# (60000, 10000)
print(byclass_train_data)
# Dataset EMNIST
# Number of datapoints: 697932
# Root location: data
# Split: Train
print(repr(byclass_train_data.root))
# 'data'
print(repr(byclass_train_data.split))
# 'byclass'
print(byclass_train_data.train)
# True
print(byclass_train_data.transform)
# None
print(byclass_train_data.target_transform)
# None
# `download` is the dataset's bound download *method*, not the
# `download=` flag that was passed to the constructor.
print(byclass_train_data.download)
# <bound method EMNIST.download of Dataset EMNIST
# Number of datapoints: 697932
# Root location: data
# Split: Train>
print(len(byclass_train_data.classes))
# 62
print(byclass_train_data.classes)
# ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
# 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
# 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
# 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
# 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
# Each item is an (image, label) pair. The printed PIL repr may also
# include the object's memory address.
for idx in range(5):
    print(byclass_train_data[idx])
# (<PIL.Image.Image image mode=L size=28x28>, 35)
# (<PIL.Image.Image image mode=L size=28x28>, 36)
# (<PIL.Image.Image image mode=L size=28x28>, 6)
# (<PIL.Image.Image image mode=L size=28x28>, 3)
# (<PIL.Image.Image image mode=L size=28x28>, 22)
import matplotlib.pyplot as plt
def show_images(data, main_title=None):
    """Show the first ten ``(image, label)`` pairs of *data* in a 2x5 grid.

    Args:
        data: Iterable yielding ``(image, label)`` pairs, such as an EMNIST
            dataset. Only its first ten items are consumed.
        main_title: Optional title drawn above the whole figure.
    """
    from itertools import islice

    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # islice stops after the 10th pair, so no more items are pulled than
    # the 2x5 grid can hold.
    for i, (im, lab) in enumerate(islice(data, 10), start=1):
        plt.subplot(2, 5, i)
        plt.title(label=lab)
        plt.imshow(X=im)
    plt.tight_layout()
    plt.show()
# Preview the first ten samples of every split, train then test, in the
# same order as the datasets were created above.
for title, dataset in (
    ("byclass_train_data", byclass_train_data),
    ("byclass_test_data", byclass_test_data),
    ("balanced_train_data", balanced_train_data),
    ("balanced_test_data", balanced_test_data),
    ("letters_train_data", letters_train_data),
    ("letters_test_data", letters_test_data),
    ("digits_train_data", digits_train_data),
    ("digits_test_data", digits_test_data),
    ("mnist_train_data", mnist_train_data),
    ("mnist_test_data", mnist_test_data),
):
    show_images(data=dataset, main_title=title)
from torchvision.datasets import EMNIST
from torchvision.transforms import v2
# Undo EMNIST's stored orientation with a horizontal flip followed by a
# 90-degree rotation. With p=1.0 the flip always happens, and the
# degenerate range (90, 90) always gives exactly 90 degrees, so both
# "random" transforms are deterministic here.
tran = v2.Compose([
    v2.RandomHorizontalFlip(p=1.0),
    v2.RandomRotation(degrees=(90, 90)),
])

# The same splits as before, now with the orientation fix applied to
# every image.
byclass_train_data = EMNIST(root="data", split="byclass", train=True, transform=tran)
byclass_test_data = EMNIST(root="data", split="byclass", train=False, transform=tran)
balanced_train_data = EMNIST(root="data", split="balanced", train=True, transform=tran)
balanced_test_data = EMNIST(root="data", split="balanced", train=False, transform=tran)
letters_train_data = EMNIST(root="data", split="letters", train=True, transform=tran)
letters_test_data = EMNIST(root="data", split="letters", train=False, transform=tran)
digits_train_data = EMNIST(root="data", split="digits", train=True, transform=tran)
digits_test_data = EMNIST(root="data", split="digits", train=False, transform=tran)
mnist_train_data = EMNIST(root="data", split="mnist", train=True, transform=tran)
mnist_test_data = EMNIST(root="data", split="mnist", train=False, transform=tran)
import matplotlib.pyplot as plt
def show_images(data, main_title=None):
    """Show the first ten ``(image, label)`` pairs of *data* in a 2x5 grid.

    Args:
        data: Iterable yielding ``(image, label)`` pairs, such as an EMNIST
            dataset. Only its first ten items are consumed.
        main_title: Optional title drawn above the whole figure.
    """
    from itertools import islice

    plt.figure(figsize=(10, 5))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # islice stops after the 10th pair, so no more items are pulled than
    # the 2x5 grid can hold.
    for i, (im, lab) in enumerate(islice(data, 10), start=1):
        plt.subplot(2, 5, i)
        plt.title(label=lab)
        plt.imshow(X=im)
    plt.tight_layout()
    plt.show()
# Map each figure title to the (now re-oriented) dataset it previews.
# Dicts keep insertion order, so the figures appear in this order.
previews = {
    "byclass_train_data": byclass_train_data,
    "byclass_test_data": byclass_test_data,
    "balanced_train_data": balanced_train_data,
    "balanced_test_data": balanced_test_data,
    "letters_train_data": letters_train_data,
    "letters_test_data": letters_test_data,
    "digits_train_data": digits_train_data,
    "digits_test_data": digits_test_data,
    "mnist_train_data": mnist_train_data,
    "mnist_test_data": mnist_test_data,
}
for title, dataset in previews.items():
    show_images(data=dataset, main_title=title)
Top comments (0)