Table of Contents
- What Is a Variational Autoencoder?
- The Problem VAEs Solve
- Mathematical Foundations
- Architecture Deep Dive
- The ELBO Loss Function
- Implementation: VAE on MNIST with TensorFlow/Keras
- Visualising the Latent Space
- Conditional VAE (CVAE)
- Common Failure Modes and Fixes
- VAEs vs GANs vs Diffusion Models
- Real-World Use Cases
- Further Reading
1. What Is a Variational Autoencoder?
A Variational Autoencoder (VAE) is a generative model that learns a compressed, continuous, and structured probabilistic representation of data. Introduced by Kingma & Welling in 2013, it sits at the intersection of deep learning and Bayesian inference.
Unlike a standard autoencoder, which maps each input to a single point in latent space, a VAE maps each input to a probability distribution over the latent space. This single design choice unlocks the ability to:
- Generate new, coherent samples by sampling from the latent prior
- Interpolate smoothly between data points
- Learn partially disentangled factors of variation in the data (more strongly with variants such as β-VAE)
- Detect anomalies by measuring reconstruction probability
Think of a VAE as a lossy compression algorithm that has been forced to be well-organised: similar inputs cluster together in latent space, and the space between clusters is meaningful rather than empty.
2. The Problem VAEs Solve
Standard Autoencoders Are Broken Generative Models
A standard autoencoder consists of an encoder E and decoder D:
x → E(x) = z → D(z) = x̂
The encoder collapses each input to a single point z. The problem is that the latent space is irregular — there is no guarantee that the region between two known encodings corresponds to anything meaningful. Interpolating between the encoding of a "3" and a "7" in MNIST may pass through a void that the decoder maps to noise.
VAEs force the encoder to predict a Gaussian distribution q(z|x) = N(μ, σ²) instead of a point. The decoder then reconstructs x from a sample drawn from that distribution. Two competing pressures shape the training objective:
Reconstruction loss: the decoder must recover x faithfully from the sampled z.
KL divergence: the predicted distribution must stay close to the standard Normal prior N(0, I).
The tension between these two objectives shapes the latent space into something smooth, complete, and generative.
3. Mathematical Foundations
Generative Model
VAEs assume data x is generated by a two-step process:
z ~ p(z) = N(0, I) (latent prior)
x ~ p(x|z), with distribution parameters computed by Decoder(z) (likelihood)
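The true posterior p(z|x) = p(x|z) p(z) / p(x) is intractable because the evidence p(x) = ∫ p(x|z) p(z) dz has no closed form when the decoder is a neural network. VAEs approximate the posterior with a learned distribution q_φ(z|x), which is the encoder.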
The KL Divergence Term
KL divergence measures how far q_φ(z|x) is from the prior p(z). When q_φ(z|x) is a diagonal Gaussian N(μ, σ²) and the prior is N(0, I), it has a closed form (the sum runs over the latent dimensions):
KL(N(μ, σ²) || N(0, I)) = -½ Σ (1 + log σ² - μ² - σ²)
This is cheap to compute, differentiable, and pushes the encoder toward organised, zero-centred distributions.
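As a sanity check, the closed-form expression can be compared against a Monte Carlo estimate of E_q[log q(z) - log p(z)]. The NumPy sketch below is illustrative only: the function names and the example μ and log σ² values are assumptions, not part of the article's model.

```python
import numpy as np


def kl_closed_form(mu, log_var):
    """KL(N(mu, diag(exp(log_var))) || N(0, I)) via the closed-form expression."""
    return -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var))


def log_normal(z, mu, log_var):
    """Log-density of a diagonal Gaussian, summed over the last axis."""
    return -0.5 * np.sum(
        np.log(2.0 * np.pi) + log_var + (z - mu) ** 2 / np.exp(log_var), axis=-1
    )


def kl_monte_carlo(mu, log_var, n_samples=200_000, seed=0):
    """Estimate E_q[log q(z) - log p(z)] by sampling z ~ q."""
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((n_samples, mu.shape[0]))
    z = mu + np.exp(0.5 * log_var) * eps          # samples from q
    log_q = log_normal(z, mu, log_var)
    log_p = log_normal(z, np.zeros_like(mu), np.zeros_like(log_var))  # N(0, I)
    return np.mean(log_q - log_p)


mu = np.array([0.5, -1.0])
log_var = np.array([-0.7, 0.3])
print(kl_closed_form(mu, log_var))   # about 0.748
print(kl_monte_carlo(mu, log_var))   # should agree to roughly two decimal places
```

The Monte Carlo version is what you would fall back on for priors or posteriors without a closed form; for the diagonal-Gaussian case the analytic formula is both exact and cheaper.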
The Reparameterisation Trick
Sampling z ~ N(μ, σ²) is not differentiable — gradients cannot flow through a stochastic node. The fix is elegant:
z = μ + σ ⊙ ε, where ε ~ N(0, I)
Now ε is the randomness and μ, σ are deterministic outputs of the encoder. Gradients flow through μ and σ unobstructed.
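To see why this helps, take a toy objective f(z) = z² with z ~ N(μ, σ²). Its expectation is μ² + σ², so the exact gradients are 2μ and 2σ. Writing z = μ + σε turns those gradients into averages of ordinary derivatives over fixed noise samples. The NumPy sketch below is illustrative (the objective and names are assumptions); in the VAE, TensorFlow's autodiff performs the same chain rule automatically.

```python
import numpy as np


def pathwise_gradients(mu, sigma, n_samples=200_000, seed=0):
    """Reparameterised estimates of d/dmu and d/dsigma of E[z**2].

    With z = mu + sigma * eps, the chain rule gives
    d f(z)/d mu = f'(z) * 1 and d f(z)/d sigma = f'(z) * eps, where f'(z) = 2z.
    """
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal(n_samples)  # all randomness lives in eps
    z = mu + sigma * eps                  # deterministic in (mu, sigma)
    df_dz = 2.0 * z                       # derivative of f(z) = z**2
    grad_mu = np.mean(df_dz * 1.0)        # dz/dmu = 1
    grad_sigma = np.mean(df_dz * eps)     # dz/dsigma = eps
    return grad_mu, grad_sigma


g_mu, g_sigma = pathwise_gradients(mu=1.5, sigma=0.8)
print(g_mu, g_sigma)  # close to the exact values 2*mu = 3.0 and 2*sigma = 1.6
```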
4. Architecture Deep Dive
Encoder: maps x → (μ, log σ²). Two parallel output heads share a common backbone. Predicting log σ² rather than σ lets the head output any real number while the variance exp(log σ²) stays positive.
Sampling layer: implements the reparameterisation trick. No learnable parameters.
Decoder: maps z → x̂. Usually mirrors the encoder architecture. Output activation depends on data type (sigmoid for binary, linear for continuous).
β-VAE: a variant that scales the KL term by a factor β > 1, encouraging more disentangled latents at the cost of reconstruction fidelity.
5. The ELBO Loss Function
The VAE is trained to maximise the Evidence Lower BOund (ELBO):
ELBO = E_{z ~ q_φ(z|x)}[log p(x|z)] - KL(q_φ(z|x) || p(z))
In practice the expectation is estimated with a single sample of z per input, and the objective is minimised:
Loss = Reconstruction Loss + KL Loss
= -E_{z ~ q_φ(z|x)}[log p(x|z)] + KL(q_φ(z|x) || p(z))
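Reconstruction loss is typically:
Binary cross-entropy for image pixels in [0, 1], the negative log-likelihood of a Bernoulli decoder
Mean squared error for continuous data, which matches a Gaussian decoder with fixed variance up to scale and an additive constant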
KL loss uses the closed-form expression above.
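The two reconstruction losses are negative log-likelihoods in disguise. Binary cross-entropy summed over pixels equals the negative log-likelihood of independent Bernoulli pixels, and a Gaussian negative log-likelihood with fixed variance equals half the summed squared error plus a constant. The NumPy sketch below checks both numerically; the function names, the tiny 16-value "images" and the unit variance are illustrative assumptions.

```python
import numpy as np


def bce_sum(x, p, eps=1e-7):
    """Binary cross-entropy summed over pixels, with clipping for stability."""
    p = np.clip(p, eps, 1.0 - eps)
    return -np.sum(x * np.log(p) + (1.0 - x) * np.log(1.0 - p))


def bernoulli_nll(x, p, eps=1e-7):
    """-log prod_i p_i^x_i (1 - p_i)^(1 - x_i) for binary pixels x_i."""
    p = np.clip(p, eps, 1.0 - eps)
    likelihood = np.prod(np.where(x == 1.0, p, 1.0 - p))
    return -np.log(likelihood)


def gaussian_nll(x, mean, var=1.0):
    """-log N(x; mean, var * I), summed over dimensions."""
    d = x.size
    return 0.5 * d * np.log(2.0 * np.pi * var) + np.sum((x - mean) ** 2) / (2.0 * var)


rng = np.random.default_rng(0)
x_bin = (rng.random(16) > 0.5).astype(float)  # a tiny binary "image"
p = rng.uniform(0.05, 0.95, size=16)          # decoder output probabilities
print(bce_sum(x_bin, p), bernoulli_nll(x_bin, p))  # equal up to rounding

x_cont = rng.normal(size=16)
mean = rng.normal(size=16)
sse = np.sum((x_cont - mean) ** 2)
const = 0.5 * x_cont.size * np.log(2.0 * np.pi)
print(gaussian_nll(x_cont, mean), 0.5 * sse + const)  # identical
```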
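The ELBO is a lower bound on log p(x): log p(x) = ELBO + KL(q_φ(z|x) || p(z|x)), and that last KL term is never negative. Maximising the ELBO therefore raises a lower bound on the marginal log-likelihood of the data while also shrinking the gap between q_φ(z|x) and the true posterior. It is not exactly equivalent to maximising log p(x), but it is a tractable surrogate for the quantity unsupervised learning ultimately cares about.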
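The gap between the ELBO and log p(x) is exactly KL(q_φ(z|x) || p(z|x)), i.e. log p(x) = ELBO + KL(q_φ(z|x) || p(z|x)). This can be checked exactly on a one-dimensional linear-Gaussian toy model: z ~ N(0, 1) and x | z ~ N(z, s²). There the evidence is N(x; 0, 1 + s²) and the true posterior is N(x / (1 + s²), s² / (1 + s²)). The model, the noise variance s² and the candidate distributions q = N(m, v) are illustrative assumptions chosen so that every term has a closed form.

```python
import numpy as np


def log_evidence(x, s2):
    """log p(x) for z ~ N(0, 1), x | z ~ N(z, s2): marginally x ~ N(0, 1 + s2)."""
    var = 1.0 + s2
    return -0.5 * np.log(2.0 * np.pi * var) - x**2 / (2.0 * var)


def elbo(x, s2, m, v):
    """E_q[log p(x|z)] - KL(q || p(z)) for q(z) = N(m, v)."""
    # E_q[(x - z)^2] = (x - m)^2 + v, so the expected log-likelihood is exact
    expected_loglik = -0.5 * np.log(2.0 * np.pi * s2) - ((x - m) ** 2 + v) / (2.0 * s2)
    kl_prior = 0.5 * (v + m**2 - 1.0 - np.log(v))  # KL(N(m, v) || N(0, 1))
    return expected_loglik - kl_prior


def kl_to_posterior(x, s2, m, v):
    """KL(N(m, v) || p(z|x)), with p(z|x) = N(x / (1 + s2), s2 / (1 + s2))."""
    post_mean = x / (1.0 + s2)
    post_var = s2 / (1.0 + s2)
    return 0.5 * (np.log(post_var / v) + (v + (m - post_mean) ** 2) / post_var - 1.0)


x, s2 = 1.3, 0.5
for m, v in [(0.0, 1.0), (0.5, 0.2), (x / (1 + s2), s2 / (1 + s2))]:
    bound = elbo(x, s2, m, v)
    gap = kl_to_posterior(x, s2, m, v)
    # ELBO + gap reproduces log p(x); the gap vanishes when q is the posterior
    print(f"m={m:.3f} v={v:.3f}  ELBO={bound:.4f}  gap={gap:.4f}  "
          f"ELBO+gap={bound + gap:.4f}  log p(x)={log_evidence(x, s2):.4f}")
```

In a real VAE neither log p(x) nor the true posterior is available, which is exactly why the ELBO is optimised instead.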
6. Implementation: VAE on MNIST with TensorFlow/Keras
6.1 Setup
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
# Reproducibility
tf.random.set_seed(42)
np.random.seed(42)
# Hyperparameters
LATENT_DIM = 2 # 2D for easy visualisation; use 64–256 for real tasks
EPOCHS = 30
BATCH_SIZE = 128
BETA = 1.0 # KL weight; increase for β-VAE disentanglement
6.2 Load and Preprocess Data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Normalise to [0, 1] and flatten
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
print(f"Train: {x_train.shape} | Test: {x_test.shape}")
# Train: (60000, 784) | Test: (10000, 784)
6.3 Sampling Layer (Reparameterisation Trick)
class Sampling(layers.Layer):
    """
    Reparameterisation trick: z = μ + σ * ε, ε ~ N(0, I)
    Allows gradients to flow through the stochastic sampling step.
    """
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        eps = tf.random.normal(shape=(batch, dim))  # ε ~ N(0, I)
        return z_mean + tf.exp(0.5 * z_log_var) * eps  # z = μ + σ·ε
6.4 Encoder
def build_encoder(latent_dim: int) -> keras.Model:
    inputs = keras.Input(shape=(784,), name="encoder_input")
    x = layers.Dense(512, activation="relu")(inputs)
    x = layers.Dense(256, activation="relu")(x)
    z_mean = layers.Dense(latent_dim, name="z_mean")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
    z = Sampling(name="z")([z_mean, z_log_var])
    return keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
encoder = build_encoder(LATENT_DIM)
encoder.summary()
6.5 Decoder
def build_decoder(latent_dim: int) -> keras.Model:
    inputs = keras.Input(shape=(latent_dim,), name="decoder_input")
    x = layers.Dense(256, activation="relu")(inputs)
    x = layers.Dense(512, activation="relu")(x)
    # Sigmoid output: pixels treated as Bernoulli probabilities
    outputs = layers.Dense(784, activation="sigmoid", name="decoder_output")(x)
    return keras.Model(inputs, outputs, name="decoder")
decoder = build_decoder(LATENT_DIM)
decoder.summary()
6.6 VAE Model with Custom Training Step
class VAE(keras.Model):
    def __init__(self, encoder, decoder, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.beta = beta
        # Metrics tracked across batches
        self.total_loss_tracker = keras.metrics.Mean(name="loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="recon_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def _vae_losses(self, data):
        # fit()/evaluate() may deliver (x,) or (x, y) tuples; the VAE only needs x
        if isinstance(data, tuple):
            data = data[0]
        z_mean, z_log_var, z = self.encoder(data)
        reconstruction = self.decoder(z)
        # Reconstruction loss: binary cross-entropy summed over the 784 pixels.
        # keras.losses.binary_crossentropy averages over the last axis, so
        # multiply by 784 for the per-image sum, then average over the batch.
        recon_loss = tf.reduce_mean(
            784.0 * keras.losses.binary_crossentropy(data, reconstruction)
        )
        # KL divergence: closed-form for N(μ, σ²) vs N(0, I)
        kl_loss = -0.5 * tf.reduce_mean(
            tf.reduce_sum(
                1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var),
                axis=1
            )
        )
        total_loss = recon_loss + self.beta * kl_loss
        return total_loss, recon_loss, kl_loss

    def _update_trackers(self, total_loss, recon_loss, kl_loss):
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(recon_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        with tf.GradientTape() as tape:
            total_loss, recon_loss, kl_loss = self._vae_losses(data)
        grads = tape.gradient(total_loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return self._update_trackers(total_loss, recon_loss, kl_loss)

    def test_step(self, data):
        # Needed so validation_data reports real val_loss / val_recon_loss / val_kl_loss
        total_loss, recon_loss, kl_loss = self._vae_losses(data)
        return self._update_trackers(total_loss, recon_loss, kl_loss)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        return self.decoder(z)
6.7 Train
vae = VAE(encoder, decoder, beta=BETA)
vae.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))
history = vae.fit(
x_train,
epochs=EPOCHS,
batch_size=BATCH_SIZE,
validation_data=(x_test, x_test),
verbose=1
)
Example output after 30 epochs (exact values vary between runs and library versions):
Epoch 30/30
469/469 ━━━━━━━━━━━━ 3s 6ms/step
loss: 162.4 recon_loss: 155.2 kl_loss: 7.2 val_loss: 165.1
Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for ax, key in zip(axes, ["loss", "recon_loss", "kl_loss"]):
    ax.plot(history.history[key], label="train")
    ax.plot(history.history[f"val_{key}"], label="val", linestyle="--")
    ax.set_title(key.replace("_", " ").title())
    ax.set_xlabel("Epoch")
    ax.legend()
plt.tight_layout()
plt.savefig("training_curves.png", dpi=150)
plt.show()
7. Visualising the Latent Space
7.1 Encode the Test Set and Plot μ
z_mean, _, _ = encoder.predict(x_test, batch_size=256)
plt.figure(figsize=(8, 6))
scatter = plt.scatter(z_mean[:, 0], z_mean[:, 1],
c=y_test, cmap="tab10", alpha=0.5, s=2)
plt.colorbar(scatter, label="Digit class")
plt.xlabel("z₁")
plt.ylabel("z₂")
plt.title("VAE Latent Space — MNIST Test Set")
plt.tight_layout()
plt.savefig("latent_space.png", dpi=150)
plt.show()
With a well-trained 2D VAE, the test-set means typically form a roughly circular cloud around the origin (the KL term pulls them toward N(0, I)), with each digit class occupying its own region. Visually similar classes (3/8, 4/9) tend to sit closer together.
7.2 Decode a 2D Grid of Latent Points
This reveals what the decoder has learned across the full latent space:
n = 20 # Grid size: 20×20 = 400 images
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# Evenly spaced grid over [-3, 3] in each latent dimension,
# i.e. about ±3 standard deviations of the N(0, I) prior
grid_x = np.linspace(-3, 3, n)
grid_y = np.linspace(-3, 3, n)[::-1]
for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        x_decoded = decoder.predict(z_sample, verbose=0)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[
            i * digit_size : (i + 1) * digit_size,
            j * digit_size : (j + 1) * digit_size,
        ] = digit
plt.figure(figsize=(12, 12))
plt.imshow(figure, cmap="Greys_r")
plt.title("VAE Latent Space Manifold", fontsize=16)
plt.axis("off")
plt.tight_layout()
plt.savefig("latent_manifold.png", dpi=150)
plt.show()
The manifold plot is VAEs' signature strength: you can watch digits continuously morph into one another as you traverse the latent space. The transition from "1" to "7" passes through recognisable intermediate forms rather than noise.
7.3 Interpolation Between Two Images
def interpolate(vae, x_a, x_b, steps=10):
    """Linearly interpolate between two images via their latent means."""
    # The encoder's first output is z_mean; convert to NumPy, shape (1, latent_dim)
    z_a = vae.encoder(x_a[np.newaxis])[0].numpy()
    z_b = vae.encoder(x_b[np.newaxis])[0].numpy()
    alphas = np.linspace(0, 1, steps)[:, np.newaxis]  # shape (steps, 1)
    # Broadcasting gives shape (steps, latent_dim); cast to the model's float32
    z_interp = ((1 - alphas) * z_a + alphas * z_b).astype("float32")
    images = vae.decoder(z_interp).numpy()
    return images.reshape(steps, 28, 28)
# Interpolate between a "3" and an "8"
idx_a = np.where(y_test == 3)[0][0]
idx_b = np.where(y_test == 8)[0][0]
frames = interpolate(vae, x_test[idx_a], x_test[idx_b], steps=12)
fig, axes = plt.subplots(1, 12, figsize=(18, 2))
for ax, frame in zip(axes, frames):
    ax.imshow(frame, cmap="Greys_r")
    ax.axis("off")
plt.suptitle("Latent Interpolation: 3 → 8", fontsize=13)
plt.tight_layout()
plt.savefig("interpolation.png", dpi=150)
plt.show()
8. Conditional VAE (CVAE)
A standard VAE generates samples without control over their class. A Conditional VAE solves this by passing the label y, as a one-hot vector, to both networks: it is concatenated with the input x in the encoder and with the latent code z in the decoder.
NUM_CLASSES = 10
def build_cvae_encoder(latent_dim):
    x_in = keras.Input(shape=(784,))
    y_in = keras.Input(shape=(NUM_CLASSES,))
    h = layers.Concatenate()([x_in, y_in])
    h = layers.Dense(512, activation="relu")(h)
    h = layers.Dense(256, activation="relu")(h)
    z_mean = layers.Dense(latent_dim)(h)
    z_log_var = layers.Dense(latent_dim)(h)
    z = Sampling()([z_mean, z_log_var])
    return keras.Model([x_in, y_in], [z_mean, z_log_var, z], name="cvae_encoder")

def build_cvae_decoder(latent_dim):
    z_in = keras.Input(shape=(latent_dim,))
    y_in = keras.Input(shape=(NUM_CLASSES,))
    h = layers.Concatenate()([z_in, y_in])
    h = layers.Dense(256, activation="relu")(h)
    h = layers.Dense(512, activation="relu")(h)
    out = layers.Dense(784, activation="sigmoid")(h)
    return keras.Model([z_in, y_in], out, name="cvae_decoder")
With a trained CVAE, you can generate digit "7" on demand:
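# cvae_decoder: a decoder from build_cvae_decoder(LATENT_DIM), trained as part of a CVAE
label = tf.one_hot([7], NUM_CLASSES)  # condition on the class "7"
z_sample = tf.random.normal(shape=(1, LATENT_DIM))
generated = cvae_decoder([z_sample, label])  # shape (1, 784)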
9. Common Failure Modes and Fixes
Posterior Collapse
Symptom: KL loss drops to near zero early; the decoder ignores z entirely and generates blurry means.
Cause: The decoder is powerful enough to model x without using z (common with autoregressive decoders, which can predict each pixel or token from the ones before it).
Fixes:
KL annealing: start β = 0 and linearly ramp to 1 over the first few epochs
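Use a Free Bits constraint: do not penalise KL below a minimum threshold λ per dimension (the loss uses max(λ, KL_j)), so each latent dimension can carry up to λ nats of information at no cost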
Reduce decoder capacity
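# KL annealing schedule: β ramps linearly from 0 to 1 over `warmup` epochs
def kl_weight(epoch, total_epochs=30, warmup=10):
    if epoch < warmup:
        return epoch / warmup
    return 1.0
# Note: train_step runs inside a tf.function, so a plain Python float stored in
# self.beta is captured as a constant when the step is traced. To change β
# between epochs, store it as a tf.Variable and update it with .assign(),
# e.g. from a keras.callbacks.Callback's on_epoch_begin.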
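Free bits can be sketched as follows, assuming the per-dimension KL is first averaged over the minibatch and then clamped from below at a threshold λ (called `lam` here; the name and the value 0.5 are assumptions). Dimensions whose KL is above λ are penalised as usual; dimensions below it contribute the constant λ, which has zero gradient, so the optimiser gains nothing by collapsing them further. The NumPy version only shows the arithmetic; inside the Keras `train_step` the same expression would use `tf.maximum`.

```python
import numpy as np


def kl_per_dim(z_mean, z_log_var):
    """Closed-form KL to N(0, I) for each latent dimension, averaged over the batch."""
    kl = -0.5 * (1.0 + z_log_var - z_mean**2 - np.exp(z_log_var))  # (batch, dims)
    return kl.mean(axis=0)                                          # (dims,)


def free_bits_kl(z_mean, z_log_var, lam=0.5):
    """Sum over dimensions of max(lam, KL_j): KL below lam is not penalised."""
    return np.sum(np.maximum(lam, kl_per_dim(z_mean, z_log_var)))


rng = np.random.default_rng(0)
z_mean = rng.normal(scale=2.0, size=(64, 4))
z_mean[:, 0] = 0.0                    # dimension 0 matches the prior exactly...
z_log_var = np.zeros((64, 4))         # ...so its KL is 0, below the threshold
print(kl_per_dim(z_mean, z_log_var))  # per-dimension KL; the first entry is 0
print(free_bits_kl(z_mean, z_log_var, lam=0.5))
```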
Blurry Reconstructions
Symptom: Reconstructions look washed out and lack fine detail.
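Cause: A pixel-wise likelihood (binary cross-entropy or MSE) treats pixels as independent given z, so when z leaves fine detail uncertain, the loss is minimised by predicting the average of all plausible images.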
Fixes:
Use a more expressive decoder (convolutional, residual blocks)
Switch to a perceptual loss (VGG feature matching)
Use a flow-based or autoregressive decoder
Consider moving to a diffusion model for the highest quality
Mode Dropping
Symptom: The model generates only a subset of classes.
Fixes:
Check that batches are class-balanced
Increase β to prevent over-fitting to easy modes
Use a CVAE to give the model explicit class supervision
10. VAEs vs GANs vs Diffusion Models
When to choose a VAE:
You need a compact, interpretable latent space
You need fast inference (real-time applications)
You need anomaly detection or data imputation
You need smooth interpolation between data points
You are working with structured/tabular data (not raw pixels)
11. Real-World Use Cases
Anomaly Detection
Compute the per-sample ELBO on held-out data. Samples with very low ELBO (high reconstruction loss + high KL) are anomalies.
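def anomaly_score(vae, x, n_samples=50):
    """
    Monte Carlo estimate of -ELBO per sample, used as an anomaly score.
    x: float32 array of shape (batch, 784) with values in [0, 1].
    Higher = more anomalous.
    """
    scores = []
    for _ in range(n_samples):
        # Each encoder call draws a fresh z through the Sampling layer
        z_mean, z_log_var, z = vae.encoder(x)
        x_recon = vae.decoder(z)
        # binary_crossentropy averages over the 784 pixels; multiply by 784
        # for the per-image sum, shape (batch,)
        recon = 784.0 * keras.losses.binary_crossentropy(x, x_recon)
        kl = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1
        )
        scores.append(recon + kl)
    return tf.reduce_mean(tf.stack(scores), axis=0).numpy()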
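Turning scores into decisions needs a threshold. One common approach, shown here as an assumption rather than the article's method, is to score held-out data known to be normal and flag anything above a high percentile of those scores. The helper names, the 99th-percentile default and the synthetic stand-in scores are illustrative; in practice both score arrays would come from `anomaly_score`.

```python
import numpy as np


def fit_threshold(normal_scores, percentile=99.0):
    """Threshold = the given percentile of -ELBO scores on known-normal data."""
    return np.percentile(normal_scores, percentile)


def flag_anomalies(scores, threshold):
    """True where the score (higher = more anomalous) exceeds the threshold."""
    return scores > threshold


rng = np.random.default_rng(0)
normal_val = rng.normal(loc=160.0, scale=10.0, size=5_000)  # stand-in validation scores
threshold = fit_threshold(normal_val, percentile=99.0)      # roughly 160 + 2.33 * 10

new_scores = np.array([150.0, 170.0, 240.0, 300.0])         # stand-in test scores
print(threshold)
print(flag_anomalies(new_scores, threshold))  # expected: only the last two are flagged
```

The percentile trades false alarms for missed anomalies: by construction, about 1% of normal validation samples score above a 99th-percentile threshold.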
Drug Discovery
VAEs are used to learn a continuous, differentiable latent space over molecular SMILES strings. Gradient-based optimisation in latent space, followed by decoding back to molecules, can explore chemical property landscapes far more efficiently than random search, although decoded strings are not guaranteed to be valid molecules.
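The latent-space optimisation loop can be sketched in a few lines. Here the property predictor is a hypothetical stand-in, a smooth function of z with a known optimum, so its gradient is analytic. In a real pipeline it would be a trained, differentiable regressor on top of the VAE's latent space, and each optimised z would be decoded and checked for validity. Keeping z inside a ball of radius `max_norm` is one simple heuristic (an assumption here) for staying where the decoder saw training data.

```python
import numpy as np

# Hypothetical surrogate "property" model: peaks at Z_STAR, falls off quadratically.
Z_STAR = np.array([1.0, -0.5, 0.25])


def predicted_property(z):
    return -np.sum((z - Z_STAR) ** 2)


def property_gradient(z):
    return -2.0 * (z - Z_STAR)


def optimise_latent(z0, lr=0.1, steps=100, max_norm=3.0):
    """Gradient ascent in latent space, projected back into a ball around the origin."""
    z = z0.copy()
    for _ in range(steps):
        z = z + lr * property_gradient(z)
        norm = np.linalg.norm(z)
        if norm > max_norm:  # heuristic: stay in the high-density region of N(0, I)
            z = z * (max_norm / norm)
    return z


z_opt = optimise_latent(np.zeros(3))
print(z_opt, predicted_property(z_opt))  # z_opt approaches Z_STAR
```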
Recommendation Systems
Variational Autoencoders for Collaborative Filtering (VAE-CF) encode each user's sparse interaction vector; by learning non-linear item representations and regularising via KL divergence, they have been reported to outperform matrix-factorisation baselines on sparse user-item interaction matrices.
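The VAE-CF paper uses a multinomial likelihood over items rather than per-item binary cross-entropy: the decoder produces one logit per item, a softmax turns them into a distribution π over the catalogue, and the reconstruction loss for a user's click vector x is -Σ_i x_i log π_i. The NumPy sketch below computes only that term; the toy catalogue of five items and the logits are assumptions.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def multinomial_nll(clicks, logits):
    """-sum_i x_i * log pi_i for each user, where pi = softmax(logits)."""
    return -np.sum(clicks * log_softmax(logits), axis=-1)


# Two users, a catalogue of 5 items; 1 = the user interacted with the item.
clicks = np.array([[1, 0, 1, 0, 0],
                   [0, 0, 0, 1, 1]], dtype=float)
good_logits = np.array([[3.0, -1.0, 3.0, -1.0, -1.0],
                        [-1.0, -1.0, -1.0, 3.0, 3.0]])
flat_logits = np.zeros((2, 5))  # uniform predictions
print(multinomial_nll(clicks, good_logits))  # lower: mass on the clicked items
print(multinomial_nll(clicks, flat_logits))  # 2 * log(5) per user
```

Because the softmax makes items compete for a fixed probability budget, the model is pushed to rank clicked items above the rest, which fits top-N recommendation.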
Time-Series Imputation and Generation
VAEs naturally handle missing data: mask unknown timesteps in the reconstruction loss, and the model learns to impute from the latent distribution.
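A masked reconstruction loss can be sketched as follows, assuming a squared-error reconstruction term and a binary mask that is 1 where a value was observed and 0 where it is missing (the array layout and names are assumptions). Missing entries contribute nothing, so the model is never penalised for what it predicts there; after training, the decoder's output at those positions serves as the imputation.

```python
import numpy as np


def masked_reconstruction_loss(x, x_recon, mask):
    """Squared error over observed entries only, summed over time, averaged over series.

    x, x_recon, mask: arrays of shape (batch, timesteps); mask is 1 for observed
    values and 0 for missing ones. Missing values in x may be NaN.
    """
    observed = mask.astype(bool)
    diff = np.where(observed, x - x_recon, 0.0)  # missing entries contribute 0
    return np.mean(np.sum(diff**2, axis=-1))


x = np.array([[1.0, 2.0, np.nan, 4.0]])     # the third timestep is missing
mask = np.array([[1.0, 1.0, 0.0, 1.0]])
x_recon = np.array([[1.0, 2.0, 3.0, 4.0]])  # the model "imputes" 3.0 there
print(masked_reconstruction_loss(x, x_recon, mask))  # 0.0: the imputed value is not penalised
```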
Text Generation (with modifications)
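Text VAEs require careful treatment: tokens are discrete, so the Gaussian reparameterisation trick cannot be applied to them, and gradients cannot flow through a sampled token (or through a discrete latent code). Solutions include the Gumbel-Softmax trick or a continuous approximation of the token embedding space.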
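The Gumbel-Softmax (also called Concrete) relaxation replaces a hard categorical sample with a differentiable, temperature-controlled softmax: add Gumbel(0, 1) noise g = -log(-log u), with u ~ Uniform(0, 1), to the logits and compute softmax((logits + g) / τ). As τ → 0 the samples approach one-hot vectors, and the argmax is distributed exactly as softmax(logits). The NumPy sketch below shows only the sampling arithmetic; in a model the same operations would be written in the framework so gradients flow through them. The three-token vocabulary and the temperatures are assumptions.

```python
import numpy as np


def gumbel_softmax_sample(logits, tau, rng):
    """Relaxed one-hot sample over the last axis of `logits` at temperature tau."""
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)  # avoid log(0)
    gumbel = -np.log(-np.log(u))                             # Gumbel(0, 1) noise
    y = (logits + gumbel) / tau
    y = y - y.max(axis=-1, keepdims=True)                    # stable softmax
    expy = np.exp(y)
    return expy / expy.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
logits = np.log(np.array([[0.1, 0.6, 0.3]]))  # a 3-token vocabulary
print(gumbel_softmax_sample(logits, tau=5.0, rng=rng))  # smooth: mass spread across tokens
print(gumbel_softmax_sample(logits, tau=0.1, rng=rng))  # typically close to one-hot
```

Lower τ gives samples closer to real tokens but higher-variance gradients; annealing τ during training is a common compromise.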
