Generative Adversarial Networks

In 2014, Ian Goodfellow, who currently works as Director of Machine Learning at Apple, published a paper called "Generative Adversarial Networks", or GAN for short. It describes a system of two neural networks, a Generator and a Discriminator, that can generate images (or other data) similar to a provided dataset, starting from essentially random noise.

The diagram below shows the basic architecture of a GAN.

https://res.cloudinary.com/idiomprog/image/upload/v1613042386/gan_diagram_jhuerz.svg

Source: https://developers.google.com/

The Generator network takes some random noise as input and produces a fake image (or data sample). This output is passed to the Discriminator along with real images from the dataset as ground truth, and both the Discriminator and the Generator are trained based on how well the Discriminator can tell the two apart.

As you can see, the concept of a GAN is very simple. When I started learning about GANs I thought I could easily implement one. But boy, was I wrong. In reality, training a GAN is extremely hard: the Generator and the Discriminator must be trained side by side, and if one overpowers the other the training fails. We will talk about the problems you might face while training a GAN, but for now let's look at a simple GAN in PyTorch. We will be using the MNIST dataset for this post.

Understanding the Dataset

The MNIST dataset is a large database of handwritten digits from 0 to 9, widely used for Optical Character Recognition, i.e. reading numbers from an image.

This dataset consists of 28x28 grayscale images of handwritten digits, where each pixel holds an intensity value.

https://res.cloudinary.com/idiomprog/image/upload/v1613042427/687474703a2f2f692e7974696d672e636f6d2f76692f3051493378675875422d512f687164656661756c742e6a7067_rx3nur.jpg

The computer vision extension of PyTorch, Torchvision, provides this dataset which we can download using the following code snippet.

mnist = datasets.MNIST(root='datasets', train=True, transform=transformations, download=True)

But before we can run this code, we need to import some libraries.

import torch                              # torch, nn and optim are used later for the models and training
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

You can install the torchvision library using this pip command.

pip install torchvision

The DataLoader class is used to load the data into memory in batches; this prevents your system from running out of memory while training.

The transforms module is used to apply random augmentations to the images, such as random rotation, resizing, cropping, etc. For now we will only convert the images to tensors and normalise them to the range -1 to 1 (ToTensor scales pixels to [0, 1], and Normalize with mean 0.5 and standard deviation 0.5 maps that to [-1, 1]).

The entire function, which you can copy and paste, looks like this.

def load_dataset():
    transformations = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])
    mnist = datasets.MNIST(root='datasets', train=True, transform=transformations, download=True)
    return DataLoader(mnist, batch_size=32, shuffle=True)
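As a quick sanity check (a minimal sketch, assuming the imports above), you can grab one batch from the loader and confirm both the batch shape and the [-1, 1] pixel range:

train_data = load_dataset()
images, labels = next(iter(train_data))
print(images.shape)                              # torch.Size([32, 1, 28, 28])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0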

Implementing the Generator

In this section, we will implement a generator network which takes a random noise vector of size 100 and converts it into a vector of size 784, which we will later reshape into a 28x28 image.

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(100, 256),   # noise vector -> hidden layer
            nn.LeakyReLU(0.01),
            nn.Linear(256, 784),   # hidden layer -> flattened 28x28 image
            nn.Tanh(),             # make outputs [-1, 1]
        )

    def forward(self, x):
        return self.gen(x)
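Before moving on, a small sketch to verify the shapes: feeding a batch of noise vectors through an untrained Generator should give a batch of 784-dimensional vectors squashed into [-1, 1] by the Tanh.

gen = Generator()
noise = torch.randn(16, 100)   # 16 random noise vectors of size 100
fake_batch = gen(noise)
print(fake_batch.shape)        # torch.Size([16, 784])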

We are implementing the original GAN described by Ian Goodfellow in the paper Generative Adversarial Networks.

Implementing the Discriminator

Next, we will implement a discriminator that takes a vector of size 784, which may come either from the generator or from a real image.

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(784, 128),   # flattened 28x28 image -> hidden layer
            nn.LeakyReLU(0.01),
            nn.Linear(128, 1),     # hidden layer -> single real-vs-fake score
            nn.Sigmoid(),          # make outputs [0, 1]
        )

    def forward(self, x):
        return self.disc(x)
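Another minimal sanity check (again just a sketch, reusing the generated batch from above): the Discriminator should map each 784-dimensional vector, real or generated, to a single probability between 0 and 1.

disc = Discriminator()
preds = disc(fake_batch)                       # discriminator scores for the fake batch
print(preds.shape)                             # torch.Size([16, 1])
print(preds.min().item(), preds.max().item())  # both values sit inside [0, 1]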

Hyperparameters

As I mentioned earlier, GANs are extremely hard to train. One of the reasons is that a GAN is very sensitive to its initial values and hyperparameters. You have to follow the original papers to get the training right; otherwise, you might have to spend many days tuning the hyperparameters.

config = {
    'device': "cuda" if torch.cuda.is_available() else "cpu",
    'lr': 3e-4,
    'epochs': 50,
    'batch_size': 32
}

In our case, we will stick to values that are known to work well for this setup: a learning rate of 3e-4 (0.0003) and 50 epochs.

What is a learning rate?

Simple answer: it's the rate at which a machine learning model learns. The smaller the number, the slower it learns; the higher the number, the faster it learns (but too high and training can become unstable). For more info, check out our post about Perceptrons.
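Concretely, in plain gradient descent (a simplified view of what any optimiser is doing) each weight is nudged in the direction opposite to the gradient of the loss, and the learning rate $\eta$ scales the size of that nudge:

$$\theta \leftarrow \theta - \eta \cdot \nabla_\theta \mathcal{L}(\theta)$$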

Training Time

Finally, it's time to actually generate some handwritten numbers, that is, train our GAN. First, let's create the objects for the Generator, the Discriminator, and their optimizers; we will be using the Adam optimizer.

Adam, or Adaptive Moment estimation, is an optimisation algorithm used to update the weights such that the overall error goes down.
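For reference, the update rule from the Adam paper (reference 2 below) keeps exponentially decaying estimates of the gradient's first and second moments, corrects their bias, and uses them to scale every step:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}, \qquad \theta_t = \theta_{t-1} - \frac{\eta\, \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$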

disc = Discriminator().to(config['device'])
gen = Generator().to(config['device'])

optimiser_g = optim.Adam(params=gen.parameters(), lr=config['lr'])
optimiser_d = optim.Adam(params=disc.parameters(), lr=config['lr'])

We also need to define a loss function before we train. We will be using the Binary Cross Entropy loss function to calculate the error of our model.

Binary Cross Entropy (BCE) Loss

This loss formula is used to calculate the distance between two probability distributions.

$$\text{BCE} = \ell(x, y) = L = (l_1, \dots, l_N), \qquad l_n = -w_n \left[\, y_n \cdot \log x_n + (1 - y_n) \cdot \log(1 - x_n) \,\right]$$

loss_fn = nn.BCELoss()
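As a tiny worked example (just a sketch using the loss_fn defined above): if the discriminator outputs 0.9 for an image whose true label is 1 (real), the loss is -log(0.9) ≈ 0.105; the same output judged against a label of 0 (fake) costs -log(1 - 0.9) ≈ 2.303.

print(loss_fn(torch.tensor([0.9]), torch.tensor([1.0])))  # tensor(0.1054)
print(loss_fn(torch.tensor([0.9]), torch.tensor([0.0])))  # tensor(2.3026)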

Now we will write our training loop, where each step is one cycle of generating a batch of fake digits, passing them and real digits through the discriminator, and updating both models.

train_data = load_dataset()

for epoch in range(config['epochs']):
    for batch_idx, (real, label) in enumerate(train_data):
        real = real.view(-1, 784).to(config['device']) # flatten the 28x28 images into vectors of size 784

        noise = torch.randn(config['batch_size'], 100).to(config['device']) # create a batch of random noise vectors
        fake = gen(noise) # generate a batch of fake numbers
        disc_real = disc(real).view(-1) # pass the real numbers through the discriminator
        lossD_real = loss_fn(disc_real, torch.ones_like(disc_real)) # calculate the loss for real images (label 1)
        disc_fake = disc(fake).view(-1) # pass the fake numbers through the discriminator
        lossD_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake)) # calculate the loss for fake images (label 0)
        lossD = (lossD_real + lossD_fake) / 2 # average the two losses

        # update the weights for the discriminator
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        optimiser_d.step()

        output = disc(fake).view(-1)
        lossG = loss_fn(output, torch.ones_like(output)) # the generator wants the discriminator to label its fakes as real (1)

        # update the weights for the generator
        gen.zero_grad()
        lossG.backward()
        optimiser_g.step()
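Once training finishes, you can sample brand-new digits by pushing fresh noise through the generator and reshaping the 784-dimensional outputs back into 28x28 images. A minimal sketch, assuming matplotlib is installed (it is not used elsewhere in this post):

import matplotlib.pyplot as plt

with torch.no_grad():
    noise = torch.randn(16, 100).to(config['device'])
    samples = gen(noise).view(-1, 28, 28).cpu()  # reshape the 784-vectors into 28x28 images

fig, axes = plt.subplots(4, 4, figsize=(6, 6))
for img, ax in zip(samples, axes.flatten()):
    ax.imshow(img, cmap='gray')
    ax.axis('off')
plt.show()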

This code is heavily inspired by a YouTube video by Aladdin Persson.

Find the complete code at our GitHub repo.

If you have any questions, feel free to comment below.

References

  1. Goodfellow, Ian J., et al. “Generative Adversarial Networks.” ArXiv.org, 10 June 2014, arxiv.org/abs/1406.2661v1.
  2. Kingma, Diederik P., and Jimmy Ba. “Adam: A Method for Stochastic Optimization.” ArXiv.org, 30 Jan. 2017, arxiv.org/abs/1412.6980.
  3. “THE MNIST DATABASE.” MNIST Handwritten Digit Database, Yann LeCun, Corinna Cortes and Chris Burges, yann.lecun.com/exdb/mnist/.
  4. Understanding Categorical Cross-Entropy Loss, Binary Cross-Entropy Loss, Softmax Loss, Logistic Loss, Focal Loss and All Those Confusing Names, gombru.github.io/2018/05/23/cross_entropy_loss/.
