DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

MovingMNIST in PyTorch

Buy Me a Coffee

*Memos:

MovingMNIST() can load the Moving MNIST dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Optional-Default:None-Type:str): *Memos:
    • None, "train" or "test" can be set to it.
    • If it's None, all 20 frames(images) of each video are returned, ignoring split_ratio.
  • The 3rd argument is split_ratio(Optional-Default:10-Type:int): *Memos:
    • If split is "train", data[:, :split_ratio] is returned.
    • If split is "test", data[:, split_ratio:] is returned.
    • If split is None, it's ignored.
  • The 4th argument is transform(Optional-Default:None-Type:callable).
  • The 5th argument is download(Optional-Default:False-Type:bool): *Memos:
    • If it's True, the dataset is downloaded from the internet to root.
    • If it's True and the dataset is already downloaded, nothing happens (the dataset is the single file mnist_test_seq.npy, so there is nothing to extract).
    • It should be False if the dataset is already downloaded because it's faster.
    • You can manually download the dataset (mnist_test_seq.npy) from here and put it in data/MovingMNIST/.
from torchvision.datasets import MovingMNIST

# Build the full dataset from the required root argument only.
all_data = MovingMNIST(
    root="data"
)

# The same dataset again, with every optional argument written out at the
# default value listed in the memos above.
all_data = MovingMNIST(
    root="data",
    split=None,
    split_ratio=10,
    download=False,
    transform=None
)

# split="train" keeps frames [:split_ratio] of every video.
train_data = MovingMNIST(
    root="data",
    split="train"
)

# split="test" keeps frames [split_ratio:] of every video.
test_data = MovingMNIST(
    root="data",
    split="test"
)

# Each bare expression below is followed by a comment showing the value it
# displays when run interactively (e.g. as the last line of a notebook cell).

# Splitting trims frames, not videos: all three datasets hold 10000 items.
len(all_data), len(train_data), len(test_data)
# (10000, 10000, 10000)

# Frames per video: all 20 for split=None, split_ratio (10) for "train",
# and the remaining 20 - split_ratio (10) for "test".
len(all_data[0]), len(train_data[0]), len(test_data[0])
# (20, 10, 10)

all_data
# Dataset MovingMNIST
#     Number of datapoints: 10000
#     Root location: data

all_data.root
# 'data'

print(all_data.split)
# None

all_data.split_ratio
# 10

# NOTE: this is the dataset's `download` method, not the download=False
# argument passed to the constructor above.
all_data.download
# <bound method MovingMNIST.download of Dataset MovingMNIST
#     Number of datapoints: 10000
#     Root location: data>

print(all_data.transform)
# None

# Each item is one video as a 4-D uint8 tensor; its first axis indexes the
# frames (length 20 here, matching len(all_data[0]) above).
all_data[0]
# tensor([[[[0, 0, 0,  ..., 0, 0, 0],
#           ...,
#           [0, 0, 0,  ..., 0, 0, 0]]],
#         ...
#         [[[0, 0, 0,  ..., 0, 0, 0],
#           ...,
#           [0, 0, 0,  ..., 0, 0, 0]]]], dtype=torch.uint8)

all_data[1]
# tensor([[[[0, 0, 0,  ..., 0, 0, 0],
#           ...,
#           [0, 0, 0,  ..., 0, 0, 0]]],
#         ...
#         [[[0, 0, 0,  ..., 0, 0, 0],
#           ...,
#           [0, 0, 0,  ..., 0, 0, 0]]]], dtype=torch.uint8)

all_data[2]
# tensor([[[[0, 0, 0,  ..., 0, 0, 0],
#           ...,
#           [0, 0, 0,  ..., 0, 0, 0]]],
#         ...
#         [[[0, 0, 0,  ..., 0, 0, 0],
#           ...,
#           [0, 0, 0,  ..., 0, 0, 0]]]], dtype=torch.uint8)

import matplotlib.pyplot as plt

def show_images(data, labs):
    """Show the first frame of each video side by side, one subplot per video.

    Args:
        data: Iterable of video tensors whose first axis indexes frames
            (e.g. items of a MovingMNIST dataset).
        labs: Iterable of subplot titles, paired with ``data`` via ``zip``.

    The layout is one row of three subplots, so at most three
    (video, label) pairs can be shown.
    """
    plt.figure(figsize=(8, 4))
    for i, (vid, lab) in enumerate(iterable=zip(data, labs), start=1):
        plt.subplot(1, 3, i)
        # Select frame 0 *before* squeezing. Squeezing first would also drop
        # the time axis of a single-frame video (e.g. split="test" with
        # split_ratio=19), and [0] would then pick a pixel row, not a frame.
        plt.imshow(X=vid[0].squeeze())
        plt.title(label=lab)
    plt.tight_layout()
    plt.show()

# Pair the first video of each dataset with the name shown above its subplot.
titles = ("all_data", "train_data", "test_data")
videos = tuple(dataset[0] for dataset in (all_data, train_data, test_data))

show_images(data=videos, labs=titles)
Enter fullscreen mode Exit fullscreen mode

Image description

from torchvision.datasets import MovingMNIST

