<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/">
  <channel>
    <title>DEV Community: Manish Dhakal</title>
    <description>The latest articles on DEV Community by Manish Dhakal (@manishdhakal).</description>
    <link>https://dev.to/manishdhakal</link>
    <image>
      <url>https://media2.dev.to/dynamic/image/width=90,height=90,fit=cover,gravity=auto,format=auto/https:%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fuploads%2Fuser%2Fprofile_image%2F333616%2Fbfac1445-a7ac-49e3-96ab-294fb2e3cd59.jpeg</url>
      <title>DEV Community: Manish Dhakal</title>
      <link>https://dev.to/manishdhakal</link>
    </image>
    <atom:link rel="self" type="application/rss+xml" href="https://dev.to/feed/manishdhakal"/>
    <language>en</language>
    <item>
      <title>Super Resolution with GAN and Keras (SRGAN)</title>
      <dc:creator>Manish Dhakal</dc:creator>
      <pubDate>Thu, 04 Feb 2021 04:47:00 +0000</pubDate>
      <link>https://dev.to/manishdhakal/super-resolution-with-gan-and-keras-srgan-38ma</link>
      <guid>https://dev.to/manishdhakal/super-resolution-with-gan-and-keras-srgan-38ma</guid>
      <description>&lt;h2&gt;
  
  
  &lt;strong&gt;Prior Knowledge&lt;/strong&gt;
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;Neural Networks&lt;/li&gt;
&lt;li&gt;Python&lt;/li&gt;
&lt;li&gt;Keras (helpful, but not required)&lt;/li&gt;
&lt;/ul&gt;




&lt;h2&gt;
  
  
  &lt;strong&gt;Generative Adversarial Networks (GAN)&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;Generative Adversarial Networks (GANs) are a family of neural network models introduced by Ian Goodfellow and his colleagues in 2014. &lt;a href="https://arxiv.org/pdf/1609.04802.pdf" rel="noopener noreferrer"&gt;SRGAN&lt;/a&gt; is a GAN-based method for increasing the resolution of an image (single-image super-resolution).&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fhy9nezuvacl390hqhsfl.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fhy9nezuvacl390hqhsfl.png" alt="GAN"&gt;&lt;/a&gt;&lt;br&gt;&lt;br&gt;
A GAN has two parts: a &lt;strong&gt;generator&lt;/strong&gt; and a &lt;strong&gt;discriminator&lt;/strong&gt;. The generator turns its input (random noise in the original GAN, a low-resolution image in SRGAN) into generated data. The discriminator receives two kinds of data: real-world samples and the generator's output. It is trained to label real data as ‘1’ and generated data as ‘0’. By analogy, the generator is an &lt;strong&gt;artist&lt;/strong&gt; and the discriminator is a &lt;strong&gt;critic&lt;/strong&gt;: the artist creates a work, and the critic judges it.&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fverfwzsdfcv4glmy67cd.jpeg" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Fverfwzsdfcv4glmy67cd.jpeg" alt="ARTIST AND CRITIC"&gt;&lt;/a&gt;&lt;br&gt;&lt;/p&gt;

&lt;p&gt;As the generator improves with training, the discriminator's performance drops, because it can no longer easily tell real images from fake ones. In theory, once the generator is good enough, the discriminator's accuracy falls to 50%, no better than a coin flip.&lt;/p&gt;
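&lt;p&gt;To make the labels and the coin-flip limit concrete, here is a small NumPy sketch (not part of the original code) of the discriminator's binary cross-entropy loss and accuracy. It assumes the discriminator outputs a probability in (0, 1) that its input is real, and that a score above 0.5 counts as a "real" verdict. A discriminator that outputs 0.5 for every image is right half the time, and its loss settles at ln 2 ≈ 0.693.&lt;/p&gt;

```python
import numpy as np

def binary_crossentropy(labels, probs, eps=1e-7):
    """Mean binary cross-entropy between 0/1 labels and predicted probabilities."""
    probs = np.clip(probs, eps, 1 - eps)
    return float(-np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)))

def accuracy(labels, probs):
    """Fraction of samples classified correctly, using a 0.5 threshold."""
    return float(np.mean((probs > 0.5) == labels))

batch_size = 4
real_label = np.ones((batch_size, 1))   # real images are labelled 1
gen_label = np.zeros((batch_size, 1))   # generated images are labelled 0
labels = np.concatenate([real_label, gen_label])

# A confident, correct discriminator (early in training)
sharp = np.concatenate([np.full((batch_size, 1), 0.95), np.full((batch_size, 1), 0.05)])
# A discriminator that can no longer tell real from fake: a coin flip
coin_flip = np.full((2 * batch_size, 1), 0.5)

print(accuracy(labels, sharp), binary_crossentropy(labels, sharp))          # 1.0 and about 0.051
print(accuracy(labels, coin_flip), binary_crossentropy(labels, coin_flip))  # 0.5 and about 0.693
```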

&lt;p&gt;&lt;em&gt;So our motto is to focus on our artwork until the people who judge us can do no better than guess.&lt;/em&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Structure of SRGAN&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Flqoxdjym10z8306uzm2e.png" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2Flqoxdjym10z8306uzm2e.png" alt="SRGAN MODEL"&gt;&lt;/a&gt;&lt;br&gt;&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Alternate Training&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;The generator and discriminator are trained alternately. First the discriminator is trained for one or more steps, then the generator is trained for one or more steps; together these make up one training cycle (the code below alternates once per batch). A pretrained VGG19 model is used to extract image features for the generator's content loss.&lt;br&gt;
While the generator is being trained, the discriminator's parameters are frozen. Otherwise the generator update, which labels generated images as real (‘1’), would also push the discriminator toward accepting fakes and undo what it has just learned.&lt;/p&gt;
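&lt;p&gt;The alternating schedule and the freezing rule can be illustrated without Keras. The NumPy sketch below is a deliberately tiny stand-in, not SRGAN: a generator G(z) = a*z + b and a logistic discriminator D(x) = sigmoid(w*x + c), with hand-derived gradients. The discriminator step updates only (w, c); the generator step labels its outputs as real (1) and updates only (a, b), reading the discriminator's parameters without changing them. The "real" data distribution, the learning rate, and all parameter names are illustrative assumptions.&lt;/p&gt;

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

# Hypothetical toy parameters: generator (a, b), discriminator (w, c)
gen = {"a": 1.0, "b": 0.0}
disc = {"w": 0.1, "c": 0.0}
lr = 0.01

def d_loss(disc, x_real, x_fake):
    """Discriminator BCE: real samples labelled 1, generated samples labelled 0."""
    s_real = sigmoid(disc["w"] * x_real + disc["c"])
    s_fake = sigmoid(disc["w"] * x_fake + disc["c"])
    return -np.mean(np.log(s_real)) - np.mean(np.log(1 - s_fake))

def g_loss(gen, disc, z):
    """Generator loss: generated samples are labelled 1 (real)."""
    x_fake = gen["a"] * z + gen["b"]
    return -np.mean(np.log(sigmoid(disc["w"] * x_fake + disc["c"])))

def train_discriminator_step(gen, disc, x_real, z):
    x_fake = gen["a"] * z + gen["b"]          # generator output, treated as fixed data
    s_real = sigmoid(disc["w"] * x_real + disc["c"])
    s_fake = sigmoid(disc["w"] * x_fake + disc["c"])
    grad_w = np.mean((s_real - 1) * x_real) + np.mean(s_fake * x_fake)
    grad_c = np.mean(s_real - 1) + np.mean(s_fake)
    disc["w"] -= lr * grad_w                  # only the discriminator moves
    disc["c"] -= lr * grad_c

def train_generator_step(gen, disc, z):
    x_fake = gen["a"] * z + gen["b"]
    s_fake = sigmoid(disc["w"] * x_fake + disc["c"])
    # d/dt of -log(sigmoid(t)) is sigmoid(t) - 1, with t = w * (a*z + b) + c
    grad_a = np.mean((s_fake - 1) * disc["w"] * z)
    grad_b = np.mean((s_fake - 1) * disc["w"])
    gen["a"] -= lr * grad_a                   # only the generator moves;
    gen["b"] -= lr * grad_b                   # disc is read but never written (frozen)

# One training cycle: a discriminator step, then a generator step
x_real = rng.normal(3.0, 1.0, size=64)        # assumed "real" data
z = rng.normal(size=64)                       # input noise

before_d = d_loss(disc, x_real, gen["a"] * z + gen["b"])
train_discriminator_step(gen, disc, x_real, z)
after_d = d_loss(disc, x_real, gen["a"] * z + gen["b"])

frozen = dict(disc)                           # snapshot to show disc stays unchanged
before_g = g_loss(gen, disc, z)
train_generator_step(gen, disc, z)
after_g = g_loss(gen, disc, z)

print(f"D loss {before_d:.6f} -> {after_d:.6f}")
print(f"G loss {before_g:.6f} -> {after_g:.6f}")
```

&lt;p&gt;Each step lowers its own loss while leaving the other model's parameters untouched, which is exactly the division of labour the Keras code below sets up with &lt;code&gt;trainable&lt;/code&gt;.&lt;/p&gt;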


&lt;h2&gt;
  
  
  &lt;strong&gt;Code&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;Import necessary dependencies&lt;/strong&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;import numpy as np
from keras import Model
from keras.layers import Conv2D, PReLU, BatchNormalization, Flatten
from keras.layers import UpSampling2D, LeakyReLU, Dense, Input, add

&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;strong&gt;Necessary variables&lt;/strong&gt;&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;lr_ip = Input(shape=(25,25,3))
hr_ip = Input(shape=(100,100,3))
train_lr,train_hr = #training images arrays normalized between 0 &amp;amp; 1
test_lr, test_hr = # testing images arrays normalized between 0 &amp;amp; 1
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
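&lt;p&gt;The post leaves data loading open. As one possible sketch (an assumption, not the author's pipeline), high-resolution images stored as 8-bit NumPy arrays can be scaled to [0, 1] and paired with 4x-smaller low-resolution versions made by averaging each 4x4 block of pixels. The SRGAN paper downsamples with a bicubic kernel instead; block averaging is used here only because it needs nothing beyond NumPy. The function names and the random stand-in data are hypothetical.&lt;/p&gt;

```python
import numpy as np

def normalize(images_uint8):
    """Scale 8-bit pixel values (0..255) to float32 values in [0, 1]."""
    return images_uint8.astype(np.float32) / 255.0

def downsample_4x(hr):
    """Make low-resolution images by averaging non-overlapping 4x4 pixel blocks.

    hr has shape (N, H, W, C) with H and W divisible by 4.
    """
    n, h, w, c = hr.shape
    return hr.reshape(n, h // 4, 4, w // 4, 4, c).mean(axis=(2, 4))

# Hypothetical stand-in data: 8 random 100x100 RGB images
rng = np.random.default_rng(0)
raw_hr = rng.integers(0, 256, size=(8, 100, 100, 3), dtype=np.uint8)

demo_hr = normalize(raw_hr)         # shape (8, 100, 100, 3)
demo_lr = downsample_4x(demo_hr)    # shape (8, 25, 25, 3)
print(demo_hr.shape, demo_lr.shape)
```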



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;Define Generator&lt;/strong&gt;&lt;/p&gt;

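&lt;p&gt;Next, we define a function that returns the generator model, which produces the high-resolution image. The residual block returns the sum of its input and the output of its last layer (a skip connection). Each upscale block doubles the height and width; note that this code uses &lt;code&gt;UpSampling2D&lt;/code&gt;, whereas the SRGAN paper uses sub-pixel convolution layers.&lt;br&gt;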
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;# Residual block
def res_block(ip):

    res_model = Conv2D(64, (3,3), padding = "same")(ip)
    res_model = BatchNormalization(momentum = 0.5)(res_model)
    res_model = PReLU(shared_axes = [1,2])(res_model)

    res_model = Conv2D(64, (3,3), padding = "same")(res_model)
    res_model = BatchNormalization(momentum = 0.5)(res_model)

    return add([ip,res_model])

# Upscale the image 2x
def upscale_block(ip):
    up_model = Conv2D(256, (3,3), padding="same")(ip)
    up_model = UpSampling2D(size=2)(up_model)
    up_model = PReLU(shared_axes=[1,2])(up_model)

    return up_model

num_res_block = 16

# Generator Model
def create_gen(gen_ip):
    layers = Conv2D(64, (9,9), padding="same")(gen_ip)
    layers = PReLU(shared_axes=[1,2])(layers)
    temp = layers
    for i in range(num_res_block):
        layers = res_block(layers)
    layers = Conv2D(64, (3,3), padding="same")(layers)
    layers = BatchNormalization(momentum=0.5)(layers)
    layers = add([layers,temp])
    layers = upscale_block(layers)
    layers = upscale_block(layers)
    op = Conv2D(3, (9,9), padding="same")(layers)
    return Model(inputs=gen_ip, outputs=op)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
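&lt;p&gt;To see why two upscale blocks turn a 25x25 input into a 100x100 output, here is a NumPy sketch of the two shape-relevant operations. &lt;code&gt;UpSampling2D(size=2)&lt;/code&gt; with its default nearest-neighbour interpolation repeats every pixel twice along height and width, which &lt;code&gt;np.repeat&lt;/code&gt; reproduces; the residual connection is an element-wise addition, so it never changes the shape. The convolutions are replaced by a hypothetical shape-preserving placeholder, and only the spatial size is tracked (the real upscale block's &lt;code&gt;Conv2D(256)&lt;/code&gt; also changes the channel count).&lt;/p&gt;

```python
import numpy as np

def upsample_nearest_2x(x):
    """Nearest-neighbour 2x upsampling of an (N, H, W, C) array, like UpSampling2D(size=2)."""
    return np.repeat(np.repeat(x, 2, axis=1), 2, axis=2)

def residual(x, branch):
    """Skip connection: add the block input to the branch output (shapes must match)."""
    return x + branch(x)

def placeholder_branch(x):
    """Hypothetical stand-in for Conv2D -> BN -> PReLU -> Conv2D -> BN with 'same' padding."""
    return 0.1 * x

x = np.random.default_rng(0).random((1, 25, 25, 64))  # features after the first 9x9 conv
for _ in range(16):                                    # num_res_block residual blocks
    x = residual(x, placeholder_branch)
print(x.shape)                                         # (1, 25, 25, 64)

x = upsample_nearest_2x(x)                             # first upscale block: 25 -> 50
x = upsample_nearest_2x(x)                             # second upscale block: 50 -> 100
print(x.shape)                                         # (1, 100, 100, 64)
```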



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;Define Discriminator&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;This block defines the discriminator: a stack of convolutional blocks that learns to tell real images from generated ones. Going deeper, the number of filters doubles after every two blocks (64, 128, 256, 512), and every second block uses stride 2 to halve the height and width.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;#Small block inside the discriminator
def discriminator_block(ip, filters, strides=1, bn=True):

    disc_model = Conv2D(filters, (3,3), strides, padding="same")(ip)
    disc_model = LeakyReLU( alpha=0.2 )(disc_model)
    if bn:
        disc_model = BatchNormalization( momentum=0.8 )(disc_model)
    return disc_model

# Discriminator Model
def create_disc(disc_ip):
    df = 64

    d1 = discriminator_block(disc_ip, df, bn=False)
    d2 = discriminator_block(d1, df, strides=2)
    d3 = discriminator_block(d2, df*2)
    d4 = discriminator_block(d3, df*2, strides=2)
    d5 = discriminator_block(d4, df*4)
    d6 = discriminator_block(d5, df*4, strides=2)
    d7 = discriminator_block(d6, df*8)
    d8 = discriminator_block(d7, df*8, strides=2)

    d8_5 = Flatten()(d8)
    d9 = Dense(df*16)(d8_5)
    d10 = LeakyReLU(alpha=0.2)(d9)
    validity = Dense(1, activation='sigmoid')(d10)
    return Model(disc_ip, validity)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
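&lt;p&gt;With &lt;code&gt;padding="same"&lt;/code&gt;, a convolution with stride s turns n input pixels into ceil(n / s) output pixels. The short sketch below (our own arithmetic, not part of the model code) traces a 100x100 input through the eight blocks of &lt;code&gt;create_disc()&lt;/code&gt; and counts how many features &lt;code&gt;Flatten()&lt;/code&gt; passes to &lt;code&gt;Dense(df*16)&lt;/code&gt;.&lt;/p&gt;

```python
import math

df = 64
# (filters, strides) for d1..d8, copied from create_disc()
blocks = [(df, 1), (df, 2), (df * 2, 1), (df * 2, 2),
          (df * 4, 1), (df * 4, 2), (df * 8, 1), (df * 8, 2)]

size = 100  # height and width of the high-resolution input
for i, (filters, strides) in enumerate(blocks, start=1):
    size = math.ceil(size / strides)  # "same" padding: output = ceil(input / stride)
    print(f"d{i}: {size}x{size}x{filters}")

flat_features = size * size * blocks[-1][0]
dense_params = flat_features * (df * 16) + df * 16   # weights + biases of Dense(1024)
print("Flatten:", flat_features, "features; Dense(1024) params:", dense_params)
```

&lt;p&gt;The 7x7x512 map flattens to 25,088 features, so the first dense layer alone holds about 25.7 million parameters, several times more than all eight convolutional blocks together, and that number grows with the input area.&lt;/p&gt;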



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;VGG19 Model&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;In this code block, we use a VGG19 model pretrained on ImageNet to extract image features. The model is frozen later so that its parameters are not updated during training.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;from keras.applications import VGG19
# Build the VGG19 model upto 10th layer 
# Used to extract the features of high res imgaes
def build_vgg():
    vgg = VGG19(weights="imagenet")
    vgg.outputs = [vgg.layers[9].output]
    img = Input(shape=hr_shape)
    img_features = vgg(img)
    return Model(img, img_features)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
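&lt;p&gt;To check which layer &lt;code&gt;vgg.layers[9]&lt;/code&gt; is, the layer order of Keras's VGG19 without its classifier can be rebuilt by hand: an input layer, then for each of the five blocks its convolutions (2, 2, 4, 4 and 4 of them, named like &lt;code&gt;block1_conv1&lt;/code&gt;) followed by one pooling layer. This pure-Python sketch assumes that naming scheme; the input layer's real name varies between Keras versions, so it is simply called "input" here.&lt;/p&gt;

```python
# Convolution layers per block in VGG19
convs_per_block = [2, 2, 4, 4, 4]

layer_names = ["input"]                      # index 0 is the input layer
for block, n_convs in enumerate(convs_per_block, start=1):
    layer_names += [f"block{block}_conv{k}" for k in range(1, n_convs + 1)]
    layer_names.append(f"block{block}_pool")

print(layer_names[9])                        # block3_conv3

# Each pooling layer halves the spatial size; count the pools before index 9
pools_before = sum(name.endswith("_pool") for name in layer_names[:9])
print(100 // 2 ** pools_before)              # 25, so the features are 25x25x256
```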



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;Combined Model&lt;/strong&gt;&lt;/p&gt;

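&lt;p&gt;Now we chain the generator, the discriminator, and VGG19 into one combined model, which is used only to train the generator. Its two outputs are the discriminator's verdict on the generated image (for the adversarial loss) and the VGG19 features of that image (for the content loss). The discriminator is frozen inside this combined model, so training it updates only the generator.&lt;br&gt;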
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;# Attach the generator and discriminator
def create_comb(gen_model, disc_model, vgg, lr_ip, hr_ip):
    gen_img = gen_model(lr_ip)
    gen_features = vgg(gen_img)
    disc_model.trainable = False
    validity = disc_model(gen_img)
    return Model([lr_ip, hr_ip],[validity,gen_features])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;Declare models&lt;/strong&gt;&lt;/p&gt;

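&lt;p&gt;Then we create the generator, discriminator, and VGG models, which are passed as arguments to the combined model.&lt;br&gt;
The combined model reuses these same model objects rather than copies, so they share weights: when &lt;code&gt;gan_model&lt;/code&gt; is trained, the weights of &lt;code&gt;generator&lt;/code&gt; itself are updated. Keras fixes a model's trainable state when &lt;code&gt;compile()&lt;/code&gt; is called, which is why &lt;code&gt;discriminator&lt;/code&gt; is compiled before &lt;code&gt;create_comb()&lt;/code&gt; freezes it.&lt;br&gt;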
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;generator = create_gen(lr_ip)
discriminator = create_disc(hr_ip)
discriminator.compile(loss="binary_crossentropy", optimizer="adam",      
  metrics=['accuracy'])
vgg = build_vgg()
vgg.trainable = False
gan_model = create_comb(generator, discriminator, vgg, lr_ip, hr_ip)
gan_model.compile(loss=["binary_crossentropy","mse"], loss_weights=
  [1e-3, 1], optimizer="adam")
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
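&lt;p&gt;The compile call for &lt;code&gt;gan_model&lt;/code&gt; gives the generator a weighted sum of two losses: binary cross-entropy on the discriminator's verdict with target 1 (weight 1e-3, the adversarial loss) plus mean squared error between VGG19 features of the generated and real images (weight 1, the content loss). The NumPy sketch below computes that sum for made-up arrays standing in for model outputs; Keras's internal details, such as its exact clipping epsilon, are simplified.&lt;/p&gt;

```python
import numpy as np

def bce(labels, probs, eps=1e-7):
    probs = np.clip(probs, eps, 1 - eps)
    return float(-np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs)))

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def generator_loss(validity, gen_features, real_features, adv_weight=1e-3, content_weight=1.0):
    """Weighted sum used to train the generator (targets are all 1 = 'real')."""
    real_label = np.ones_like(validity)
    adversarial = bce(real_label, validity)
    content = mse(gen_features, real_features)
    return adv_weight * adversarial + content_weight * content, adversarial, content

# Hypothetical stand-ins: discriminator scores and 25x25x256 VGG feature maps for 2 images
validity = np.array([[0.5], [0.5]])
gen_features = np.zeros((2, 25, 25, 256))
real_features = np.full((2, 25, 25, 256), 0.5)

total, adversarial, content = generator_loss(validity, gen_features, real_features)
print(total, adversarial, content)  # about 0.250693, 0.693147, 0.25
```

&lt;p&gt;With these weights the content term dominates, and the small adversarial term nudges the generator toward images the discriminator accepts as real.&lt;/p&gt;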



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;Split the training data into small batches&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Because the training set is too large to process at once, we split the images into small batches to avoid a &lt;strong&gt;ResourceExhaustedError&lt;/strong&gt; (out of memory): the available RAM is not enough to train on all the images at once. Any images left over after the last full batch are skipped.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;batch_size = 20
train_lr_batches = []
train_hr_batches = []
for it in range(int(train_hr.shape[0] / batch_size)):
    start_idx = it * batch_size
    end_idx = start_idx + batch_size
    train_hr_batches.append(train_hr[start_idx:end_idx])
    train_lr_batches.append(train_lr[start_idx:end_idx])
train_lr_batches = np.array(train_lr_batches)
train_hr_batches = np.array(train_hr_batches)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
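&lt;p&gt;The loop above keeps the images in their original order. A common alternative, sketched below as our own suggestion rather than the original code, is to shuffle the indices once per epoch and yield batches lazily, so that only one batch at a time is copied instead of a second copy of the whole dataset. &lt;code&gt;iterate_batches&lt;/code&gt; is a hypothetical helper name, and the data is made up.&lt;/p&gt;

```python
import numpy as np

def iterate_batches(lr_images, hr_images, batch_size, rng):
    """Yield shuffled (lr, hr) mini-batches; a final partial batch is dropped."""
    indices = rng.permutation(len(hr_images))
    n_batches = len(hr_images) // batch_size
    for b in range(n_batches):
        idx = indices[b * batch_size:(b + 1) * batch_size]
        yield lr_images[idx], hr_images[idx]   # same indices keep LR/HR pairs aligned

# Hypothetical data: 45 image pairs, so 2 full batches of 20 and 5 leftovers
rng = np.random.default_rng(0)
demo_lr = np.arange(45).reshape(45, 1, 1, 1) * np.ones((1, 25, 25, 3))
demo_hr = np.arange(45).reshape(45, 1, 1, 1) * np.ones((1, 100, 100, 3))

batches = list(iterate_batches(demo_lr, demo_hr, batch_size=20, rng=rng))
print(len(batches), batches[0][0].shape, batches[0][1].shape)
```

&lt;p&gt;In the training loop, &lt;code&gt;for lr_imgs, hr_imgs in iterate_batches(train_lr, train_hr, batch_size, rng):&lt;/code&gt; would then replace the indexing into &lt;code&gt;train_lr_batches&lt;/code&gt; and &lt;code&gt;train_hr_batches&lt;/code&gt;.&lt;/p&gt;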



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;Training the model&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;This block is the core of the whole program. Here we train the discriminator and the generator alternately, as described above. The discriminator was left frozen by &lt;code&gt;create_comb()&lt;/code&gt;; the loop sets it to trainable before training the discriminator and freezes it again afterwards, which keeps the intent explicit.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;epochs = 100
for e in range(epochs):
    gen_label = np.zeros((batch_size, 1))
    real_label = np.ones((batch_size,1))
    g_losses = []
    d_losses = []
    for b in range(len(train_hr_batches)):
        lr_imgs = train_lr_batches[b]
        hr_imgs = train_hr_batches[b]
        gen_imgs = generator.predict_on_batch(lr_imgs)
        #Dont forget to make the discriminator trainable
        discriminator.trainable = True

        #Train the discriminator
        d_loss_gen = discriminator.train_on_batch(gen_imgs,
          gen_label)
        d_loss_real = discriminator.train_on_batch(hr_imgs,
          real_label)
        discriminator.trainable = False
        d_loss = 0.5 * np.add(d_loss_gen, d_loss_real)
        image_features = vgg.predict(hr_imgs)

        #Train the generator
        g_loss, _, _ = gan_model.train_on_batch([lr_imgs, hr_imgs], 
          [real_label, image_features])

        d_losses.append(d_loss)
        g_losses.append(g_loss)
    g_losses = np.array(g_losses)
    d_losses = np.array(d_losses)

    g_loss = np.sum(g_losses, axis=0) / len(g_losses)
    d_loss = np.sum(d_losses, axis=0) / len(d_losses)
    print("epoch:", e+1 ,"g_loss:", g_loss, "d_loss:", d_loss)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;Evaluate the model&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Here we measure the generator's performance on the test set. The test loss may be somewhat higher than the training loss; a small gap is normal, while a large gap suggests the model is overfitting.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;label = np.ones((len(test_lr),1))
test_features = vgg.predict(test_hr)
eval,_,_ = gan_model.evaluate([test_lr, test_hr], [label,test_features])
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;



&lt;p&gt;&lt;br&gt;&lt;br&gt;
&lt;strong&gt;Predict the output&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Finally, we generate high-resolution images with the generator model.&lt;br&gt;
&lt;/p&gt;

&lt;div class="highlight js-code-highlight"&gt;
&lt;pre class="highlight plaintext"&gt;&lt;code&gt;test_prediction = generator.predict_on_batch(test_lr)
&lt;/code&gt;&lt;/pre&gt;

&lt;/div&gt;
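&lt;p&gt;The generator ends with a plain &lt;code&gt;Conv2D&lt;/code&gt; that has no activation, so predicted pixel values can fall slightly outside [0, 1]. A small sketch (an assumption about how you might save or display results, not part of the original post) clips them and converts back to 8-bit pixels; &lt;code&gt;to_uint8_images&lt;/code&gt; is a hypothetical helper.&lt;/p&gt;

```python
import numpy as np

def to_uint8_images(predictions):
    """Clip generator outputs to [0, 1] and convert them to 8-bit pixel values."""
    clipped = np.clip(predictions, 0.0, 1.0)
    return np.round(clipped * 255.0).astype(np.uint8)

# Hypothetical stand-in for generator.predict_on_batch(test_lr): one 1x1 RGB "image"
test_prediction = np.array([[[[-0.1, 0.5, 1.2]]]])
images = to_uint8_images(test_prediction)
print(images.ravel())  # values 0, 128 and 255
```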



&lt;p&gt;The output is quite amazing…&lt;br&gt;
&lt;a href="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2F3qo0i7gil7dnfpcrqijs.jpeg" class="article-body-image-wrapper"&gt;&lt;img src="https://media.dev.to/dynamic/image/width=800%2Cheight=%2Cfit=scale-down%2Cgravity=auto%2Cformat=auto/https%3A%2F%2Fdev-to-uploads.s3.amazonaws.com%2Fi%2F3qo0i7gil7dnfpcrqijs.jpeg" alt="SRGAN Output"&gt;&lt;/a&gt;&lt;/p&gt;

&lt;p&gt;You can find my &lt;a href="https://github.com/manishdhakal/SuperResolution" rel="noopener noreferrer"&gt;implementation&lt;/a&gt;, trained on Google Colab, on my GitHub profile.&lt;/p&gt;




&lt;h2&gt;
  
  
  &lt;strong&gt;Tips&lt;/strong&gt;
&lt;/h2&gt;

&lt;ul&gt;
&lt;li&gt;Keep track of which model is trainable at each step.&lt;/li&gt;
&lt;li&gt;When training the generator through the combined model, use label 1 (real) for the generated images.&lt;/li&gt;
&lt;li&gt;Low-resolution inputs larger than 25x25 carry more detail for the generator to work with, so prefer larger images when memory allows.&lt;/li&gt;
&lt;li&gt;Do not forget to normalize the NumPy arrays to the range [0, 1].&lt;/li&gt;
&lt;/ul&gt;




&lt;h2&gt;
  
  
  &lt;strong&gt;References&lt;/strong&gt;
&lt;/h2&gt;

&lt;p&gt;Jason Brownlee. 2019. &lt;em&gt;Generative Adversarial Networks with Python&lt;/em&gt;.&lt;br&gt;
Christian Ledig et al. 2017. &lt;em&gt;Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network&lt;/em&gt; (the SRGAN paper). &lt;a href="https://arxiv.org/pdf/1609.04802" rel="noopener noreferrer"&gt;https://arxiv.org/pdf/1609.04802&lt;/a&gt;&lt;/p&gt;

</description>
      <category>keras</category>
      <category>superresolution</category>
      <category>gan</category>
      <category>neuralnetwork</category>
    </item>
  </channel>
</rss>
