I was recently introduced to a splendid machine learning idea: Generative Adversarial Networks (GANs), which are especially popular for image generation. GANs were introduced by Ian Goodfellow in 2014, and their architecture is built around a competition between two neural networks. In this blog, I'll first explain what a GAN is, and then walk through TensorFlow code for training a simple GAN.
What are GANs?
At its core, a GAN consists of two neural networks: a generator that produces fake data, and a discriminator that learns to distinguish the fake data from the real thing.
- Generator: Takes random noise as input and transforms it into output data that resembles the training data.
- Discriminator: Takes an input sample and tries to predict whether it was drawn from the training data or synthesized by the generator.
These two networks are trained simultaneously in a zero-sum game: the generator tries to fool the discriminator into believing its outputs are real, while the discriminator tries to correctly separate real data from fake data.
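Formally, the two networks play the minimax game from Goodfellow's original paper, where the discriminator D tries to maximize and the generator G tries to minimize the value function:

min_G max_D V(D, G) = E_{x ~ p_data}[log D(x)] + E_{z ~ p_z}[log(1 - D(G(z)))]

In practice (and in the code below), the generator is instead trained to maximize log D(G(z)), which gives it stronger gradients early in training.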
Step-by-Step Guide to Building a Simple GAN
Step 1: Setting Up the Environment
pip install tensorflow matplotlib
Step 2: Import Necessary Libraries
import os
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
Step 3: Define the Generator
The generator network takes a random noise vector and maps it to a data point that resembles the training data. Here, it upsamples a 100-dimensional noise vector into a 128x128 RGB image.
def build_generator():
    model = tf.keras.Sequential()
    # Project the 100-dim noise vector and reshape it into an 8x8 feature map
    model.add(layers.Dense(8*8*128, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((8, 8, 128)))
    assert model.output_shape == (None, 8, 8, 128)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 8, 8, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    # Each strided transposed convolution doubles the spatial resolution
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    assert model.output_shape == (None, 16, 16, 128)

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    assert model.output_shape == (None, 32, 32, 128)

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    assert model.output_shape == (None, 64, 64, 128)

    # tanh keeps the output in [-1, 1], matching the normalized training images
    model.add(layers.Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 128, 128, 3)
    return model
generator = build_generator()
generator.summary()
Step 4: Define the Discriminator
The discriminator network takes an input image and classifies it as real or fake, outputting a single logit.
def build_discriminator():
    model = tf.keras.Sequential()
    # Each strided convolution halves the spatial resolution: 128 -> 64 -> 32 -> 16
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                            input_shape=[128, 128, 3]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(256, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    # Single output logit: positive means "real", negative means "fake"
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model
discriminator = build_discriminator()
discriminator.summary()
Step 5: Test the Models
# Generate one image from random noise with the (still untrained) generator
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

# The untrained discriminator outputs a logit close to zero
print(discriminator(generated_image))

# Rescale from [-1, 1] to [0, 255] for display
plt.imshow((generated_image[0] * 127.5 + 127.5).numpy().astype(np.uint8))
Step 6: Set Up the Loss Functions and Optimizers
# from_logits=True because the discriminator outputs raw logits
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output):
    # Real images should be classified as 1, generated images as 0
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    # The generator succeeds when the discriminator labels its output as real (1)
    return cross_entropy(tf.ones_like(fake_output), fake_output)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
Step 7: Set Up Checkpointing
checkpoint_dir = 'training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir,'ckpt')
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
generator=generator,
discriminator=discriminator)
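If a training run gets interrupted, you can restore the latest saved state before resuming. Here's a minimal sketch using the checkpoint object defined above:

# Restore the most recent checkpoint in the directory, if one exists
latest = tf.train.latest_checkpoint(checkpoint_dir)
if latest:
    checkpoint.restore(latest)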
Step 8: Define the Training Step
@tf.function
def train_step(images):
    # Sample one noise vector per image in the batch
    noise = tf.random.normal([tf.shape(images)[0], noise_dims])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    # Compute and apply gradients for each network separately
    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    dis_gradients = dis_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(dis_gradients, discriminator.trainable_variables))
    return gen_loss, disc_loss
Step 9: Setting Up the Training Loop and Saving Generated Images
from IPython import display
import time
# Track the per-epoch losses so we can plot them after training
total_gloss = []
total_dloss = []
def train(dataset, epochs):
    for epoch in range(epochs):
        disc_loss = gen_loss = 0
        start = time.time()
        count = 0
        for batch in dataset:
            losses = train_step(batch)
            count += 1
            gen_loss += losses[0]
            disc_loss += losses[1]
        total_gloss.append(gen_loss.numpy())
        total_dloss.append(disc_loss.numpy())
        # Save a checkpoint and a sample image grid every 50 epochs
        if (epoch + 1) % 50 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            display.clear_output(wait=True)
            generate_and_save_output(generator, epoch + 1, seed)
        print(f'Time for epoch {epoch + 1} is {time.time() - start}')
        print(f'Gloss: {gen_loss.numpy() / count} , Dloss: {disc_loss.numpy() / count}', end='\n\n')
    # Generate a final grid of images after the last epoch
    display.clear_output(wait=True)
    generate_and_save_output(generator, epochs, seed)
def generate_and_save_output(model, epoch, test_input):
    # training=False so BatchNorm runs in inference mode
    predictions = model(test_input, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Rescale from [-1, 1] back to [0, 255] for display
        plt.imshow((predictions[i] * 127.5 + 127.5).numpy().astype(np.uint8))
        plt.axis('off')
    plt.savefig(f'image_at_epoch_{epoch}.png')
    plt.show()
Step 10: Train the GAN
Let's train our GAN. I used a dog image dataset: the Stanford Dogs dataset, which is available on Kaggle. The training code expects `train_images` to be a batched `tf.data.Dataset` of 128x128 RGB images scaled to [-1, 1]; one way to build it is sketched below.
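Here's a minimal preprocessing sketch, assuming the dataset images have been downloaded to a local `dogs/` directory (the path and batch size are my choices, not fixed by the dataset):

# Hypothetical local path to the downloaded Stanford Dogs images
data_dir = 'dogs/'
batch_size = 32

train_images = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    labels=None,            # we only need the images, not the breed labels
    image_size=(128, 128),  # match the discriminator's input shape
    batch_size=batch_size)

# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
train_images = train_images.map(lambda x: (x - 127.5) / 127.5)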
EPOCHS = 500
noise_dims = 100
num_egs_to_generate = 16
seed = tf.random.normal([num_egs_to_generate,noise_dims])
train(train_images,EPOCHS)
Note: To generate good-quality images, the model requires a large number of epochs.
Trying our model:
new_image = generator(tf.random.normal([1,100]),training=False)
plt.imshow((new_image[0]*127.5+127.5).numpy().astype(np.uint8))
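Since `train` recorded the per-epoch losses in `total_gloss` and `total_dloss`, we can also plot the two curves to see how the balance between the generator and the discriminator evolved:

# Plot the per-epoch generator and discriminator losses
plt.plot(total_gloss, label='Generator loss')
plt.plot(total_dloss, label='Discriminator loss')
plt.xlabel('Epoch')
plt.ylabel('Total loss per epoch')
plt.legend()
plt.show()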
Conclusion
GANs are useful for producing realistic data: the generator learns the distribution of the training data and then creates new samples from it. As we've seen, building a simple GAN is quite feasible, and training one reveals the delicate balance between the generator and the discriminator. This guide aims only to introduce you to GANs and offer a first taste of what is possible in this burgeoning research area.
Resources:
Ian Goodfellow's Original Paper
TensorFlow Documentation
My Github Repo
Feel free to ask questions or share your GAN projects in the comments below!