The Theoretical Intuition Behind the Expectation-Maximization (EM) Algorithm for the Mixture of Experts Framework

Bakary Badjie
Sep 12, 2023

In our article Intuitions Behind Mixture of Experts Ensemble Learning, we outlined the high-level intuition behind the mixture of experts framework and noted that the gating network is trained together with the experts, so that the gating network learns when and how far to trust each expert, that is, which expert to assign to a given input. This joint training procedure was traditionally implemented with the expectation-maximization (EM) algorithm.

The EM algorithm is a statistical and computational technique, used in machine learning and statistics, for estimating the unknown parameters of probabilistic models, especially models that contain hidden (latent) or unobserved variables. At a high level, EM can be explained through the following steps:

Expectation (E-step):

In the first step, EM starts from an initial guess of the model’s parameters. It then uses these parameters to estimate the hidden or unobserved variables of the model; more precisely, it computes the probability distribution of the hidden variables given the observed data. These estimates represent our “expectation” of the hidden variables given the observed data.
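
For concreteness, assume N independent data points x_1, …, x_N, each with one discrete latent variable z_i, and write θ^(t) for the parameter guess at iteration t (this notation is introduced here, not taken from the original). Under those assumptions, the E-step amounts to computing each point's posterior over its latent variable:

```latex
q_i^{(t)}(z) = p\left(z_i = z \mid x_i, \theta^{(t)}\right)
             = \frac{p\left(x_i, z_i = z \mid \theta^{(t)}\right)}
                    {\sum_{z'} p\left(x_i, z_i = z' \mid \theta^{(t)}\right)}
```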

Maximization (M-step):

In the second step, EM uses the estimates of the hidden variables obtained in the E-step to update the model’s parameters. It chooses the parameters that maximize the expected log-likelihood of the data, with the E-step estimates standing in for the unobserved values. In other words, this step seeks the parameters that make the data most probable, given the current estimates of the hidden variables.
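
Under the same illustrative assumptions (N independent points, one discrete latent variable z_i per point, parameters θ^(t) at iteration t), the M-step can be sketched as maximizing the expected complete-data log-likelihood Q, where each point's log-likelihood is weighted by the E-step posterior q_i^(t):

```latex
% q_i^{(t)}(z) = p(z_i = z \mid x_i, \theta^{(t)}) is the E-step posterior
Q\left(\theta \mid \theta^{(t)}\right) = \sum_{i=1}^{N} \sum_{z} q_i^{(t)}(z) \, \log p\left(x_i, z_i = z \mid \theta\right)

\theta^{(t+1)} = \arg\max_{\theta} \; Q\left(\theta \mid \theta^{(t)}\right)
```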

Iterative Process:

EM repeats the E-step and the M-step until the model converges. Each iteration refines the estimates of both the latent variables and the model parameters, and the likelihood of the observed data never decreases from one iteration to the next. As the iterations progress, the algorithm approaches a local maximum of the likelihood function; it is not guaranteed to find the global maximum, so the result can depend on the initial guess.
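
The property behind this behavior, a standard result about EM stated here without proof, is that one iteration (θ^(t) to θ^(t+1), notation assumed for this sketch) never lowers the log-likelihood of the observed data X. When the likelihood is bounded above, the sequence of log-likelihood values therefore converges, and under mild regularity conditions the parameters approach a stationary point, typically a local maximum:

```latex
\log p\left(X \mid \theta^{(t+1)}\right) \;\ge\; \log p\left(X \mid \theta^{(t)}\right)
\qquad \text{for every iteration } t
```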

Convergence and Final Estimates:

EM continues iterating until a convergence criterion is met, typically when the change in the log-likelihood or in the parameters between two iterations falls below a small tolerance, which indicates that the parameters have stabilized. The final parameter estimates then describe the underlying data distribution, even when some of the data are missing or incomplete.
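
A minimal sketch of the full loop, including a tolerance-based stopping rule, assuming the simplest possible model: a single normal distribution whose sample has some missing values (marked `None`). The function name `em_gaussian_missing`, the initial guess, and the tolerance are illustrative choices, not part of the original article. The E-step fills in the expected sufficient statistics of the missing values, and the M-step recomputes the mean and variance from them.

```python
def em_gaussian_missing(data, tol=1e-10, max_iter=1000):
    """Hypothetical helper: EM estimate of the mean and variance of a normal
    distribution when some observations are missing (encoded as None)."""
    n = len(data)
    observed = [x for x in data if x is not None]
    n_missing = n - len(observed)
    sum_obs = sum(observed)
    sum_sq_obs = sum(x * x for x in observed)

    mu, var = 0.0, 1.0  # arbitrary initial guess
    for iteration in range(1, max_iter + 1):
        # E-step: expected statistics of each missing value under (mu, var):
        # E[x] = mu and E[x^2] = mu^2 + var
        exp_x = mu
        exp_x2 = mu * mu + var
        # M-step: maximum-likelihood update from the completed statistics
        new_mu = (sum_obs + n_missing * exp_x) / n
        new_var = (sum_sq_obs + n_missing * exp_x2) / n - new_mu * new_mu
        # Convergence criterion: parameters no longer change beyond tol
        converged = abs(new_mu - mu) < tol and abs(new_var - var) < tol
        mu, var = new_mu, new_var
        if converged:
            break
    return mu, var, iteration

mu, var, iters = em_gaussian_missing([2.0, None, 4.0, 5.0, None, 7.0])
# Converges to the MLE of the observed values alone: mean 4.5, variance 3.25
```

This toy case is chosen because its answer is known in closed form: the fixed point of the iteration is the maximum-likelihood estimate computed from the observed values only, which makes the convergence easy to check.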

Figure: An illustration of the high-level concept of the Expectation-Maximization (EM) algorithm

The Expectation-Maximization algorithm is a two-step process that iteratively refines its estimates of the hidden variables and the model parameters to maximize the likelihood of the observed data. It is particularly useful when some data are unobserved or missing, and it provides a principled approach to learning the underlying structure of probabilistic models. EM has applications in many fields; it is most commonly used for clustering and density estimation with mixture models, and, as noted above, it has traditionally been used to train mixture of experts models.

When training a mixture of experts model, the goal is to improve the accuracy and generalization of both the individual experts and the gating network. Both are driven by the same cost function, evaluated at every epoch. Computing the partial derivatives of the cost function with respect to the outputs of the gating network gives a gradient that indicates how the gate’s weights should be updated; following it makes the gate better at assigning the right expert to a particular input. Similarly, computing the partial derivatives of the cost function with respect to an expert’s output gives the gradient used to update that expert’s weights, which makes each expert better at its own part of the task.
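
A small sketch of how the gate gradient connects to EM, assuming a softmax gate and a log-likelihood cost for a single input (both assumptions; the article does not fix a particular gate or cost). For the cost L = log Σ_k g_k P_k, with g = softmax(u) over gate logits u and P_k the likelihood of the target under expert k, a standard identity gives ∂L/∂u_k = h_k − g_k, where h_k is the posterior probability that expert k produced the target, exactly the quantity an E-step computes. The code checks this identity against finite differences; all numbers are hypothetical.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(u - m) for u in logits]
    total = sum(exps)
    return [e / total for e in exps]

def moe_log_likelihood(gate_logits, expert_likelihoods):
    """log sum_k g_k * P_k for one input, with g = softmax(gate_logits)."""
    gates = softmax(gate_logits)
    return math.log(sum(g * p for g, p in zip(gates, expert_likelihoods)))

def gate_gradient(gate_logits, expert_likelihoods):
    """Analytic gradient dL/du_k = h_k - g_k (posterior minus gate prior)."""
    gates = softmax(gate_logits)
    weighted = [g * p for g, p in zip(gates, expert_likelihoods)]
    total = sum(weighted)
    posteriors = [w / total for w in weighted]  # E-step responsibilities
    return [h - g for h, g in zip(posteriors, gates)]

def numerical_gradient(gate_logits, expert_likelihoods, eps=1e-6):
    """Central finite differences, used only to check the analytic form."""
    grad = []
    for k in range(len(gate_logits)):
        up = list(gate_logits)
        up[k] += eps
        down = list(gate_logits)
        down[k] -= eps
        diff = (moe_log_likelihood(up, expert_likelihoods)
                - moe_log_likelihood(down, expert_likelihoods))
        grad.append(diff / (2 * eps))
    return grad

logits = [0.2, -0.5, 1.0]          # hypothetical gate outputs for one input
likelihoods = [0.05, 0.60, 0.10]   # hypothetical p(y | x, expert k)
analytic = gate_gradient(logits, likelihoods)
numeric = numerical_gradient(logits, likelihoods)
```

Here expert 1 explains the target best, so its gate logit receives a positive gradient while the others receive negative ones: the gate is pushed toward the expert that the posterior already favors.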

Maximum Likelihood Estimation

Maximum likelihood estimation (MLE) is an approach to density estimation that searches across probability distributions and their parameters for the ones that make the observed dataset most probable.
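
Written out, assuming the dataset X consists of N independent and identically distributed points x_1, …, x_N drawn from a model with parameters θ (notation introduced for this sketch), MLE selects:

```latex
\hat{\theta}_{\text{MLE}} = \arg\max_{\theta} \; \log p(X \mid \theta)
                          = \arg\max_{\theta} \sum_{i=1}^{N} \log p(x_i \mid \theta)
```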

It is a general and practical approach that underlies many machine learning algorithms, although it requires the training dataset to be complete, i.e., all relevant interacting random variables must be present. MLE becomes difficult, and often intractable, if there are variables that interact with those in the dataset but are hidden or unobserved, so-called latent variables.

The expectation-maximization algorithm is an approach for performing MLE in the presence of latent variables (missing or unobserved values in the data). It does this by estimating the values of the latent variables, optimizing the model, and repeating these two steps until a convergence criterion is met. It is a practical and general approach and is most commonly used for density estimation with missing data, for example when fitting a Gaussian Mixture Model for clustering.

MLE is challenging in the presence of latent variables, but EM provides an iterative solution. With EM, a Gaussian mixture model can be used for density estimation, with the parameters of its component distributions fitted to the data.

The Problems Introduced by Latent Variables During Maximum Likelihood Estimation

A typical modeling problem involves estimating a joint probability distribution for a dataset.

Density estimation involves selecting a probability distribution function and the parameters of that distribution that best explain the joint probability distribution of the observed data.

There are many techniques for solving this problem, although maximum likelihood estimation is one of the most common.

MLE treats the problem as an optimization or search problem, in which we seek the set of parameters that best fits the joint probability of the data sample.
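
To make the "search" framing concrete, here is a minimal sketch for a single normal distribution, a case where the optimum is known in closed form (sample mean and biased sample standard deviation). The sample values are made up for illustration, and the helper names are hypothetical. The loop confirms that nudging either parameter away from the closed-form answer lowers the log-likelihood.

```python
import math

def gaussian_log_likelihood(data, mu, sigma):
    """Sum of log N(x | mu, sigma^2) over the sample."""
    n = len(data)
    return (-0.5 * n * math.log(2 * math.pi * sigma ** 2)
            - sum((x - mu) ** 2 for x in data) / (2 * sigma ** 2))

def gaussian_mle(data):
    """Closed-form maximizer: sample mean and biased sample std (divide by n)."""
    n = len(data)
    mu = sum(data) / n
    sigma = math.sqrt(sum((x - mu) ** 2 for x in data) / n)
    return mu, sigma

sample = [2.1, 2.9, 3.4, 1.8, 2.6, 3.0]  # illustrative values, not real data
mu_hat, sigma_hat = gaussian_mle(sample)
best = gaussian_log_likelihood(sample, mu_hat, sigma_hat)

# Any nearby parameter setting scores a lower log-likelihood.
for d_mu, d_sigma in [(0.1, 0.0), (-0.1, 0.0), (0.0, 0.1), (0.0, -0.1)]:
    assert gaussian_log_likelihood(sample, mu_hat + d_mu, sigma_hat + d_sigma) < best
```

When every variable is observed, as here, the optimization is easy; the difficulty discussed next arises only once some variables are hidden.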

One of the limitations of MLE is that it assumes that the dataset is complete or fully observed. This does not mean that the model has access to all data; instead, it assumes that all variables that are relevant to the problem are present.

This is not always the case! In some datasets, only some of the relevant variables can be observed; the others, although they influence the observed random variables, remain hidden.

Many real-world problems have hidden features (latent variables) that are not observable in the data available for model training.

Conventional MLE does not work well in the presence of latent variables.

With missing data or latent variables, computing the MLE becomes much harder, because the technique assumes complete data: the unobserved values cannot be read from the dataset and must instead be summed (marginalized) out of the likelihood.
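
The difficulty is visible when the two log-likelihoods are written side by side, assuming N independent points with one discrete latent variable z_i each (an illustrative setup). With complete data the logarithm acts on each term separately; with latent variables a sum sits inside the logarithm, which couples the parameters and usually leaves no closed-form maximizer:

```latex
% Complete data (z_i observed): one log term per point
\log p(X, Z \mid \theta) = \sum_{i=1}^{N} \log p(x_i, z_i \mid \theta)

% Latent z_i: it must be summed out, leaving a sum inside the logarithm
\log p(X \mid \theta) = \sum_{i=1}^{N} \log \sum_{z} p(x_i, z_i = z \mid \theta)
```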

The EM algorithm is one approach to addressing this limitation of MLE. It can be applied in many domains, although it is best known in unsupervised machine learning problems such as density estimation and clustering.

Perhaps the most discussed application of the EM algorithm is clustering with mixture models.

Gaussian Mixture Model and the EM Algorithm

A mixture model is a model composed of a weighted combination of multiple probability distribution functions whose parameters are not known in advance.

A statistical procedure or learning algorithm is used to estimate the parameters of the probability distributions to best fit the density of a given training dataset.

The Gaussian Mixture Model, or GMM for short, is a mixture model that uses a combination of Gaussian (normal) probability distributions and requires estimating the mean and standard deviation of each component, along with each component’s weight in the mixture.
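
For one-dimensional data and K components (assumptions made for this sketch), the GMM density is a weighted sum of normal densities, and the parameters to estimate are θ = {π_k, μ_k, σ_k} for k = 1, …, K:

```latex
p(x \mid \theta) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}\!\left(x \mid \mu_k, \sigma_k^2\right),
\qquad \pi_k \ge 0, \quad \sum_{k=1}^{K} \pi_k = 1
```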

There are many techniques for estimating the parameters for a GMM, although a maximum likelihood estimate is perhaps the most common.

Consider a dataset made up of many points that happen to be generated by two different processes. The points from each process follow a Gaussian probability distribution, but the data are combined, and the distributions are similar enough that it is not obvious to which distribution a given point belongs.
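
The generating story can be simulated directly. In this sketch (the function name, means, spreads, and seed are all hypothetical), each point first draws a hidden process label and then a value from that process’s normal distribution. A learner would only ever see `points`; `labels` plays the role of the unobserved latent variable.

```python
import random

def sample_two_processes(n, params, weight0=0.5, seed=0):
    """Draw n points: each first picks a hidden process (0 or 1),
    then draws from that process's normal distribution."""
    rng = random.Random(seed)
    points, labels = [], []
    for _ in range(n):
        z = 0 if rng.random() < weight0 else 1  # latent variable
        mu, sigma = params[z]
        points.append(rng.gauss(mu, sigma))
        labels.append(z)
    return points, labels

# Hypothetical, deliberately overlapping processes
points, labels = sample_two_processes(1000, {0: (0.0, 1.0), 1: (1.5, 1.0)})
```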

The process that generated each data point (process 0 or process 1) is a latent variable: it influences the data but is not observed. As such, the EM algorithm is an appropriate approach for estimating the parameters of the distributions.

In the EM algorithm, the E-step estimates, for each data point, the probability that it was generated by each process (the latent variable), and the M-step optimizes the parameters of the probability distributions to best capture the density of the data. The two steps are repeated until the estimates stop changing, at which point the likelihood of the data has reached a (local) maximum.

  • E-Step. Estimate the expected value for each latent variable.
  • M-Step. Optimize the parameters of the distribution using maximum likelihood.
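
A minimal, self-contained sketch of these two steps for a one-dimensional mixture of two Gaussians, written in plain Python for readability. The function name `em_two_gaussians`, the initialization (means at the data extremes, shared overall spread), the tolerance, and the synthetic data are illustrative assumptions, not a reference implementation; a production version would work in log space to avoid underflow.

```python
import math
import random

def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) at x."""
    return math.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))

def em_two_gaussians(xs, max_iter=500, tol=1e-8):
    """Fit a two-component 1-D Gaussian mixture with EM.
    Returns (pi0, [mu0, mu1], [sigma0, sigma1], log-likelihood history)."""
    n = len(xs)
    mean = sum(xs) / n
    spread = math.sqrt(sum((x - mean) ** 2 for x in xs) / n)
    # Initial guess (arbitrary): means at the data extremes, shared spread
    pi0, mu, sigma = 0.5, [min(xs), max(xs)], [spread, spread]
    history = []
    for _ in range(max_iter):
        # E-step: r_i = P(process 0 | x_i) under the current parameters
        resp, ll = [], 0.0
        for x in xs:
            p0 = pi0 * normal_pdf(x, mu[0], sigma[0])
            p1 = (1 - pi0) * normal_pdf(x, mu[1], sigma[1])
            resp.append(p0 / (p0 + p1))
            ll += math.log(p0 + p1)
        history.append(ll)
        # M-step: responsibility-weighted maximum-likelihood updates
        n0 = sum(resp)
        n1 = n - n0
        pi0 = n0 / n
        mu = [sum(r * x for r, x in zip(resp, xs)) / n0,
              sum((1 - r) * x for r, x in zip(resp, xs)) / n1]
        sigma = [
            math.sqrt(sum(r * (x - mu[0]) ** 2 for r, x in zip(resp, xs)) / n0),
            math.sqrt(sum((1 - r) * (x - mu[1]) ** 2 for r, x in zip(resp, xs)) / n1),
        ]
        # Stop once the log-likelihood has stopped improving
        if len(history) > 1 and history[-1] - history[-2] < tol:
            break
    return pi0, mu, sigma, history

# Synthetic data from two hidden processes (illustrative only)
rng = random.Random(42)
xs = [rng.gauss(0.0, 1.0) for _ in range(600)] + [rng.gauss(5.0, 1.0) for _ in range(400)]
pi0, mu, sigma, history = em_two_gaussians(xs)
```

The recorded `history` of log-likelihood values should be non-decreasing, which is a convenient sanity check on any EM implementation.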

We can imagine how this optimization procedure could be constrained to just the distribution means or generalized to a mixture of many different Gaussian distributions.

I am a PhD student in computer science. My research focuses on enhancing the robustness of machine learning models for autonomous driving systems.