- Neural Networks
- Keras (helpful, but not required)
A Generative Adversarial Network (GAN) is a neural network architecture introduced by Ian Goodfellow and his colleagues. SRGAN (Super-Resolution GAN) is a GAN-based method for increasing the resolution of an image.
A GAN has two parts: a generator and a discriminator. The generator produces refined output data from its input (random noise in a classic GAN, a low-resolution image in SRGAN). The discriminator receives two types of data: real-world data and the output of the generator. For the discriminator, real data has the label '1' and generated data has the label '0'. We can think of the generator as an artist and the discriminator as a critic: the artist creates an artwork, which is then judged by the critic.
As the generator improves with training, the discriminator's performance gets worse, because it can no longer easily tell the difference between real and fake. Theoretically, the discriminator ends up with 50% accuracy, no better than the flip of a coin.
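To make the labels and the 50% limit concrete, here is a minimal NumPy sketch of the binary cross-entropy the discriminator is scored with. The helper `binary_crossentropy` and the prediction values (0.9, 0.1, 0.5) are illustrative assumptions, not part of the tutorial's code.

```python
import numpy as np

def binary_crossentropy(labels, preds, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    preds = np.clip(preds, eps, 1.0 - eps)  # avoid log(0)
    losses = -(labels * np.log(preds) + (1 - labels) * np.log(1 - preds))
    return float(np.mean(losses))

real_label = np.ones((4, 1))   # label 1 for real images
gen_label = np.zeros((4, 1))   # label 0 for generated images

# A confident, correct critic: high score for real, low score for generated.
confident = 0.5 * (binary_crossentropy(real_label, np.full((4, 1), 0.9))
                   + binary_crossentropy(gen_label, np.full((4, 1), 0.1)))

# A critic that can no longer tell the difference outputs 0.5 for everything.
confused = 0.5 * (binary_crossentropy(real_label, np.full((4, 1), 0.5))
                  + binary_crossentropy(gen_label, np.full((4, 1), 0.5)))

print("confident critic loss:", confident)
print("confused critic loss:", confused)
```

The confused critic's loss equals ln 2, which is the value the discriminator loss tends to hover around when it is effectively guessing.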
So, in terms of the analogy, our goal is to lower the accuracy of the critic who judges us by focusing on our artwork.
Structure of SRGAN
The generator and discriminator are trained alternately. First the discriminator is trained for one or more steps, then the generator is trained for one or more steps; together this completes one cycle. A pretrained VGG19 model is used to extract features from the images during training.
While the generator is being trained, the parameters of the discriminator are frozen. Otherwise the generator would be chasing a moving target and might never converge.
Import necessary dependencies
import numpy as np
from keras import Model
from keras.layers import Conv2D, PReLU, BatchNormalization, Flatten
from keras.layers import UpSampling2D, LeakyReLU, Dense, Input, add
Some necessary variables
lr_ip = Input(shape=(25,25,3))
hr_ip = Input(shape=(100,100,3))
train_lr, train_hr = # training image arrays normalized between 0 and 1
test_lr, test_hr = # testing image arrays normalized between 0 and 1
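The tutorial does not show how the low-resolution arrays are produced. One possible way, sketched here purely as an assumption, is to average each 4x4 patch of a normalized high-resolution batch so that 100x100 images become 25x25. The helper name `downsample_4x` and the random demo data are hypothetical; the original may well have used a different resizing method.

```python
import numpy as np

def downsample_4x(hr_batch):
    """Average each non-overlapping 4x4 patch: (N, H, W, C) -> (N, H/4, W/4, C)."""
    n, h, w, c = hr_batch.shape
    assert h % 4 == 0 and w % 4 == 0, "height and width must be divisible by 4"
    # Split each spatial axis into (blocks, 4) and average over the two size-4 axes.
    return hr_batch.reshape(n, h // 4, 4, w // 4, 4, c).mean(axis=(2, 4))

# Stand-in for real images already normalized between 0 and 1.
rng = np.random.default_rng(0)
demo_hr = rng.random((2, 100, 100, 3))
demo_lr = downsample_4x(demo_hr)

print(demo_hr.shape, "->", demo_lr.shape)
```

Because averaging keeps values inside the original range, the low-resolution arrays stay normalized between 0 and 1.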
We have to define a function that returns the generator model, which is used to produce the high-resolution image. The residual block is a function that returns the sum of its input layer and its final layer.
# Residual block
def res_block(ip):
    res_model = Conv2D(64, (3,3), padding="same")(ip)
    res_model = BatchNormalization(momentum=0.5)(res_model)
    res_model = PReLU(shared_axes=[1,2])(res_model)
    res_model = Conv2D(64, (3,3), padding="same")(res_model)
    res_model = BatchNormalization(momentum=0.5)(res_model)
    return add([ip, res_model])

# Upscale the image 2x
def upscale_block(ip):
    up_model = Conv2D(256, (3,3), padding="same")(ip)
    up_model = UpSampling2D(size=2)(up_model)
    up_model = PReLU(shared_axes=[1,2])(up_model)
    return up_model

num_res_block = 16

# Generator Model
def create_gen(gen_ip):
    layers = Conv2D(64, (9,9), padding="same")(gen_ip)
    layers = PReLU(shared_axes=[1,2])(layers)
    temp = layers
    for i in range(num_res_block):
        layers = res_block(layers)
    layers = Conv2D(64, (3,3), padding="same")(layers)
    layers = BatchNormalization(momentum=0.5)(layers)
    layers = add([layers, temp])
    layers = upscale_block(layers)
    layers = upscale_block(layers)
    op = Conv2D(3, (9,9), padding="same")(layers)
    return Model(inputs=gen_ip, outputs=op)
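Two ideas in the generator can be checked without Keras: the residual connection is simply "input plus transformed input", and the two 2x upscale blocks are what turn a 25x25 input into a 100x100 output. The helpers below (`toy_res_block`, `generator_output_shape`) are hypothetical names used only for this sketch.

```python
import numpy as np

def toy_res_block(x, transform):
    """Return x + transform(x), mirroring add([ip, res_model]) in res_block."""
    return x + transform(x)

def generator_output_shape(lr_shape, num_upscale_blocks=2):
    """Trace the spatial size through the generator.

    All convolutions use padding="same" with stride 1, so only the
    upscale blocks (each 2x) change the height and width.
    """
    h, w, _ = lr_shape
    for _ in range(num_upscale_blocks):
        h, w = h * 2, w * 2
    return (h, w, 3)  # the final Conv2D has 3 filters (RGB)

x = np.ones((1, 25, 25, 64))
# If the convolutional branch contributes nothing, the block passes x through unchanged.
out = toy_res_block(x, lambda t: np.zeros_like(t))

print(out.shape, generator_output_shape((25, 25, 3)))
```

This is also why the skip connection makes deep stacks easier to train: each block only has to learn a correction on top of its input.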
This block of code defines the structure of the discriminator model and all the layers it uses to distinguish real images from generated ones. As we go deeper, the number of filters doubles after every two layers.
# Small block inside the discriminator
def discriminator_block(ip, filters, strides=1, bn=True):
    disc_model = Conv2D(filters, (3,3), strides=strides, padding="same")(ip)
    disc_model = LeakyReLU(alpha=0.2)(disc_model)
    if bn:
        disc_model = BatchNormalization(momentum=0.8)(disc_model)
    return disc_model

# Discriminator Model
def create_disc(disc_ip):
    df = 64
    d1 = discriminator_block(disc_ip, df, bn=False)
    d2 = discriminator_block(d1, df, strides=2)
    d3 = discriminator_block(d2, df*2)
    d4 = discriminator_block(d3, df*2, strides=2)
    d5 = discriminator_block(d4, df*4)
    d6 = discriminator_block(d5, df*4, strides=2)
    d7 = discriminator_block(d6, df*8)
    d8 = discriminator_block(d7, df*8, strides=2)
    d8_5 = Flatten()(d8)
    d9 = Dense(df*16)(d8_5)
    d10 = LeakyReLU(alpha=0.2)(d9)
    validity = Dense(1, activation='sigmoid')(d10)
    return Model(disc_ip, validity)
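To see how the filter count doubles while the image shrinks, the following sketch traces the feature-map shape through the eight discriminator blocks. It assumes the standard rule that a convolution with `padding="same"` produces an output of size ceil(input / stride); the function name `discriminator_shapes` is made up for this illustration.

```python
import math

def discriminator_shapes(hr_shape=(100, 100, 3), df=64):
    """Return the (height, width, filters) after each block d1..d8."""
    h, w, _ = hr_shape
    # (filters, strides) for d1..d8, matching create_disc
    blocks = [(df, 1), (df, 2), (df * 2, 1), (df * 2, 2),
              (df * 4, 1), (df * 4, 2), (df * 8, 1), (df * 8, 2)]
    shapes = []
    for filters, strides in blocks:
        # padding="same": output size is ceil(input size / stride)
        h, w = math.ceil(h / strides), math.ceil(w / strides)
        shapes.append((h, w, filters))
    return shapes

shapes = discriminator_shapes()
for name, shape in zip(["d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8"], shapes):
    print(name, shape)

final_h, final_w, final_c = shapes[-1]
flat_size = final_h * final_w * final_c  # input size of the first Dense layer
print("flattened:", flat_size)
```

Every second block uses a stride of 2, so the spatial size roughly halves four times while the filters grow from 64 to 512.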
In this code block, we use the VGG19 model pretrained on the ImageNet dataset to extract features. This model is frozen later so that its parameters won't get updated.
from keras.applications import VGG19

# Build the VGG19 model up to the 10th layer (index 9)
# Used to extract the features of high-resolution images
def build_vgg(hr_shape=(100,100,3)):
    # include_top=False drops the dense classifier, so inputs need not be 224x224
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=hr_shape)
    return Model(inputs=vgg.input, outputs=vgg.layers[9].output)
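The VGG features are compared with a mean squared error, so the generator is rewarded for matching what the image looks like to VGG19 rather than for matching individual pixels. Below is a minimal NumPy sketch of that comparison. The function `content_loss` and the random stand-in feature maps are assumptions for illustration; in the real pipeline the features come from the VGG model and Keras computes the `"mse"` loss itself.

```python
import numpy as np

def content_loss(hr_features, gen_features):
    """Mean squared error between two feature maps of the same shape."""
    return float(np.mean((hr_features - gen_features) ** 2))

rng = np.random.default_rng(0)
# Stand-ins for VGG feature maps of a real and a generated image.
hr_features = rng.random((1, 25, 25, 256))
gen_features = hr_features + 0.5  # every feature is off by exactly 0.5

print("identical features:", content_loss(hr_features, hr_features))
print("shifted features:", content_loss(hr_features, gen_features))
```

A perfect reconstruction in feature space gives a loss of zero, and a constant offset of 0.5 gives 0.25, the square of the offset.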
Now we attach the generator and discriminator models together. The resulting combined model is used only to train the generator. While training this combined model, the discriminator has to stay frozen in every epoch.
# Attach the generator and discriminator
def create_comb(gen_model, disc_model, vgg, lr_ip, hr_ip):
    gen_img = gen_model(lr_ip)
    gen_features = vgg(gen_img)
    disc_model.trainable = False
    validity = disc_model(gen_img)
    return Model([lr_ip, hr_ip], [validity, gen_features])
Then we declare the generator, discriminator and VGG models. These models are passed as arguments to the combined model.
Any change to the smaller models inside the combined model also affects those models outside it, for example weight updates or freezing.
generator = create_gen(lr_ip)
discriminator = create_disc(hr_ip)
discriminator.compile(loss="binary_crossentropy", optimizer="adam", metrics=['accuracy'])
vgg = build_vgg()
vgg.trainable = False
gan_model = create_comb(generator, discriminator, vgg, lr_ip, hr_ip)
gan_model.compile(loss=["binary_crossentropy", "mse"], loss_weights=[1e-3, 1], optimizer="adam")
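The `loss_weights=[1e-3, 1]` argument means the total generator loss is a weighted sum: a small adversarial term plus the full VGG feature term. The sketch below reproduces that arithmetic for a single sample. The function name `combined_generator_loss` and the example numbers are illustrative assumptions.

```python
import math

def combined_generator_loss(disc_pred_on_fake, feature_mse,
                            adv_weight=1e-3, content_weight=1.0):
    """Weighted sum of the two losses the combined model is compiled with.

    The generator is trained with label 1, so the binary cross-entropy
    for one sample reduces to -log(D(G(x))).
    """
    adversarial = -math.log(disc_pred_on_fake)
    return adv_weight * adversarial + content_weight * feature_mse

# Example: the discriminator gives the generated image a score of 0.5
# and the VGG feature MSE is 0.2 (both values are made up).
total = combined_generator_loss(disc_pred_on_fake=0.5, feature_mse=0.2)
print("total generator loss:", total)
```

With these weights the feature loss dominates, and the adversarial term acts as a gentle push toward images the discriminator accepts as real.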
Sample the training data into small batches
Because the training set is too large to process at once, we split the images into small batches to avoid a Resource Exhausted error. Resources such as RAM would not be enough to train on all the images in one go.
batch_size = 20
train_lr_batches = []
train_hr_batches = []
# shape[0] is the number of images; a final partial batch is dropped
for it in range(int(train_hr.shape[0] / batch_size)):
    start_idx = it * batch_size
    end_idx = start_idx + batch_size
    train_hr_batches.append(train_hr[start_idx:end_idx])
    train_lr_batches.append(train_lr[start_idx:end_idx])
train_lr_batches = np.array(train_lr_batches)
train_hr_batches = np.array(train_hr_batches)
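As an alternative sketch, the same slicing can be written as a Python generator that yields one batch at a time instead of building a second copy of the whole dataset. The name `iter_batches` and the zero-filled demo arrays are hypothetical and only serve to show the shapes.

```python
import numpy as np

def iter_batches(lr, hr, batch_size):
    """Yield (lr_batch, hr_batch) pairs of exactly batch_size images.

    Leftover images that do not fill a whole batch are skipped,
    which matches the integer division used in the loop above.
    """
    num_batches = len(hr) // batch_size
    for it in range(num_batches):
        start_idx = it * batch_size
        end_idx = start_idx + batch_size
        yield lr[start_idx:end_idx], hr[start_idx:end_idx]

# Stand-in data: 45 images, so batch_size=20 gives 2 full batches.
demo_lr = np.zeros((45, 25, 25, 3), dtype=np.float32)
demo_hr = np.zeros((45, 100, 100, 3), dtype=np.float32)

batches = list(iter_batches(demo_lr, demo_hr, batch_size=20))
print("number of batches:", len(batches))
print("batch shapes:", batches[0][0].shape, batches[0][1].shape)
```

Keeping every batch the same size matters here because the label arrays in the training loop are created with a fixed `batch_size`.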
Training the model
This block is the core of the whole program. Here we train the discriminator and the generator alternately, as described above. At this point the discriminator is frozen, so do not forget to unfreeze it before training the discriminator and to freeze it again afterwards, as shown in the code below.
epochs = 100
for e in range(epochs):
    gen_label = np.zeros((batch_size, 1))
    real_label = np.ones((batch_size, 1))
    g_losses = []
    d_losses = []
    for b in range(len(train_hr_batches)):
        lr_imgs = train_lr_batches[b]
        hr_imgs = train_hr_batches[b]
        gen_imgs = generator.predict_on_batch(lr_imgs)
        # Don't forget to make the discriminator trainable
        discriminator.trainable = True
        # Train the discriminator
        d_loss_gen = discriminator.train_on_batch(gen_imgs, gen_label)
        d_loss_real = discriminator.train_on_batch(hr_imgs, real_label)
        discriminator.trainable = False
        d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)
        image_features = vgg.predict(hr_imgs)
        # Train the generator
        g_loss, _, _ = gan_model.train_on_batch([lr_imgs, hr_imgs], [real_label, image_features])
        d_losses.append(d_loss)
        g_losses.append(g_loss)
    g_losses = np.array(g_losses)
    d_losses = np.array(d_losses)
    g_loss = np.sum(g_losses, axis=0) / len(g_losses)
    d_loss = np.sum(d_losses, axis=0) / len(d_losses)
    print("epoch:", e+1, "g_loss:", g_loss, "d_loss:", d_loss)
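The loss bookkeeping in the loop can be confusing because the discriminator is compiled with `metrics=['accuracy']`, so each `train_on_batch` call returns a pair `[loss, accuracy]` rather than a single number. The sketch below walks through the averaging with made-up values; none of these numbers come from a real training run.

```python
import numpy as np

# Hypothetical [loss, accuracy] pairs from one batch.
d_loss_gen = [0.70, 0.40]    # discriminator on generated images
d_loss_real = [0.50, 0.80]   # discriminator on real images

# Element-wise average of the two results for this batch.
d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)

# Collect per-batch results (the second row is another made-up batch),
# then average over the batch axis to get one pair per epoch.
d_losses = np.array([d_loss, [0.40, 0.80]])
epoch_d_loss = d_losses.mean(axis=0)  # same as np.sum(d_losses, axis=0) / len(d_losses)

print("batch d_loss [loss, accuracy]:", d_loss)
print("epoch d_loss [loss, accuracy]:", epoch_d_loss)
```

So the printed `d_loss` at the end of each epoch is a two-element array: the mean discriminator loss followed by the mean discriminator accuracy.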
Evaluate the model
Here we measure the performance of the generator on the test dataset. The loss may be a little larger than on the training dataset, but there is no need to worry as long as the difference is small.
label = np.ones((len(test_lr), 1))
test_features = vgg.predict(test_hr)
# eval_loss avoids shadowing Python's built-in eval()
eval_loss, _, _ = gan_model.evaluate([test_lr, test_hr], [label, test_features])
Predict the output
We can now generate high-resolution images with the generator model.
test_prediction = generator.predict_on_batch(test_lr)
You can find my implementation, which was trained on Google Colab, on my GitHub profile.
- Always keep track of which model is trainable and which is frozen.
- While training the generator, use a label value of one.
- It is better to use input images larger than 25x25, as they carry more detail for the generated images.
- Do not forget to normalize the NumPy dataset between 0 and 1.
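The normalization tip can be sketched with two small helpers: one that scales 8-bit images to the range 0 to 1 before training, and one that converts generator output back to 8-bit pixels for viewing. The names `normalize` and `denormalize` are assumptions. The clipping step is included because the generator's last layer has no activation, so its output is not guaranteed to stay inside that range.

```python
import numpy as np

def normalize(images_uint8):
    """Scale 8-bit pixel values (0..255) to floats between 0 and 1."""
    return images_uint8.astype(np.float32) / 255.0

def denormalize(images_float):
    """Convert model output back to 8-bit pixels, clipping out-of-range values."""
    clipped = np.clip(images_float, 0.0, 1.0)
    return (clipped * 255.0).round().astype(np.uint8)

# Every possible pixel value survives a round trip.
pixels = np.arange(256, dtype=np.uint8).reshape(16, 16, 1)
restored = denormalize(normalize(pixels))

# Values outside 0..1 are clipped instead of wrapping around.
out_of_range = denormalize(np.array([-0.2, 0.5, 1.3]))

print(np.array_equal(pixels, restored), out_of_range)
```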
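Jason Brownlee. 2019. Generative Adversarial Networks with Python.
Christian Ledig et al. 2016. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network (the SRGAN paper). https://arxiv.org/pdf/1609.04802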