CelebA() can use the CelebA dataset as shown below:
*Memos:
- The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
- The 2nd argument is split(Optional-Default:"train"-Type:str). *"train"(162,770 images), "valid"(19,867 images), "test"(19,962 images) or "all"(202,599 images) can be set to it.
- The 3rd argument is target_type(Optional-Default:"attr"-Type:str or list of str): *Memos:
  - "attr", "identity", "bbox" and/or "landmarks" can be set to it.
  - An empty list can also be set to it.
  - The same value can be set to it multiple times.
  - The order of the returned target elements follows the order of the values, so a different order of values gives a different order of elements.
- The 4th argument is transform(Optional-Default:None-Type:callable).
- The 5th argument is target_transform(Optional-Default:None-Type:callable).
- The 6th argument is download(Optional-Default:False-Type:bool): *Memos:
  - If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
  - If it's True and the dataset is already downloaded, it's extracted.
  - If it's True and the dataset is already downloaded and extracted, nothing happens.
  - It should be False if the dataset is already downloaded and extracted because it's faster.
  - gdown is required to download the dataset.
  - You can manually download and extract the dataset(img_align_celeba.zip with identity_CelebA.txt, list_attr_celeba.txt, list_bbox_celeba.txt, list_eval_partition.txt and list_landmarks_align_celeba.txt) from here to data/celeba/.
 
- If it's 
from torchvision.datasets import CelebA
# Only root is required; every other argument falls back to its default.
train_attr_data = CelebA(root="data")
# The same training dataset again, this time with every default spelled out.
train_attr_data = CelebA(root="data", split="train", target_type="attr",
                         transform=None, target_transform=None,
                         download=False)

# One dataset per split, each yielding a different kind of target.
valid_identity_data = CelebA(root="data", split="valid",
                             target_type="identity")
test_bbox_data = CelebA(root="data", split="test", target_type="bbox")
all_landmarks_data = CelebA(root="data", split="all",
                            target_type="landmarks")

# target_type may also be a list: an empty one, or several types at once.
all_empty_data = CelebA(root="data", split="all", target_type=[])
all_all_data = CelebA(root="data", split="all",
                      target_type=["attr", "identity", "bbox", "landmarks"])
