Deep Causal Generative Modelling — A Brief Tutorial

Harish Ramani
Published in The Startup · Sep 29, 2020 · 9 min read

Integrating a causal model into a deep learning architecture.

This article summarizes a project integrating a causal directed graphical model with a variational autoencoder (VAE). Deep generative models like VAEs can generate synthetic images that are comparable to a set of training images.

The content of the images in the training data follows rules. At the most basic level, they follow the laws of physics. However, within a specific domain, the rules can simplify to a set of abstract entities and relationships between those entities. For example, your training data could have a simple set of entities like “boy”, “girl”, and “dog”, and a set of interactions including “boy walks dog”, “dog bites girl”, and “girl greets boy.”

What if you wanted to specify these rules when generating from the model? For example, you might want to see what an image looks like if “boy greets boy.” You could add examples of those interactions to your training data. But what if it were “girl bites dog”? It might be challenging to find such an image for inclusion in the training data if you are working with real-world images.

One solution is to build a deep causal generative model. The entities and the rules of their interactions come from knowing the domain you wish to model and are made explicit in the model architecture. The mapping of those entities and relationships to pixels in an image is up to the neural network elements of the model architecture — in the case of a VAE, this is the decoder. Given an image, the encoder maps us back to the abstractions describing the entities and interactions between entities characterized within the image, and we can reason about them causally.

This tutorial demonstrates a proof of concept for this flavor of modeling using procedurally generated images of fights between characters one might find in a role-playing video game.

~ Robert Osazuwa Ness, Ph.D — Altdeep.ai

Target Audience and Potential Use cases of Causal Generative Models

This article assumes the reader has some prerequisite knowledge of how to build generative models for images and of architectures like variational autoencoders (VAEs) and generative adversarial networks (GANs). The causal generative modeling referred to in this article isn’t about causal discovery, but rather about integrating a causal model with deep generative models like variational autoencoders.

Two boys on top of a cliff watching the sunrise

Artists could potentially use this model to evaluate queries like, “How would the image look if it were night instead of day?” If it had been night, perhaps the birds wouldn’t be there and many things might look aesthetically different. Current image generation models don’t take causality into account. The approach could also be applied to medical imaging (“How would the MRI look if I changed from the current settings to different settings?”) or to mass spectrometry (“How would the mass spectrometry image look if a different compound or element were present?”).

Note: An in-depth tutorial with sample code can be found at https://linkinnation1792.gitbook.io/causal-scene-generation/, and the code can be found in the causalML repo.

Data generation process and Dataset

The situation we consider is one where two characters interact: one character instigates an action and the other character reacts to it. This situation is best illustrated using characters from a game. We represent the data generation process (i.e., the scene where the two game characters interact) with a directed acyclic graph.

Directed Acyclic Graph (DAG) describing the scene.
Explanations for the DAG
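
To make the structure concrete, here is a rough sketch of that DAG written as a parent map in Python. The node names are assumptions inferred from the attributes discussed later in this post (the actor and reactor characters, their types, their strength/attack/defense attributes, and the action/reaction); the authoritative graph is the one in the figure above.

# Illustrative parent map for the scene DAG. Node names are assumptions
# based on the attributes discussed later in the post; see the figure
# above for the authoritative structure.
scene_dag = {
    "actor": [],
    "actor_type": ["actor"],
    "actor_strength": ["actor", "actor_type"],
    "actor_attack": ["actor", "actor_type"],
    "reactor": [],
    "reactor_type": ["reactor"],
    "reactor_defense": ["reactor", "reactor_type"],
    "action": ["actor_strength", "actor_attack"],
    "reaction": ["action", "reactor_defense"],
    # The image is generated from the visible entities and their interaction.
    "image": ["actor", "actor_type", "reactor", "reactor_type",
              "action", "reaction"],
}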

Procedural Generation

This step is unnecessary if you already have the images you want to train on. Since we are constructing the dataset from scratch, a procedural generation scheme is created to generate images for all possible combinations of characters, their variants, and the different actions and reactions, as sketched below.

A Satyr variant attacking another Satyr variant, which gets hurt.
A Satyr variant and a Golem variant attacking each other.
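
A minimal sketch of such a scheme is shown below. The entity lists and the render_scene helper are hypothetical placeholders; the actual asset names and rendering code live in the linked repo.

from itertools import product

# Illustrative entity lists; the actual asset names live in the repo.
characters = ["Satyr", "Golem"]
variants = [1, 2, 3]
actions = ["Attack", "Taunt"]
reactions = ["Hurt", "Dodge", "Idle"]

# Enumerate every (actor, reactor, action, reaction) combination and
# render one image per combination. render_scene is a hypothetical
# helper that composites the two character sprites into a frame.
for (actor, a_var), (reactor, r_var), action, reaction in product(
        product(characters, variants),
        product(characters, variants),
        actions,
        reactions):
    frame = render_scene(actor, a_var, reactor, r_var, action, reaction)
    frame.save(f"data/{actor}{a_var}_{action}_{reactor}{r_var}_{reaction}.png")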

Probabilistic queries

The main purpose of using a causal model is to evaluate probabilistic queries. There are three types of probabilistic queries one can ask. Judea Pearl refers to this as the causal ladder, a hierarchy of three types of problems of increasing difficulty: prediction, intervention, and counterfactuals.

Prediction

To do prediction within a system, we only need to ask questions about the system as it currently is. Therefore, it’s sufficient to have the joint distribution P over all variables V in the system. If we have the joint distribution P, then we can answer any question of the form, “Given that the variables X⊆V are observed to be x, what is the probability that the variables Y⊆V are equal to y?”

Query evaluated on our dataset: How would the image look if the actor, the reactor, and their types were set to specific values?
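
As a minimal sketch of what such a conditioning query looks like in Pyro, consider a toy stand-in for the scene model with just two discrete nodes. The node names and probabilities here are illustrative, not the ones from the actual model.

import torch
import pyro
import pyro.distributions as dist

# Toy stand-in for the scene model: an actor node and an action node
# whose distribution depends on the actor.
def toy_model():
    actor = pyro.sample("actor", dist.Categorical(torch.tensor([0.5, 0.5])))
    action_probs = torch.tensor([[0.8, 0.2],
                                 [0.3, 0.7]])[actor]
    action = pyro.sample("action", dist.Categorical(action_probs))
    return actor, action

# Prediction (rung one): observe that the actor equals 0 and reason
# about the downstream distribution of the action.
conditioned_model = pyro.condition(toy_model, data={"actor": torch.tensor(0)})

Running inference on conditioned_model (for example with Pyro’s enumeration or importance sampling utilities) yields P(action | actor = 0).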

Intervention

By performing an exogenous intervention in a system, we change its distribution. The original distribution P may no longer be valid. Thus, to answer questions about interventions, we need a family of distributions {P_{X=x}} indexed by X⊆V. Once we have this family, we can answer any question of the form, “If we force the variables X⊆V to x, what is the probability that the variables Y⊆V are equal to y?”

Query evaluated on our dataset: How would the image look if we intervene on the action, setting it to Attacking, and then infer upstream variables like the attacking skill of the actor?
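
In Pyro, the analogous operation is the do handler, which cuts the edges coming into the intervened node. Reusing the toy stand-in model from the previous sketch:

# Intervention (rung two): force the action to 0 with pyro.do. Unlike
# conditioning, this severs the influence of the actor's attributes on
# the action, so upstream nodes keep their prior distribution unless we
# also observe something downstream of them.
intervened_model = pyro.do(toy_model, data={"action": torch.tensor(0)})

This difference between conditioning and intervening is exactly what the comparison of distributions later in this post illustrates.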

Counterfactual

At the counterfactual level, we are allowed to ask questions of the form, “Given that the variables Z⊆V were observed to be z, if the variables X⊆V had been forced to be x, how likely is it that the variables Y⊆V would have been equal to y?”

Note: No counterfactual example is evaluated on our dataset, as that part is still under implementation.

Implementation

Model Overview during training

Technology stack: Pyro (a probabilistic programming language from Uber), PyTorch, and bnlearn and gRain (R packages for working with Bayesian networks).

All probabilistic programs are built up by composing primitive stochastic functions and deterministic computation. In our case, the data generating process is encoded in the DAG and is implemented as a stochastic function, conventionally named model. In this stochastic function, we define each of the nodes in the DAG to be sampled from a specific distribution. In our case, all the nodes except the image are discrete variables and are therefore sampled from categorical distributions. Since we have training labels, we can condition on them when sampling these nodes.

If there are any learnable parameters, as in our case the encoder and decoder neural networks, we need another stochastic function, named guide, to help learn them. Inference algorithms in Pyro, such as stochastic variational inference, use the guide function as an approximate posterior distribution. A guide must satisfy two criteria to be a valid approximation of the model. First, all the unobserved sample statements that appear in the model must appear in the guide. Second, the guide has the same signature as the model, i.e., it takes the same arguments.

actor = pyro.sample(
    "actor",
    dist.OneHotCategorical(self.cpts["character"]),
    obs=actorObs,
)

Here, we sample the actor from a one-hot categorical distribution whose prior probabilities come from the CPT (as mentioned in the CPT section), and we condition on the observed label using the obs argument.

In the guide function, we sample the unobserved nodes of the DAG. In our example, we don’t observe the strength, attack, and defense attributes of the actor and the reactor in the image. Hence, we use the guide function to learn their posterior distributions. We compute the conditional probability of each unobserved node, indexed by the values of its parent and child nodes. For actor_strength, the parents are actor and actor type and the child is action; these three variables are observed in our training data.

One of the guide statements is shown below:

actor_strength = pyro.sample(
    "actor_strength",
    dist.Categorical(
        self.inverse_cpts["action_strength"][action, actor_type, actor]),
)

Note: Please refer to the DAG if you’re confused. Each variable in the guide function is inferred from the values of its parent and child nodes.

Variational Autoencoder Architecture

Variational autoencoders are directed probabilistic graphical models whose posteriors are approximated by a neural network with an autoencoder-like architecture. The autoencoder architecture comprises an encoder unit, which reduces the large input space to a latent representation, usually of lower dimension than the input space, and a decoder unit, which reconstructs the input from the latent representation.

Labels + Latent produce Image
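
As the caption above suggests, the decoder consumes both the label vector and the latent code. Below is a minimal PyTorch sketch of such convolutional encoder and decoder units; the layer sizes and image resolution are illustrative assumptions, not the exact ones used in the repo.

import torch
import torch.nn as nn

class Encoder(nn.Module):
    # Maps a 64x64 image (plus labels) to the parameters of q(z | image, labels).
    def __init__(self, label_dim, z_dim, img_channels=3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(img_channels, 32, 4, stride=2, padding=1),  # 64 -> 32
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),            # 32 -> 16
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),           # 16 -> 8
            nn.ReLU(),
        )
        self.fc_loc = nn.Linear(128 * 8 * 8 + label_dim, z_dim)
        self.fc_scale = nn.Linear(128 * 8 * 8 + label_dim, z_dim)

    def forward(self, image, labels):
        h = torch.cat([self.conv(image).flatten(start_dim=1), labels], dim=-1)
        return self.fc_loc(h), nn.functional.softplus(self.fc_scale(h))

class Decoder(nn.Module):
    # Maps the concatenated label vector and latent code back to an image.
    def __init__(self, label_dim, z_dim, img_channels=3):
        super().__init__()
        self.fc = nn.Linear(label_dim + z_dim, 128 * 8 * 8)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),  # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),   # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(32, img_channels, 4, stride=2, padding=1),  # 32 -> 64
            nn.Sigmoid(),  # pixel intensities in [0, 1]
        )

    def forward(self, labels, z):
        h = self.fc(torch.cat([labels, z], dim=-1))
        return self.deconv(h.view(-1, 128, 8, 8))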

A brief outline of the causal variational autoencoder network is given below.

Causal VAE higher outline

The encoder and decoder use convolutional units, since we are dealing with images, and we train the model with stochastic variational inference.
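
A minimal sketch of that SVI setup is shown below, assuming an object (here called causal_vae) that exposes the model and guide stochastic functions described earlier, plus a standard PyTorch data loader; the names and hyperparameters are placeholders.

import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# causal_vae is a placeholder for the object holding the model and guide
# stochastic functions; train_loader yields (images, labels) batches.
svi = SVI(causal_vae.model, causal_vae.guide, Adam({"lr": 1e-3}), loss=Trace_ELBO())

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for images, labels in train_loader:
        # One gradient step on the ELBO objective; the arguments are
        # forwarded to both the model and the guide.
        epoch_loss += svi.step(images, labels)
    print(f"epoch {epoch}: loss per sample {epoch_loss / len(train_loader.dataset):.4f}")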

Inference Mode

Model Overview during Inference Mode

In inference mode, we use the trained decoder network in conjunction with the latent node to generate an image for various probabilistic queries. Instead of doing inference with MCMC or HMC, we pre-compute the posterior distributions analytically using the gRain package in R. This is not possible in all cases.
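
A hedged sketch of that generation step is shown below. It assumes the posterior (or intervention) distributions over the discrete label nodes have already been computed with gRain and exported as tables of probabilities, and it reuses the hypothetical trained decoder from the earlier sketch.

import torch

def generate_image(decoder, label_probs, z_dim):
    # label_probs is assumed to be a dict mapping each discrete node to a
    # tensor of probabilities exported from the gRain posterior/intervention
    # distribution, e.g. {"actor": tensor([...]), "action": tensor([...]), ...}.
    pieces = []
    for node, probs in label_probs.items():
        idx = torch.distributions.Categorical(probs).sample()
        one_hot = torch.zeros(len(probs))
        one_hot[idx] = 1.0
        pieces.append(one_hot)
    labels = torch.cat(pieces).unsqueeze(0)

    # Latent code drawn from the standard normal prior.
    z = torch.randn(1, z_dim)
    with torch.no_grad():
        return decoder(labels, z)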

Results of Probabilistic queries

Before we test the probabilistic queries, we show the reconstruction capabilities of the causal generative model to demonstrate that it has learned the distribution of the images to a reasonable degree.

How would the image look if the actor, the reactor, and their types were set to specific values?

In this prediction/conditioning query, we set the actor character to a Satyr with the type-1 variant and the reactor character to a Golem with the type-3 variant.

Image from the decoder.
The original image is generated by procedural generation.

In addition to getting an image, we can use the causal model to infer the attributes that are not observed in the image.

How would the image look if we intervene on the action and set it to Attacking?

In this example, we will see the difference between conditioning and intervention statements in terms of the resulting probability distributions. The intervention we apply is on the actor’s action, setting it to attack. We then infer the nodes upstream of the actor’s action, such as the actor’s attacking capability.

# Apply the intervention do(AACT = "Attack") by mutilating the fitted network
intervention_2_bn <- mutilated(dfit, list(AACT = "Attack"))
# Convert to a gRain object for exact inference
intervention_2_grain <- as.grain(intervention_2_bn)
Intervention distribution of Actor Attack
The conditional distribution of Actor Attack

We can see that the attacking capability is different under the intervention distribution than under the conditional distribution. As above, we infer all the necessary nodes and sample from those distributions.

Image generated from intervention distribution.

Conclusion

Probabilistic queries like conditioning and intervention queries were successfully evaluated by integrating a causal DAG with a deep generative model. There are many potential use cases for this approach, such as medical imaging. There are, however, a few reasons why scaling this approach would be troublesome.

  • Figuring out a correct directed acyclic graph is difficult even in ordinary settings, and the feasibility of creating DAGs for images (explaining the image generation process) is largely untested.
  • Inference is incredibly hard, and it’s made worse by having a deep generative model with, potentially, millions of parameters.

There’s a lot to gain by integrating causality into machine learning models like increased reasoning capabilities. Causality and causal inference techniques can be applied to almost all branches of machine learning like Computer Vision, Natural Language Processing, Reinforcement Learning, etc.

Robert Osazuwa Ness writes at altdeep. Check out the course repo at causalML.

Any suggestions or feedback on the article are welcome. You can reach out to me via my LinkedIn or Twitter.
