Parameter Estimation for Latent Dirichlet Allocation explained with Collapsed Gibbs Sampling in Python

Chris
12 min read · Mar 6, 2021


Latent Dirichlet Allocation (LDA), first published in Blei et al. (2003), is one of the most popular topic modeling approaches today. LDA is a simple, easy-to-understand model based on a generative process. However, the estimation of its parameters is not as simple. Blei et al. (2003) propose an approximate inference technique based on variational methods from Bayesian statistics to estimate the parameters of LDA. Although this approach is efficient for larger datasets, the equations for variational inference are hard to derive. A less efficient but easier-to-understand approach based on collapsed Gibbs Sampling was proposed in Griffiths and Steyvers (2004). This makes it a good starting point for understanding parameter estimation for LDA.

The goal of this article is to explain parameter estimation for LDA in simple words using collapsed Gibbs Sampling. You should be familiar with LDA to understand the article in all its details. The approach is illustrated with a Python implementation derived from Griffiths and Steyvers (2004). The full Python code for this article can be found on GitHub.

Model and Notation

The core of LDA is its generative process to characterize a corpus of documents. The idea is to represent documents as a mixture over latent topics θ. Each topic is described by a distribution over words ϕ (Blei et al. 2003). Blei et al. (2003) explain the assumptions made for the document generating process in a similar way:

  1. Choose N ∼ Poisson(ξ).
  2. Choose θ ∼ Dir(α).
  3. Choose ϕ ∼ Dir(β).
  4. For each of the N words w_n:
    (a) Choose a topic z_n ∼ Multinomial(θ).
    (b) Choose a word w_n from p(w_n|z_n,ϕ), a multinomial distribution with a single trial, i.e., a categorical distribution.

A document is described by N words. For each document, a new document-topic distribution θ, whose size equals the number of chosen topics K, is drawn from a Dirichlet distribution. The Dirichlet prior is parameterized by the hyperparameter α, which directly influences the shape of θ. The Dirichlet distribution is the conjugate prior of the multinomial distribution, an important property for making inference tractable. In the next step, for each word in the document the topic assignment z_n is drawn from the multinomial distribution θ. After determining the topic assignment z_n, a word w_n from the vocabulary V is drawn. The probability p(w_n|z_n,ϕ) of drawing a word w from V depends on the topic assignment z_n and ϕ. Each topic k ∈ {1, …, K} is related to exactly one per-topic word distribution ϕ_k, which assigns probabilities to each word in the vocabulary V.
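
To make the generative process concrete, here is a minimal sketch in Python. It only illustrates the steps above; the corpus size, vocabulary size, and hyperparameter values are made up for illustration and are not taken from the article's code.

import numpy as np

rng = np.random.default_rng(0)

K, V = 3, 1000           # number of topics and vocabulary size (illustrative values)
alpha = np.full(K, 0.1)  # parameter of the document-topic Dirichlet prior
beta = np.full(V, 0.01)  # parameter of the topic-word Dirichlet prior

phi = rng.dirichlet(beta, size=K)       # 3. one per-topic word distribution phi_k per topic

def generate_document(xi=50):
    N = rng.poisson(xi)                 # 1. document length N ~ Poisson(xi)
    theta = rng.dirichlet(alpha)        # 2. document-topic distribution theta ~ Dir(alpha)
    words = []
    for _ in range(N):
        z_n = rng.choice(K, p=theta)    # 4a. topic assignment z_n ~ Multinomial(theta)
        w_n = rng.choice(V, p=phi[z_n]) # 4b. word w_n ~ p(w_n | z_n, phi)
        words.append(w_n)
    return words

doc = generate_document()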

In general, the goal is to reverse the generative process to make estimates about the latent variables z, θ and ϕ based on the observed variable w. In other words, we have to compute the posterior distribution.

Please note that the generative process described in Blei et al. (2003) slightly differs from the definition in this article. Blei et al. (2003) do not directly mention the distribution parameter ϕ in their generative process. This variable is added here for simplicity.

Full Conditional Distribution

The approach proposed in Blei et al. (2003) involves estimating the parameters ϕ and θ. In contrast, Griffiths and Steyvers (2004) integrate out ϕ and θ and estimate only the word-topic assignments z. If we are able to approximate the posterior distribution P(z|w), it is also possible to make estimates about the parameters ϕ and θ, which directly depend on z.

Posterior distribution P(z|w) described in Griffiths and Steyvers (2004).

Unfortunately, it is computationally intractable to calculate the sum over P(w, z) in the denominator, as it involves K^n terms, where K is the number of topics and n is the number of word tokens. However, since P(w, z) = P(w|z)P(z) is proportional to our posterior distribution, we can sample from the posterior without ever computing the intractable normalizing constant. Griffiths and Steyvers (2004) integrate out ϕ and θ to obtain the so-called full conditional distribution. The full conditional distribution is required to construct a Markov chain by sampling the next state based on the current state. Because the next state depends only on the current state and is independent of all previous states, this forms a Markov chain Monte Carlo approach, which is known to converge to the target distribution. In this case, the target distribution is our posterior distribution.

Full conditional distribution described in Griffiths and Steyvers (2004).

The full conditional distribution has a nice intuitive explanation. The first term equals the ratio of word w among all words assigned to topic j. It is multiplied by the second term, which equals the ratio of topic j among the words in document d. The product of both terms is an unnormalized vector that assigns a weight to each topic j.

Simplified explanation of the full conditional distribution without smoothing parameters.

Normalizing this vector gives us the probability of visiting each of the next possible states based on the current state. In the context of LDA, the next state is the assignment of the current word z_i to exactly one out of k ∈ {1, …, K} topics.

Imagine word w has a high ratio in topic j and, at the same time, topic j has a high ratio in document d. The product of both terms then indicates that sampling the topic assignment z = j for word w in document d becomes more likely.

To calculate the full conditional distribution, we have to keep track of the corpus statistics that appear in the formula. To do this, we allocate four count arrays in memory.

Data structures required to collect corpus statistics

On the left side, the overall number of words assigned to each topic is stored. Moreover, we also have to count the number of times each topic was assigned to each word in the vocabulary. Please note that W represents the number of words in the vocabulary V. On the right side, the number of words in each document is counted and below we also keep track of the overall number of words in each document assigned to each topic.
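
Assuming D documents, K topics, and a vocabulary of W words, the four count arrays could be allocated as follows. This is a sketch with illustrative names (n_kw, n_k, n_dk, n_d); the linked code may use different ones.

import numpy as np

D, K, W = 1000, 5, 10_000                # illustrative sizes: documents, topics, vocabulary

n_kw = np.zeros((K, W), dtype=np.int64)  # number of times word w is assigned to topic k
n_k = np.zeros(K, dtype=np.int64)        # total number of words assigned to topic k
n_dk = np.zeros((D, K), dtype=np.int64)  # number of words in document d assigned to topic k
n_d = np.zeros(D, dtype=np.int64)        # total number of words in document d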

The hyperparameters α and β in the formula are simply used for smoothing and equal our prior belief about the observed counts. In Bayesian statistics it is common to use a prior belief as a first best guess and update it with the observed data; the resulting distribution is called the posterior distribution. Griffiths and Steyvers (2004) assume symmetric priors for simplicity, which indicates that we don't know anything about the observed counts initially. However, Wallach et al. (2009) propose in Rethinking LDA: Why Priors Matter that an asymmetric Dirichlet prior over the document-topic distribution is in general a better choice for LDA. This favours the more intuitive assumption that documents in practice only make use of one or a few topics and not all available topics K. Therefore, we follow Wallach et al. (2009) in this article and adapt the equations in the Python code to fit a model with an asymmetric prior parameterized by a non-uniform vector α.
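
As a small illustration of the difference, a symmetric prior assigns the same weight to every topic, while an asymmetric prior favours some topics a priori (the concrete values below are arbitrary examples, not recommendations):

import numpy as np

K = 5
alpha_symmetric = np.full(K, 0.1)                         # same prior weight for every topic
alpha_asymmetric = np.array([0.5, 0.2, 0.1, 0.05, 0.05])  # some topics favoured a priori
beta = 0.01                                               # symmetric prior over the vocabulary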

Implementation of the full conditional distribution described in Griffiths and Steyvers (2004).
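
A minimal sketch of this computation for a single word occurrence could look as follows. It reuses the count-array names introduced above, takes β as a scalar and α as a (possibly asymmetric) vector, and assumes that the counts of the word's current assignment have already been removed; it is not the article's actual code.

def full_conditional(w, d, n_kw, n_k, n_dk, alpha, beta):
    """Normalized p(z_i = k | z_-i, w) for word w in document d, for all topics k."""
    W = n_kw.shape[1]
    left = (n_kw[:, w] + beta) / (n_k + W * beta)              # ratio of word w in each topic
    right = (n_dk[d] + alpha) / (n_dk[d].sum() + alpha.sum())  # ratio of each topic in document d
    p = left * right                                           # unnormalized weight per topic
    return p / p.sum()                                         # normalize to a distribution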

To summarize, the full conditional distribution is required to calculate the probability of the next states. A single draw from this multinomial distribution determines the next state in the Markov chain.

Collapsed Gibbs Sampling

Collapsed Gibbs Sampling is simple and easy to understand. Let’s have a look at the sampling procedure which mainly consists of the following steps:

  1. Collect and keep track of corpus statistics.
  2. Compute full conditional distribution.
  3. Normalize full conditional distribution.
  4. Draw random sample from posterior to generate new topic assignment z for word w.

For each word occurrence in all documents, these steps are repeated several times. It is important to note that the order of words does not matter. However, one drawback of collapsed Gibbs Sampling is its computational effort: we need to pass over all documents a few hundred times. Additionally, we have to discard the first few hundred samples, as the early states are unlikely to be a good representation of the target distribution. This step is commonly known as burn-in. Furthermore, it is common to consider only every nth sample from the posterior after the burn-in period to reduce autocorrelation between samples. Let's have a look at an example to make the sampling procedure more intuitive.

Example

The number of topics in this example is chosen to be K = 2. Please note that the hyperparameters α and β are not considered for simplicity.

  1. Collected corpus statistics.

# topic 1/2 assigned to word w: [25, 20]

# words assigned to topic 1/2: [100, 200]

# words in document assigned to topic 1/2: [15, 10]

# words in document: 25

2. Compute full conditional distribution.

word topic ratio = [25, 20] / [100, 200] = [0.25, 0.1]

topic document ratio = [15, 10] / 25 = [0.6, 0.4]

p(z_i|z_-i,w) = [0.25, 0.1] * [0.6, 0.4] = [0.15, 0.04]

3. Normalize full conditional distribution.

p(z_i|z_-i,w) = [0.15, 0.04] / 0.19 = [0.79, 0.21]

4. Draw random sample from posterior to generate new topic assignment z for word w in document d.

z_i ~ Multinomial(n=1, [0.79, 0.21])

The word topic ratio and topic document ratio are both larger for topic 1. The resulting distribution therefore favours drawing topic 1. Roughly speaking, it is four times more likely to draw topic 1 than topic 2 in the current state.

Example of a random draw: z_i = 0 (sampled index 0 equals topic 1)
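
The same arithmetic in a few lines of numpy, using the values from the example above:

import numpy as np

word_topic_counts = np.array([25, 20])   # topic 1/2 assigned to word w
topic_counts = np.array([100, 200])      # words assigned to topic 1/2
doc_topic_counts = np.array([15, 10])    # words in document assigned to topic 1/2
doc_length = 25                          # words in document

p = (word_topic_counts / topic_counts) * (doc_topic_counts / doc_length)
p = p / p.sum()                          # approximately [0.79, 0.21]

z_i = np.random.default_rng(0).choice(2, p=p)  # draw the new topic assignment (index 0 = topic 1)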

Python implementation of collapsed Gibbs Sampling for LDA

The following is a simple Python implementation of collapsed Gibbs sampling for LDA. Please note that some implementation details and most LDA parameters are hidden in the class LDABase. The presented implementation also has some extra features not mentioned so far. First, the hyperparameters for the Dirichlet priors α and β are optimized using Minka's fixed-point iteration. As we make progress in training and observe new data, we should also update our prior belief about the topic and word distributions. The approach is not described in detail in this article; if you are interested, please read Minka (2000).
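
For readers who want an idea of what such an update looks like, the following is a rough sketch of the general fixed-point iteration from Minka (2000) applied to the document-topic counts; the function name and variables are illustrative, and this is not the article's actual code.

import numpy as np
from scipy.special import psi  # digamma function

def minka_update_alpha(alpha, n_dk, n_d, iterations=10):
    """Fixed-point update of the Dirichlet parameters alpha, given per-document
    topic counts n_dk (shape D x K) and document lengths n_d (shape D)."""
    D = n_dk.shape[0]
    for _ in range(iterations):
        numerator = psi(n_dk + alpha).sum(axis=0) - D * psi(alpha)
        denominator = psi(n_d + alpha.sum()).sum() - D * psi(alpha.sum())
        alpha = alpha * numerator / denominator
    return alpha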

Implementation of collapsed Gibbs Sampling for LDA described in Griffiths and Steyvers (2004).
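
As a rough outline, one sampling sweep over the corpus could look like the sketch below. It assumes the count arrays and the full_conditional() function from the earlier sketches, docs as a list of word-id lists, and z as a list of lists holding the current topic of each word occurrence; the real implementation linked above differs in its details.

import numpy as np

rng = np.random.default_rng(0)

def gibbs_sweep(docs, z, n_kw, n_k, n_dk, alpha, beta):
    """One pass over all word occurrences, resampling every topic assignment."""
    for d, doc in enumerate(docs):
        for i, w in enumerate(doc):
            k_old = z[d][i]
            # 1. remove the current assignment from the corpus statistics
            n_kw[k_old, w] -= 1
            n_k[k_old] -= 1
            n_dk[d, k_old] -= 1
            # 2./3. compute and normalize the full conditional distribution
            p = full_conditional(w, d, n_kw, n_k, n_dk, alpha, beta)
            # 4. draw the new topic assignment from the resulting distribution
            k_new = rng.choice(len(p), p=p)
            z[d][i] = k_new
            # add the new assignment back to the corpus statistics
            n_kw[k_new, w] += 1
            n_k[k_new] += 1
            n_dk[d, k_new] += 1

# training schedule: run burn-in sweeps first, then keep only every interval-th sweep,
# e.g. burnin, samples, interval = 150, 50, 5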

Model Convergence

Another important feature of the Python code in this article is the calculation of the log likelihood of our model. The log likelihood measures how likely it is that the specified model generated the observed data. In other words, it measures how well the current model explains the observed data. In general, this quantity is negative, and increasing values indicate a better model fit. The log likelihood of a model can be used to monitor training progress and model convergence. However, this measure is directly influenced by the number of tokens in the corpus. Therefore, perplexity usually is a better choice, as it is normalized by the number of tokens: it is the exponential of the negative per-word log likelihood. In contrast to the log likelihood, perplexity is positive and decreasing values indicate a better model fit.
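
The relationship between the two measures can be written down in a single expression (n_tokens denotes the total number of word tokens in the corpus):

import numpy as np

def perplexity(log_likelihood, n_tokens):
    """Perplexity as the exponential of the negative per-token log likelihood."""
    return np.exp(-log_likelihood / n_tokens)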

Next, let’s discuss at a high level how to derive the formula for log likelihood from the equations presented in Griffiths and Steyvers (2004).

Equations required to compute the joint distribution described in Griffiths and Steyvers (2004).

The likelihood of our model equals the joint distribution P(w,z) = P(w|z)P(z). To avoid problems with numerical stability, we should take the logarithm and sum up both terms. Remember that when we take the logarithm, division is replaced by subtraction and multiplication is replaced by addition. Please note that gammaln() returns the logarithm of the gamma function.

Log likelihood derived from Griffiths and Steyvers (2004).
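
A sketch of how these equations translate into code with gammaln is shown below. It reuses the count-array names from above, a scalar β, and a (possibly asymmetric) α vector; it mirrors the structure of the equations rather than the article's actual implementation.

import numpy as np
from scipy.special import gammaln

def log_likelihood(n_kw, n_k, n_dk, n_d, alpha, beta):
    """log P(w, z) = log P(w|z) + log P(z), following Griffiths and Steyvers (2004)."""
    K, W = n_kw.shape
    D = n_dk.shape[0]

    # log P(w|z): one Dirichlet-multinomial term per topic
    log_p_w_given_z = (
        K * (gammaln(W * beta) - W * gammaln(beta))
        + gammaln(n_kw + beta).sum()
        - gammaln(n_k + W * beta).sum()
    )

    # log P(z): one Dirichlet-multinomial term per document
    log_p_z = (
        D * (gammaln(alpha.sum()) - gammaln(alpha).sum())
        + gammaln(n_dk + alpha).sum()
        - gammaln(n_d + alpha.sum()).sum()
    )

    return log_p_w_given_z + log_p_z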

Model Training

With what we have learned so far, it is possible to train and evaluate LDA. So let’s train a first model with K = 5 topics. We choose 150 burn-in iterations and 50 sampling iterations from which we select only every 5th sample to infer our model parameters. The model is trained on 21,000 reviews from Amazon (https://www.kaggle.com/lievgarcia/amazon-reviews).

df = pd.read_csv("amazon_reviews.txt", sep="\t")
texts = df["REVIEW_TEXT"].values.tolist()
corpus_train = Corpus(texts, max_features=10_000)
lda = LDA(corpus_train, K=5, alpha="asymmetric", beta=0.01, samples=50, burnin=150, interval=5, eval_every=1)
lda.plot_topic_prior_alpha()

The plots below show the log likelihood and perplexity values measured during training. It can be observed that the log likelihood of the model increases while the perplexity decreases. This is consistent with our expectation. Therefore, it can be concluded that the model converges to a stationary distribution which is our posterior distribution p(z|w).

lda.fit()

Output:

burnin iteration 0 perplexity 11082.6 likelihood -5767872.9
burnin iteration 1 perplexity 9249.0 likelihood -5655861.3
burnin iteration 2 perplexity 8453.6 likelihood -5600168.5
burnin iteration 3 perplexity 7992.5 likelihood -5565429.9
burnin iteration 4 perplexity 7635.2 likelihood -5537107.4
burnin iteration 5 perplexity 7291.4 likelihood -5508570.1
burnin iteration 6 perplexity 6932.4 likelihood -5477303.2
burnin iteration 7 perplexity 6521.2 likelihood -5439429.8
burnin iteration 8 perplexity 6084.9 likelihood -5396545.2
burnin iteration 9 perplexity 5665.3 likelihood -5352291.9
burnin iteration 10 perplexity 5313.0 likelihood -5312531.8
...
sampling iteration 190 perplexity 3096.8 likelihood -4978230.9
sampling iteration 191 perplexity 3095.9 likelihood -4978039.2
sampling iteration 192 perplexity 3099.2 likelihood -4978699.4
sampling iteration 193 perplexity 3097.7 likelihood -4978415.0
sampling iteration 194 perplexity 3093.7 likelihood -4977602.4
sampling iteration 195 perplexity 3092.5 likelihood -4977362.9
sampling iteration 196 perplexity 3097.3 likelihood -4978328.4
sampling iteration 197 perplexity 3096.7 likelihood -4978215.4
sampling iteration 198 perplexity 3092.5 likelihood -4977357.9
sampling iteration 199 perplexity 3094.2 likelihood -4977709.7

Plot log likelihood and perplexity:

lda.plot_metrics()
Change of log likelihood and perplexity during training.

Another interesting feature gives more insight into the convergence of our model. The trace of the marginal topic distribution P(T) tracks the change of the topic distribution over the whole corpus during training. The plot below shows the convergence process of each topic and is commonly known as a trace plot. The model has reached a stationary distribution and can be considered converged if the gradient of the marginal topic distribution is close to zero. In other words, stationarity is reached if there is no change in P(T) over time.

lda.plot_marginal_topic_dist()
Change of marginal topic distribution P(T) during training.

What we can observe in the trace plot above is that there are three prominent topics after 200 iterations, which together account for roughly 70 percent of the corpus. All topics except topic 4 reach stationarity after around 150 iterations; topic 4 converges earlier, after roughly 100 iterations. What you can also see in the plot is that convergence in collapsed Gibbs Sampling is quite slow. Unfortunately, the approach is also hard to parallelize, as the sampling procedure is sequential in nature. This is one major drawback of Markov chain Monte Carlo methods.

Topics

Let's look at the estimates for ϕ obtained during the training process. Remember that ϕ equals the per-topic word distribution p(w|t). It is obtained by drawing samples from the posterior after the burn-in period, keeping only every 5th sample, and averaging the resulting estimates.
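
Following Griffiths and Steyvers (2004), a point estimate of ϕ can be read off the word-topic counts of a single sample and then averaged over the retained samples. A sketch, reusing the count-array names from above (the retained_samples loop in the comment is purely illustrative):

import numpy as np

def estimate_phi(n_kw, n_k, beta):
    """Point estimate of the per-topic word distributions from one sample."""
    W = n_kw.shape[1]
    return (n_kw + beta) / (n_k[:, None] + W * beta)

# average over the retained samples, e.g. every 5th sweep after burn-in:
# phi = np.mean([estimate_phi(s.n_kw, s.n_k, beta) for s in retained_samples], axis=0)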

As you can see below, topic 0 and topic 1 clearly describe positive reviews about products considering price and quality. Furthermore, topic 3 is about the quality of electronic products, e.g. TV, camera and sound.

lda.print_topics()

Output:

p(w|t)	word

Topic #0
0.019 great
0.014 good
0.013 like
0.013 love
0.012 really
0.012 quality
0.011 just
0.010 price
0.009 bought
0.008 nice

Topic #1
0.021 br
0.014 product
0.011 like
0.010 use
0.009 just
0.008 great
0.008 really
0.007 good
0.007 love
0.007 using

Topic #2
0.033 br
0.011 use
0.008 easy
0.007 like
0.007 just
0.007 light
0.007 great
0.006 34
0.006 good
0.006 case

Topic #3
0.039 br
0.009 great
0.009 tv
0.008 use
0.008 good
0.008 sound
0.008 just
0.007 quality
0.006 product
0.006 camera

Topic #4
0.041 br
0.014 book
0.012 game
0.008 34
0.008 movie
0.008 like
0.007 just
0.006 read
0.006 time
0.006 great

Thank you for reading this article. If you liked it, I would appreciate your support. In the next article, we compare and analyze different approaches to infer the document-topic distribution of unseen documents.

References

Blei, David M.; Ng, Andrew Y.; Jordan, Michael I. (2003): Latent Dirichlet Allocation. In: Journal of Machine Learning Research 3 (pp. 993–1022).

Griffiths, Thomas L.; Steyvers, Mark (2004): Finding scientific topics. In: Proceedings of the National Academy of Sciences 101 (suppl. 1) (pp. 5228–5235).

Minka, Thomas P. (2000): Estimating a Dirichlet distribution. Technical report, M.I.T.

Wallach, Hanna M.; Mimno, David M.; McCallum, Andrew (2009): Rethinking LDA: Why priors matter. In: Advances in Neural Information Processing Systems (pp. 1973–1981).
