DEV Community

Tutor
Tutor

Posted on

การลบ Noise ในรูปภาพด้วยหลักการ Autoencoder โดยใช้ Python

การถ่ายภาพต่างๆด้วยกล้องดิจิตอลนั้นมักจะทำให้เกิด Noise ขึ้นเป็นเรื่องปกติ เนื่องจากกระแสไฟฟ้าที่ไหลเวียนอยู่ในตัวกล้องเอง โดยเฉพาะเวลาที่ถ่ายในสภาพที่แสงน้อย และต้องปรับค่า ISO ให้สูงพื่อให้กล้องรับแสงได้ไวและมากขึ้นเท่าไหร่ จะยิ่งทำให้ปรากฏ Noise ได้ชัดเจนมากขึ้นเท่านั้น

ในบทความนี้ เราจะใช้หลักการ Autoencoder ในการลบ Noise ที่อยู่ในภาพ เพื่อกู้คืนภาพที่มี Noise มาให้ใกล้เคียงกับภาพต้นฉบับมากที่สุด เราจะใช้ Google Colab ในการรันโค้ด โดย dataset เราจะใช้เป็นภาพตัวเลขที่ยังไม่มี Noise

ขั้นตอนที่ 1 นำเข้าข้อมูล

ตัวอย่างของข้อมูลสามารถ copy code และโหลดได้ตามนี้เลย

import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

# Transforms images to a PyTorch Tensor
tensor_transform = transforms.ToTensor()

# Download the MNIST Dataset
dataset = datasets.MNIST(root = "./data",
                         train = True,
                         download = True,
                         transform = tensor_transform)

# DataLoader is used to load the dataset
# for training
loader = torch.utils.data.DataLoader(dataset = dataset,
                                     batch_size = 32,
                                     shuffle = True)
Enter fullscreen mode Exit fullscreen mode

ตัวอย่างผลที่ได้จาก code
Image description

ขั้นตอนที่ 2 Setup Model

สร้าง class ที่ใช้สำหรับการ Autoencoder ซึ่งหลักการทำงานคือ นำภาพต้นฉบับมาลดขนาดให้เล็กลงด้วยการ Encode และ Decode ให้ภาพกลับมาเท่าเดิม โดยต้องการให้ model เก็บข้อมูลที่สำคัญของภาพไว้

# Creating a PyTorch class
# 28*28 ==> 9 ==> 28*28
class AE(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Building an linear encoder with Linear
        # layer followed by Relu activation function
        # 784 ==> 9
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(28 * 28, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 3)
        )

        # Building an linear decoder with Linear
        # layer followed by Relu activation function
        # The Sigmoid activation function
        # outputs the value between 0 and 1
        # 9 ==> 784
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(3, 18),
            torch.nn.ReLU(),
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 28 * 28),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
Enter fullscreen mode Exit fullscreen mode

Setup model ของ AI ที่จะใช้

# Model Initialization
model = AE().to(device)

# Validation using MSE Loss function
loss_function = torch.nn.MSELoss()

# Using an Adam Optimizer with lr = 0.001
optimizer = torch.optim.Adam(model.parameters(),
                             lr = 1e-3,
                             weight_decay = 1e-8)
Enter fullscreen mode Exit fullscreen mode

ทดลองเปรียบเทียบความแตกต่างระหว่างผลลัพธ์ที่ model ทำนาย กับภาพต้นฉบับ

from tqdm.notebook import tqdm

epochs = 10
outputs = []
losses = []
epoch_losses = []
for epoch in tqdm(range(epochs)):
    for (image, _) in loader:

      # Reshaping the image to (-1, 784)
      image = image.reshape(-1, 28*28)

      # Output of Autoencoder
      reconstructed = model(image.to(device))

      # Calculating the loss function
      loss = loss_function(reconstructed, image.to(device))

      # The gradients are set to zero,
      # the gradient is computed and stored.
      # .step() performs parameter update
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Storing the losses in a list for plotting
      losses.append(loss.cpu())
      epoch_losses.append(loss)
    print(f"epoch {epoch}: loss = {sum(epoch_losses)/len(epoch_losses)}")
    epoch_losses=[]
    outputs.append((epochs, image, reconstructed))
Enter fullscreen mode Exit fullscreen mode

ตัวอย่างผลที่ได้จาก code
Image description

แสดงผลด้วยกราฟ

# Defining the Plot Style
plt.style.use('fivethirtyeight')
plt.xlabel('Iterations')
plt.ylabel('Loss')

# Plotting the last 100 values
plt.plot(torch.tensor(losses).cpu().detach().numpy())
Enter fullscreen mode Exit fullscreen mode

Image description

แสดงผลเปรียบเทียบรูปภาพต้นฉบับและภาพที่ได้หลังจากทำการ Encode และ Decode เรียบร้อยแล้ว

