Why Machine Learning Needs Causality

A Gentle Guide to Causal Inference with Machine Learning Pt. 1

Jakob Runge
Causality in Data Science
8 min read · Feb 13, 2023


When computers fail and natural intelligence excels, we stand at the boundary of current machine learning research. But what makes natural intelligence so superior? Why can we identify a tree as a tree no matter the lighting or image quality, and why is it so challenging for ML algorithms to adapt to such changes? The answer lies in the difference between correlation and causation. In this blog series, we give a gentle introduction to causal inference for newcomers.

Inferring knowledge about causal relationships is at the heart of the scientific enterprise, and the field of causal inference is about learning such relationships from observational data. In today’s data-driven world, causal inference holds the promise of lifting machine learning to its next level by addressing some of its hardest and seemingly ever-present problems, such as overfitting and prediction under new conditions.

In the end, machine learning boils down to a large pattern recognition system that leverages the power of associations and correlations in the setting of independent and identically distributed (i.i.d.) data. In other words, machine learning sees patterns, assumes them to be true relationships, and extrapolates its decisions about the future from these past experiences.
This recognition and exploitation of associations leads to undoubtedly powerful predictions, yet it faces a fundamental challenge: real-world data often does not fulfill the conditions necessary for stable predictions across scenarios. In computer vision, for example, the test distribution might differ fundamentally from the training distribution due to changes in lighting conditions, camera quality, or viewpoint. Such changes can completely defeat, or at least heavily degrade, the accuracy and usefulness of the trained model.

On the other hand, understanding the mechanism between a cause and its effect enables causal models to cover a whole range of distributions, one for each possible intervention, and to perform reliable predictions under a change in the environment. In contrast, purely statistical learning models, which rest on the i.i.d. assumption, only describe one overall population distribution and do not offer this flexibility, making them less generalizable.

As predictions find their way into more and more human decision-making, this well-known generalization problem of association-based machine learning goes beyond reduced accuracy. It can result in fundamentally wrong decisions or conclusions about reality. A machine-learned system that detects the positive association between the European stork population and birth rates might yield a powerful prediction algorithm for birth rates. However, changing the stork population in a country will obviously not affect birth rates, even though purely statistical models may predict so.
Of course, you might think, data scientists would never train such a model. And this is true for the stork example. But there are many situations where our judgment or intuition fails, and the crucial point is that we don’t know when this is the case.

Let’s have a look at this with an example. Consider the following equations with three variables and jointly independent noise terms as a given ground truth system:
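Z := N_Z
Y := 3 · Z + N_Y
X := 2 · (Z − 1) + N_X

where the noise terms N_Z, N_Y, and N_X are jointly independent Gaussians with means 3, 1, and 3 and standard deviations 1, 0.5, and 1, respectively (matching the code below).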

In reality, we do not know such systems, which are called structural causal models (SCMs). All we have (besides some theory) is observational data from the system. So let’s act as if this were the case and generate a sample from the joint distribution of (X, Y, Z) as our “observations”.

# Defining the system's variables
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(10)  # seed NumPy's generator (random.seed would not affect np.random)
Z = np.random.normal(loc=3, scale=1, size=300).reshape(-1, 1)
Y = 3*Z + np.random.normal(loc=1, scale=0.5, size=300).reshape(-1, 1)
X = 2*(Z - 1) + np.random.normal(loc=3, scale=1, size=300).reshape(-1, 1)

Given this information, we want to get a solid understanding of the underlying system. By plotting our data, we can see how the variables X and Y relate to each other.

# Visualize the association
plt.scatter(X, Y, color='blue')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

From the scatter plot we might conclude that these two variables share a connection, which we then explore with a basic regression. Regressing Y on X will indeed result in a decent prediction model, although no direct functional relation between X and Y exists (you can check this in the ground truth SCM above: X and Y are only connected through their common cause Z). Note that this is independent of the kind of ML algorithm we use. If all we observe is X and Y, even the most powerful models will learn and exploit a relationship for their predictions that is not a causal one.

# Regress Y on X
from sklearn.linear_model import LinearRegression
linreg = LinearRegression()
result = linreg.fit(X, Y)

# Evaluate the regression
r_sq = result.score(X, Y)
print(r_sq)  # output: approx. 0.78

If we were able to intervene in the system and use the resulting interventional data, we could uncover such false conclusions.

Assume we can do that. Below, we modify our system by intervening on X.

# The effect of intervening on X
np.random.seed(10)
Z = np.random.normal(loc=3, scale=1, size=300).reshape(-1, 1)
Y = 3*Z + np.random.normal(loc=1, scale=0.5, size=300).reshape(-1, 1)
X = np.random.normal(loc=3, scale=1, size=300).reshape(-1, 1)  # intervene on X

If X and Y had a true causal connection, the intervention on X should have an impact on Y, because the mechanism connecting X and Y should be invariant under such a change. However, when performing a second regression on this interventional data, we observe something entirely different:

linreg2 = LinearRegression()
result2 = linreg2.fit(X, Y)
# Evaluate the regression
r_sq_2 = result2.score(X, Y)
print(r_sq_2) #output: 0.0025...

As can be seen above, the R-squared value is extremely low, indicating a very poor fit of our regression. In other words, there no longer seems to be any connection between X and Y.

This gets even more obvious when plotting the data:

# Visualize the association
plt.scatter(X, Y, color='blue')
plt.xlabel('X')
plt.ylabel('Y')
plt.show()

Imagine you trained a model that bases its predictions on the false relationship detected earlier and is now confronted with the above change in its environment. It is doomed to produce rubbish outcomes.
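For contrast, consider a genuine causal link such as Z → Y in our toy system. Below is a minimal sketch in which we instead intervene on Z by shifting its mean: since Z really does cause Y, Y responds to the intervention, while the mechanism Y = 3·Z + noise itself remains invariant.

# The effect of intervening on Z (a sketch: Z -> Y is a genuine causal link)
import numpy as np
from sklearn.linear_model import LinearRegression

np.random.seed(10)
Z = np.random.normal(loc=10, scale=1, size=300).reshape(-1, 1)  # intervene on Z: shift its mean
Y = 3*Z + np.random.normal(loc=1, scale=0.5, size=300).reshape(-1, 1)  # mechanism unchanged

# Regressing Y on Z still recovers the invariant mechanism Y = 3*Z + 1
linreg_z = LinearRegression().fit(Z, Y)
print(linreg_z.coef_, linreg_z.intercept_)  # coefficient close to 3, intercept close to 1
print(linreg_z.score(Z, Y))                 # R-squared remains high

This invariance of causal mechanisms under interventions on their causes is exactly what the spurious relationship between X and Y lacks.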

A quick answer to this problem could be: “Then let's do as many experiments as possible and/or feed our model with as much data as possible, hoping the model will account for all eventualities and gain a reliable understanding of its application context”.

And in some cases this can work. However, such a strategy is limited to fields with extremely large data capacities and no major risk of distribution shifts (e.g. NLP). In all others, it is either financially infeasible or simply morally unacceptable to conduct the necessary experiments. This leaves a large space of cases in which machine learning stands on a weak foundation and no experiments are possible.

That’s exactly where the toolkit of causal inference comes into play. Although we are unable to intervene directly in the system, we can use expert assumptions and algorithms to achieve a similar effect, which in the end allows us to obtain a deep understanding of the causal system we observe.
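As a minimal sketch of this idea on our toy system: if expert knowledge tells us that Z is a common cause of X and Y (the graph X ← Z → Y), we can adjust for Z, here simply by including it as an additional regressor, and the spurious influence of X on Y disappears:

# Adjusting for the common cause Z (a minimal sketch, assuming the graph X <- Z -> Y is known)
import numpy as np
from sklearn.linear_model import LinearRegression

np.random.seed(10)
Z = np.random.normal(loc=3, scale=1, size=300).reshape(-1, 1)
Y = 3*Z + np.random.normal(loc=1, scale=0.5, size=300).reshape(-1, 1)
X = 2*(Z - 1) + np.random.normal(loc=3, scale=1, size=300).reshape(-1, 1)

# Regress Y on both X and Z: we expect the coefficient of X to be close to zero,
# while the coefficient of Z is close to its true causal effect of 3
adjusted = LinearRegression().fit(np.hstack([X, Z]), Y)
print(adjusted.coef_)

This kind of covariate adjustment is one of the simplest ways in which a known causal graph improves a purely predictive model.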

So what’s the moral of the story? Simply using our data as it is, performing some feature engineering, and feeding it into powerful algorithms might lead to good results, but these results can stand on shaky ground, since no causal knowledge of the actual system is obtained. Incorporating the tools of causal inference into machine learning takes us further.

Causal Inference Powered Machine Learning

“If we wish to incorporate learning algorithms into human decision making, we need to trust that the predictions of the algorithm will remain valid if the experimental conditions are changed.” (Schölkopf et al., 2021)

From this perspective, the current evolution of machine learning towards grounding models in fundamental underlying structures rather than correlation-based associative links is not altogether surprising. This fundamental underlying structure can be offered by causal inference, a set of mathematical tools for answering questions about causal relationships from observational or experimental data by utilizing assumptions about the underlying system.

Within causal inference, two fundamental tasks can be distinguished:

  1. Utilizing qualitative causal knowledge for making causal predictions
  2. Learning causal graphs (causal discovery)

Causal machine learning emerges at the crossroads of conventional machine learning and causal inference. On the one hand, identifying the causal graph of a prediction problem provides essential support for better machine learning models. On the other hand, established machine learning methods can be utilized within the framework of causal inference to learn causal relations or even causal variables.
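To give a flavour of the second task, here is a minimal sketch, using only the tools from above, of the kind of conditional independence reasoning that constraint-based causal discovery methods (such as the PCMCI algorithms implemented in tigramite) build on: on our observational toy data, X and Y are strongly correlated, but become nearly independent once the common cause Z is regressed out of both.

# A minimal sketch of a conditional independence test via partial correlation
import numpy as np
from scipy import stats
from sklearn.linear_model import LinearRegression

np.random.seed(10)
Z = np.random.normal(loc=3, scale=1, size=300).reshape(-1, 1)
Y = 3*Z + np.random.normal(loc=1, scale=0.5, size=300).reshape(-1, 1)
X = 2*(Z - 1) + np.random.normal(loc=3, scale=1, size=300).reshape(-1, 1)

# Marginal correlation between X and Y is strong ...
r_marginal, _ = stats.pearsonr(X.ravel(), Y.ravel())
print(r_marginal)

# ... but nearly vanishes once Z is regressed out of both variables
res_X = X - LinearRegression().fit(Z, X).predict(Z)
res_Y = Y - LinearRegression().fit(Z, Y).predict(Z)
r_partial, _ = stats.pearsonr(res_X.ravel(), res_Y.ravel())
print(r_partial)

A constraint-based discovery algorithm automates exactly this kind of test over all variable pairs and conditioning sets and would therefore remove the edge between X and Y from the graph.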

Both can be understood as Causal Machine Learning.

In this series of blog posts, we will give a gentle introduction to this growing and quickly developing field of research, starting with the basics first. Join us on our journey and discover what models that provide answers to the seemingly simple question of “Why” might look like.

About the authors:

Kenneth Styppa is part of the Causal Inference group at the German Aerospace Center’s Institute of Data Science. He has a background in Information Systems and Entrepreneurship from UC Berkeley and Zeppelin University, where he has engaged in both startup and research projects related to Machine Learning. Besides working together with Jakob, Kenneth worked as a data scientist at BMW and currently pursues his graduate degree in Applied Mathematics and Computer Science at Heidelberg University. More on: https://www.linkedin.com/in/kenneth-styppa-546779159/

Jakob Runge heads the Causal Inference group at German Aerospace Center’s Institute of Data Science in Jena and is chair of computer science at TU Berlin. The Causal Inference group develops causal inference theory, methods, and accessible tools for applications in Earth system sciences and many other domains. Jakob holds a physics PhD from Humboldt University Berlin and started his journey in causal inference at the Potsdam Institute for Climate Impact Research. The group’s methods are distributed open-source on https://github.com/jakobrunge/tigramite.git. More about the group on www.climateinformaticslab.com

Quoted from:

Schölkopf, Bernhard, et al. “Toward causal representation learning.” Proceedings of the IEEE 109.5 (2021): 612–634.
