Welcome, fellow developers, to an exciting journey into the realm of generative models! In this comprehensive tutorial, we will unravel the mysteries of Variational Autoencoders (VAEs), a powerful and elegant deep learning architecture capable of learning complex data distributions and generating novel, realistic data samples. Whether you are looking to generate new images, understand data structures, or simply expand your machine learning toolkit, VAEs offer a refreshing perspective on unsupervised learning.
This article is structured to provide a clear, step-by-step understanding. We will begin by laying a solid theoretical foundation, exploring the core concepts that make VAEs so effective. Following that, we will dive into the practical aspects, guiding you through the process of building and training your own VAE using Python and a popular deep learning framework. Our goal is to make this journey as insightful and enjoyable as possible, integrating small code snippets and conceptual explanations to illuminate every detail.
Let's embark on this adventure to master Variational Autoencoders!
Section 1: Theoretical Foundations of Variational Autoencoders
To truly appreciate Variational Autoencoders, it is beneficial to first understand their simpler cousin: the standard Autoencoder.
What is an Autoencoder?
An Autoencoder is a type of artificial neural network designed to learn efficient data codings in an unsupervised manner. The core idea behind an Autoencoder is to train a network to reconstruct its own input. This network consists of two main parts: an encoder and a decoder. The encoder takes the input data and transforms it into a lower-dimensional representation, often called the latent vector or bottleneck layer. This latent vector captures the most significant features of the input data. The decoder then takes this latent vector and attempts to reconstruct the original input from it.
The training objective for a standard Autoencoder is to minimize the reconstruction error between the input and the output. For example, if we are working with images, the reconstruction error might be the Mean Squared Error (MSE) between the input image pixels and the reconstructed image pixels. Through this process, the Autoencoder learns a compressed representation of the data, effectively performing dimensionality reduction and feature extraction.
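To make the reconstruction objective concrete, here is a tiny NumPy sketch computing the MSE a standard Autoencoder would minimize. The arrays are made-up stand-ins for a flattened input and its reconstruction, not real model outputs:

```python
import numpy as np

# Hypothetical "image" and an imperfect reconstruction (values in [0, 1])
x = np.array([0.0, 0.5, 1.0, 0.25])
x_hat = np.array([0.1, 0.4, 0.9, 0.25])

# Mean Squared Error: average of the squared per-pixel differences
mse = np.mean((x - x_hat) ** 2)
print(mse)  # 0.0075
```

Training drives this quantity toward zero, which forces the bottleneck to retain the information needed to rebuild the input.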
Imagine the Autoencoder architecture conceptually as follows:
Input Data
|
V
Encoder (compresses data)
|
V
Latent Vector (bottleneck)
|
V
Decoder (reconstructs data)
|
V
Reconstructed Output
Limitations of Standard Autoencoders
While standard Autoencoders are excellent for dimensionality reduction and learning features, they possess a significant limitation, especially when it comes to generative tasks. The latent space learned by a standard Autoencoder is not explicitly regularized. This means that the latent vectors for similar inputs might not necessarily be close to each other in the latent space, and there might be large "gaps" or discontinuous regions. If you were to sample a random point from this unregularized latent space and pass it through the decoder, the resulting output might be meaningless or not resemble any real data point. This makes standard Autoencoders unsuitable for generating new, coherent data samples.
Introducing Variational Autoencoders (VAEs)
Variational Autoencoders address the limitations of standard Autoencoders by introducing a probabilistic twist. Instead of the encoder producing a single, fixed latent vector for each input, a VAE's encoder outputs the parameters of a probability distribution in the latent space. This distribution is typically assumed to be a Gaussian (normal) distribution.
The fundamental goal of a VAE is to learn a latent space that is continuous, smooth, and well-structured, allowing for meaningful interpolation and generation of new data. By forcing the latent space to conform to a known probability distribution (like a standard normal distribution), VAEs ensure that different regions of the latent space correspond to meaningful variations in the data. This regularization prevents the latent space from having arbitrary gaps and encourages similar data points to cluster together.
The Latent Space and its Importance
In a VAE, the latent space is a low-dimensional representation of the input data, but critically, it is a *probabilistic* latent space. For each input data point, the encoder does not just produce a single point in this space; it produces a distribution over possible points. This distribution is usually a multivariate Gaussian distribution, which is fully characterized by its mean vector (\(\mu\)) and its covariance matrix (or more commonly, a diagonal covariance matrix, represented by a vector of variances \(\sigma^2\)).
The importance of this probabilistic latent space cannot be overstated. It ensures that when we sample from this space, we are likely to get a meaningful latent vector, which the decoder can then transform into a realistic data sample. This structure is what enables VAEs to be powerful generative models.
The Probabilistic Approach: Encoder as a Distribution
For each input data point \(x\), the VAE's encoder does not output a single latent vector \(z\). Instead, it outputs two vectors that define a probability distribution for \(z\). These two vectors are:
1. The mean vector, \(\mu\) (mu): This vector represents the center of the Gaussian distribution in the latent space. It indicates where the latent representation for the input \(x\) is most likely to be found.
2. The log-variance vector, \(\log \sigma^2\) (log-sigma-squared): This vector represents the spread or uncertainty of the Gaussian distribution. We often use log-variance instead of variance directly for numerical stability and to ensure that the variance is always positive (since \(\exp(\text{anything})\) is always positive). From \(\log \sigma^2\), we can easily get the standard deviation \(\sigma\) by taking the exponential of half the log-variance: \(\sigma = \exp(0.5 \times \log \sigma^2)\).
Once the encoder has produced \(\mu\) and \(\log \sigma^2\), a sample \(z\) is drawn from the distribution \(\mathcal{N}(\mu, \sigma^2)\). This sampled \(z\) is then passed to the decoder to reconstruct the input.
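As a quick numerical check of the log-variance parameterization, the conversion \(\sigma = \exp(0.5 \times \log \sigma^2)\) can be verified directly. The log-variance values here are arbitrary illustrations:

```python
import numpy as np

log_var = np.array([0.0, np.log(4.0)])  # log-variances for two latent dimensions
sigma = np.exp(0.5 * log_var)           # standard deviations

# log_var = 0      -> variance 1 -> sigma 1
# log_var = log(4) -> variance 4 -> sigma 2
print(sigma)  # [1. 2.]
```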
The Reparameterization Trick
A crucial challenge arises when we sample \(z\) from the latent distribution: the sampling operation is inherently non-differentiable. This means that we cannot directly backpropagate gradients through the sampling step to train the encoder. This is where the ingenious "reparameterization trick" comes into play.
Instead of directly sampling \(z \sim \mathcal{N}(\mu, \sigma^2)\), we can express \(z\) as a deterministic function of \(\mu\), \(\sigma\), and a random variable \(\epsilon\) (epsilon) sampled from a standard normal distribution \(\mathcal{N}(0, 1)\). The formula for this transformation is:
$$z = \mu + \sigma \times \epsilon$$
Here, \(\epsilon\) is a random variable sampled from a standard normal distribution, which means it has a mean of 0 and a standard deviation of 1. By using this trick, the randomness is now external to the network's trainable parameters (\(\mu\) and \(\sigma\)). The gradients can flow through \(\mu\) and \(\sigma\) during backpropagation, allowing the encoder to be trained effectively.
Conceptually, the flow looks like this:
Input Data
|
V
Encoder Network
|
V
Outputs: \(\mu\) and \(\log \sigma^2\)
|
V
Calculate \(\sigma = \exp(0.5 \times \log \sigma^2)\)
|
V
Sample \(\epsilon \sim \mathcal{N}(0, 1)\)
|
V
Reparameterization Trick: \(z = \mu + \sigma \times \epsilon\)
|
V
Decoder Network
|
V
Reconstructed Output
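The flow above can be reduced to a few lines of NumPy. This is only a standalone sketch of the reparameterization arithmetic, not the trainable Keras layer we build later; the mu and log_var values are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

mu = np.array([1.0, -2.0])      # encoder's mean output (arbitrary example values)
log_var = np.array([0.0, 0.0])  # log-variance 0 -> sigma = 1

sigma = np.exp(0.5 * log_var)

# Draw many epsilons to check the statistics of z = mu + sigma * epsilon
epsilon = rng.standard_normal(size=(100_000, 2))
z = mu + sigma * epsilon

print(z.mean(axis=0))  # approximately [1.0, -2.0]
print(z.std(axis=0))   # approximately [1.0, 1.0]
```

Note that the randomness lives entirely in `epsilon`; `mu` and `sigma` enter only through deterministic arithmetic, which is exactly why gradients can flow through them.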
The VAE Loss Function: Reconstruction Loss + KL Divergence
The total loss function for a VAE is composed of two main terms, each serving a distinct purpose:
$$ \text{Total Loss} = \text{Reconstruction Loss} + \text{KL Divergence Loss} $$
1. Reconstruction Loss: This term measures how accurately the decoder reconstructs the original input data from the sampled latent vector \(z\). Its purpose is identical to the loss function in a standard Autoencoder.
- For data whose values lie in \([0, 1]\), such as normalized image pixels, the Binary Cross-Entropy (BCE) is a common choice: each output value is treated as a Bernoulli probability. This pairs naturally with a sigmoid activation on the decoder's final layer, which outputs values between 0 and 1.
- For unbounded continuous data, or when the outputs are not probability-like, Mean Squared Error (MSE) is often used.
The reconstruction loss ensures that the VAE can effectively learn to reproduce the input data.
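To illustrate the two candidate losses side by side, here is a NumPy sketch computing per-pixel BCE and MSE for a toy target/reconstruction pair. The arrays are invented for illustration:

```python
import numpy as np

x = np.array([0.0, 1.0, 0.5])      # target pixels in [0, 1]
x_hat = np.array([0.1, 0.9, 0.5])  # decoder output (sigmoid range)

eps = 1e-7  # guard against log(0)
bce = -np.mean(x * np.log(x_hat + eps) + (1 - x) * np.log(1 - x_hat + eps))
mse = np.mean((x - x_hat) ** 2)

print(bce)  # ~0.3013
print(mse)  # ~0.0067
```

One quirk worth knowing: unlike MSE, BCE does not reach zero at a perfect reconstruction unless the targets are exactly 0 or 1 (the x = x_hat = 0.5 pixel above still contributes \(-\log 0.5\)).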
2. KL Divergence Loss (Kullback-Leibler Divergence): This is the regularization term that makes a VAE "variational" and differentiates it from a standard Autoencoder. The KL Divergence measures the difference between the probability distribution learned by the encoder (the latent distribution \(q(z|x) = \mathcal{N}(\mu, \sigma^2)\)) and a predefined prior distribution \(p(z)\) over the latent space. The prior distribution is typically chosen to be a standard normal distribution, \(\mathcal{N}(0, 1)\), which has a mean of 0 and a variance of 1.
The KL Divergence term encourages the encoder to produce latent distributions that are close to the prior distribution. This has several crucial benefits:
- Smooth Latent Space: It forces the latent space to be continuous and smooth, preventing the "gaps" found in standard Autoencoders.
- Meaningful Sampling: By aligning the learned latent distributions with a standard normal distribution, we can later sample random \(z\) vectors directly from \(\mathcal{N}(0, 1)\) and expect the decoder to generate meaningful data, as the VAE was trained to map these distributions.
- Regularization: It acts as a regularizer, preventing the encoder from learning an overly complex or "lazy" representation where \(\sigma^2\) becomes very small, essentially collapsing to a point like a standard Autoencoder.
The analytical formula for the KL Divergence between a Gaussian distribution \(\mathcal{N}(\mu, \sigma^2)\) and a standard normal distribution \(\mathcal{N}(0, 1)\) is:
$$ \text{KL Divergence} = -0.5 \sum_{i=1}^{D} (1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2) $$
where \(D\) is the dimensionality of the latent space.
By combining these two loss terms, the VAE learns to both accurately reconstruct its inputs and organize its latent space in a structured, generative manner. The balance between these two terms is critical; too much emphasis on reconstruction might lead to a less regularized latent space, while too much emphasis on KL divergence might lead to poor reconstruction quality.
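The KL formula above translates directly into NumPy, and a quick sanity check confirms that it vanishes when the learned distribution already equals the prior (\(\mu = 0\), \(\log \sigma^2 = 0\)). The `beta` weight shown at the end is an illustrative knob for the balance discussed above; it is not part of the code later in this article:

```python
import numpy as np

def kl_divergence(mu, log_var):
    # KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions
    return -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))

# Matching the prior exactly gives zero KL
print(kl_divergence(np.zeros(2), np.zeros(2)))  # 0.0

# Moving the mean away from 0 increases the penalty
print(kl_divergence(np.array([1.0, 0.0]), np.zeros(2)))  # 0.5

# One common way to trade off the two terms is a weighting coefficient:
reconstruction_loss = 123.4  # placeholder value for illustration
beta = 1.0                   # beta > 1 regularizes harder; beta < 1 favors reconstruction
total_loss = reconstruction_loss + beta * kl_divergence(np.array([1.0, 0.0]), np.zeros(2))
```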
Section 2: Practical Implementation Guide
Now that we have a solid grasp of the theoretical underpinnings, let's translate these concepts into working code. We will use TensorFlow and Keras to build a VAE for the MNIST dataset, which consists of handwritten digits. Our running example will demonstrate how to define the encoder, decoder, the VAE model, its loss function, and finally, how to train it and generate new digits.
Setting up the Environment
First, we need to import the necessary libraries. We will be using TensorFlow and Keras, which provide high-level APIs for building and training neural networks. We will also use NumPy for numerical operations and Matplotlib for visualizing our results.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
Next, we will load and preprocess the MNIST dataset. The MNIST dataset contains 60,000 training images and 10,000 test images of handwritten digits (0-9). Each image is 28x28 pixels.
# Load the MNIST dataset
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
# Normalize and reshape the images
# We normalize pixel values to be between 0 and 1
# We reshape the images from 28x28 to a flattened vector of 784 pixels
image_size = x_train.shape[1] * x_train.shape[2]
x_train = np.reshape(x_train, [-1, image_size]).astype('float32') / 255
x_test = np.reshape(x_test, [-1, image_size]).astype('float32') / 255
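If you want to verify the preprocessing logic without downloading MNIST, the same reshape-and-normalize steps can be exercised on a synthetic uint8 array of the same pixel layout. The random data here is purely a stand-in for real images:

```python
import numpy as np

# Stand-in for x_train: batches of 28x28 uint8 pixels (smaller batch for speed)
fake_images = np.random.randint(0, 256, size=(100, 28, 28), dtype=np.uint8)

image_size = fake_images.shape[1] * fake_images.shape[2]  # 784
flat = np.reshape(fake_images, [-1, image_size]).astype('float32') / 255

print(flat.shape)  # (100, 784)
print(flat.min() >= 0.0, flat.max() <= 1.0)  # True True
```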
Defining the Encoder Network
The encoder network takes an image as input and outputs the parameters of the latent distribution: the mean (\(\mu\)) and the log-variance (\(\log \sigma^2\)). Our encoder will consist of several dense layers.
# Define the input shape for our images (flattened 28x28 = 784)
input_shape = (image_size,)
# Define the size of our latent dimension
latent_dim = 2
# Create the encoder model
def build_encoder(input_shape, latent_dim):
# Input layer for the flattened image
encoder_inputs = keras.Input(shape=input_shape, name='encoder_input')
# First hidden dense layer with ReLU activation
x = layers.Dense(512, activation='relu', name='encoder_dense_1')(encoder_inputs)
# Second hidden dense layer with ReLU activation
x = layers.Dense(256, activation='relu', name='encoder_dense_2')(x)
# Output layer for the mean of the latent distribution
z_mean = layers.Dense(latent_dim, name='z_mean')(x)
# Output layer for the log-variance of the latent distribution
z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
# Construct the Keras Model for the encoder
# It takes encoder_inputs and outputs z_mean and z_log_var
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var], name='encoder')
return encoder
# Instantiate the encoder
encoder = build_encoder(input_shape, latent_dim)
encoder.summary() # Print a summary of the encoder's architecture
Defining the Decoder Network
The decoder network takes a sample from the latent space (\(z\)) as input and reconstructs an image. It will also consist of several dense layers, culminating in an output layer with a sigmoid activation to produce pixel values between 0 and 1.
# Create the decoder model
def build_decoder(latent_dim, image_size):
# Input layer for the latent vector
latent_inputs = keras.Input(shape=(latent_dim,), name='decoder_input')
# First hidden dense layer with ReLU activation
x = layers.Dense(256, activation='relu', name='decoder_dense_1')(latent_inputs)
# Second hidden dense layer with ReLU activation
x = layers.Dense(512, activation='relu', name='decoder_dense_2')(x)
# Output layer to reconstruct the image, using sigmoid activation
# This ensures pixel values are between 0 and 1
decoder_outputs = layers.Dense(image_size, activation='sigmoid', name='decoder_output')(x)
# Construct the Keras Model for the decoder
# It takes latent_inputs and outputs the reconstructed image
decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')
return decoder
# Instantiate the decoder
decoder = build_decoder(latent_dim, image_size)
decoder.summary() # Print a summary of the decoder's architecture
Constructing the VAE Model
Now, we combine the encoder and decoder into a single VAE model. We will also implement the reparameterization trick within this VAE class. Keras allows us to create custom models by subclassing `keras.Model`.
# Define a custom sampling layer for the reparameterization trick
class Sampling(layers.Layer):
"""
Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.
"""
def call(self, inputs):
z_mean, z_log_var = inputs
batch = tf.shape(z_mean)[0]
dim = tf.shape(z_mean)[1]
# Sample epsilon from a standard normal distribution
        epsilon = tf.random.normal(shape=(batch, dim))
# Calculate z using the reparameterization trick
return z_mean + tf.exp(0.5 * z_log_var) * epsilon
# Create the VAE model by combining encoder, sampling, and decoder
class VAE(keras.Model):
def __init__(self, encoder, decoder, **kwargs):
super(VAE, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
# Keep track of the total loss components for monitoring
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = keras.metrics.Mean(
name="reconstruction_loss"
)
self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
# Define the metrics that will be tracked during training
@property
def metrics(self):
return [
self.total_loss_tracker,
self.reconstruction_loss_tracker,
self.kl_loss_tracker,
]
# The training step where forward pass, loss calculation, and backpropagation occur
def train_step(self, data):
with tf.GradientTape() as tape:
# Forward pass: encode the input data
z_mean, z_log_var = self.encoder(data)
# Apply the reparameterization trick to get a latent sample z
z = Sampling()([z_mean, z_log_var])
# Decode z to reconstruct the input data
reconstruction = self.decoder(z)
# Calculate reconstruction loss
# Binary Cross-Entropy is suitable for pixel values between 0 and 1
reconstruction_loss = tf.reduce_mean(
keras.losses.binary_crossentropy(data, reconstruction)
)
# Scale reconstruction loss by image_size to make it comparable to KL loss
reconstruction_loss *= image_size
# Calculate KL divergence loss
# The formula for KL divergence between N(mu, sigma^2) and N(0, 1)
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
# Combine the two loss components to get the total VAE loss
total_loss = reconstruction_loss + kl_loss
# Compute gradients with respect to the total loss
grads = tape.gradient(total_loss, self.trainable_weights)
# Apply gradients to update the model's weights
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
# Update the metrics
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
# Return a dictionary mapping metric names to their current value
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
# Instantiate the VAE model
vae = VAE(encoder, decoder)
Implementing the Loss Function
As detailed in the theoretical section, the VAE's loss function comprises two parts: reconstruction loss and KL divergence loss. We implemented these directly within the `train_step` method of our `VAE` class.
- Reconstruction Loss: For the MNIST dataset, where pixel values are normalized between 0 and 1, Binary Cross-Entropy (BCE) is an appropriate choice. We use `keras.losses.binary_crossentropy` and then take the mean over the batch. It's often scaled by the `image_size` to ensure it's on a similar scale to the KL divergence term, especially when the BCE is calculated per pixel.
- KL Divergence Loss: The formula \(-0.5 \sum (1 + \log \sigma^2 - \mu^2 - \sigma^2)\) is directly translated into code. `tf.square(z_mean)` calculates \(\mu^2\), and `tf.exp(z_log_var)` calculates \(\sigma^2\) from \(\log \sigma^2\). We sum this over the latent dimensions for each sample and then take the mean over the batch.
These two losses are summed to form the `total_loss`, which the optimizer then minimizes.
Training Loop
With the VAE model defined and its loss function implicitly handled by the `train_step`, we can now compile and train the model. We will use the Adam optimizer, a popular choice for its efficiency.
# Compile the VAE model with an optimizer
vae.compile(optimizer=keras.optimizers.Adam())
# Train the VAE model on the training data
# We'll train for 50 epochs with a batch size of 128
print("\nStarting VAE training...")
vae.fit(x_train, epochs=50, batch_size=128)
print("VAE training complete.")
During training, the `fit` method will iterate over the dataset for a specified number of `epochs`. For each batch of data, the `train_step` method we defined will be executed, performing the forward pass, calculating the loss, and updating the model's weights. The `metrics` property in our `VAE` class ensures that the `total_loss`, `reconstruction_loss`, and `kl_loss` are tracked and reported.
Generating New Data
One of the most exciting capabilities of a VAE is its ability to generate novel data samples. After training, the decoder has learned to map points from the latent space back into the data space. Since the KL divergence term has regularized the latent space to resemble a standard normal distribution, we can simply sample random vectors from \(\mathcal{N}(0, 1)\) and pass them through the decoder to generate new data.
# Function to plot generated digits
def plot_latent_space(vae, n=30, figsize=15):
# Display a grid of n*n digits generated by sampling from the latent space
digit_size = 28
scale = 1.0
figure = np.zeros((digit_size * n, digit_size * n))
# Linearly spaced coordinates corresponding to the 2D plot of digit classes
# in the latent space
grid_x = np.linspace(-scale, scale, n)
grid_y = np.linspace(-scale, scale, n)[::-1]
for i, yi in enumerate(grid_y):
for j, xi in enumerate(grid_x):
# Sample a latent vector from a standard normal distribution
# For a 2D latent space, we create a vector [xi, yi]
z_sample = np.array([[xi, yi]])
# Predict the reconstructed image using the decoder
            x_decoded = vae.decoder.predict(z_sample, verbose=0)
# Reshape the output to 28x28 image
digit = x_decoded[0].reshape(digit_size, digit_size)
# Place the digit in the figure grid
figure[
i * digit_size : (i + 1) * digit_size,
j * digit_size : (j + 1) * digit_size,
] = digit
plt.figure(figsize=(figsize, figsize))
start_range = digit_size // 2
    end_range = n * digit_size + start_range
pixel_range = np.arange(start_range, end_range, digit_size)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.imshow(figure, cmap="Greys_r")
plt.show()
# Generate and display new digits
print("\nGenerating new digits from latent space...")
plot_latent_space(vae)
print("Digit generation complete.")
In this `plot_latent_space` function, we iterate through a grid of points in our 2-dimensional latent space (since `latent_dim = 2`). For each point, we treat it as a latent vector \(z\), pass it through the trained `vae.decoder`, and then display the resulting reconstructed image. This allows us to visualize how the latent space maps to different digits and observe the smooth transitions between them.
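Beyond a grid, the smoothness of the latent space also supports interpolation between two points. Since the decoder requires the trained model from above, the runnable part of this sketch only builds the interpolation path; the decoding step is shown as a comment, and the endpoint coordinates are arbitrary choices:

```python
import numpy as np

# Two points in the 2D latent space (arbitrary choices for illustration)
z_start = np.array([-1.0, -1.0])
z_end = np.array([1.0, 1.0])

# Linear interpolation: z_t = (1 - t) * z_start + t * z_end
ts = np.linspace(0.0, 1.0, 10)
z_path = np.array([(1 - t) * z_start + t * z_end for t in ts])

print(z_path.shape)  # (10, 2)

# With the trained model above, each row could then be decoded, e.g.:
#   image = vae.decoder.predict(z_path[k:k + 1], verbose=0).reshape(28, 28)
```

Because the KL term keeps the latent space continuous, decoding the intermediate rows typically yields a gradual morph from one digit into another rather than abrupt jumps.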
Conclusion
Congratulations! You have successfully navigated the theoretical landscape and practical implementation of Variational Autoencoders. We began by understanding the foundational concepts of standard Autoencoders and their limitations. We then delved into the core innovation of VAEs: their probabilistic approach to the latent space, the ingenious reparameterization trick, and the dual-component loss function comprising reconstruction loss and KL divergence.
Through our step-by-step implementation using TensorFlow and Keras, you have seen how to build an encoder that outputs distribution parameters, a decoder that reconstructs data, and a VAE model that elegantly integrates the reparameterization trick and the combined loss. Finally, you learned how to leverage the trained VAE to generate novel data samples by simply querying its regularized latent space.
Variational Autoencoders are not just powerful generative models for images; their principles extend to various domains, including anomaly detection, data imputation, and even drug discovery. The ability to learn a smooth, continuous, and interpretable latent representation of data opens up a myriad of possibilities for research and application.
We encourage you to experiment further: try different network architectures, explore various datasets, or even delve into more advanced VAE variants. The world of generative models is vast and exciting, and your understanding of VAEs is a significant step forward in mastering it. Keep learning, keep building, and keep innovating!
Addendum: Full Running Example Code
This addendum provides the complete, self-contained Python code for the Variational Autoencoder discussed in this article. You can run this code directly to load the MNIST dataset, train the VAE, and generate new handwritten digits.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
# --- 1. Data Loading and Preprocessing ---
print("Loading and preprocessing MNIST data...")
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
# Calculate the size of a flattened image (28 * 28 = 784)
image_size = x_train.shape[1] * x_train.shape[2]
# Reshape the images from 28x28 to a flattened vector of 784 pixels
# Normalize pixel values to be between 0 and 1 by dividing by 255.0
x_train = np.reshape(x_train, [-1, image_size]).astype('float32') / 255.0
x_test = np.reshape(x_test, [-1, image_size]).astype('float32') / 255.0
print("Data loading and preprocessing complete.")
# --- 2. Model Parameters ---
# Define the input shape for our images (flattened 28x28 = 784)
input_shape = (image_size,)
# Define the dimensionality of the latent space.
# A 2D latent space allows for easy visualization.
latent_dim = 2
# --- 3. Encoder Network Definition ---
def build_encoder(input_shape, latent_dim):
"""
Constructs the encoder part of the VAE.
It takes an image and outputs the mean and log-variance of the latent distribution.
"""
# Input layer for the flattened image
encoder_inputs = keras.Input(shape=input_shape, name='encoder_input')
# First hidden dense layer with 512 units and ReLU activation
# ReLU (Rectified Linear Unit) is a common activation function
x = layers.Dense(512, activation='relu', name='encoder_dense_1')(encoder_inputs)
# Second hidden dense layer with 256 units and ReLU activation
x = layers.Dense(256, activation='relu', name='encoder_dense_2')(x)
# Output layer for the mean (mu) of the latent distribution
# No activation function here, as mean can be any real value
z_mean = layers.Dense(latent_dim, name='z_mean')(x)
# Output layer for the log-variance (log_sigma_squared) of the latent distribution
# No activation function here. Using log-variance helps ensure positivity
# of variance and improves numerical stability during training.
z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
# Create the Keras Model for the encoder
# It maps the input image to the z_mean and z_log_var vectors
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var], name='encoder')
return encoder
# Instantiate the encoder
encoder = build_encoder(input_shape, latent_dim)
print("\nEncoder Architecture:")
encoder.summary()
# --- 4. Decoder Network Definition ---
def build_decoder(latent_dim, image_size):
"""
Constructs the decoder part of the VAE.
It takes a latent vector and reconstructs an image.
"""
# Input layer for the latent vector (z)
latent_inputs = keras.Input(shape=(latent_dim,), name='decoder_input')
# First hidden dense layer with 256 units and ReLU activation
x = layers.Dense(256, activation='relu', name='decoder_dense_1')(latent_inputs)
# Second hidden dense layer with 512 units and ReLU activation
x = layers.Dense(512, activation='relu', name='decoder_dense_2')(x)
# Output layer to reconstruct the image (784 pixels)
# Sigmoid activation ensures that pixel values are between 0 and 1,
# suitable for binary cross-entropy reconstruction loss.
decoder_outputs = layers.Dense(image_size, activation='sigmoid', name='decoder_output')(x)
# Create the Keras Model for the decoder
# It maps the latent vector to the reconstructed image
decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')
return decoder
# Instantiate the decoder
decoder = build_decoder(latent_dim, image_size)
print("\nDecoder Architecture:")
decoder.summary()
# --- 5. Reparameterization Trick Layer ---
class Sampling(layers.Layer):
"""
Custom Keras layer that implements the reparameterization trick.
It takes the mean (z_mean) and log-variance (z_log_var) of the latent distribution
and samples a latent vector (z). This makes the sampling process differentiable.
"""
def call(self, inputs):
z_mean, z_log_var = inputs
# Get batch size and dimensionality of the latent space
batch = tf.shape(z_mean)[0]
dim = tf.shape(z_mean)[1]
# Sample epsilon from a standard normal distribution (mean=0, variance=1)
# This introduces the stochasticity needed for the VAE.
        epsilon = tf.random.normal(shape=(batch, dim))
# Calculate z using the reparameterization formula: z = mu + sigma * epsilon
# tf.exp(0.5 * z_log_var) calculates the standard deviation (sigma)
return z_mean + tf.exp(0.5 * z_log_var) * epsilon
# --- 6. VAE Model Definition ---
class VAE(keras.Model):
"""
Custom Keras Model for the Variational Autoencoder.
It combines the encoder, the sampling layer, and the decoder.
It also defines the custom training step including the VAE loss components.
"""
def __init__(self, encoder, decoder, **kwargs):
super(VAE, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
# Initialize Keras metrics to track loss components during training
self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
self.reconstruction_loss_tracker = keras.metrics.Mean(
name="reconstruction_loss"
)
self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
# Define the list of metrics that will be reported by the model
@property
def metrics(self):
return [
self.total_loss_tracker,
self.reconstruction_loss_tracker,
self.kl_loss_tracker,
]
# Custom training step for the VAE
def train_step(self, data):
# Use tf.GradientTape to record operations for automatic differentiation
with tf.GradientTape() as tape:
# Forward pass through the encoder
z_mean, z_log_var = self.encoder(data)
# Apply the reparameterization trick to get a latent sample 'z'
z = Sampling()([z_mean, z_log_var])
# Forward pass through the decoder to reconstruct the input
reconstruction = self.decoder(z)
# Calculate Reconstruction Loss: Binary Cross-Entropy
# This measures how well the VAE reconstructs the original input.
# We scale it by image_size to make it comparable to the KL loss.
reconstruction_loss = tf.reduce_mean(
keras.losses.binary_crossentropy(data, reconstruction)
)
reconstruction_loss *= image_size
# Calculate KL Divergence Loss
# This regularizes the latent space, forcing it to be close to a standard normal distribution.
# Formula: -0.5 * sum(1 + log_var - mean^2 - exp(log_var))
kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))
# Total VAE Loss is the sum of reconstruction loss and KL divergence loss
total_loss = reconstruction_loss + kl_loss
# Compute gradients of the total loss with respect to the model's trainable weights
grads = tape.gradient(total_loss, self.trainable_weights)
# Apply the computed gradients to update the model's weights using the optimizer
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
# Update the state of the metrics
self.total_loss_tracker.update_state(total_loss)
self.reconstruction_loss_tracker.update_state(reconstruction_loss)
self.kl_loss_tracker.update_state(kl_loss)
# Return a dictionary of current metric values
return {
"loss": self.total_loss_tracker.result(),
"reconstruction_loss": self.reconstruction_loss_tracker.result(),
"kl_loss": self.kl_loss_tracker.result(),
}
# Instantiate the VAE model
vae = VAE(encoder, decoder)
# --- 7. Model Compilation and Training ---
print("\nCompiling VAE model...")
# Compile the VAE model with the Adam optimizer
# Adam is a popular and effective optimizer for deep learning models.
vae.compile(optimizer=keras.optimizers.Adam())
print("VAE model compiled.")
print("\nStarting VAE training...")
# Train the VAE model on the training data
# We train for 50 epochs (passes over the entire dataset)
# with a batch size of 128 (number of samples processed before updating weights)
vae.fit(x_train, epochs=50, batch_size=128)
print("VAE training complete.")
# --- 8. Generating New Data and Visualization ---
def plot_latent_space(vae, n=30, figsize=15):
"""
Generates and displays a grid of digits by sampling points
from the 2D latent space and passing them through the decoder.
"""
digit_size = 28 # Each MNIST digit is 28x28 pixels
scale = 1.0 # Scale for sampling from the latent space
# Create an empty canvas to store the generated digits
figure = np.zeros((digit_size * n, digit_size * n))
# Generate linearly spaced coordinates for the 2D latent space grid
# These coordinates will be used as the z[0] and z[1] values
grid_x = np.linspace(-scale, scale, n)
grid_y = np.linspace(-scale, scale, n)[::-1] # Reverse y-axis for proper display
print(f"Generating a {n}x{n} grid of digits...")
# Iterate through the grid to generate and place digits
for i, yi in enumerate(grid_y):
for j, xi in enumerate(grid_x):
# Create a latent vector [xi, yi] for the current grid point
z_sample = np.array([[xi, yi]])
# Use the decoder to predict the reconstructed image from the latent sample
x_decoded = vae.decoder.predict(z_sample, verbose=0)
# Reshape the flattened output (784 pixels) back to a 28x28 image
digit = x_decoded[0].reshape(digit_size, digit_size)
# Place the generated digit into the larger figure canvas
figure[
i * digit_size : (i + 1) * digit_size,
j * digit_size : (j + 1) * digit_size,
] = digit
# Plot the generated grid of digits
plt.figure(figsize=(figsize, figsize))
# Set up tick marks and labels for the latent space axes
start_range = digit_size // 2
    end_range = n * digit_size + start_range
pixel_range = np.arange(start_range, end_range, digit_size)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0] (Latent Dimension 1)")
plt.ylabel("z[1] (Latent Dimension 2)")
plt.title("Generated Digits from 2D Latent Space")
# Display the figure using a grayscale colormap
plt.imshow(figure, cmap="Greys_r")
plt.show()
print("Digit generation and visualization complete.")
# Generate and display new digits from the learned latent space
plot_latent_space(vae)