ImageNet() can use the ImageNet dataset as shown below:

*Memos:

- The 1st argument is `root`(Required-Type:`str` or `pathlib.Path`). *An absolute or relative path is possible.
- The 2nd argument is `split`(Optional-Default:`"train"`-Type:`str`). *Memos:
  - `"train"`(1,281,167 images) or `"val"`(50,000 images) can be set to it.
  - `"test"`(100,000 images) isn't supported, so I requested the feature on GitHub.
- There is a `transform` argument(Optional-Default:`None`-Type:`callable`). *`transform=` must be used.
- There is a `target_transform` argument(Optional-Default:`None`-Type:`callable`). *`target_transform=` must be used.
- There is a `loader` argument(Optional-Default:`torchvision.datasets.folder.default_loader`-Type:`callable`). *`loader=` must be used.
- You have to manually download the dataset(ILSVRC2012_devkit_t12.tar.gz, ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar) to `data/`, then running `ImageNet()` extracts and loads the dataset. *See the layout sketch just after this list.
- About the labels, the train and validation image index ranges of the first classes are as follows(train indices first, validation indices second):
  - tench & Tinca tinca(0): 0~1299 and 0~49
  - goldfish & Carassius auratus(1): 1300~2599 and 50~99
  - great white shark & white shark & man-eater & man-eating shark & Carcharodon carcharias(2): 2600~3899 and 100~149
  - tiger shark & Galeocerdo cuvieri(3): 3900~5199 and 150~199
  - hammerhead & hammerhead shark(4): 5200~6499 and 200~249
  - electric ray & crampfish & numbfish & torpedo(5): 6500~7799 and 250~299
  - stingray(6): 7800~9099 and 300~349
  - cock(7): 9100~10399 and 350~399
  - hen(8): 10400~11699 and 400~449
  - ostrich & Struthio camelus(9): 11700~12999 and 450~499
  - etc.
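A minimal sketch of the expected layout under `root`(here `data/`) before the first run, using the file names from the memo above:

data
 ├─ ILSVRC2012_devkit_t12.tar.gz
 ├─ ILSVRC2012_img_train.tar
 └─ ILSVRC2012_img_val.tar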
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import default_loader
train_data = ImageNet(
    root="data"
)

train_data = ImageNet(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    loader=default_loader
)

val_data = ImageNet(
    root="data",
    split="val"
)
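
# A sketch with non-default arguments(assuming the standard torchvision.transforms
# API). `my_data` is just an illustrative name; the transform and target_transform
# below are arbitrary examples of callables.
from torchvision import transforms

my_data = ImageNet(
    root="data",
    split="val",
    transform=transforms.Compose([transforms.Resize(size=256),
                                  transforms.CenterCrop(size=224),
                                  transforms.ToTensor()]),
    target_transform=lambda lab: f"Class {lab}"
)
im, lab = my_data[0]
im.shape, lab
# (torch.Size([3, 224, 224]), 'Class 0')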
len(train_data), len(val_data)
# (1281167, 50000)
train_data
# Dataset ImageNet
#     Number of datapoints: 1281167
#     Root location: data
#     Split: train
train_data.root
# 'data'
train_data.split
# 'train'
print(train_data.transform)
# None
print(train_data.target_transform)
# None
train_data.loader
# <function torchvision.datasets.folder.default_loader(path: str) -> Any>
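
# A sketch of a custom loader(any callable taking a file path and returning the
# image works; loading as grayscale here is just an illustrative choice).
from PIL import Image

def my_loader(path):
    return Image.open(path).convert("L")  # Grayscale instead of the default RGB

gray_data = ImageNet(
    root="data",
    split="val",
    loader=my_loader
)
gray_data[0]
# (<PIL.Image.Image image mode=L size=500x375>, 0)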
len(train_data.classes), train_data.classes
# (1000,
# [('tench', 'Tinca tinca'), ('goldfish', 'Carassius auratus'),
# ('great white shark', 'white shark', 'man-eater', 'man-eating shark',
# 'Carcharodon carcharias'), ('tiger shark', 'Galeocerdo cuvieri'),
# ('hammerhead', 'hammerhead shark'), ('electric ray', 'crampfish',
# 'numbfish', 'torpedo'), ('stingray',), ('cock',), ('hen',),
# ('ostrich', 'Struthio camelus'), ..., ('bolete',), ('ear', 'spike',
# 'capitulum'), ('toilet tissue', 'toilet paper', 'bathroom tissue')])
train_data[0]
# (<PIL.Image.Image image mode=RGB size=250x250>, 0)
train_data[1]
# (<PIL.Image.Image image mode=RGB size=200x150>, 0)
train_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)
train_data[1300]
# (<PIL.Image.Image image mode=RGB size=640x480>, 1)
train_data[2600]
# (<PIL.Image.Image image mode=RGB size=500x375>, 2)
val_data[0]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)
val_data[1]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)
val_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)
val_data[50]
# (<PIL.Image.Image image mode=RGB size=500x500>, 1)
val_data[100]
# (<PIL.Image.Image image mode=RGB size=679x444>, 2)
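
# A label can be mapped back to its class names with `classes`, and for the
# validation split the label of index `i` is simply `i // 50`, since the images
# are stored 50 per class in class order(an observation based on the outputs above).
im, lab = train_data[1300]
lab, train_data.classes[lab]
# (1, ('goldfish', 'Carassius auratus'))
val_data[123][1], 123 // 50
# (2, 2)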
import matplotlib.pyplot as plt
def show_images(data, ims, main_title=None):
    # Plot the images of `data` at the indices `ims` in a 2x5 grid,
    # using each image's label as its title.
    plt.figure(figsize=[12, 6])
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, j in enumerate(iterable=ims, start=1):
        plt.subplot(2, 5, i)
        im, lab = data[j]
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout(h_pad=3.0)
    plt.show()
train_ims = [0, 1, 2, 1300, 2600, 3900, 5200, 6500, 7800, 9100]
val_ims = [0, 1, 2, 50, 100, 150, 200, 250, 300, 350]
show_images(data=train_data, ims=train_ims, main_title="train_data")
show_images(data=val_data, ims=val_ims, main_title="val_data")
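
# A minimal sketch of feeding the dataset to a DataLoader(a ToTensor() transform
# is assumed so the images can be batched; batch_size=64 is an arbitrary choice).
from torch.utils.data import DataLoader
from torchvision import transforms

loader_data = ImageNet(
    root="data",
    split="train",
    transform=transforms.Compose([transforms.Resize(size=256),
                                  transforms.CenterCrop(size=224),
                                  transforms.ToTensor()])
)
train_loader = DataLoader(dataset=loader_data, batch_size=64, shuffle=True)
ims, labs = next(iter(train_loader))
ims.shape, labs.shape
# (torch.Size([64, 3, 224, 224]), torch.Size([64]))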