DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

ImageNet in PyTorch

Buy Me a Coffee

*My post explains ImageNet.

ImageNet() can use ImageNet dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Optional-Default:"train"-Type:str): *Memos:
    • "train"(1,281,167 images) or "val"(50,000 images) can be set to it.
    • "test"(100,000 images) isn't supported so I requested the feature on GitHub.
  • There is transform argument(Optional-Default:None-Type:callable). *transform= must be used.
  • There is target_transform argument(Optional-Default:None-Type:callable). - There is transform argument(Optional-Default:None-Type:callable). *target_transform= must be used.
  • There is loader argument(Optional-Default:torchvision.datasets.folder.default_loader-Type:callable). *loader= must be used.
  • You have to manually download the dataset(ILSVRC2012_devkit_t12.tar.gz, ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar to data/, then running ImageNet() extracts and loads the dataset.
  • About the label from the classes for the train and validation image indices respectively, tench&Tinca tinca(0) are 0~1299 and 0~49, goldfish&Carassius auratus(1) are 1300~2599 and 50~99, great white shark&white shark&man-eater&man-eating shark&Carcharodon carcharias(2) are 2600~3899 and 100~149, tiger shark&Galeocerdo cuvieri(3) are 3900~5199 and 150~199, hammerhead&hammerhead shark(4) are 5200~6499 and 200~249, electric ray&crampfish&numbfish&torpedo(5) are 6500~7799 and 250~299, stingray(6) is 7800~9099 and 250~299, cock(7) is 9100~10399 and 300~349, hen(8) is 10400~11699 and 350~399, ostrich&Struthio camelus(9) are 11700~12999 and 400~449, etc.
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import default_loader

train_data = ImageNet(
    root="data"
)

train_data = ImageNet(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    loader=default_loader
)

val_data = ImageNet(
    root="data",
    split="val"
)

len(train_data), len(val_data)
# (1281167, 50000)

train_data
# Dataset ImageNet
#     Number of datapoints: 1281167
#     Root location: D:/data
#     Split: train

train_data.root
# 'data'

train_data.split
# 'train'

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.loader
# <function torchvision.datasets.folder.default_loader(path: str) -> Any>

len(train_data.classes), train_data.classes
# (1000,
#  [('tench', 'Tinca tinca'), ('goldfish', 'Carassius auratus'),
#   ('great white shark', 'white shark', 'man-eater', 'man-eating shark',
#    'Carcharodon carcharias'), ('tiger shark', 'Galeocerdo cuvieri'),
#   ('hammerhead', 'hammerhead shark'), ('electric ray', 'crampfish',
#    'numbfish', 'torpedo'), ('stingray',), ('cock',), ('hen',),
#   ('ostrich', 'Struthio camelus'), ..., ('bolete',), ('ear', 'spike',
#    'capitulum'), ('toilet tissue', 'toilet paper', 'bathroom tissue')])

train_data[0]
# (<PIL.Image.Image image mode=RGB size=250x250>, 0)

train_data[1]
# (<PIL.Image.Image image mode=RGB size=200x150>, 0)

train_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

train_data[1300]
# (<PIL.Image.Image image mode=RGB size=640x480>, 1)

train_data[2600]
# (<PIL.Image.Image image mode=RGB size=500x375>, 2)

val_data[0]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[1]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[50]
# (<PIL.Image.Image image mode=RGB size=500x500>, 1)

val_data[100]
# (<PIL.Image.Image image mode=RGB size=679x444>, 2)

import matplotlib.pyplot as plt

def show_images(data, ims, main_title=None):
    plt.figure(figsize=(12, 6))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, j in enumerate(iterable=ims, start=1):
        plt.subplot(2, 5, i)
        im, lab = data[j]
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout(h_pad=3.0)
    plt.show()

train_ims = (0, 1, 2, 1300, 2600, 3900, 5200, 6500, 7800, 9100)
val_ims = (0, 1, 2, 50, 100, 150, 200, 250, 300, 350)

show_images(data=train_data, ims=train_ims, main_title="train_data")
show_images(data=val_data, ims=val_ims, main_title="val_data")
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Heroku

Build apps, not infrastructure.

Dealing with servers, hardware, and infrastructure can take up your valuable time. Discover the benefits of Heroku, the PaaS of choice for developers since 2007.

Visit Site

Top comments (0)

Heroku

This site is built on Heroku

Join the ranks of developers at Salesforce, Airbase, DEV, and more who deploy their mission critical applications on Heroku. Sign up today and launch your first app!

Get Started

👋 Kindness is contagious

Engage with a sea of insights in this enlightening article, highly esteemed within the encouraging DEV Community. Programmers of every skill level are invited to participate and enrich our shared knowledge.

A simple "thank you" can uplift someone's spirits. Express your appreciation in the comments section!

On DEV, sharing knowledge smooths our journey and strengthens our community bonds. Found this useful? A brief thank you to the author can mean a lot.

Okay