DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

1

AugMix in PyTorch (8)

Buy Me a Coffee

*Memos:

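AugMix() can randomly apply AugMix to an image, as shown below. *This post covers the chain_depth argument (1). A positive chain_depth fixes the number of operations in each augmentation chain, while 0 or a negative value makes the depth random (sampled from 1 to 3) for each chain: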

from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import AugMix
from torchvision.transforms.functional import InterpolationMode

def _pet_dataset(transform):
    """Create an Oxford-IIIT Pet dataset rooted at ``data`` using ``transform``."""
    return OxfordIIITPet(root="data", transform=transform)


# Un-augmented reference dataset.
origin_data = _pet_dataset(None)

# `cd` is chain_depth: one dataset per non-negative chain depth.
cd0_data = _pet_dataset(AugMix(chain_depth=0))
cd1_data = _pet_dataset(AugMix(chain_depth=1))
cd2_data = _pet_dataset(AugMix(chain_depth=2))
cd5_data = _pet_dataset(AugMix(chain_depth=5))
cd10_data = _pet_dataset(AugMix(chain_depth=10))
cd25_data = _pet_dataset(AugMix(chain_depth=25))
cd50_data = _pet_dataset(AugMix(chain_depth=50))

# `n` is negative: the same comparison with negative chain depths.
cdn1_data = _pet_dataset(AugMix(chain_depth=-1))
cdn2_data = _pet_dataset(AugMix(chain_depth=-2))
cdn5_data = _pet_dataset(AugMix(chain_depth=-5))
cdn10_data = _pet_dataset(AugMix(chain_depth=-10))
cdn25_data = _pet_dataset(AugMix(chain_depth=-25))
cdn50_data = _pet_dataset(AugMix(chain_depth=-50))

import matplotlib.pyplot as plt

def show_images1(data, main_title=None):
    """Draw the first five images from ``data`` in a single row.

    ``data`` is iterated as ``(image, target)`` pairs. The target is ignored,
    and each image is drawn exactly as the iteration yields it; this function
    applies no transform of its own. The figure is titled with ``main_title``.
    """
    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    for slot, (image, _target) in enumerate(data, start=1):
        plt.subplot(1, 5, slot)
        plt.imshow(X=image)
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
        if slot == 5:
            # Stop after the fifth panel so no sixth sample is fetched.
            break
    plt.tight_layout()
    plt.show()

# Reference row first, then the non-negative and negative chain depths.
show_images1(data=origin_data, main_title="origin_data")
print()
for _dataset, _title in (
    (cd0_data, "cd0_data"),
    (cd1_data, "cd1_data"),
    (cd2_data, "cd2_data"),
    (cd5_data, "cd5_data"),
    (cd10_data, "cd10_data"),
    (cd25_data, "cd25_data"),
    (cd50_data, "cd50_data"),
):
    show_images1(data=_dataset, main_title=_title)
print()
for _dataset, _title in (
    (cd0_data, "cd0_data"),
    (cdn1_data, "cdn1_data"),
    (cdn2_data, "cdn2_data"),
    (cdn5_data, "cdn5_data"),
    (cdn10_data, "cdn10_data"),
    (cdn25_data, "cdn25_data"),
    (cdn50_data, "cdn50_data"),
):
    show_images1(data=_dataset, main_title=_title)
# Drop the loop variables so the module namespace is unchanged.
del _dataset, _title

# ↓ ↓ ↓ ↓ ↓ ↓ The code below is equivalent to the code above. ↓ ↓ ↓ ↓ ↓ ↓
# (It applies AugMix inside the plotting function instead of via each dataset's transform.)
def show_images2(data, main_title=None, s=3, mw=3, cd=-1, a=1.0,
                 ao=True, ip=InterpolationMode.BILINEAR, f=None):
    """Draw the first five images of ``data``, applying AugMix while plotting.

    An ``AugMix`` transform is built from the arguments below and applied to
    each image as it is drawn. When ``main_title`` is exactly
    ``"origin_data"``, the images are drawn unmodified instead.

    Args:
        data: Iterable of ``(image, target)`` pairs; only the first five are
            drawn and the target is ignored.
        main_title: Figure title. The value ``"origin_data"`` also disables
            augmentation.
        s: AugMix ``severity``.
        mw: AugMix ``mixture_width``.
        cd: AugMix ``chain_depth``.
        a: AugMix ``alpha``.
        ao: AugMix ``all_ops``.
        ip: AugMix ``interpolation``.
        f: AugMix ``fill``.
    """
    # The title doubles as the "no augmentation" switch; kept for callers
    # that rely on it.
    if main_title != "origin_data":
        # Build the transform once, before opening a figure. Its parameters do
        # not change per image, and AugMix draws its random choices on each
        # call, not at construction. Invalid arguments therefore raise here
        # without leaving an empty figure behind.
        augment = AugMix(severity=s, mixture_width=mw, chain_depth=cd,
                         alpha=a, all_ops=ao, interpolation=ip, fill=f)
    else:
        augment = None

    plt.figure(figsize=[10, 5])
    plt.suptitle(t=main_title, y=0.8, fontsize=14)
    for i, (im, _) in zip(range(1, 6), data):
        plt.subplot(1, 5, i)
        plt.imshow(X=im if augment is None else augment(im))
        plt.xticks(ticks=[])
        plt.yticks(ticks=[])
    plt.tight_layout()
    plt.show()

# Reference row first, then the non-negative and negative chain depths,
# each augmenting `origin_data` on the fly.
show_images2(data=origin_data, main_title="origin_data")
print()
for _title, _depth in (
    ("cd0_data", 0),
    ("cd1_data", 1),
    ("cd2_data", 2),
    ("cd5_data", 5),
    ("cd10_data", 10),
    ("cd25_data", 25),
    ("cd50_data", 50),
):
    show_images2(data=origin_data, main_title=_title, cd=_depth)
print()
for _title, _depth in (
    ("cd0_data", 0),
    ("cdn1_data", -1),
    ("cdn2_data", -2),
    ("cdn5_data", -5),
    ("cdn10_data", -10),
    ("cdn25_data", -25),
    ("cdn50_data", -50),
):
    show_images2(data=origin_data, main_title=_title, cd=_depth)
# Drop the loop variables so the module namespace is unchanged.
del _title, _depth
Enter fullscreen mode Exit fullscreen mode

Image description


Image description

Image description

Image description

Image description

Image description

Image description

Image description


Image description

Image description

Image description

Image description

Image description

Image description

Image description

Heroku

Amplify your impact where it matters most — building exceptional apps.

Leave the infrastructure headaches to us, while you focus on pushing boundaries, realizing your vision, and making a lasting impression on your users.

Get Started

Top comments (1)

Collapse
 
champsoft profile image
champsoft

I’m interested in this. keep it up

AWS Q Developer image

Your AI Code Assistant

Automate your code reviews. Catch bugs before your coworkers. Fix security issues in your code. Built to handle large projects, Amazon Q Developer works alongside you from idea to production code.

Get started free in your IDE

👋 Kindness is contagious

Explore a trove of insights in this engaging article, celebrated within our welcoming DEV Community. Developers from every background are invited to join and enhance our shared wisdom.

A genuine "thank you" can truly uplift someone’s day. Feel free to express your gratitude in the comments below!

On DEV, our collective exchange of knowledge lightens the road ahead and strengthens our community bonds. Found something valuable here? A small thank you to the author can make a big difference.

Okay