DEV Community

Super Kai (Kazuya Ito)

CelebA in PyTorch

*My post explains CelebA.

CelebA() can load the CelebA dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Optional-Default:"train"-Type:str). *"train"(162,770 images), "valid"(19,867 images), "test"(19,962 images) or "all"(202,599 images) can be set to it.
  • The 3rd argument is target_type(Optional-Default:"attr"-Type:str or list of str): *Memos:
    • "attr", "identity", "bbox" and/or "landmarks" can be set to it.
    • An empty list can also be set to it. *Then the target is None.
    • The same value can be set multiple times. *Then its element appears multiple times in the target.
    • The order of the elements in the target follows the order of the values.
  • The 4th argument is transform(Optional-Default:None-Type:callable).
  • The 5th argument is target_transform(Optional-Default:None-Type:callable).
  • The 6th argument is download(Optional-Default:False-Type:bool): *Memos:
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • gdown is required to download the dataset.
    • You can also download the dataset manually from here (img_align_celeba.zip with identity_CelebA.txt, list_attr_celeba.txt, list_bbox_celeba.txt, list_eval_partition.txt and list_landmarks_align_celeba.txt), put the files in data/celeba/ (the celeba folder under root) and extract img_align_celeba.zip there.
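The target_type memos above (empty list, repeated values, order) can be illustrated with a minimal sketch. assemble_target() below is a hypothetical, simplified stand-in written for this post, not torchvision's code: it collects one element per value in target_type, returns None for an empty list, the element itself for a single value and a tuple otherwise, which matches the outputs shown below. Plain lists and ints stand in for tensors.

```python
# Hypothetical sketch (not torchvision's code) of how target_type shapes
# each sample's target. Plain lists/ints stand in for tensors.
def assemble_target(target_type, annotations):
    parts = []
    for t in target_type:
        if t not in annotations:
            raise ValueError(f"unknown target type: {t!r}")
        parts.append(annotations[t])  # one element per value, in order
    if not parts:
        return None  # empty target_type -> no target
    return parts[0] if len(parts) == 1 else tuple(parts)

# Annotations copied from all_all_data[0] below ("attr" shortened to 4 values).
first = {
    "attr": [0, 1, 1, 0],
    "identity": 2880,
    "bbox": [95, 71, 226, 313],
    "landmarks": [69, 109, 106, 113, 77, 142, 73, 152, 108, 154],
}

print(assemble_target([], first))                        # None
print(assemble_target(["identity"], first))              # 2880
print(assemble_target(["bbox", "identity"], first))      # ([95, 71, 226, 313], 2880)
print(assemble_target(["identity", "identity"], first))  # (2880, 2880)
```
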
from torchvision.datasets import CelebA

train_attr_data = CelebA(
    root="data"
)

train_attr_data = CelebA(
    root="data",
    split="train",
    target_type="attr",
    transform=None,
    target_transform=None,
    download=False
)

valid_identity_data = CelebA(
    root="data",
    split="valid",
    target_type="identity"
)

test_bbox_data = CelebA(
    root="data",
    split="test",
    target_type="bbox"
)

all_landmarks_data = CelebA(
    root="data",
    split="all",
    target_type="landmarks"
)

all_empty_data = CelebA(
    root="data",
    split="all",
    target_type=[]
)

all_all_data = CelebA(
    root="data",
    split="all",
    target_type=["attr", "identity", "bbox", "landmarks"]
)

len(train_attr_data), len(valid_identity_data), len(test_bbox_data)
# (162770, 19867, 19962)

len(all_landmarks_data), len(all_empty_data), len(all_all_data)
# (202599, 202599, 202599)

train_attr_data
# Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train

train_attr_data.root
# 'data'

train_attr_data.split
# 'train'

train_attr_data.target_type
# ['attr']

print(train_attr_data.transform)
# None

print(train_attr_data.target_transform)
# None

train_attr_data.download
# <bound method CelebA.download of Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train>

len(train_attr_data.attr), train_attr_data.attr
# (162770,
#  tensor([[0, 1, 1, ..., 0, 0, 1],
#          [0, 0, 0, ..., 0, 0, 1],
#          [0, 0, 0, ..., 0, 0, 1],
#          ...,
#          [1, 0, 1, ..., 0, 1, 1],
#          [0, 0, 0, ..., 0, 0, 1],
#          [0, 1, 1, ..., 1, 0, 1]]))

len(train_attr_data.attr_names), train_attr_data.attr_names
# (41,
#  ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 
#   'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
#   'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
#   ...
#   'Wearing_Necklace', 'Wearing_Necktie', 'Young', ''])

len(train_attr_data.identity), len(train_attr_data.identity.unique())
# (162770, 8192)

train_attr_data.identity
# tensor([[2880], [2937], [8692], ..., [7391], [8610], [2304]])

len(train_attr_data.bbox), train_attr_data.bbox
# (162770,
#  tensor([[95, 71, 226, 313],
#          [72, 94, 221, 306],
#          [216, 59, 91, 126],
#          ...,
#          [103, 103, 143, 198],
#          [30, 59, 216, 280],
#          [376, 4, 372, 515]]))

len(train_attr_data.landmarks_align), train_attr_data.landmarks_align
# (162770,
#  tensor([[69, 109, 106, ..., 152, 108, 154],
#          [69, 110, 107, ..., 151, 108, 153],
#          [76, 112, 104, ..., 156, 98, 158],
#          ...,
#          [69, 113, 109, ..., 151, 110, 151],
#          [68, 112, 109, ..., 150, 108, 151],
#          [70, 111, 107, ..., 153, 102, 152]]))

train_attr_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#          0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#          0, 1, 1, 0, 1, 0, 1, 0, 0, 1]))

train_attr_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#          0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1]))

train_attr_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#          1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#          0, 0, 0, 1, 0, 0, 0, 0, 0, 1]))

valid_identity_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2594))

valid_identity_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2795))

valid_identity_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(947))

test_bbox_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([147, 82, 120, 166]))

test_bbox_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([106, 34, 140, 194]))

test_bbox_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([107, 78, 109, 151]))

all_landmarks_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154]))

all_landmarks_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153]))

all_landmarks_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158]))

all_empty_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_all_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#           0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#           0, 1, 1, 0, 1, 0, 1, 0, 0, 1]),
#   tensor(2880),
#   tensor([95, 71, 226, 313]),
#   tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154])))

all_all_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#           0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1]),
#   tensor(2937),
#   tensor([72, 94, 221, 306]),
#   tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153])))

all_all_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#           1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#           1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#           0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
#   tensor(8692),
#   tensor([216, 59, 91, 126]),
#   tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158])))

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.patches import Circle

def show_images(data, main_title=None):
    # Images only: "attr" alone or an empty target_type has nothing to draw
    if ("attr" in data.target_type and len(data.target_type) == 1) \
        or not data.target_type:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, _) in zip(range(1, 11), data):
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif "identity" in data.target_type and len(data.target_type) == 1:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, lab) in zip(range(1, 11), data):
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
            plt.title(label=lab.item())
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif "bbox" in data.target_type and len(data.target_type) == 1:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        fig.suptitle(t=main_title, y=1.0, fontsize=14)
        for (im, (x, y, w, h)), axis in zip(data, axes.ravel()):
            axis.imshow(X=im)
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor='r',
                             facecolor='none')
            axis.add_patch(p=rect)
        fig.tight_layout(h_pad=3.0)
        plt.show()
    elif "landmarks" in data.target_type and len(data.target_type) == 1:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, lm) in zip(range(1, 11), data):
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
            for px, py in lm.split(2):
                plt.scatter(x=px, y=py, c='#1f77b4')
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif len(data.target_type) == 4:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        fig.suptitle(t=main_title, y=1.0, fontsize=14)
        for (im, (_, lab, (x, y, w, h), lm)), axis in zip(data, axes.ravel()):
            axis.imshow(X=im)
            axis.set_title(label=lab.item())
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor='r',
                             facecolor='none', clip_on=True)
            axis.add_patch(p=rect)

# ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓

            axis.autoscale(enable=False) # Important: otherwise the images are shrunk
            for px, py in lm.split(2):
                axis.scatter(x=px, y=py, c='#1f77b4')

# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑

# ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
            # Alternatively, draw the points as Circle patches:
            # for px, py in lm.split(2):
            #     axis.add_patch(p=Circle(xy=(px, py)))

# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑

# ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓ ↓
            # Alternatively, connect the points with lines:
            # axis.autoscale(enable=False) # Important: otherwise the images are shrunk
            # px = []
            # py = []
            # for j, v in enumerate(lm):
            #     if j%2 == 0:
            #         px.append(v)
            #     else:
            #         py.append(v)
            # axis.plot(px, py)

# ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑ ↑

        fig.tight_layout(h_pad=3.0)
        plt.show()

show_images(data=train_attr_data, main_title="train_attr_data")
show_images(data=valid_identity_data, main_title="valid_identity_data")
show_images(data=test_bbox_data, main_title="test_bbox_data")
show_images(data=all_landmarks_data, main_title="all_landmarks_data")
show_images(data=all_empty_data, main_title="all_empty_data")
show_images(data=all_all_data, main_title="all_all_data")
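
attr_names has 41 entries because of the trailing empty string shown above, while each attr target holds 40 zeros and ones. The raw list_attr_celeba.txt stores -1 and 1 instead; the 0/1 values above correspond to (v + 1) // 2. positive_attributes() below is a hypothetical helper (not part of torchvision) that turns a 0/1 vector back into the names of the attributes it marks as present, skipping empty names.

```python
# Hypothetical helpers, not part of torchvision.
def to_binary(raw):
    """Map raw -1/1 attribute labels to 0/1 via (v + 1) // 2."""
    return [(v + 1) // 2 for v in raw]

def positive_attributes(values, names):
    """Return the names whose value is 1, skipping empty names like ''."""
    return [name for value, name in zip(values, names) if name and value == 1]

# First 12 names and the first 12 values of train_attr_data[0] from above.
names = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
         'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
         'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair']
values = [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1]

print(to_binary([-1, 1, 1, -1]))           # [0, 1, 1, 0]
print(positive_attributes(values, names))  # ['Arched_Eyebrows', 'Attractive', 'Brown_Hair']
```

With the real data, something like positive_attributes(train_attr_data[0][1], train_attr_data.attr_names) should work the same way, since zip() stops at the shorter, 40-element vector.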

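The bbox values are (x, y, width, height), which is how show_images() passes them to Rectangle. Many boxes in the outputs above do not fit inside a 178x218 aligned image (for example [376, 4, 372, 515]); as far as I know, list_bbox_celeba.txt describes the original in-the-wild images rather than the aligned and cropped ones in img_align_celeba, which would explain this. The hypothetical helpers below (not part of torchvision) convert a box to corner form and check whether it fits the aligned image size.

```python
# Hypothetical helpers, not part of torchvision.
# Size of the aligned images, (width, height), as in the outputs above.
ALIGNED_SIZE = (178, 218)

def xywh_to_corners(box):
    """Convert (x, y, width, height) to (x_min, y_min, x_max, y_max)."""
    x, y, w, h = box
    return (x, y, x + w, y + h)

def fits_in_image(box, size=ALIGNED_SIZE):
    """Return True if the whole box lies inside an image of (width, height)."""
    x_min, y_min, x_max, y_max = xywh_to_corners(box)
    width, height = size
    return x_min >= 0 and y_min >= 0 and x_max <= width and y_max <= height

print(xywh_to_corners([95, 71, 226, 313]))  # (95, 71, 321, 384)
print(fits_in_image([95, 71, 226, 313]))    # False
print(fits_in_image([376, 4, 372, 515]))    # False
print(fits_in_image([40, 60, 100, 120]))    # True (hypothetical box)
```
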

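Each landmarks target holds 10 numbers, which show_images() splits into (x, y) pairs with lm.split(2). In list_landmarks_align_celeba.txt the five points are, in order, the left eye, right eye, nose, left mouth corner and right mouth corner. name_landmarks() below is a hypothetical helper (not part of torchvision) that pairs the values with those names in plain Python.

```python
# Hypothetical helper, not part of torchvision.
LANDMARK_NAMES = ["lefteye", "righteye", "nose", "leftmouth", "rightmouth"]

def name_landmarks(flat):
    """Pair a flat [x1, y1, ..., x5, y5] sequence into named (x, y) points."""
    if len(flat) != 2 * len(LANDMARK_NAMES):
        raise ValueError(f"expected 10 values, got {len(flat)}")
    points = zip(flat[0::2], flat[1::2])  # even indices are x, odd are y
    return dict(zip(LANDMARK_NAMES, points))

# Landmarks of all_landmarks_data[0] from the output above.
print(name_landmarks([69, 109, 106, 113, 77, 142, 73, 152, 108, 154]))
# {'lefteye': (69, 109), 'righteye': (106, 113), 'nose': (77, 142),
#  'leftmouth': (73, 152), 'rightmouth': (108, 154)}
```
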