Aman Gupta
What are GANs and how to use them?

This article explains what GANs are, how they differ from supervised ML algorithms, and how to implement one using TensorFlow.

So what exactly are GANs?

  • Let's start by explaining the difference between a typical supervised machine learning algorithm and a GAN.

  • A supervised machine learning algorithm starts with an input that is fed into a model. The model generates an output, which is compared against a ground-truth output, and based on the result we update the model's weights. This loop repeats until the model performs well (see the short sketch below).

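To make that loop concrete, here is a minimal supervised classifier in Keras. This is a hypothetical sketch for illustration only; it is separate from the GAN code later in this post.

import tensorflow as tf

# Labeled data: inputs (digit images) paired with ground-truth outputs (labels)
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0  # scale pixels to [0, 1]

# Input -> model -> predicted output
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10)
])

# The loss compares predictions against the ground truth,
# and the optimizer updates the model based on the result
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1)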

  • A GAN, on the other hand, stands for Generative Adversarial Network. It pits two different models against each other, which is where the "Adversarial" part of its name comes from.

  • Of the two models, one is the Generator and the other is the Discriminator. To understand the core principle, let's take the example of training a GAN to generate flowers. Usually both models are CNNs, which are excellent at identifying patterns in images.


  • Initially, the Discriminator is fed a bunch of domain images (real photos of flowers) and is trained until it understands what a flower looks like and can discriminate between a real flower and any other object.

  • After this, the Generator is given a random input and asked to generate a fake flower, which is then fed to the Discriminator to be tested as real or fake. This is a game with one winner per round: the winner gets to keep its weights.

  • If the Discriminator correctly identifies the flower as fake, it stays the same while the Generator has to update its weights. If the Discriminator fails to spot the fake flower, then it is the one that has to update its weights.

  • This cycle keeps going until the Generator gets good enough to fool the Discriminator and can generate convincing images of flowers. This back-and-forth is often written formally as a minimax game, shown below.
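
The objective from the original GAN paper (Goodfellow et al., 2014) captures this game in one line. It is stated here for reference; the tutorial code below uses an equivalent cross-entropy formulation (with the common non-saturating variant of the generator loss):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]

Here D(x) is the Discriminator's estimate that x is real and G(z) is the Generator's output for random noise z; D tries to maximize this value while G tries to minimize it.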

  • This approach has many use cases, for example:

  1. Generate photorealistic images: Portraits, landscapes, objects, clothing, product prototypes, avatars.

  2. Fill in missing parts of images: Restore old photos, repair artistic works.

  3. Enhance low-quality images: Sharpen blurry videos, zoom in without losing detail.

  4. Apply artistic styles: Transfer styles between images, create personal art, conceptual design.

  5. Generate from text descriptions: Paintings, landscapes from text, personalized artwork.

  6. Predict future frames in videos: Complete videos, create visual effects.

Now let's code one up

Setup:

# To generate GIFs, install these first (in a shell or notebook
# cell, not inside the Python script itself):
# pip install imageio
# pip install git+https://github.com/tensorflow/docs

import glob
import os
import time

import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
from tensorflow.keras import layers

from IPython import display

Load and prepare the dataset:

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

BUFFER_SIZE = 60000
BATCH_SIZE = 256

# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
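
As an optional sanity check (not part of the original flow), you can pull a single batch from the pipeline and confirm its shape and normalized value range:

# Expect shape (256, 28, 28, 1) and values in [-1, 1]
for image_batch in train_dataset.take(1):
    print(image_batch.shape)
    print(image_batch.numpy().min(), image_batch.numpy().max())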

Create the models:

  • The Generator: The generator uses 'tf.keras.layers.Conv2DTranspose' (upsampling) layers to produce an image from a seed (random noise). Start with a Dense layer that takes this seed as input, then upsample several times until you reach the desired image size of 28x28x1. Notice the 'tf.keras.layers.LeakyReLU' activation for each layer, except the output layer, which uses 'tanh' to match the [-1, 1] range the images were normalized to.
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model
generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
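
Since the generator is untrained, this plot will just be noise. As a quick optional check that the upsampling stack really ends at 28x28x1, print the architecture:

generator.summary()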

The Discriminator: The discriminator is a CNN-based image classifier.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model
  • Use the (as yet untrained) discriminator to classify the generated image as real or fake. The model will be trained to output positive values for real images and negative values for fake images; the code for this check appears right after the loss functions below.

Define the losses and optimizers:

# A helper to compute cross-entropy loss between labels and logits
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# Test the untrained discriminator on the generated image
# (positive means "real", negative means "fake")
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print(decision)

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
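
One piece the snippets above leave out: the training loop below calls checkpoint.save(), so the checkpoint object has to be defined first. A minimal setup using the standard tf.train.Checkpoint API (the directory name here is an arbitrary choice):

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)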

Define the training loop:

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed over time (so it's easier
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)


Generate and save images:

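The train() loop above calls generate_and_save_images(), which renders a 4x4 grid of samples from the fixed seed (16 images, matching num_examples_to_generate) and writes each grid to disk using the image_at_epoch_XXXX.png naming that the GIF step below depends on. A minimal implementation:

def generate_and_save_images(model, epoch, test_input):
  # `training` is set to False so every layer runs in
  # inference mode (this matters for batch normalization)
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
    plt.subplot(4, 4, i + 1)
    # Undo the [-1, 1] normalization for display
    plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
    plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()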

Train the model:

Call the train() method defined above to train the generator and discriminator simultaneously. Note that training GANs can be tricky: it's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).

At the beginning of training, the generated images look like random noise. As training progresses, the generated digits will look increasingly real. After about 50 epochs, they resemble MNIST digits. This may take about one minute per epoch with the default settings on Colab.

train(train_dataset, EPOCHS)
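
If the runtime restarts partway through training, the most recent saved weights can be restored before continuing (this uses the checkpoint_dir defined earlier):

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))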

Create the GIF:

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))

display_image(EPOCHS)

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  # Repeat the final frame so the GIF lingers on the last image
  image = imageio.imread(filename)
  writer.append_data(image)

import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

Thank you for reading :)
