Flow-based Generative Models

Samuel Tober
Published in KTH AI Society
Feb 2, 2022


Let us start with a simple question: what are generative models?

As the name suggests, generative models create new instances of data. Discriminative models, on the other hand, distinguish between different kinds of data instances.

For example, a generative model could be tasked with generating new instances of a bike or a car, whereas a discriminative model would be tasked with telling a bike apart from a car.

In case you want to get into the mathematics behind this: generative models can be thought of as learning the joint probability P(X, Y), where X denotes the data instances and Y the labels (or simply P(X) when there are no labels). In other words, the model tries to learn the overall distribution of the dataset. Discriminative models, on the other hand, learn the conditional probability P(Y | X).
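To make the distinction concrete, here is a toy sketch with a made-up joint distribution over a discrete feature X (three values) and a binary label Y. A model of the joint P(X, Y) can both generate labelled samples and, via Bayes' rule, recover the conditional P(Y | X) that a discriminative model learns directly; the converse is not possible, since P(Y | X) says nothing about how X itself is distributed. The probability table is invented purely for illustration.

```python
import numpy as np

# Hypothetical joint distribution P(X, Y): rows are the 3 values of X,
# columns are the 2 labels Y (e.g. 0 = bike, 1 = car). Entries sum to 1.
joint = np.array([
    [0.30, 0.05],
    [0.10, 0.10],
    [0.05, 0.40],
])

# Marginal P(X) = sum over Y of P(X, Y).
p_x = joint.sum(axis=1)

# Conditional P(Y | X) = P(X, Y) / P(X): what a discriminative model learns.
p_y_given_x = joint / p_x[:, None]

# Generation: draw (x, y) pairs from the joint by sampling a flat cell index.
rng = np.random.default_rng(0)
flat_idx = rng.choice(joint.size, size=5, p=joint.ravel())
samples = np.stack(np.unravel_index(flat_idx, joint.shape), axis=1)

print(p_y_given_x.round(3))
print(samples)  # each row is one generated (x, y) pair
```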

Assuming that we have observed data D drawn from an underlying distribution pᴰ(x), the goal of a generative model is to approximate this distribution with a model distribution pᵐᵒᵈᵉˡ(x). Once we have an approximation of the full joint distribution, we can perform various downstream inference tasks. Generally, there are three fundamental inference tasks in generative modeling: density estimation, i.e. evaluating the approximate density pᵐᵒᵈᵉˡ(x) ≈ pᴰ(x) of a given sample x; sampling, i.e. generating new data from the learned distribution; and representation learning, i.e. learning useful feature representations of the data.

Trade-offs between these inference tasks have led to the development of very diverse families of generative models. At one end of the spectrum we have likelihood-based models such as the variational autoencoder (VAE)¹, normalizing flows² and autoregressive models such as the transformer³. At the other end are models like the seminal generative adversarial network (GAN)⁴, which do not explicitly model the likelihood.

Overview of deep generative models

The focus of this blog post is to introduce flow-based models, first from a theoretical perspective and then through a practical example with an actual implementation.

Theoretical Background

The goal of a normalizing flow model is to map simple distributions to complex ones. This is done through a series of K bijective transformations $f_1, \dots, f_K$ that transform a simple distribution (typically a Gaussian) into one with more expressive power, such that:

$$x = f_K \circ f_{K-1} \circ \cdots \circ f_1(z)$$

which can be written more compactly:

$$x = f(z), \qquad f = f_K \circ f_{K-1} \circ \cdots \circ f_1$$

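where $z$ is a random variable drawn from a simple base distribution $p_z(z)$. Writing the chain step by step as $z_k = f_k(z_{k-1})$, with $z_0 = z$ and $z_K = x$, and applying the change-of-variables (transformation) theorem to each step, one can derive an expression for the likelihood and log-likelihood:

$$p_x(x) = p_z(z_0) \prod_{k=1}^{K} \left| \det \frac{\partial f_k(z_{k-1})}{\partial z_{k-1}} \right|^{-1}$$

$$\log p_x(x) = \log p_z(z_0) - \sum_{k=1}^{K} \log \left| \det \frac{\partial f_k(z_{k-1})}{\partial z_{k-1}} \right|$$
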
This is good news: it means we can use the exact log-likelihood directly as a tractable optimization objective, in contrast to, for example, the VAE, which optimizes a lower bound on the log-likelihood.
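As a sanity check of the log-likelihood formula, consider a two-step flow in one dimension: an affine map $f_1(z) = a z + b$ followed by $f_2(y) = \exp(y)$. The result $x = \exp(a z + b)$ is log-normally distributed, so its density is known in closed form and can be compared with the flow formula, which subtracts the log absolute derivative of each step. The values of a and b below are arbitrary.

```python
import numpy as np

def log_normal_pdf(z):
    """Log-density of the standard normal base distribution p_z."""
    return -0.5 * z**2 - 0.5 * np.log(2 * np.pi)

a, b = 0.7, 0.3             # arbitrary parameters of the affine step (a > 0)
z0 = np.linspace(-3, 3, 7)  # points in the base space

# Forward pass through the chain, keeping each intermediate z_k.
z1 = a * z0 + b             # f_1: affine, log|df_1/dz_0| = log|a|
x = np.exp(z1)              # f_2: exp,    log|df_2/dz_1| = z1

# Flow formula: log p_x(x) = log p_z(z_0) - sum_k log|det J_k|.
log_px_flow = log_normal_pdf(z0) - np.log(abs(a)) - z1

# Closed form: x is log-normal with parameters mu = b, sigma = a.
log_px_exact = (-np.log(x * a * np.sqrt(2 * np.pi))
                - (np.log(x) - b) ** 2 / (2 * a**2))

print(np.max(np.abs(log_px_flow - log_px_exact)))
```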

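So far we have been agnostic as to what the transformations $f_k$ actually look like. In general we are free to choose them, but for the log-likelihood to be efficiently computable, the determinant of each Jacobian $\partial f_k / \partial z_{k-1}$ must be efficiently computable too (for a general D×D matrix it costs O(D³)). Moreover, each $f_k$ must be bijective so that its inverse exists. Generating a sample only requires running the flow forward, but evaluating the likelihood of an observed data point $x$ requires running it backwards to recover the corresponding $z$:

$$z \sim p_z(z),\quad x = f(z) \qquad \text{and} \qquad z = f^{-1}(x) = f_1^{-1} \circ \cdots \circ f_K^{-1}(x)$$
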
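The two directions can be sketched with a small chain of invertible element-wise maps: an affine map and a leaky-ReLU-style map with slope alpha on negative inputs. These are simple stand-ins chosen because their inverses and Jacobian determinants are trivial to write down; they are not planar flows, and all class and function names are our own. Sampling applies $f_1, \dots, f_K$ in order, while density evaluation applies the inverses in reverse order and accumulates the log-determinants.

```python
import numpy as np

class Affine:
    """Element-wise f(z) = scale * z + shift (every scale must be non-zero)."""
    def __init__(self, scale, shift):
        self.scale, self.shift = np.asarray(scale), np.asarray(shift)
    def forward(self, z):
        return self.scale * z + self.shift
    def inverse(self, x):
        return (x - self.shift) / self.scale
    def log_det_jacobian(self, z):
        # Diagonal Jacobian: log|det| is the sum of log|scale| over dimensions.
        return np.sum(np.log(np.abs(self.scale))) * np.ones(z.shape[0])

class LeakyReLU:
    """Element-wise f(z) = z for z >= 0 and alpha * z otherwise (alpha > 0)."""
    def __init__(self, alpha):
        self.alpha = alpha
    def forward(self, z):
        return np.where(z >= 0, z, self.alpha * z)
    def inverse(self, x):
        return np.where(x >= 0, x, x / self.alpha)
    def log_det_jacobian(self, z):
        return np.sum(np.where(z >= 0, 0.0, np.log(self.alpha)), axis=1)

flows = [Affine([2.0, 0.5], [1.0, -1.0]), LeakyReLU(0.3), Affine([1.5, 1.5], [0.0, 0.0])]

def sample(n, rng):
    z = rng.standard_normal((n, 2))       # z ~ p_z, a standard 2-D Gaussian
    for f in flows:                       # x = f_K(...f_1(z)...)
        z = f.forward(z)
    return z

def log_prob(x):
    log_det = np.zeros(x.shape[0])
    for f in reversed(flows):             # z = f_1^{-1}(...f_K^{-1}(x)...)
        x = f.inverse(x)
        log_det += f.log_det_jacobian(x)  # Jacobian of f_k evaluated at its input z_{k-1}
    log_pz = -0.5 * np.sum(x**2, axis=1) - np.log(2 * np.pi)
    return log_pz - log_det

rng = np.random.default_rng(0)
x = sample(1000, rng)
print(log_prob(x)[:3])
```

Integrating `np.exp(log_prob(...))` over a fine grid that covers the samples should give approximately one, which is a quick way to catch sign errors in the log-determinant bookkeeping.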
Planar Flows

In this post we will look at one type of flow, namely planar flows. A single planar flow layer is defined by the function:

$$f(z) = z + u\, h(w^\top z + b)$$

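where the vectors $u$, $w$ and the scalar $b$ are free parameters, and $h$ is a smooth element-wise non-linearity such as $\tanh$. With our flow in place, we have to compute the determinant of its Jacobian to be able to use the transformation theorem. Writing $\psi(z) = h'(w^\top z + b)\, w$, the Jacobian is $I + u\,\psi(z)^\top$, and by the matrix determinant lemma its determinant reduces to a single inner product, whose cost is linear rather than cubic in the dimension:

$$\left| \det \frac{\partial f}{\partial z} \right| = \left| \det\left(I + u\, \psi(z)^\top\right) \right| = \left| 1 + u^\top \psi(z) \right|$$
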
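To see that the closed-form determinant agrees with the full Jacobian, the numpy sketch below implements one planar layer with $h = \tanh$ (an assumption; the choice of non-linearity is not stated above) and compares $\log|1 + u^\top \psi(z)|$ with the log-determinant of a Jacobian built by central finite differences. The parameter values are arbitrary and the function names are our own.

```python
import numpy as np

def planar_forward(z, u, w, b):
    """Planar flow f(z) = z + u * tanh(w^T z + b) for a batch z of shape (N, D)."""
    a = z @ w + b                              # (N,) pre-activations w^T z + b
    return z + np.outer(np.tanh(a), u)

def planar_log_det(z, u, w, b):
    """log|det df/dz| = log|1 + u^T psi(z)| with psi(z) = h'(w^T z + b) w."""
    h_prime = 1.0 - np.tanh(z @ w + b) ** 2    # derivative of tanh
    psi = h_prime[:, None] * w[None, :]        # (N, D)
    return np.log(np.abs(1.0 + psi @ u))

def numerical_log_det(z, u, w, b, eps=1e-6):
    """Brute-force check: build each D x D Jacobian by central differences."""
    out = []
    for zi in z:
        D = zi.size
        J = np.empty((D, D))
        for j in range(D):
            e = np.zeros(D)
            e[j] = eps
            J[:, j] = (planar_forward((zi + e)[None], u, w, b)[0]
                       - planar_forward((zi - e)[None], u, w, b)[0]) / (2 * eps)
        out.append(np.log(abs(np.linalg.det(J))))
    return np.array(out)

rng = np.random.default_rng(1)
u, w, b = np.array([0.8, -0.4]), np.array([1.2, 0.5]), 0.1   # arbitrary parameters
z = rng.standard_normal((5, 2))
print(planar_log_det(z, u, w, b))
print(numerical_log_det(z, u, w, b))
```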
The planar transformation $f$ modifies the initial density $p_z(z)$ by applying a series of contractions and expansions in the direction perpendicular to the hyperplane $w^\top z + b = 0$, hence we refer to these maps as planar flows⁵.
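A planar map is not automatically invertible. For $h = \tanh$, Rezende and Mohamed⁵ show that $w^\top u \ge -1$ is sufficient, and suggest enforcing it by replacing $u$ with $\hat u = u + [m(w^\top u) - w^\top u]\, w / \lVert w \rVert^2$, where $m(x) = -1 + \log(1 + e^x)$. This is presumably what the `bijector._u()` call in the training loop of the implementation takes care of (its comment reads "ensure that w.T * u > -1"), although that class is not shown; the function below is our own numpy sketch of the reparametrization.

```python
import numpy as np

def constrain_u(u, w):
    """Return u_hat with w^T u_hat = m(w^T u) > -1, where m(x) = -1 + softplus(x)."""
    wu = w @ u
    m = -1.0 + np.logaddexp(0.0, wu)   # -1 + log(1 + e^{wu}), computed stably
    return u + (m - wu) * w / (w @ w)

rng = np.random.default_rng(2)
for _ in range(5):
    u, w = rng.standard_normal(2) * 3, rng.standard_normal(2)
    u_hat = constrain_u(u, w)
    print(f"w^T u = {w @ u:+.3f}  ->  w^T u_hat = {w @ u_hat:+.3f}")
```

Because $m$ is smooth, the constraint can be applied inside the forward pass while gradients still flow to the unconstrained $u$.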

Implementation

We will explore planar flows by using them to recover the density of a toy dataset. In terms of the inference tasks of generative models discussed earlier, this falls under density estimation.

In this example we will use the very simple checkerboard dataset. Below we have split the dataset into training and validation sets, and our goal will be to learn the density p(x) of this data.

Right: training dataset. Left: validation dataset.
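The post does not show how the checkerboard samples are generated. One common recipe draws $x_1$ uniformly and then places $x_2$ in alternating unit squares; the sketch below follows that recipe, scaled to $[-4, 4]^2$ to match the heatmap range used later. The `make_checkerboard` helper, the sample count and the 80/10/10 split are hypothetical choices, not necessarily the authors' data pipeline.

```python
import numpy as np

def make_checkerboard(n, rng):
    """Hypothetical helper: n points spread uniformly over the occupied squares
    of a 4x4 checkerboard covering [-4, 4) x [-4, 4)."""
    x1 = rng.uniform(-2, 2, n)
    # Pick the lower or upper half of a column, then shift every other column
    # up by one unit so that the occupied squares alternate.
    x2 = rng.uniform(0, 1, n) - rng.integers(0, 2, n) * 2
    x2 = x2 + np.floor(x1) % 2
    return np.stack([x1, x2], axis=1) * 2

rng = np.random.default_rng(0)
raw = make_checkerboard(10_000, rng)
data = raw.astype(np.float32)           # float32 to match the base distribution
n_train, n_val = 8_000, 1_000           # arbitrary 80/10/10 split
train_data = data[:n_train]
val_data = data[n_train:n_train + n_val]
test_data = data[n_train + n_val:]
print(train_data.shape, val_data.shape, test_data.shape)
```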

We start off with a standard Gaussian base distribution and then stack planar flow layers, as formulated above, on top of it to try to capture the true data distribution:

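```python
import tensorflow as tf
import tensorflow_probability as tfp
tfd, tfb = tfp.distributions, tfp.bijectors

# assign a standard 2-D normal base distribution
base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros([2], tf.float32))
# create a flow of `layers` planar bijectors; `PlanarFlow` and `layers`
# come from the authors' code and are not shown in this post
bijectors = []
for _ in range(layers):
    bijectors.append(PlanarFlow(input_dimensions=2, case="density_estimation"))
# tfb.Chain applies its list last-to-first, so reverse it to apply bijectors[0] first
bijector = tfb.Chain(bijectors=list(reversed(bijectors)), name='chain_of_planar')
x_dist = tfd.TransformedDistribution(
    distribution=base_dist,
    bijector=bijector,
)
```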
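`PlanarFlow` is the authors' own bijector class and is not shown. Since a planar map has no closed-form inverse, one plausible reading of `case="density_estimation"` is that each layer is applied in the normalizing direction, mapping data $x$ towards the base space, so that `log_prob` only needs cheap forward evaluations; this is our assumption. The numpy sketch below follows it: it stacks three planar layers $g_k$ (with the invertibility constraint on $u$), computes $\log p(x) = \log \mathcal{N}(g(x); 0, I) + \sum_k \log|1 + \hat u_k^\top \psi_k|$, and checks on a grid that the resulting density integrates to roughly one. The layer count, parameter values and function names are all hypothetical.

```python
import numpy as np

def constrained_u(u, w):
    # Reparametrize u so that w^T u_hat > -1, keeping each tanh layer invertible.
    wu = w @ u
    return u + (-1.0 + np.logaddexp(0.0, wu) - wu) * w / (w @ w)

def planar_layer(z, u, w, b):
    """Apply one planar map; return (output, log|det Jacobian|) per row."""
    u = constrained_u(u, w)
    a = z @ w + b
    out = z + np.outer(np.tanh(a), u)
    log_det = np.log(np.abs(1.0 + (1.0 - np.tanh(a) ** 2) * (w @ u)))
    return out, log_det

def log_prob(x, params):
    """log p(x), assuming the planar layers map data x to the base space z."""
    z, total_log_det = x, np.zeros(x.shape[0])
    for u, w, b in params:
        z, log_det = planar_layer(z, u, w, b)
        total_log_det += log_det
    log_base = -0.5 * np.sum(z**2, axis=1) - np.log(2 * np.pi)  # standard 2-D normal
    return log_base + total_log_det

# Three hypothetical layers (u, w, b); in a real model these would be trained.
params = [
    (np.array([0.5, -0.3]), np.array([1.0, 0.8]), 0.2),
    (np.array([-0.6, 0.4]), np.array([0.5, -1.0]), -0.1),
    (np.array([0.3, 0.3]), np.array([-0.7, 0.9]), 0.0),
]

# Riemann-sum check that exp(log p) integrates to about 1 over a wide grid.
grid = np.linspace(-8, 8, 401)
xx, yy = np.meshgrid(grid, grid)
points = np.stack([xx.ravel(), yy.ravel()], axis=1)
cell = (grid[1] - grid[0]) ** 2
mass = np.exp(log_prob(points, params)).sum() * cell
print(mass)
```

The trade-off of this direction is that sampling would now require inverting each planar layer numerically.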

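Having set up our flow model, we train it with mini-batch gradient descent on the negative log-likelihood. Every 100 epochs we compute the validation loss, checkpoint the model whenever that loss reaches a new minimum, and stop early once it has not improved for delta_stop = 1000 epochs: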

```python
import time
import numpy as np
import tensorflow as tf

# max_epochs, batched_train_data, opt, train_density_estimation, val_data,
# checkpoint, checkpoint_prefix and plot_heatmap_2d are defined elsewhere
global_step = []
train_losses = []
val_losses = []
min_val_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)  # high value to ensure that first loss < min_loss
min_train_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)
min_val_epoch = 0
min_train_epoch = 0
delta_stop = 1000  # threshold for early stopping
t_start = time.time()  # start time

# start training
for i in range(max_epochs):
    for batch in batched_train_data:
        train_loss = train_density_estimation(x_dist, opt, batch)
        # ensure that w.T * u > -1 (invertibility)
        for planar in x_dist.bijector.bijectors:
            planar._u()

    if i % 100 == 0:
        val_loss = -tf.reduce_mean(x_dist.log_prob(val_data))
        global_step.append(i)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"{i}, train_loss: {train_loss}, val_loss: {val_loss}")

        if train_loss < min_train_loss:
            min_train_loss = train_loss
            min_train_epoch = i

        if val_loss < min_val_loss:
            min_val_loss = val_loss
            min_val_epoch = i
            checkpoint.write(file_prefix=checkpoint_prefix)  # save the best model so far
        elif i - min_val_epoch > delta_stop:  # no decrease in min_val_loss for delta_stop epochs
            break

    if i % 1000 == 0:
        # plot heatmap
        plot_heatmap_2d(x_dist, -4.0, 4.0, -4.0, 4.0)

train_time = time.time() - t_start
```
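`train_density_estimation` itself is not shown in the post; judging from how `val_loss` is computed, each step presumably minimizes the mean negative log-likelihood of a batch with the optimizer `opt`. The same objective can be illustrated without TensorFlow on the simplest possible flow, a 1-D affine map $x = e^{s} z + b$ with a standard normal base, whose gradients can be written by hand. The data, learning rate and step count are arbitrary; maximizing the likelihood should drive $b$ to the sample mean and $e^{s}$ to the sample standard deviation.

```python
import numpy as np

def nll(x, b, s):
    """Mean negative log-likelihood of x under x = exp(s) * z + b, z ~ N(0, 1)."""
    z = (x - b) * np.exp(-s)                                   # inverse pass z = f^{-1}(x)
    return np.mean(0.5 * z**2 + 0.5 * np.log(2 * np.pi) + s)   # + log|dx/dz| = s

def grads(x, b, s):
    """Hand-derived gradients of the mean NLL with respect to b and s."""
    z = (x - b) * np.exp(-s)
    return np.mean(-z * np.exp(-s)), np.mean(1.0 - z**2)

rng = np.random.default_rng(0)
data = rng.normal(loc=3.0, scale=0.5, size=2_000)   # toy 1-D training data

b, s, lr = 0.0, 0.0, 0.05
losses = []
for step in range(3_000):                            # full-batch gradient descent
    g_b, g_s = grads(data, b, s)
    b, s = b - lr * g_b, s - lr * g_s
    losses.append(nll(data, b, s))

print(f"b = {b:.3f} (sample mean {data.mean():.3f}), "
      f"exp(s) = {np.exp(s):.3f} (sample std {data.std():.3f})")
```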

Looking at the training and validation losses, we can see that training is stable and the model does not overfit:

Finally, we restore the model with the lowest validation loss, evaluate it on the test set and plot its density heatmap:

```python
import time
import tensorflow as tf

# load the best model (lowest validation loss)
checkpoint.restore(checkpoint_prefix)
# evaluate the negative log-likelihood on the test dataset
t_start = time.time()
test_loss = -tf.reduce_mean(x_dist.log_prob(test_data))
test_time = time.time() - t_start
# plot the density estimated by the best model
plot_heatmap_2d(x_dist, -4.0, 4.0, -4.0, 4.0, name=save_dir)
```
Heatmap of recovered density using best planar flow model, bright areas indicate higher density

As we can see, the model is quite successful at estimating the density of the checkerboard distribution.
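`plot_heatmap_2d` is another helper that the post does not show. A heatmap like the one above can be produced by evaluating the model's log-density on a regular grid over $[-4, 4]^2$ and exponentiating. The hypothetical `density_grid` function below does this for any vectorized log-density callable (here a standard Gaussian stands in for `x_dist.log_prob`), and its output could then be passed to an image plotting routine such as matplotlib's `imshow`.

```python
import numpy as np

def density_grid(log_prob_fn, x_min, x_max, y_min, y_max, resolution=200):
    """Hypothetical helper: evaluate exp(log p) on a resolution x resolution grid.

    Returns (xs, ys, density) where density[i, j] is p at (xs[j], ys[i]).
    """
    xs = np.linspace(x_min, x_max, resolution)
    ys = np.linspace(y_min, y_max, resolution)
    xx, yy = np.meshgrid(xs, ys)
    points = np.stack([xx.ravel(), yy.ravel()], axis=1)
    density = np.exp(log_prob_fn(points)).reshape(resolution, resolution)
    return xs, ys, density

def std_normal_log_prob(points):
    # Stand-in for a trained model's log_prob: a standard 2-D Gaussian.
    return -0.5 * np.sum(points**2, axis=1) - np.log(2 * np.pi)

xs, ys, density = density_grid(std_normal_log_prob, -4.0, 4.0, -4.0, 4.0)
# The brightest pixel should sit at a grid point next to the mode (0, 0).
i, j = np.unravel_index(np.argmax(density), density.shape)
print(xs[j], ys[i])
```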

We hope this post has given you a basic understanding of what generative models are and why they are so cool, as well as a clear picture of how flow-based generative models work. Stay tuned for more on generative modelling.

Authors

Samuel Tober is a member of the KTH AI Society, MSc student in Computer Science at the KTH Royal Institute of Technology, currently doing his master's thesis at Klarna. You can reach him on LinkedIn or by email at samuel@kthais.se

Vishal Nedungadi is a member of the KTH AI Society, MSc student in Machine Learning at the KTH Royal Institute of Technology. You can reach him on LinkedIn or by email at vishalned@gmail.com

References:

  1. Diederik P. Kingma and Max Welling. Auto-Encoding Variational Bayes. 2014. arXiv:1312.6114 [stat.ML].
  2. Laurent Dinh, David Krueger, and Yoshua Bengio. NICE: Non-linear Independent Components Estimation. 2015. arXiv:1410.8516 [cs.LG].
  3. Ashish Vaswani et al. Attention Is All You Need. 2017. arXiv:1706.03762 [cs.CL].
  4. Ian J. Goodfellow et al. Generative Adversarial Networks. 2014. arXiv:1406.2661 [stat.ML].
  5. Danilo Jimenez Rezende and Shakir Mohamed. Variational Inference with Normalizing Flows. 2016. arXiv:1505.05770 [stat.ML].
