
Super Kai (Kazuya Ito)

ImageNet in PyTorch

*My post explains ImageNet.

ImageNet() can load the ImageNet dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Optional-Default:"train"-Type:str): *Memos:
    • "train"(1,281,167 images) or "val"(50,000 images) can be set to it.
    • "test"(100,000 images) isn't supported so I requested the feature on GitHub.
  • There is transform argument(Optional-Default:None-Type:callable). *transform= must be used.
  • There is target_transform argument(Optional-Default:None-Type:callable). *target_transform= must be used. *See the transform sketch after the example code below.
  • There is loader argument(Optional-Default:torchvision.datasets.folder.default_loader-Type:callable). *loader= must be used.
  • You have to manually download the dataset files(ILSVRC2012_devkit_t12.tar.gz, ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar) to data/, then running ImageNet() extracts and loads the dataset.
  • About the labels from the classes, the train and validation image index ranges are as follows(train indices / validation indices): *Memos:
    • tench&Tinca tinca(0): 0~1299 / 0~49.
    • goldfish&Carassius auratus(1): 1300~2599 / 50~99.
    • great white shark&white shark&man-eater&man-eating shark&Carcharodon carcharias(2): 2600~3899 / 100~149.
    • tiger shark&Galeocerdo cuvieri(3): 3900~5199 / 150~199.
    • hammerhead&hammerhead shark(4): 5200~6499 / 200~249.
    • electric ray&crampfish&numbfish&torpedo(5): 6500~7799 / 250~299.
    • stingray(6): 7800~9099 / 300~349.
    • cock(7): 9100~10399 / 350~399.
    • hen(8): 10400~11699 / 400~449.
    • ostrich&Struthio camelus(9): 11700~12999 / 450~499, etc.
    • Each validation class has 50 images(50,000 images / 1,000 classes), so the validation ranges continue in steps of 50.
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import default_loader

train_data = ImageNet(
    root="data"
)

train_data = ImageNet(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    loader=default_loader
)

val_data = ImageNet(
    root="data",
    split="val"
)

len(train_data), len(val_data)
# (1281167, 50000)

train_data
# Dataset ImageNet
#     Number of datapoints: 1281167
#     Root location: data
#     Split: train

train_data.root
# 'data'

train_data.split
# 'train'

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.loader
# <function torchvision.datasets.folder.default_loader(path: str) -> Any>

len(train_data.classes), train_data.classes
# (1000,
#  [('tench', 'Tinca tinca'), ('goldfish', 'Carassius auratus'),
#   ('great white shark', 'white shark', 'man-eater', 'man-eating shark',
#    'Carcharodon carcharias'), ('tiger shark', 'Galeocerdo cuvieri'),
#   ('hammerhead', 'hammerhead shark'), ('electric ray', 'crampfish',
#    'numbfish', 'torpedo'), ('stingray',), ('cock',), ('hen',),
#   ('ostrich', 'Struthio camelus'), ..., ('bolete',), ('ear', 'spike',
#    'capitulum'), ('toilet tissue', 'toilet paper', 'bathroom tissue')])

train_data[0]
# (<PIL.Image.Image image mode=RGB size=250x250>, 0)

train_data[1]
# (<PIL.Image.Image image mode=RGB size=200x150>, 0)

train_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

train_data[1300]
# (<PIL.Image.Image image mode=RGB size=640x480>, 1)

train_data[2600]
# (<PIL.Image.Image image mode=RGB size=500x375>, 2)

val_data[0]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[1]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[50]
# (<PIL.Image.Image image mode=RGB size=500x500>, 1)

val_data[100]
# (<PIL.Image.Image image mode=RGB size=679x444>, 2)

import matplotlib.pyplot as plt

def show_images(data, ims, main_title=None):
    plt.figure(figsize=(12, 6))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, j in enumerate(iterable=ims, start=1):
        plt.subplot(2, 5, i)
        im, lab = data[j]
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout(h_pad=3.0)
    plt.show()

train_ims = (0, 1, 2, 1300, 2600, 3900, 5200, 6500, 7800, 9100)
val_ims = (0, 1, 2, 50, 100, 150, 200, 250, 300, 350)

show_images(data=train_data, ims=train_ims, main_title="train_data")
show_images(data=val_data, ims=val_ims, main_title="val_data")

[Figure: the 10 sample train_data images displayed by show_images(), titled with their labels]

[Figure: the 10 sample val_data images displayed by show_images(), titled with their labels]
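
The example code above passes transform=None, target_transform=None and the default loader. The following is a minimal sketch of setting transform= and feeding the dataset to a DataLoader, assuming the archives are already extracted under data/; my_transform, my_train_data and the batch size of 64 are illustrative choices, not from the original post:

from torchvision.datasets import ImageNet
from torchvision import transforms
from torch.utils.data import DataLoader

# Resize and crop every image to a fixed size, then convert it to a tensor so
# samples can be batched.
my_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

my_train_data = ImageNet(
    root="data",
    split="train",
    transform=my_transform,
    target_transform=None # Could be any callable that remaps the integer label
)

train_loader = DataLoader(dataset=my_train_data, batch_size=64, shuffle=True)

imgs, labs = next(iter(train_loader))
imgs.shape, labs.shape
# (torch.Size([64, 3, 224, 224]), torch.Size([64]))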
