TensorFlow: Variational Autoencoder (VAE) for MNIST Digits

Posted August 26, 2024 at 12:06 pm
Sang-Heon Lee
SHLee AI Financial Model

Excerpt

This post demonstrates a TensorFlow implementation of a Variational Autoencoder (VAE), using the well-established example of MNIST digit data.

VAE in TensorFlow

Variational Autoencoder (VAE)

The Variational Autoencoder (VAE) is a generative model that allows us to learn a probabilistic representation of data.

The VAE architecture consists of an encoder and a decoder. The encoder maps input data to a probability distribution in a latent space, while the decoder generates data from samples drawn from the latent space.

The core concept of VAE is the latent space: rather than mapping each input to a single point, the encoder represents it by the mean and variance of a Gaussian distribution over the latent variable z.

The loss function for VAE includes a reconstruction loss and a regularization term to encourage the latent space to be normally distributed.
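For reference, the loss described above is usually written in the following standard form. This is a sketch of the common textbook expression, chosen to match the code later in this post, and not necessarily the notation of the original figure. Here q(z|x) is the encoder distribution with mean μ and variance σ², p(x|z) is the decoder, and d is the latent dimension.

```latex
\mathcal{L}(x)
  = \underbrace{-\,\mathbb{E}_{q(z \mid x)}\bigl[\log p(x \mid z)\bigr]}_{\text{reconstruction loss}}
  + \underbrace{D_{\mathrm{KL}}\bigl(q(z \mid x)\,\|\,\mathcal{N}(0, I)\bigr)}_{\text{regularization term}},
\qquad
D_{\mathrm{KL}}
  = -\frac{1}{2}\sum_{j=1}^{d}\bigl(1 + \log\sigma_j^{2} - \mu_j^{2} - \sigma_j^{2}\bigr)
```

The closed-form KL term holds when q(z|x) is a diagonal Gaussian and the prior is a standard Gaussian.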

I’m omitting the derivation of the aforementioned loss function as there are abundant educational resources on Google. Numerous high-quality materials provide a better explanation than I can offer.

The reparameterization trick allows the training of generative models with stochastic elements while maintaining differentiability. It is crucial when working with continuous latent variables.

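z = μ + σ ⊙ ε

Here, μ and σ are the mean and standard deviation of the distribution of the latent variable z, and ε is sampled from a fixed distribution, typically a standard Gaussian distribution, N(0,1). Because the randomness enters only through ε, gradients can flow through μ and σ during training.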
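As a quick sanity check of the reparameterization trick, the sketch below draws samples as z = μ + σ·ε using only the Python standard library and compares the sample statistics with μ and σ. The function name `sample_latent` and the example values are illustrative assumptions, not part of the original post.

```python
import math
import random


def sample_latent(mu, log_var, rng):
    """Draw z = mu + sigma * eps elementwise, with eps ~ N(0, 1)."""
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mu, log_var)
    ]


rng = random.Random(0)                      # fixed seed for repeatability
mu = [1.0, -2.0]                            # illustrative means
log_var = [math.log(0.25), math.log(4.0)]   # sigma = 0.5 and 2.0

samples = [sample_latent(mu, log_var, rng) for _ in range(20000)]

sample_means, sample_stds = [], []
for j in range(len(mu)):
    column = [s[j] for s in samples]
    mean = sum(column) / len(column)
    std = math.sqrt(sum((v - mean) ** 2 for v in column) / len(column))
    sample_means.append(mean)
    sample_stds.append(std)
    print(f"dim {j}: sample mean {mean:.2f}, sample std {std:.2f}")
```

The sample means and standard deviations should land close to the chosen μ and σ, which is what lets the network learn those two quantities directly.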

Python Jupyter Notebook Code

A well-established example of VAE’s application is with MNIST digits. The following code reads MNIST data and performs some preprocessing.

import numpy as np
import matplotlib.pyplot as plt
 
from keras.datasets import mnist
from keras.layers import Input, Lambda, Dense
from keras.models import Model
from keras import backend as K
from keras.utils import plot_model
from keras.losses import binary_crossentropy
 
# network parameters
rec_dim=784
input_shape = (rec_dim,)
int_dim = 512
lat_dim = 2
 
# Load the MNIST data
(x_tr, y_tr), (x_te, y_te) = mnist.load_data()
 
# normalize values of image pixels between 0 and 1
x_tr = x_tr.astype('float32') / 255.
x_te = x_te.astype('float32') / 255.
 
# flatten each 28x28 image into a 784-element 1D vector
x_tr = x_tr.reshape((len(x_tr), np.prod(x_tr.shape[1:])))
x_te = x_te.reshape((len(x_te), np.prod(x_te.shape[1:])))
 
print(x_tr.shape, x_te.shape)
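To see what the preprocessing does without downloading MNIST, here is a minimal sketch on a small synthetic array. The array `fake_images` is a hypothetical stand-in for the real data; only the scaling and flattening steps mirror the code above.

```python
import numpy as np

# Hypothetical stand-in for MNIST: three random 28x28 uint8 "images"
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(3, 28, 28), dtype=np.uint8)

# Scale pixel values from [0, 255] to [0, 1]
scaled = fake_images.astype("float32") / 255.0

# Flatten each image; -1 lets NumPy infer 28 * 28 = 784
flat = scaled.reshape(len(scaled), -1)

print(flat.shape, flat.dtype, flat.min(), flat.max())
```

Using `-1` in `reshape` is equivalent to the `np.prod(...)` expression used above and is a common shorthand.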

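The following code includes both the encoder and decoder. The encoder outputs the mean and log-variance of the latent factors and then samples them through the reparameterization trick. Predicting the log-variance rather than the variance keeps the implied variance positive without constraining the layer output.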

#=======================
# Encoder
#=======================
# Z sampling function
def sampling(args):
    z_mean, z_log_var = args
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    
    # Reparameterization trick
    # draw a random sample ε from a Gaussian (normal) distribution
    # by default, random_normal has mean = 0 and std = 1.0
    epsilon = K.random_normal(shape=(batch, dim))
    
    return z_mean + K.exp(0.5 * z_log_var) * epsilon
 
# Input shape
inputs = Input(shape=input_shape)
enc_x  = Dense(int_dim, activation='relu')(inputs)
 
z_mean    = Dense(lat_dim)(enc_x)
z_log_var = Dense(lat_dim)(enc_x)
 
# sampling z
z_sampling = Lambda(sampling, output_shape=(lat_dim,))([z_mean, z_log_var])
 
# encoder model has multi-output so a list is used
encoder = Model(inputs,[z_mean,z_log_var,z_sampling])
encoder.summary()
 
#=======================
# Decoder
#=======================
# Input of decoder is z
input_z = Input(shape=(lat_dim,))
dec_h   = Dense(int_dim, activation='relu')(input_z)
outputs = Dense(rec_dim, activation='sigmoid')(dec_h)
 
# z is the input and the reconstructed image is the output
decoder = Model(input_z, outputs)
decoder.summary()

After constructing the VAE model, which chains the encoder and decoder, the VAE loss is calculated as the sum of the reconstruction loss and the Kullback-Leibler (KL) loss. This sum is the negative of the Evidence Lower Bound (ELBO), so minimizing it maximizes the ELBO. In the case of beta-VAE, the KL loss is additionally multiplied by a scaling factor, beta, to balance the two components; the code in this post corresponds to beta = 1.

#=======================
# VAE model
#=======================
outputs = decoder(encoder(inputs)[2])
vae = Model(inputs, outputs)
 
#--------------------------------------------------
# VAE_loss = negative ELBO
#--------------------------------------------------
# (1) Reconstruction loss: binary cross-entropy
#     binary_crossentropy averages over pixels, so scale by rec_dim to get the per-image sum
rec_loss = binary_crossentropy(inputs,outputs)
rec_loss *= rec_dim
# (2) KL divergence (latent loss) between N(z_mean, exp(z_log_var)) and N(0, 1)
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = -0.5*K.sum(kl_loss, 1)
# (3) Negative ELBO, averaged over the batch
vae_loss = K.mean(rec_loss + kl_loss)
#--------------------------------------------------
 
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
vae.summary()
 
history = vae.fit(x_tr, x_tr, shuffle=True, 
                  epochs=30, batch_size=64, 
                  validation_data=(x_te, x_te))
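To make the two loss terms concrete, the following standard-library sketch computes them for a single sample. The function names and the `beta` parameter are illustrative assumptions (the training code above effectively uses beta = 1); this is a hand-checkable companion to the Keras code, not a replacement for it.

```python
import math


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, exp(log_var)) || N(0, 1) ), summed over latent dims."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var)
    )


def bce_sum(x, x_hat, eps=1e-7):
    """Binary cross-entropy summed over pixels; predictions are clipped for stability."""
    total = 0.0
    for t, p in zip(x, x_hat):
        p = min(max(p, eps), 1.0 - eps)
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total


def negative_elbo(x, x_hat, mu, log_var, beta=1.0):
    """Reconstruction loss plus beta-weighted KL loss (beta is hypothetical here)."""
    return bce_sum(x, x_hat) + beta * kl_to_standard_normal(mu, log_var)


# Tiny illustrative example: two "pixels" and a one-dimensional latent
x, x_hat = [1.0, 0.0], [0.5, 0.5]
mu, log_var = [1.0], [0.0]
print(negative_elbo(x, x_hat, mu, log_var))
print(negative_elbo(x, x_hat, mu, log_var, beta=4.0))
```

A larger beta penalizes deviation from the standard normal prior more heavily, typically at the cost of reconstruction quality.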

Visit SHLee AI Financial Model for details on how to visualize the training and validation losses across epochs.

Originally posted on SHLee AI Financial Model blog.

Disclosure: Interactive Brokers

Information posted on IBKR Campus that is provided by third-parties does NOT constitute a recommendation that you should contract for the services of that third party. Third-party participants who contribute to IBKR Campus are independent of Interactive Brokers and Interactive Brokers does not make any representations or warranties concerning the services offered, their past or future performance, or the accuracy of the information provided by the third party. Past performance is no guarantee of future results.

This material is from SHLee AI Financial Model and is being posted with its permission. The views expressed in this material are solely those of the author and/or SHLee AI Financial Model and Interactive Brokers is not endorsing or recommending any investment or trading discussed in the material. This material is not and should not be construed as an offer to buy or sell any security. It should not be construed as research or investment advice or a recommendation to buy, sell or hold any security or commodity. This material does not and is not intended to take into account the particular financial conditions, investment objectives or requirements of individual customers. Before acting on this material, you should consider whether it is suitable for your particular circumstances and, as necessary, seek professional advice.