for i, item in enumerate(image):
  # Reshape the array for plotting
  item = item.reshape(-1, 28, 28)
  plt.imshow(item[0])
  plt.show()
  break

for i, item in enumerate(reconstructed):
  item = item.reshape(-1, 28, 28)
  plt.imshow(item[0].cpu().detach().numpy())
  plt.show()
  break
Enter fullscreen mode Exit fullscreen mode

ภาพต้นฉบับ
Image description
ภาพที่ได้หลังจากทำการ Encode และ Decode
Image description

ขั้นตอนที่ 3 เตรียมข้อมูลในการฝึก Model

ทดลองเพิ่ม Noise ลงในรูปภาพต้นฉบับ

from skimage.util import random_noise

with torch.no_grad():
  for (image, _) in loader:
    print("Before adding noise")
    image_show = image[0].reshape(-1, 28, 28)
    noisy_img = random_noise(image_show, mode='gaussian')
    plt.imshow(image_show[0], cmap='gray')
    plt.show()
    print("After adding noise")
    plt.imshow(noisy_img[0], cmap='gray')
    plt.show()
    break
Enter fullscreen mode Exit fullscreen mode

ผลที่ได้
Before adding noise
Image description
After adding noise
Image description

ขั้นตอนที่ 4 ฝึก Model

ทำการเพิ่ม Noise ให้กับรูปภาพต้นฉบับ จากนั้นนำไป Encode ให้มีขนาดเล็กลง แล้ว Decode กลับมา เพื่อเปรียบเทียบความแตกต่างระหว่างภาพต้นฉบับที่ยังไม่มี Noise กับภาพที่สร้างขึ้นมาใหม่ด้วย Autoencoder

epochs = 10
outputs = []
losses = []
epoch_losses = []
for epoch in tqdm(range(epochs)):
    for (image, _) in loader:

      # Reshaping the image to (-1, 784)
      image = image.reshape(-1, 28*28)
      noisy_image = random_noise(image, mode='gaussian')

      # Output of Autoencoder
      reconstructed = model(torch.from_numpy(noisy_image).float().to(device))

      # Calculating the loss function
      loss = loss_function(reconstructed, image.to(device))

      # The gradients are set to zero,
      # the gradient is computed and stored.
      # .step() performs parameter update
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Storing the losses in a list for plotting
      losses.append(loss.cpu())
      epoch_losses.append(loss)
    print(f"epoch {epoch}: loss = {sum(epoch_losses)/len(epoch_losses)}")
    epoch_losses=[]
    outputs.append((epochs, image, reconstructed))
Enter fullscreen mode Exit fullscreen mode

ผลลัพธ์การการฝึก model
Image description

ขั้นตอนที่ 5 ทดสอบ Model

แสดงผลลัพธ์เปรียบเทียบระหว่างรูปที่มี Noise และรูปที่ผ่าน Autoencoder

#โชว์รูปก่อนทำนายละหลังทำนาย
with torch.no_grad():
  for (image, _) in loader:
    print("Input image + noise:")
    input_image = image[0].reshape(-1, 28*28)
    noisy_img = random_noise(input_image, mode='gaussian')
    plt.imshow(noisy_img.reshape(28, 28), cmap='gray')
    plt.show()

    latent_space = model.encoder(torch.from_numpy(noisy_img).float().to(device))
    print(f"latent_space: {latent_space}")
    after_decode = model.decoder(latent_space)
    after_decode_show = after_decode.reshape(28, 28)
    print("\nPredicted result:")
    plt.imshow(after_decode_show.cpu().detach().numpy(), cmap='gray')
    plt.show()
    break
Enter fullscreen mode Exit fullscreen mode

ผลที่ได้
Input image + noise:
Image description
latent_space: tensor([[1.3030, 4.0983, 2.9082]], device='cuda:0')

Predicted result:
Image description

สรุปผล

สิ่งที่เราต้องการจาก Encoder คือ ข้อมูลขนาดเล็กที่เก็บค่าที่สำคัญของรูปภาพ ยกเว้นค่า Noise จะเห็นว่าใน epoch แรกๆ จะมีค่า loss ของตัว Model ที่สูงและจะลดลงเรื่อยๆ ซึ่งในตัวอย่างมีการฝึก Model แค่ 10 epoch ถ้าหากต้องการให้ภาพที่ผ่าน Autoencoder มีความใกล้เคียงกับภาพต้นฉบับมากขึ้น การเพิ่ม epoch จะช่วยให้ภาพใกล้เคียงต้นฉบับมากขึ้นนั่นเอง

References
https://www.youtube.com/watch?v=f9ZeovxMaqw
https://keras.io/examples/vision/autoencoder/
https://blog.keras.io/building-autoencoders-in-keras.html
https://108-daily.blogspot.com/2019/04/what-is-noise.html

Top comments (0)