Learning Interpretable Latent Representations with InfoGAN
A tutorial on implementing InfoGAN in TensorFlow
In this week’s post I want to explore a simple addition to Generative Adversarial Networks that makes them more useful both for researchers interested in their potential as an unsupervised learning tool and for enthusiasts or practitioners who want more control over the kinds of data they can generate. If you are new to GANs, check out this earlier tutorial I wrote a couple weeks ago introducing them. The addition I want to go over in this post is called InfoGAN, and it was introduced in this paper published by OpenAI earlier this year. It allows GANs to learn disentangled latent representations, which can then be exploited in a number of useful ways. For those interested in the mathematics behind the technique, I highly recommend reading the paper, as it is a theoretically interesting approach. In this post, though, I would like to provide a more intuitive explanation of what InfoGANs do, and how they can easily be added to existing GANs.
The structure of a GAN is as follows: a generator (G) and a discriminator (D) are updated in competing fashion in order to produce realistic data samples from latent variables. The discriminator is optimized to distinguish the generator’s samples from true data samples, while the generator is optimized to produce samples most likely to fool the discriminator. To generate a sample, the generator is fed a random noise vector z. This z is the set of latent variables the generator uses to produce samples; these variables can be thought of as seeds containing all the information needed to grow a data sample. Once the GAN has been trained to convergence, each variable in z should hypothetically correspond to some aspect of the generated sample. Ideally, each of these variables would not only correspond to the data, but do so in a semantically meaningful way. In image data, for example, we would hope for some variables to adjust lighting, others to adjust object position, and others to adjust colors. In practice, however, the z variables of a trained GAN often fail to correspond to any semantically decipherable aspects of an image.
This is where InfoGAN comes in. By adding an objective that maximizes the mutual information between a chosen subset of latent variables and the generated sample, we can have the network learn disentangled representations for those variables in addition to producing visually convincing samples. Using this technique, the GAN can learn latent variables that, in principle, correspond to meaningful aspects of an image, such as the presence or absence of certain objects, or positional features of a scene. One of the more impressive aspects of InfoGAN is that, unlike Conditional GAN models, which rely on supervised labels to give their latent codes meaning, InfoGAN is entirely unsupervised: the model learns to make its latent variables semantically meaningful on its own. All we have to provide is a template for the kinds of variables we want it to learn.
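To make "mutual information" concrete: it measures how much observing the generated sample tells us about the latent code. The toy sketch below uses plain Python, a hypothetical helper `mutual_information`, and made-up two-value distributions; it is not part of any InfoGAN implementation. It computes I(C; X) for a discrete joint distribution, which equals log 2 when the output fully reveals a binary code and 0 when the output ignores the code. InfoGAN pushes the generator toward the first situation. Because this quantity cannot be computed directly for real images, InfoGAN approximates it with an auxiliary network.

```python
import math

def mutual_information(joint):
    """Mutual information I(C; X) in nats for a discrete joint distribution.

    joint[i][j] is P(C = i, X = j); entries must be non-negative and sum to 1.
    """
    p_c = [sum(row) for row in joint]        # marginal P(C = i)
    p_x = [sum(col) for col in zip(*joint)]  # marginal P(X = j)
    mi = 0.0
    for i, row in enumerate(joint):
        for j, p in enumerate(row):
            if p > 0:  # zero-probability cells contribute nothing
                mi += p * math.log(p / (p_c[i] * p_x[j]))
    return mi

# The output fully determines a binary code: I(C; X) = H(C) = log 2.
informative = [[0.5, 0.0],
               [0.0, 0.5]]
# The output ignores the code: C and X are independent, so I(C; X) = 0.
ignored = [[0.25, 0.25],
           [0.25, 0.25]]

print(mutual_information(informative))  # ≈ 0.693 (log 2)
print(mutual_information(ignored))      # 0.0
```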
With a few additions, we can easily implement InfoGAN on top of the DCGAN model I discussed here. In that example we built a GAN that could generate MNIST digits. With InfoGAN, we can have it learn a representation in which each of the ten digits corresponds to one value of a single categorical latent variable.
Little adjustment is needed to use the generator network in an InfoGAN. In addition to the original z vector, we simply add a set of latent codes c, which correspond to the representations we want our model to learn. For categorical representations such as object type, we use one-hot vectors; for continuous representations such as object rotation, we use float values drawn from a uniform distribution between -1 and 1. These c variables are concatenated with the z vector, and the result is fed into the generator.
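As a rough sketch of the generator input described above, here is a plain-Python function that builds one batch of [z | one-hot c | continuous c] vectors. The dimensions (`z_dim=64`, ten categories, two continuous codes) and drawing z uniformly from [-1, 1] are illustrative assumptions, and `sample_generator_input` is a hypothetical helper. The sampled codes are also returned separately, since the Q-network losses need them as targets.

```python
import random

def sample_generator_input(batch_size, z_dim=64, n_categories=10,
                           n_continuous=2, rng=random):
    """Build generator inputs [z | one-hot categorical c | continuous c].

    Dimensions and the uniform prior on z are illustrative assumptions.
    Returns (inputs, categories, continuous_codes).
    """
    inputs, categories, continuous_codes = [], [], []
    for _ in range(batch_size):
        z = [rng.uniform(-1.0, 1.0) for _ in range(z_dim)]  # unstructured noise
        category = rng.randrange(n_categories)               # uniform categorical code
        one_hot = [1.0 if k == category else 0.0 for k in range(n_categories)]
        continuous = [rng.uniform(-1.0, 1.0) for _ in range(n_continuous)]  # codes in [-1, 1]
        inputs.append(z + one_hot + continuous)              # concatenate into one vector
        categories.append(category)
        continuous_codes.append(continuous)
    return inputs, categories, continuous_codes

rng = random.Random(0)
inputs, cats, conts = sample_generator_input(4, rng=rng)
print(len(inputs), len(inputs[0]))  # 4 76
```

Each vector has length 64 + 10 + 2 = 76. The categorical indices and continuous values are kept alongside the inputs so they can later be compared with what the Q-network recovers.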
The discriminator network is where the bulk of the change happens. It is here that the constraints between the latent variables and the properties of the generated image are enforced. To impose these constraints, we add a second top layer to the discriminator, which we refer to as a Q-network. It should not be confused with the Q-networks of reinforcement learning; the name comes from the connection to variational inference techniques, which use a q-distribution to approximate a true p-distribution. The Q-network shares all lower layers with the discriminator; only the top layers differ. After a fully connected layer, the Q-network has softmax outputs for categorical variables and a tanh output for continuous variables. We then compute the loss for categorical outputs as the cross-entropy −Σ c·log(ĉ), and for continuous outputs as the squared error (c − ĉ)². In both cases c is the latent value given to the generator, and ĉ is the value the Q-network predicts from the generated image. By minimizing the distance between c and ĉ, we encourage the generator to produce samples from which the latent information can be recovered, that is, samples that use this information in a semantically meaningful way.
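The two Q losses can be sketched for a single sample in plain Python. This is a minimal illustration, not the post's TensorFlow code. The helper names are hypothetical, the categorical head receives raw logits and applies a softmax, and the continuous head applies tanh to raw outputs so that ĉ lies in [-1, 1]. In a real model these would be batched tensor operations averaged over a minibatch.

```python
import math

def softmax(logits):
    m = max(logits)                            # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def categorical_q_loss(one_hot_c, q_logits):
    """Cross-entropy -sum_k c_k * log(c_hat_k) between the assigned code and Q's softmax."""
    c_hat = softmax(q_logits)
    # Only the assigned class (c_k = 1) contributes for a one-hot code.
    return -sum(c * math.log(p) for c, p in zip(one_hot_c, c_hat) if c > 0)

def continuous_q_loss(c, q_raw):
    """Squared error sum_i (c_i - c_hat_i)^2, with c_hat squashed into [-1, 1] by tanh."""
    c_hat = [math.tanh(x) for x in q_raw]
    return sum((ci - hi) ** 2 for ci, hi in zip(c, c_hat))

# Q is confident and correct about category 3, so the loss is small...
print(categorical_q_loss([0, 0, 0, 1], [0.0, 0.0, 0.0, 5.0]))
# ...and much larger when it is confident about the wrong category.
print(categorical_q_loss([0, 0, 0, 1], [5.0, 0.0, 0.0, 0.0]))
```

A uniform softmax over k categories gives a loss of log k, so the categorical loss falls below that value only when Q actually recovers information about c.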
When training the GAN, we now just need to sample appropriate latent c variables and attach them to the z vector. During optimization, in addition to updating the discriminator with the discriminator loss and the generator with the generator loss, we also update the variables of both networks, including the Q-network layers, using the Q-network losses.
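To connect these updates to the paper: the InfoGAN objective adds a mutual information term to the standard GAN value function V(D, G), weighted by a hyperparameter λ. Because the true mutual information is intractable, the paper replaces it with a variational lower bound L_I that depends on the auxiliary distribution Q(c | x):

```latex
\begin{aligned}
\min_{G,Q}\,\max_{D}\; V_{\mathrm{InfoGAN}}(D,G,Q) &= V(D,G) - \lambda\, L_I(G,Q), \\
L_I(G,Q) &= \mathbb{E}_{c \sim P(c),\; x \sim G(z,c)}\big[\log Q(c \mid x)\big] + H(c) \;\le\; I\big(c;\, G(z,c)\big).
\end{aligned}
```

Since H(c) is fixed by the chosen prior on c, maximizing L_I amounts to minimizing −E[log Q(c | x)]. For a categorical code, this is the categorical Q loss described above. For a continuous code, if Q's output is treated as the mean of a fixed-variance Gaussian (an assumption the post does not state), it reduces to the squared error up to constants. Because L_I depends on both G and Q, the Q loss updates both networks.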
Returning to the MNIST dataset, we can define a new categorical latent variable with 10 possible values, which we hope will correspond to the 10 digits. We can also define two continuous latent variables, with values ranging from -1 to 1. Once training has converged, the network produces different images depending on the values of these latent variables.
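One way to inspect what each code has learned is to hold z fixed and sweep the codes. Each row of a grid uses a different categorical value, and each column uses a different value of one continuous code, with the other continuous code held at 0. The sketch below only builds the generator inputs for such a grid. `traversal_grid` is a hypothetical helper, and using an all-zero z with `z_dim=64` is an assumption. Feeding the inputs to a trained generator and tiling the resulting images is left out.

```python
def traversal_grid(n_categories=10, steps=10, n_continuous=2, vary_index=0,
                   z=None, z_dim=64):
    """Generator inputs for a grid: row = categorical value, column = one continuous code.

    The same z is reused in every cell so that only the codes change.
    Requires steps >= 2.
    """
    if z is None:
        z = [0.0] * z_dim  # fixed noise vector (zeros here; any fixed sample works)
    grid = []
    for category in range(n_categories):
        one_hot = [1.0 if k == category else 0.0 for k in range(n_categories)]
        row = []
        for s in range(steps):
            value = -1.0 + 2.0 * s / (steps - 1)  # evenly spaced from -1 to 1
            continuous = [0.0] * n_continuous     # other continuous codes held at 0
            continuous[vary_index] = value
            row.append(list(z) + one_hot + continuous)
        grid.append(row)
    return grid

grid = traversal_grid()
print(len(grid), len(grid[0]), len(grid[0][0]))  # 10 10 76
```

Reading across a row should then show the effect of one continuous code on a fixed category. Reading down a column should show the categorical code switching between the kinds of digits it has learned.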
This technique can be applied to any dataset whose underlying distribution has properties that can be meaningfully broken up into categorical or continuous factors. Although MNIST is relatively simple, many datasets have several obvious categorical distinctions that an InfoGAN could be trained to learn. Beyond more purposeful data generation, InfoGAN can also serve as a first step in supervised learning problems. For example, the Q-network portion of a GAN trained as described above could serve as a classifier for new real-world data.
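One caveat applies when using the Q-network as a classifier. Because training is unsupervised, categorical value k has no reason to correspond to digit k; the network only learns some consistent assignment. The learned categories therefore need to be matched to real labels first, for instance with a small labeled set. The sketch below assumes we already have Q's predicted category (the argmax of its softmax) for each labeled example. It uses a majority vote per category, which is one simple choice among several. Both helpers are hypothetical, and the example data are made up for illustration.

```python
from collections import Counter, defaultdict

def map_categories_to_labels(predicted_categories, true_labels):
    """Assign each learned category index the most common true label among its samples."""
    votes = defaultdict(Counter)
    for category, label in zip(predicted_categories, true_labels):
        votes[category][label] += 1
    return {category: counts.most_common(1)[0][0] for category, counts in votes.items()}

def accuracy(predicted_categories, true_labels, mapping):
    """Fraction of examples whose mapped category matches the true label."""
    hits = sum(mapping.get(c) == y for c, y in zip(predicted_categories, true_labels))
    return hits / len(true_labels)

# Made-up example: Q's categories 3, 7 and 0 turn out to mean digits 1, 4 and 9.
preds = [3, 3, 7, 7, 0, 3]
labels = [1, 1, 4, 4, 9, 2]
mapping = map_categories_to_labels(preds, labels)
print(mapping)                            # {3: 1, 7: 4, 0: 9}
print(accuracy(preds, labels, mapping))   # 0.8333333333333334
```

Categories never seen in the labeled set map to nothing, so their predictions count as errors. A larger labeled set makes the mapping more reliable.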
For an implementation of the entire thing in TensorFlow, see the Gist below:
When I first learned about InfoGAN, it was clearly a promising idea, but it took me a while to wrap my head around it. I hope this tutorial makes it a little clearer for those new to the concept. If you’ve read this and still want to know more, I recommend the original OpenAI paper, as well as their implementation, both of which are excellent.