Gruhesh Sri Sai Karthik Kurra

Building Variational Autoencoders from Scratch: A Complete PyTorch Implementation

Ever wondered how AI models can generate new images that look remarkably similar to real ones? Today, I'll walk you through building a Variational Autoencoder (VAE), one of the most elegant generative models in deep learning, from scratch using PyTorch!

🎯 What We'll Build

In this tutorial, we'll create a complete VAE implementation that can:

  • ✨ Generate new handwritten digits
  • 🔍 Compress images into meaningful 2D representations
  • 🎨 Smoothly interpolate between different digits
  • 📊 Visualize learned latent spaces

(Figure: VAE results)

🧠 What is a Variational Autoencoder?

A VAE is like a smart compression algorithm that learns to:

  1. Encode images into a compact latent space
  2. Sample from learned probability distributions
  3. Decode samples back into realistic images

Unlike regular autoencoders, VAEs add a probabilistic twist: they learn distributions rather than fixed points, which is what makes it possible to generate new data!

🏗️ Architecture Overview

Our VAE consists of three main components:

import torch
import torch.nn as nn
import torch.nn.functional as F


class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=2, hidden_dim=256, beta=1.0):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        self.beta = beta  # weight on the KL term (β-VAE when beta > 1)

        # Encoder: Image → Latent Distribution Parameters
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(True),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(True),
            nn.Dropout(0.2)
        )

        # Latent space parameters: mean and log-variance of q(z|x)
        self.fc_mu = nn.Linear(hidden_dim // 2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim // 2, latent_dim)

        # Decoder: Latent → Reconstructed Image
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(True),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(True),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # outputs pixel values in [0, 1]
        )
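To tie these pieces together, the model also needs an encode/decode path and a forward pass. Here is a minimal sketch of how that typically looks inside the class; the method names and the (recon_x, mu, logvar) return signature are my own choices for this post, and the repository may organize things slightly differently:

# Inside the VAE class (a sketch; method names are illustrative)
def encode(self, x):
    h = self.encoder(x)
    return self.fc_mu(h), self.fc_logvar(h)

def decode(self, z):
    return self.decoder(z)

def forward(self, x):
    mu, logvar = self.encode(x)
    z = self.reparameterize(mu, logvar)  # defined in the next section
    return self.decode(z), mu, logvar

With this, calling vae(x) returns the reconstruction along with the distribution parameters needed by the loss function below.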

🔑 The Magic: Reparameterization Trick

The heart of VAEs lies in the reparameterization trick, which allows gradients to flow through random sampling:

def reparameterize(self, mu, logvar):
    """
    Sample z = μ + σ * ε where ε ~ N(0,1)
    This makes sampling differentiable!
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std
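Because the randomness comes from ε, which is sampled outside the computation graph, gradients flow back into μ and log σ² even though z is a random sample. A quick standalone check (a toy snippet with made-up shapes, not part of the repository's code) confirms this:

import torch

# Toy check: the reparameterized sample is differentiable w.r.t. mu and logvar
mu = torch.zeros(4, 2, requires_grad=True)
logvar = torch.zeros(4, 2, requires_grad=True)
z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
z.sum().backward()
print(mu.grad.shape, logvar.grad.shape)  # both torch.Size([4, 2])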

📈 The Loss Function: Balancing Act

VAEs optimize two competing objectives:

def loss_function(self, recon_x, x, mu, logvar):
    # Reconstruction Loss: How well can we rebuild the input?
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # KL Divergence: Keep latent space well-behaved
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    # Total VAE Loss
    total_loss = recon_loss + self.beta * kl_loss
    return total_loss, recon_loss, kl_loss
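The kl_loss line is the closed-form KL divergence between the learned Gaussian q(z|x) = N(μ, σ²) and the standard normal prior N(0, I), a standard result:

D_KL( N(μ, σ²) ‖ N(0, I) ) = -½ · Σⱼ ( 1 + log σⱼ² − μⱼ² − σⱼ² )

In the code, logvar is log σ², so logvar.exp() recovers σ², and self.beta scales the KL term (β = 1 is the standard VAE; β > 1 gives a β-VAE that trades some reconstruction quality for a more disentangled latent space).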

🚀 Training Results

After training on MNIST for 20 epochs, our VAE achieves impressive results:

📊 Training Metrics

  • Final Training Loss: ~85.2
  • Reconstruction Loss: ~83.5
  • KL Divergence: ~1.7

(Figure: training curves)
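For reference, here is a minimal sketch of the kind of training loop that produces these numbers. It uses standard torchvision MNIST loading and plain Adam; the repository's Trainer class wraps the same idea with logging and checkpointing, and the exact hyperparameters here are illustrative:

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = DataLoader(
    datasets.MNIST("./data", train=True, download=True, transform=transforms.ToTensor()),
    batch_size=128, shuffle=True
)

vae = VAE(input_dim=784, latent_dim=2).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

for epoch in range(20):
    vae.train()
    total = 0.0
    for x, _ in train_loader:
        x = x.view(x.size(0), -1).to(device)          # flatten 28x28 images to 784
        recon_x, mu, logvar = vae(x)                   # forward pass from the sketch above
        loss, recon_loss, kl_loss = vae.loss_function(recon_x, x, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
    print(f"Epoch {epoch + 1}: avg loss per image {total / len(train_loader.dataset):.2f}")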

🎨 Latent Space Visualization

The most exciting part is the 2D latent space itself, which beautifully organizes the digits into clusters:

(Figure: latent space)
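To reproduce a plot like this yourself, one rough approach (assuming the trained vae, the encode() sketch from earlier, the device variable, and a standard MNIST test_loader) is to encode the test set and scatter the 2D latent means, colored by digit label:

import matplotlib.pyplot as plt

vae.eval()
zs, labels = [], []
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), -1).to(device)
        mu, logvar = vae.encode(x)    # use the latent means as each image's coordinates
        zs.append(mu.cpu())
        labels.append(y)
zs, labels = torch.cat(zs), torch.cat(labels)

plt.scatter(zs[:, 0], zs[:, 1], c=labels, cmap="tab10", s=3)
plt.colorbar(label="digit")
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()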

🔄 Reconstruction Quality

The reconstructed digits closely match the originals:

(Figure: original vs. reconstructed digits)

🌊 Smooth Interpolations

Watch digits smoothly transform into each other:

(Figure: latent space interpolation between digits)
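A rough sketch of how such an interpolation can be produced: encode two test images, linearly interpolate between their latent means, and decode each intermediate point. This again leans on the encode()/decode() sketch from earlier and two hypothetical input images x1 and x2; the repository's visualizer may differ in the details:

vae.eval()
with torch.no_grad():
    mu1, _ = vae.encode(x1.view(1, -1).to(device))   # x1, x2: two 28x28 MNIST images
    mu2, _ = vae.encode(x2.view(1, -1).to(device))
    steps = torch.linspace(0, 1, 10, device=device)
    frames = [vae.decode((1 - t) * mu1 + t * mu2).view(28, 28).cpu() for t in steps]

fig, axes = plt.subplots(1, len(frames), figsize=(15, 2))
for ax, frame in zip(axes, frames):
    ax.imshow(frame, cmap="gray")
    ax.axis("off")
plt.show()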

💡 Key Features of Our Implementation

🛠️ Production-Ready Code

  • Modular Design: Separate classes for model, trainer, logger, visualizer
  • Comprehensive Logging: Track all metrics during training
  • Automatic Checkpointing: Save best models automatically
  • Rich Visualizations: Generate publication-ready plots

📚 Educational Value

  • Detailed Comments: Every line explained
  • Mathematical Background: Complete derivations included
  • Visualization Examples: Understand what VAEs learn
  • Training Analysis: Monitor and improve performance

🎯 Real-World Applications

This VAE implementation can be adapted for:

  • 🎨 Art Generation: Create new artistic styles
  • 🔍 Anomaly Detection: Identify unusual patterns
  • 📊 Data Compression: Efficient representation learning
  • 🔄 Data Augmentation: Generate synthetic training data
  • 🧬 Drug Discovery: Generate new molecular structures

🚀 Try It Yourself!

Want to experiment with VAEs? Here's how to get started:

GitHub Repository

git clone https://github.com/GruheshKurra/VariationalAutoencoders.git
cd VariationalAutoencoders
pip install torch torchvision matplotlib pandas numpy seaborn
jupyter notebook Untitled.ipynb

Hugging Face Model Hub

Check out the pre-trained model and detailed documentation:
🤗 karthik-2905/VariationalAutoencoders

🔧 Customization Ideas

Experiment with different configurations:

# β-VAE for better disentanglement
vae = VAE(latent_dim=10, beta=4.0)

# Larger model for complex datasets
vae = VAE(hidden_dim=512, latent_dim=64)

# Different datasets
# Try CIFAR-10, CelebA, or your own data!
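Beyond hyperparameters, the trained decoder doubles as a generator: sample latent vectors from the standard normal prior and decode them into brand-new digits. A minimal sketch, reusing the decode() helper assumed earlier:

vae.eval()
with torch.no_grad():
    z = torch.randn(16, vae.latent_dim, device=device)   # latent_dim is stored in __init__
    samples = vae.decode(z).view(-1, 28, 28).cpu()

fig, axes = plt.subplots(2, 8, figsize=(12, 3))
for ax, img in zip(axes.flat, samples):
    ax.imshow(img, cmap="gray")
    ax.axis("off")
plt.show()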

📝 Key Takeaways

  1. VAEs balance reconstruction and regularization through their dual loss function
  2. The reparameterization trick enables end-to-end training of generative models
  3. 2D latent spaces provide excellent visualization opportunities
  4. Proper logging and visualization are crucial for understanding model behavior
  5. Modular code design makes experimentation easier

🔮 What's Next?

This implementation opens doors to explore:

  • β-VAEs for better disentanglement
  • Conditional VAEs for controlled generation
  • Hierarchical VAEs for complex data
  • VQ-VAEs for discrete representations

🤝 Connect & Contribute

Found this helpful? Let's connect and build amazing AI together!

Have questions or want to contribute? Open an issue or submit a PR!


Happy coding, and may your latent spaces be well-organized! 🎓✨

#DeepLearning #PyTorch #GenerativeAI #MachineLearning #VAE #AI #OpenSource
