
Super Kai (Kazuya Ito)

InterpolationMode in PyTorch (3)

PyTorch's NEAREST matches OpenCV's INTER_NEAREST, which is buggy, as shown below. *This post covers NEAREST and NEAREST_EXACT:

from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import InterpolationMode, PILToTensor, Resize
import numpy as np
import cv2

# Load the Oxford-IIIT Pet dataset (downloaded to data/ if needed);
# origin_data[0][0] is the first sample as a PIL image.
origin_data = OxfordIIITPet(root="data", download=True)

ptt = PILToTensor()

pytorchimagetensor = ptt(origin_data[0][0])

r = Resize(size=[50, 50], interpolation=InterpolationMode.NEAREST)

pytorch_resize = r(pytorchimagetensor).permute(dims=[1, 2, 0])
pytorch_resize
# tensor([[[37, 20, 12],
#          [36, 17, 10],
#          ...,
#          [252, 250, 253]],
#         [[36, 15, 14],
#          [24, 8, 9],
#          ...,
#          [255, 255, 255]],
#         [[255, 255, 196],
#          [253, 255, 206],
#          ...,
#          [255, 255, 255]],
#         ...,
#         [[14, 16, 54],
#          [12, 14, 52],
#          ...,
#          [254, 254, 254]],
#         [[8, 11, 44],
#          [11, 16, 46],
#          ...,
#          [255, 254, 255]],
#         [[4, 9, 31],
#          [0, 2, 0],
#          ...,
#          [57, 109, 231]]], dtype=torch.uint8)

numpyimagearray = np.array(object=origin_data[0][0])

opencv_resize = cv2.resize(src=numpyimagearray, dsize=(50, 50),
                           interpolation=cv2.INTER_NEAREST)
opencv_resize
# array([[[37, 20, 12],
#         [36, 17, 10],
#         ...,
#         [252, 250, 253]],
#        [[36, 15, 14],
#         [24, 8, 9],
#         ...,
#         [255, 255, 255]],
#        [[255, 255, 196],
#         [253, 255, 206],
#         ...,
#         [255, 255, 255]],
#        ...,
#        [[14, 16, 54],
#         [12, 14, 52],
#         ...,
#         [254, 254, 254]],
#        [[8, 11, 44],
#         [11, 16, 46],
#         ...,
#         [255, 254, 255]],
#        [[4, 9, 31],
#         [0, 2, 0],
#         ...,
#         [57, 109, 231]]], dtype=uint8)
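
To confirm that the two results really are identical, you can compare them element-wise (a minimal check, assuming the code above has been run in the same session; True is the expected output since the results match):

import numpy as np

# True if PyTorch's NEAREST and OpenCV's INTER_NEAREST produce identical pixels.
print(np.array_equal(pytorch_resize.numpy(), opencv_resize))
# True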

PyTorch's NEAREST_EXACT matches Scikit-image's nearest-neighbor and PIL's (Pillow's) NEAREST, which aren't buggy:

from torchvision.datasets import OxfordIIITPet
from torchvision.transforms.v2 import InterpolationMode, PILToTensor, Resize
import numpy as np
from skimage.transform import resize

# Load the Oxford-IIIT Pet dataset (downloaded to data/ if needed);
# origin_data[0][0] is the first sample as a PIL image.
origin_data = OxfordIIITPet(root="data", download=True)

ptt = PILToTensor()

pytorchimagetensor = ptt(origin_data[0][0])

r = Resize(size=[50, 50], interpolation=InterpolationMode.NEAREST_EXACT)

pytorch_resize = r(pytorchimagetensor).permute(dims=[1, 2, 0])
pytorch_resize
# tensor([[[36, 19, 11],
#          [31, 20, 14],
#          ...,
#          [254, 254, 254]],
#         [[241, 230, 114],
#          [252, 251, 98],
#          ...,
#          [255, 255, 255]],
#         [[73, 158, 195],
#          [255, 253, 192],
#          ...,
#          [255, 255, 255]],
#         ...,
#         [[13, 13, 49],
#          [12, 12, 50],
#          ...,
#          [253, 253, 253]],
#         [[10, 14, 43],
#          [177, 176, 172],
#          ...,
#          [61, 90, 216]],
#         [[6, 8, 21],
#          [14, 16, 41],
#          ...,
#          [60, 103, 231]]], dtype=torch.uint8)

numpyimagearray = np.array(object=origin_data[0][0])

scikitimage_resize = resize(image=numpyimagearray, output_shape=[50, 50],
                            order=0,              # `0` is Nearest-neighbor.
                            preserve_range=True,  # keep the original 0-255 range
                            anti_aliasing=False).astype(np.uint8)
scikitimage_resize
# array([[[36, 19, 11],
#         [31, 20, 14],
#         ...,
#         [254, 254, 254]],
#        [[241, 230, 114],
#         [252, 251, 98],
#         ...,
#         [255, 255, 255]],
#        [[73, 158, 195],
#         [255, 253, 192],
#         ...,
#         [255, 255, 255]],
#        ...,
#        [[13, 13, 49],
#         [12, 12, 50],
#         ...,
#         [253, 253, 253]],
#        [[10, 14, 43],
#         [177, 176, 172],
#         ...,
#         [61, 90, 216]],
#        [[6, 8, 21],
#         [14, 16, 41],
#         ...,
#         [60, 103, 231]]], dtype=uint8)

PILimagearray = np.array(origin_data[0][0].resize(size=[50, 50],
                         resample=0)) # `0` is Nearest.
PILimagearray
# array([[[36, 19, 11],
#         [31, 20, 14],
#         ...,
#         [254, 254, 254]],
#        [[241, 230, 114],
#         [252, 251, 98],
#         ...,
#         [255, 255, 255]],
#        [[73, 158, 195],
#         [255, 253, 192],
#         ...,
#         [255, 255, 255]],
#        ...,
#        [[13, 13, 49],
#         [12, 12, 50],
#         ...,
#         [253, 253, 253]],
#        [[10, 14, 43],
#         [177, 176, 172],
#         ...,
#         [61, 90, 216]],
#        [[6, 8, 21],
#         [14, 16, 41],
#         ...,
#         [60, 103, 231]]], dtype=uint8)

ptt = PILToTensor()

PILimagetensor = ptt(origin_data[0][0].resize(size=[50, 50],
                     resample=0)).permute(dims=[1, 2, 0]) # `0` is Nearest.
PILimagetensor
# tensor([[[36, 19, 11],
#          [31, 20, 14],
#          ...,
#          [254, 254, 254]],
#         [[241, 230, 114],
#          [252, 251, 98],
#          ...,
#          [255, 255, 255]],
#         [[73, 158, 195],
#          [255, 253, 192],
#          ...,
#          [255, 255, 255]],
#         ...,
#         [[13, 13, 49],
#          [12, 12, 50],
#          ...,
#          [253, 253, 253]],
#         [[10, 14, 43],
#          [177, 176, 172],
#          ...,
#          [61, 90, 216]],
#         [[6, 8, 21],
#          [14, 16, 41],
#          ...,
#          [60, 103, 231]]], dtype=torch.uint8)
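
Likewise, you can check element-wise that Scikit-image, PIL, and PyTorch's NEAREST_EXACT all produce the same pixels (a minimal check, assuming the code above has been run in the same session; True is the expected output since the results match):

import numpy as np

# True if NEAREST_EXACT matches Scikit-image's nearest-neighbor and PIL's NEAREST.
print(np.array_equal(pytorch_resize.numpy(), scikitimage_resize))     # True
print(np.array_equal(pytorch_resize.numpy(), PILimagearray))          # True
print(np.array_equal(pytorch_resize.numpy(), PILimagetensor.numpy())) # True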
