Made Easy — Mixture Density Network for multivariate Regression

Dave Cote, M.Sc.
15 min read · Feb 8, 2022


In this article, I will first briefly explain what an MDN (Mixture Density Network) is, and then give you the Python code to build your own MDN model with only a few lines of code.

Before we start…

If you prefer to follow along in the Jupyter notebook: https://github.com/CoteDave/blog/blob/master/Made%20easy/MDN%20regression/Made%20easy%20-%20MDN%20regression.ipynb

If you want to play around and experiment with my experimental custom MDN Python class: https://github.com/CoteDave/blog/blob/master/Made%20easy/MDN%20regression/mdn_model.py

First… Mixture Density Network regression… huh, what!? Let's start slow and quickly explain each of the main terms:

1. Regression

2. Density

3. Mixture Density

4. Network

5. Regression Network

6. Density Network

7. Mixture Density Network regression

8. Practice fun — Univariate dataset with custom MDN class

9. Practice fun — Multivariate dataset with custom MDN class

10. Conclusion

1. REGRESSION

First, the term "regression". If you have clicked on this article, I assume you already know what regression is all about, but if not, here is a really quick and simple explanation from Jason Brownlee of Machine Learning Mastery (great guy, by the way!):

“Regression predictive modeling is the task of approximating a mapping function (f) from input variables (X) to a continuous output variable (y) […] A regression problem requires the prediction of a quantity. A problem with multiple input variables is often called a multivariate regression problem […] For example, a house may be predicted to sell for a specific dollar value, perhaps in the range of $100,000 to $200,000 […] ” (Source: https://machinelearningmastery.com/classification-versus-regression-in-machine-learning/#:~:text=A%20regression%20problem%20requires%20the,called%20a%20multivariate%20regression%20problem.)

Here is another visual explanation that differentiates a classification problem from a regression problem:

source: https://towardsdatascience.com/regression-or-classification-linear-or-logistic-f093e8757b9c

One last example for good measure…

source: https://vinodsblog.com/2018/11/08/classification-and-regression-demystified-in-machine-learning/)

2. DENSITY

Okay, now what is the term "density" all about? Here is a quick "cheesy" example:

Suppose you are delivering pizzas for Pizza Hut. You decide to record the time (in minutes) of each delivery you make. After 1,000 deliveries, you decide to visualize your data to see how well you performed. Here is the result:

source: https://statisticsbyjim.com/basics/normal-distribution/

This is the "density" of the distribution of your pizza delivery times. So, on average, it took you 30 minutes per delivery (the peak in the graph). It also says that 95% of the time (within 2 standard deviations (SD)), your deliveries took between 20 and 40 minutes. Here, the density loosely represents the "frequency" of the time results. The difference between "frequency" and "density" is that:

· Frequency: If you draw a histogram under this curve and add up all the bins, the total equals the number of observations in your dataset.

· Density: If you draw a histogram under this curve, the total area of the bins always sums to 1. We can also call this curve the probability density function (pdf). (See the short sketch right after this list.)
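To make the frequency-versus-density distinction concrete, here is a tiny sketch of mine (not from the original notebook) using hypothetical simulated delivery times:

```python
import numpy as np

# Simulate 1,000 hypothetical delivery times from a normal distribution
# with a mean of 30 minutes and a standard deviation of 5 minutes.
rng = np.random.default_rng(42)
delivery_times = rng.normal(loc=30, scale=5, size=1000)

# Frequency histogram: the bin counts sum to the number of observations.
counts, edges = np.histogram(delivery_times, bins=30)
print(counts.sum())                      # 1000

# Density histogram: the total area of the bars sums to 1.
density, edges = np.histogram(delivery_times, bins=30, density=True)
print((density * np.diff(edges)).sum())  # ~1.0
```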

In statistical terms, this is a beautiful normal / Gaussian distribution. The normal distribution has two parameters:

· The mean

· The standard deviation (According to Wikipedia: “Standard deviation is a number used to tell how measurements for a group are spread out from the average (mean), or expected value. A low standard deviation means that most of the numbers are close to the average. A high standard deviation means that the numbers are more spread out.”)

Changing the mean and the standard deviation affects the shape of the distribution. For example:

source: https://towardsdatascience.com/probability-concepts-explained-probability-distributions-introduction-part-3-4a5db81858dc

Of course, there exists a variety of different distribution types with different kinds of parameters. For example:

3. MIXTURE DENSITY

Okay now let’s have a look at those 3 distributions:

Source: https://medium.com/analytics-vidhya/learning-from-multimodal-target-mixture-density-network-94891d4e357e

If we take this bimodal distribution (Also called a general distribution):

A Mixture Density Network uses the assumption that any general distribution, like this bimodal one, can be broken down into a mixture of normal distributions (the mixture can also be customized with other types of distributions, like the Laplace distribution, for example):

So if you add up those 2 normal distributions (red and blue), where each distribution has its own mean and standard deviation, you get back the bimodal general distribution (purple).
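Here is a small sketch of that idea; the weights, means and standard deviations below are hypothetical, chosen only to reproduce a bimodal shape:

```python
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt

# Hypothetical mixture parameters (not from the article).
pi = [0.5, 0.5]       # mixture weights, must sum to 1
mu = [-2.0, 2.0]      # means of the two components
sigma = [1.0, 0.8]    # standard deviations of the two components

x = np.linspace(-6, 6, 500)

# Each component pdf, weighted by its mixture weight.
pdf_1 = pi[0] * norm.pdf(x, mu[0], sigma[0])
pdf_2 = pi[1] * norm.pdf(x, mu[1], sigma[1])

plt.plot(x, pdf_1, "b", label="component 1")
plt.plot(x, pdf_2, "r", label="component 2")
plt.plot(x, pdf_1 + pdf_2, color="purple", label="mixture (bimodal)")
plt.legend()
plt.show()
```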

4. NETWORK

"Network" stands for artificial neural network. So yes, a Mixture Density Network is a type of artificial neural network. Here is a classical example of a neural network:

source: www.SuperDataScience.com

You have the input layer (yellow), the hidden layer (green) and the output layer (red). (I assume here that everyone knows what an artificial neural network is; if not, Google it!)

5. REGRESSION NETWORK

source: www.SuperDataScience.com

A regression network is an artificial neural network where the goal is to learn to output a continuous value given some input features. In the example above, given the age, sex, education and other features, the artificial neural network learns to predict the salary.

6. DENSITY NETWORK

source: www.SuperDataScience.com

A density network is an artificial neural network where the goal is not simply to learn to output a single continuous value, but to learn to output the distribution parameters (here, mean and standard deviation), given some input features. In the example above, given the age, sex, education and other features, the artificial neural network learns to predict the mean and the standard deviation of the expected salary distribution. Predicting a distribution instead of a single value has advantages, like being able to give uncertainty bounds along with the prediction. This is a "Bayesian" way of approaching a regression problem. Here is a great illustration of predicting the distribution of each expected continuous value:

Source: https://towardsdatascience.com/bayesian-thinking-estimating-posterior-distribution-for-linear-regression-data-ketchup-2f50a597eb06

Another great illustration that shows us the distribution of the expected values, for each predicted instance:

https://engineering.taboola.com/predicting-probability-distributions/
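To make the idea concrete, here is a minimal, generic sketch of a density network head built with TensorFlow Probability's Keras layers. This is not my custom MDN class (which targets TensorFlow < 2); the layer sizes and number of features are arbitrary, and it only illustrates the "predict a mean and a standard deviation" idea:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

n_features = 4  # e.g. age, sex, education, ... (hypothetical)

# The last Dense layer outputs two values per instance, interpreted as
# the mean and the (softplus-transformed, strictly positive) standard
# deviation of a Normal distribution.
density_net = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu", input_shape=(n_features,)),
    tf.keras.layers.Dense(2),  # [mean, raw standard deviation]
    tfp.layers.DistributionLambda(
        lambda t: tfd.Normal(loc=t[..., :1],
                             scale=1e-3 + tf.math.softplus(t[..., 1:]))),
])

# Train by minimizing the negative log-likelihood of y (shape (n, 1))
# under the predicted distribution, instead of a plain squared error.
negloglik = lambda y, p_y: -p_y.log_prob(y)
density_net.compile(optimizer="adam", loss=negloglik)
# density_net.fit(X_train, y_train, epochs=200)
```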

7. MIXTURE DENSITY NETWORK

Finally! A mixture density network is an artificial neural network where the goal is to learn to output all the parameters (here, the means, standard deviations and Pis) of all the distributions mixed in the general distribution, given some input features. The new parameter "Pi" is the mixture weight, which gives the weight/probability of a given distribution in the final mixture.

source: www.SuperDataScience.com

The final result:

source: www.SuperDataScience.com
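Again, to make it concrete, here is a minimal generic sketch of an MDN head using TensorFlow Probability's MixtureNormal layer. This is not my custom class; the number of components and layer sizes are arbitrary:

```python
import tensorflow as tf
import tensorflow_probability as tfp

K = 3  # number of mixture components (an arbitrary choice here)

# The last Dense layer outputs everything the mixture needs: K mixture
# weights (Pi), K means and K standard deviations per instance.
params_size = tfp.layers.MixtureNormal.params_size(K, event_shape=[1])

mdn_head = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu", input_shape=(1,)),
    tf.keras.layers.Dense(params_size),
    # Splits the parameters into Categorical logits (Pi) and Normal
    # components (means, standard deviations) and returns the mixture.
    tfp.layers.MixtureNormal(K, event_shape=[1]),
])

# Same idea as the density network: minimize the negative log-likelihood,
# but now under a mixture of Normals.
mdn_head.compile(optimizer="adam", loss=lambda y, p_y: -p_y.log_prob(y))
# mdn_head.fit(X_train, y_train, epochs=200)
```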

8. PRACTICE FUN — UNIVARIATE DATASET WITH CUSTOM MDN CLASS

Okay, enough of this theoretical stuff, let’s do it!

Here we have this famous “half-moon” dataset:

Just by looking at the data, we can see that there are two overlapping clusters.

If we plot the density distribution of the target value (y):

We have a nice “ghost-like” multimodal distribution (general distribution). If we try a standard Linear Regression on this dataset to predict y with X:
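The exact data-generation code is in the notebook; as a rough equivalent, here is a hedged sketch assuming the data comes from scikit-learn's make_moons, with the first coordinate used as the input X and the second as the target y:

```python
import numpy as np
from sklearn.datasets import make_moons
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt

# Hypothetical data generation: 2,500 noisy "half-moon" points, with the
# first coordinate as the input X and the second as the regression target y.
data, _ = make_moons(n_samples=2500, noise=0.05, random_state=42)
X, y = data[:, [0]], data[:, 1]

# Standard linear regression baseline.
lin = LinearRegression().fit(X, y)

plt.scatter(X, y, s=5, alpha=0.3, label="data")
plt.scatter(X, lin.predict(X), s=5, color="red", label="linear fit")
plt.legend()
plt.show()
```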

Did not work very well!

Let’s now try a nonlinear model (Radial Basis Function Kernel Ridge regression):
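Continuing with the X and y from the previous sketch, a rough equivalent with illustrative (untuned) hyperparameters:

```python
from sklearn.kernel_ridge import KernelRidge
import matplotlib.pyplot as plt

# X and y come from the make_moons sketch above; alpha and gamma are
# illustrative values, not the ones used in the notebook.
krr = KernelRidge(kernel="rbf", alpha=0.1, gamma=10.0).fit(X, y)

plt.scatter(X, y, s=5, alpha=0.3, label="data")
plt.scatter(X, krr.predict(X), s=5, color="red", label="RBF kernel ridge")
plt.legend()
plt.show()
```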

Better, but not quite yet!

The main reason why both models failed is that, if you look along the X-axis, multiple different y values exist for the same X value… more specifically, there seems to be more than one possible y distribution for the same X. The regression models just tried to find the optimal function that minimizes the error and did not take the mixture of densities into consideration, which is exactly what we have here! The X values in the middle do not have a unique Y solution; they have two possible solutions, one high and one low!

Let's now try an MDN model to see what this beast can do! For that, I implemented a quick, nice and easy-to-use "fit-predict", "sklearn-like" custom Python MDN class. Here is the link to my Python code if you want to use it yourself (be aware: this MDN class is experimental and has not been extensively tested): https://github.com/CoteDave/blog/blob/master/Made%20easy/MDN%20regression/mdn_model.py

To be able to use this class, you will need sklearn, TensorFlow Probability, TensorFlow < 2 (if you want to make this MDN model class compatible with TF2, feel free to contribute!), umap and hdbscan (for the custom visualization class functions). The main parameters of the class are listed below (a usage sketch follows the list):

· n_mixtures: number of mixture components used by the MDN. If set to -1, it will automatically find the optimal number of mixtures using a Gaussian mixture model (GMM) and an HDBSCAN model on X and y.

· dist: distribution type to use in the mixture. For now, there are two choices available: "normal" or "laplace". (Based on some experiments, the Laplace distribution gave me better results than the normal distribution.)

· input_neurons: Number of neurons to use in the input layer of the MDN

· hidden_neurons: architecture of the hidden layers of the MDN, as a list of neurons per hidden layer. This parameter lets you choose both the number of hidden layers and the number of neurons per hidden layer.

· gmm_boost: Boolean. If set to True, will add cluster features to the dataset.

· optimizer: Optimization algorithm to use.

· learning_rate: Learning rate of the optimization algorithm

· early_stopping: helps avoid overfitting during training. This trigger stops training when the metric has not improved over a given number of epochs.

· tf_mixture_family: Boolean. If set to True, uses the TensorFlow Probability Mixture family (recommended): the Mixture object implements batched mixture distributions, defined by a Categorical distribution (the mixture weights) and a Python list of Distribution objects.

· input_activation: Activation function of the input layer

· hidden_activation: Activation function of the hidden layer
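Putting it together, here is a hypothetical usage sketch. The class name, method names and the values below are assumptions based on the parameter list above, not the class's actual defaults; check mdn_model.py for the exact names and signatures:

```python
from mdn_model import MDN  # class name assumed; see mdn_model.py

# Hypothetical configuration built from the parameter list above.
mdn = MDN(n_mixtures=-1,            # auto-detect the number of mixtures
          dist="laplace",           # "normal" or "laplace"
          input_neurons=128,
          hidden_neurons=[64, 32],  # two hidden layers
          gmm_boost=False,
          optimizer="adam",
          learning_rate=0.001,
          early_stopping=10,
          tf_mixture_family=True,
          input_activation="relu",
          hidden_activation="relu")

mdn.fit(X, y)            # sklearn-like fit...
y_pred = mdn.predict(X)  # ...and predict
```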

Now that our MDN model is fitted on the data, let's sample from the mixture density distribution and plot the probability density function:
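If you are not following along in the notebook, here is a minimal sketch of what sampling from a fitted mixture looks like. It assumes you can extract per-instance arrays pi, mu and sigma (shape: n_samples x n_mixtures) from the fitted model; these accessors are hypothetical, not the class's actual API:

```python
import numpy as np

# pi, mu, sigma: hypothetical arrays of shape (n_samples, n_mixtures)
# holding the predicted mixture weights, means and standard deviations.
def sample_from_mixture(pi, mu, sigma, rng=np.random.default_rng(0)):
    n, k = pi.shape
    # Pick a component for each instance according to its Pi weights...
    comp = np.array([rng.choice(k, p=pi[i]) for i in range(n)])
    # ...then sample from the chosen Normal component.
    return rng.normal(mu[np.arange(n), comp], sigma[np.arange(n), comp])

# y_samples = sample_from_mixture(pi, mu, sigma)
# plt.hist(y_samples, bins=50, density=True)  # compare against the true y pdf
```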

Wow! Our MDN model fits the true general distribution pretty well!

We can break down the final mixture distribution into its individual distributions to see how it looks:

Let's again sample some Y data using our learned mixture distribution and, this time, plot the generated samples against the true ones:

Again, pretty close to reality!

Given X, we can also generate multiple batches of samples to produce statistics like quantiles, means, etc.:

Now, we can plot the mean of each of our learned distributions, with their respective mixture weights (pi):

Wow, MDN just nailed it, again!

As we have the mean and the standard deviation of each distribution, we can also plot the uncertainties with the full tails; let's say we plot the means with a 95% confidence interval:

We can also mix the distributions together and, when there are multiple y distributions for the same X, choose the most probable one using the highest Pi parameter value (a short sketch follows the bullet below):

· Y_preds = for each X, choose the Y mean of the distribution with the max probability/weight (Pi parameter)
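A hedged sketch of this strategy, reusing the hypothetical pi and mu arrays from the sampling sketch above:

```python
import numpy as np

# "Highest Pi" strategy: for each X, keep the mean of the component
# with the largest mixture weight.
best_component = pi.argmax(axis=1)
y_pred_max_pi = mu[np.arange(len(mu)), best_component]
```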

If we add the 95% confidence interval:

Not ideal here, because there are clearly two distinct clusters overlapping in our data with nearly equivalent density. The errors would be higher than with a standard regression model. This also suggests that an important feature may be missing from our dataset, one that could help separate the overlapping clusters in a higher dimensionality.

We can also choose to mix the distributions using the Pi parameters and the means of all the distributions together (see the sketch after the bullet below):

· Y_preds = (mean_1 * Pi1) + (mean_2 * Pi2)
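The corresponding sketch for this weighted option, with the same hypothetical pi and mu arrays as before:

```python
# Pi-weighted strategy: the prediction is the mixture-weighted
# average of the component means.
y_pred_weighted = (pi * mu).sum(axis=1)
```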

Which gives:

If we add the 95% confidence interval:

This option gives us pretty much the same result as a nonlinear regression model, mixing everything to minimize the distance between the dots and the function. It produces a single function that passes through the two clusters.

In this very particular case, my favorite option would be to assume that in some regions of the data, X has more than one Y, while in other regions only one of the mixtures is used. That gives us something like:

For example, when X = 0, there are two possible distinct Y solutions, one from each mixture. When X = -1.5, there is a unique Y solution from mixture 1. Depending on the use case or business context, an action or decision can be triggered when more than one solution exists for the same X.

With this option, the rows are duplicated when there is an overlapping distribution (if both mixture probabilities are >= a given probability threshold). This gives something like:

With the 95% confidence interval:

The dataset went from 2,500 to 4,063 rows, and the final predictions dataset looks like this:

In this data table, we see that, for example, when X = -0.276839, Y can be 1.43926 (with a probability of 0.351525 for mixture_0) but can also be -0.840593 (with a probability of 0.648475 for mixture_1).
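Here is a hedged sketch of how such a duplicated-predictions table could be assembled; the 0.2 threshold and the column names are illustrative, not the class's actual output:

```python
import numpy as np
import pandas as pd

# pi and mu as in the earlier sketches, X of shape (n_samples, 1).
threshold = 0.2
rows = []
for i in range(len(pi)):
    best = pi[i].argmax()
    for k in range(pi.shape[1]):
        # Always keep the most probable component; also keep any other
        # component above the threshold, which duplicates the row when
        # two mixtures overlap for the same X.
        if k == best or pi[i, k] >= threshold:
            rows.append({"X": X[i, 0], "mixture": k,
                         "y_pred": mu[i, k], "pi": pi[i, k]})

predictions = pd.DataFrame(rows)
```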

The instances with multiple distributions also give an important piece of information: something is going on in the data that may need more analysis. Maybe there are data quality problems, or maybe it indicates that an important feature is missing from the dataset!

Another example: “traffic scene prediction is a good candidate for where Mixture Density Networks can be used. In traffic scene prediction, we need a distribution over behaviors an agent can exhibit — for example, an agent could turn left, turn right, or go straight. Thus, Mixture Density Networks can be used to represent “behaviors” in each of the mixture it learns, where the behavior consists of a probability and trajectory ((x, y) coordinates up to some time horizon in the future).(source: http://blog.adeel.io/tag/mixture-density-networks/)“


9. PRACTICE FUN — MULTIVARIATE DATASET WITH CUSTOM MDN CLASS

Okay, finally: does the MDN perform well on multivariate regression problems? Let's find out!

First of all, let’s have a look at the dataset we will use:

· age: age of primary beneficiary

· sex: insurance contractor gender, female, male

· bmi: body mass index, an objective index of body weight relative to height (kg / m²), indicating weights that are relatively high or low for a given height; ideally between 18.5 and 24.9

· children: Number of children covered by health insurance / Number of dependents

· smoker: whether the beneficiary smokes

· region: the beneficiary’s residential area in the US, northeast, southeast, southwest, northwest.

· charges: Individual medical costs billed by health insurance. This is the target that we want to forecast

I took this data dictionary from https://www.kaggle.com/mirichoi0218/insurance. If you go to this link, you can also have an interactive visualization of the dataset.

If you want to do some experiments, the dataset is also available here: https://github.com/stedy/Machine-Learning-with-R-datasets

The problem statement is: Can you accurately predict insurance costs (charges)?

For a quick exploratory data analysis of this dataset, you can read one of my previous articles here: https://www.linkedin.com/pulse/demonstrating-power-feature-engineering-part-ii-how-i-dave-c%C3%B4t%C3%A9/

Now, let’s import the dataset:

…and do some basic data preparation:
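For reference, here is a minimal loading-and-preparation sketch; the raw-file path is assumed from the stedy repository linked above, and the exact preprocessing in my notebook may differ:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Assumed raw-file location in the stedy/Machine-Learning-with-R-datasets repo.
url = ("https://raw.githubusercontent.com/stedy/"
       "Machine-Learning-with-R-datasets/master/insurance.csv")
df = pd.read_csv(url)

# One-hot encode the categorical features and split features/target.
df = pd.get_dummies(df, columns=["sex", "smoker", "region"], drop_first=True)
X = df.drop(columns="charges").values
y = df["charges"].values

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)
```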

The data is ready to feed the MDN model!

Let’s now predict the test dataset using the “best mixture probabilities (Pi parameter) strategy” and plot the result (y_pred vs y_test):
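A hedged sketch of this step, assuming the sklearn-like fit/predict API from the univariate example above (the actual prediction-strategy option lives in mdn_model.py):

```python
from sklearn.metrics import r2_score, mean_absolute_error

# mdn is the (hypothetically configured) MDN instance from earlier;
# X_train, X_test, y_train, y_test come from the preparation sketch.
mdn.fit(X_train, y_train)
y_pred = mdn.predict(X_test)

print("R2 :", r2_score(y_test, y_pred))
print("MAE:", mean_absolute_error(y_test, y_pred))
```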

For this visualization technique, you can read one of my previous articles: https://www.linkedin.com/pulse/visualization-trick-multivariate-regression-problems-dave-c%C3%B4t%C3%A9/

Wow! This is completely insane! With an R2 of 89.09 and an MAE of 882.54, the MDN just destroyed it!

Let’s plot the fitted vs true distribution for fun:

Almost exactly the same!

If we break down the mixture model:

There is a total of six different distributions in the general mixture distribution.

We can generate multivariate samples from the fitted mixture model (a PCA is applied to visualize the result in 2D):

The generated samples are very close to the real ones!

We can predict from each distribution if we want:

We can check if there are overlapping clusters in the dataset:

y_pred_overlaps has the same number of rows as the X_test dataset, so there is no overlap! (We can adjust the Pi parameter threshold to tune the sensitivity.)

10. CONCLUSION

· The MDN did a great job on the univariate regression dataset, where two clusters overlapped and X could have more than one Y output, compared to classical linear or nonlinear ML models.

· The MDN also did a great job on the multivariate regression problem and can compete with popular models like XGBoost.

· The MDN is a great and unique tool to have in your ML toolbox; it can solve particular problems that other models can't (it is capable of learning from data drawn from a mixture of distributions).

· As the MDN learns the distribution, you can also attach uncertainty to your predictions or generate new samples from the learned distribution.

I hope this article helped you in your Machine Learning journey!


Dave Cote, M.Sc.

Data Scientist at an insurance company. More than 10 years in Data Science, delivering actionable « Data-Driven » solutions.