# Number of images per split. The "all" datasets have the same length no
# matter which target_type they were built with.
len(train_attr_data), len(valid_identity_data), len(test_bbox_data)
# (162770, 19867, 19962)
len(all_landmarks_data), len(all_empty_data), len(all_all_data)
# (202599, 202599, 202599)
# The repr summarises size, root, target types and split.
train_attr_data
# Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train
# Constructor arguments are kept as attributes; the str target_type "attr"
# comes back as the one-element list ['attr'].
train_attr_data.root
# 'data'
train_attr_data.split
# 'train'
train_attr_data.target_type
# ['attr']
print(train_attr_data.transform)
# None
print(train_attr_data.target_transform)
# None
# Note: .download is the download *method*, not the download=False flag
# passed to the constructor.
train_attr_data.download
# <bound method CelebA.download of Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train>
# All annotation tensors below are available even though target_type is
# ["attr"]; each has one row per image of the split.
len(train_attr_data.attr), train_attr_data.attr
# (162770,
#  tensor([[0, 1, 1, ..., 0, 0, 1],
#          [0, 0, 0, ..., 0, 0, 1],
#          [0, 0, 0, ..., 0, 0, 1],
#          ...,
#          [1, 0, 1, ..., 0, 1, 1],
#          [0, 0, 0, ..., 0, 0, 1],
#          [0, 1, 1, ..., 1, 0, 1]]))
# attr_names has 41 entries, but the last one is an empty string, not a
# real attribute name.
# NOTE(review): presumably an artifact of trailing whitespace in the
# attribute file header; confirm against list_attr_celeba.txt.
len(train_attr_data.attr_names), train_attr_data.attr_names
# (41,
#  ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
#   'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
#   'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
#   ...
#   'Wearing_Necklace', 'Wearing_Necktie', 'Young', ''])
# identity holds one person id per image (a single column); the train split
# contains 8192 distinct ids.
len(train_attr_data.identity), len(train_attr_data.identity.unique())
# (162770, 8192)
train_attr_data.identity
# tensor([[2880], [2937], [8692], ..., [7391], [8610], [2304]])
# bbox rows hold 4 values, which show_images below unpacks as (x, y, w, h).
# NOTE(review): values such as x=376 or h=515 exceed the 178x218 aligned
# images, so these boxes appear to refer to a different image frame
# (possibly the original, un-aligned photos) — confirm before overlaying
# them on img_align_celeba images.
len(train_attr_data.bbox), train_attr_data.bbox
# (162770,
#  tensor([[95, 71, 226, 313],
#          [72, 94, 221, 306],
#          [216, 59, 91, 126],
#          ...,
#          [103, 103, 143, 198],
#          [30, 59, 216, 280],
#          [376, 4, 372, 515]]))
# landmarks_align rows hold the flattened landmark coordinates; show_images
# below reads them as consecutive (x, y) pairs.
len(train_attr_data.landmarks_align), train_attr_data.landmarks_align
# (162770,
#  tensor([[69, 109, 106, ..., 152, 108, 154],
#          [69, 110, 107, ..., 151, 108, 153],
#          [76, 112, 104, ..., 156, 98, 158],
#          ...,
#          [69, 113, 109, ..., 151, 110, 151],
#          [68, 112, 109, ..., 150, 108, 151],
#          [70, 111, 107, ..., 153, 102, 152]]))
# Indexing returns (PIL image, target); every image here is a 178x218 RGB
# JPEG. With target_type="attr" the target is a 0/1 vector of 40 attributes.
train_attr_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#          0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#          0, 1, 1, 0, 1, 0, 1, 0, 0, 1]))
train_attr_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#          0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1]))
train_attr_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#          1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#          0, 0, 0, 1, 0, 0, 0, 0, 0, 1]))
# With target_type="identity" the target is a scalar tensor holding the
# person id.
valid_identity_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2594))
valid_identity_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2795))
valid_identity_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(947))
# With target_type="bbox" the target is the 4-value box of that image.
test_bbox_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([147, 82, 120, 166]))
test_bbox_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([106, 34, 140, 194]))
test_bbox_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([107, 78, 109, 151]))
# With target_type="landmarks" the target holds 10 values, i.e. 5 (x, y)
# points.
all_landmarks_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154]))
all_landmarks_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153]))
all_landmarks_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158]))
# An empty target_type yields None as the target.
all_empty_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)
all_empty_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)
all_empty_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)
# Several target types yield a tuple ordered like target_type. These values
# match the first rows of train_attr_data's attr/identity/bbox/landmarks
# tensors, so the "all" split appears to start with the same images as
# "train".
all_all_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#           0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#           0, 1, 1, 0, 1, 0, 1, 0, 0, 1]),
#   tensor(2880),
#   tensor([95, 71, 226, 313]),
#   tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154])))
all_all_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#           0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1]),
#   tensor(2937),
#   tensor([72, 94, 221, 306]),
#   tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153])))
all_all_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#           1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#           1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#           0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
#   tensor(8692),
#   tensor([216, 59, 91, 126]),
#   tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158])))
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.patches import Circle
def show_images(data, main_title=None):
    """Plot the first 10 samples of a CelebA dataset in a 2x5 grid.

    Overlays are chosen from ``data.target_type``. Any combination of the
    names below works, in any order:

    - "identity": the person id becomes the subplot title.
    - "bbox": the box, read as (x, y, w, h), is drawn as a red rectangle.
    - "landmarks": the landmark (x, y) pairs are drawn as dots.
    - "attr", or an empty target_type: the image is shown without overlay.

    Args:
        data: A ``torchvision.datasets.CelebA`` dataset whose items are
            ``(image, target)`` pairs.
        main_title: Optional figure title.
    """
    target_types = list(data.target_type)
    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
    fig.suptitle(t=main_title, y=1.0, fontsize=14)
    # zip() stops at the 10th axis, so only the first 10 samples are loaded.
    for (im, target), axis in zip(data, axes.ravel()):
        targets = _named_targets(target_types, target)
        axis.imshow(X=im)
        # Freeze the view at the image extent. Otherwise later artists (boxes
        # reaching past the image, scatter points) can rescale the axes and
        # shrink the image.
        axis.autoscale(enable=False)
        if "identity" in targets:
            axis.set_title(label=targets["identity"].item())
        if "bbox" in targets:
            x, y, w, h = targets["bbox"]
            axis.add_patch(p=Rectangle(xy=(x, y), width=w, height=h,
                                       linewidth=3, edgecolor='r',
                                       facecolor='none', clip_on=True))
        if "landmarks" in targets:
            # split(2) yields consecutive (x, y) pairs. Alternatives: add one
            # Circle(xy=(px, py)) patch per point, or join the points with
            # axis.plot(xs, ys).
            for px, py in targets["landmarks"].split(2):
                axis.scatter(x=px, y=py, c='#1f77b4')
    fig.tight_layout(h_pad=3.0)
    plt.show()


def _named_targets(target_types, target):
    """Return a {target type: value} dict for one CelebA sample's target.

    CelebA yields None for an empty target_type, the bare value for a single
    type, and a tuple ordered like target_type for several types.
    """
    if not target_types:
        return {}
    if len(target_types) == 1:
        return {target_types[0]: target}
    return dict(zip(target_types, target))
# Plot the first samples of every dataset built above, titled by its name.
for title, dataset in (
    ("train_attr_data", train_attr_data),
    ("valid_identity_data", valid_identity_data),
    ("test_bbox_data", test_bbox_data),
    ("all_landmarks_data", all_landmarks_data),
    ("all_empty_data", all_empty_data),
    ("all_all_data", all_all_data),
):
    show_images(data=dataset, main_title=title)
 







 
    
Top comments (0)