# One dataset per split value: None keeps every frame of each video,
# "train" keeps the first split_ratio frames and "test" keeps the rest.
all_data, train_data, test_data = (
    MovingMNIST(root="data", split=split) for split in (None, "train", "test")
)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Show every frame of one video in a 4x5 grid of subplots.

    Args:
        data: One video tensor whose first axis indexes frames (e.g. an
            item of a MovingMNIST dataset). At most 20 frames fit the grid.
        main_title: Optional title for the whole figure.
    """
    plt.figure(figsize=(12, 10))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # Iterate over frames and squeeze each one on its own. Squeezing the
    # whole video first would also drop the time axis of a single-frame
    # video, so the loop would then walk pixel rows instead of frames.
    for i, im in enumerate(iterable=data, start=1):
        plt.subplot(4, 5, i)
        plt.title(label=i)
        plt.imshow(X=im.squeeze())
    plt.tight_layout()
    plt.show()

# One figure per dataset, each showing every frame of that dataset's first video.
show_images(
    main_title="all_data",
    data=all_data[0],
)
show_images(
    main_title="train_data",
    data=train_data[0],
)
show_images(
    main_title="test_data",
    data=test_data[0],
)
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

from torchvision.datasets import MovingMNIST

# Build the unsplit, "train" and "test" datasets in that order.
all_data, train_data, test_data = map(
    lambda split: MovingMNIST(root="data", split=split),
    (None, "train", "test"),
)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    """Show the first frame of each of the first five videos in ``data``.

    Args:
        data: Iterable collection of videos (e.g. a MovingMNIST dataset);
            each video's first axis indexes frames.
        main_title: Optional title for the whole figure.

    Only the first row of a 4x5 subplot grid is used.
    """
    plt.figure(figsize=(10, 8))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    # zip pulls from range first and stops after position 5, so at most
    # five videos are read from ``data``.
    for i, vid in zip(range(1, 6), data):
        plt.subplot(4, 5, i)
        plt.title(label=i)
        # Select frame 0 before squeezing; squeezing first would also drop
        # the time axis of a single-frame video and [0] would pick a pixel row.
        plt.imshow(X=vid[0].squeeze())
    plt.tight_layout()
    plt.show()

# One figure per dataset, each showing the first frame of its first five videos.
show_images(
    main_title="all_data",
    data=all_data,
)
show_images(
    main_title="train_data",
    data=train_data,
)
show_images(
    main_title="test_data",
    data=test_data,
)
Enter fullscreen mode Exit fullscreen mode

Image description

from torchvision.datasets import MovingMNIST
import matplotlib.animation as animation

all_data = MovingMNIST(root="data")

import matplotlib.pyplot as plt
from IPython.display import HTML

figure, axis = plt.subplots()

# ArtistAnimation: pre-draw one image artist per frame of the first video.
# Each inner single-element list is the set of artists visible in that frame.
ims = [[axis.imshow(X=frame)] for frame in all_data[0].squeeze()]
ani = animation.ArtistAnimation(fig=figure, artists=ims, interval=100)

# Alternative using FuncAnimation, which draws each frame on demand:
# def animate(i):
#     axis.imshow(X=all_data[0].squeeze()[i])
#
# ani = animation.FuncAnimation(fig=figure, func=animate,
#                               frames=20, interval=100)

# To save the animation as a `.gif` file instead:
# ani.save('result.gif')

# Turn interactive mode off (intended to hide the extra static figure).
plt.ioff()

# Display the animation as an interactive JavaScript player; the
# commented-out line embeds an HTML5 video instead.
HTML(ani.to_jshtml())
# HTML(ani.to_html5_video())

# Alternatively, set how animations are rendered in the notebook:
# plt.rcParams["animation.html"] = "jshtml" # JavaScript player
# plt.rcParams["animation.html"] = "html5" # HTML5 video
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

from torchvision.datasets import MovingMNIST
from ipywidgets import interact, IntSlider

all_data = MovingMNIST(
    root="data"
)

import matplotlib.pyplot as plt
from IPython.display import HTML

def func(i):
    """Draw frame ``i`` of the first video in ``all_data``.

    Args:
        i: Frame index, supplied by the ``interact`` slider below.
    """
    # Index the frame before squeezing so that a single-frame video still
    # yields an image rather than a pixel row.
    plt.imshow(X=all_data[0][i].squeeze())

# Take the slider bounds from the video's frame count instead of a hard-coded
# 19, so they stay valid if the dataset is built with split="train"/"test"
# or a different split_ratio.
last_frame = len(all_data[0]) - 1
interact(func, i=(0, last_frame, 1))
# interact(func, i=IntSlider(min=0, max=last_frame, step=1, value=0))
# ↑ Use IntSlider instead to also set the start value (value=0). ↑
plt.show()
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image of Timescale

🚀 pgai Vectorizer: SQLAlchemy and LiteLLM Make Vector Search Simple

We built pgai Vectorizer to simplify embedding management for AI applications—without needing a separate database or complex infrastructure. Since launch, developers have created over 3,000 vectorizers on Timescale Cloud, with many more self-hosted.

Read more

Top comments (0)

Postmark Image

Speedy emails, satisfied customers

Are delayed transactional emails costing you user satisfaction? Postmark delivers your emails almost instantly, keeping your customers happy and connected.

Sign up