Unifying Word Embeddings and Matrix Factorization — Part 3

Kian Kenyon-Dean
Published in Radix
12 min read · Jun 24, 2019

Implementing Word2vec as matrix factorization… in TensorFlow 2.0!

TL;DR

In this 3-part blog series we present a unifying perspective on pre-trained word embeddings under a general framework of matrix factorization. The most popular word embedding model, Word2vec, has traditionally been presented as a (shallow) neural network. By the end of this series, you will have learned how to compute the same Word2vec embeddings with a single matrix factorization. In this final part, we show you how to compute that matrix factorization using a standard deep learning library: TensorFlow 2.0.

This series is based on my recent work pursued at Mila (Quebec AI Institute) in collaboration with Edward Newell (PhD), currently under review. You can also refer to my recently published Master’s Thesis for an in-depth engagement with the ideas presented here. It proceeds as follows:

Part 1: Introduction and motivation; check it out!

Part 2: Mathematically deriving Word2vec as a matrix factorization; check it out!

Part 3: In this post, we will present simple experiments to verify the correctness of our explicit matrix factorization formulation of Word2vec, and demonstrate how to implement matrix factorization in TensorFlow 2.0.

Summary

In this post we provide a very simple matrix factorization implementation of SGNS (i.e., skip-gram with negative sampling, Word2vec) in TensorFlow 2.0. This method can produce a set of useful word embeddings in less than 30 minutes, and it does not require the original text corpus! This post is accompanied by an interactive Colab file for reference.

This implementation is unusual because it uses matrix factorization. Instead of a large corpus of text, it only requires a much smaller file containing pre-extracted corpus statistics (as we discussed in Part 2). The statistics extraction only needs to be run once for a given vocabulary and context window size, and a fast, parallelized implementation (e.g., my implementation in Go) can process a 10GB corpus of compressed .gz text files in less than 3 hours.

Step 1. Looking at the Corpus Statistics

Recall from last time that SGNS is really factorizing a matrix filled with the pointwise mutual information (PMI) statistics, shifted by a global constant, log(k), where k is the number of negative samples. In this post we use k=1, but this hyperparameter can easily be tuned and tested in practice.
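Written as a target for each individual dot product (here v_i is the vector for term i and w_j the covector for context j, notation introduced only for this sketch; Ni, Nj, Nij and N are the corpus counts used in the PMI definition below), the shifted-PMI claim reads:

```latex
\mathbf{v}_i^{\top}\mathbf{w}_j \;\approx\; \mathrm{PMI}(i,j) - \log k
  \;=\; \log\frac{N\,N_{ij}}{N_i\,N_j} - \log k
```

With k = 1 the shift vanishes (log 1 = 0), so the dot products target PMI itself.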

Before implementing any model, it’s always useful to look at the data in front of us. I extracted corpus statistics from a 2018 Wikipedia dump (about 10GB compressed) for the 5000 most frequent words, using a context window size of w=5. Words that do not fall in that vocabulary are removed from the dataset prior to extraction. Note that this is a very small vocabulary, which we use here only for simple exposition on limited Colab resources.
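The Go extractor mentioned above is not shown in this post. As a rough illustration of what it produces, here is a minimal, single-threaded Python sketch. It assumes pre-tokenized sentences, a fixed vocabulary, and a symmetric window in which every pair within w positions counts once in each direction (no distance weighting). It also assumes Ni and Nj are the row and column sums of Nij and N is their total; the helper name is hypothetical.

```python
import numpy as np

def extract_cooccurrence_stats(sentences, vocab, window=5):
    """Hypothetical helper: symmetric co-occurrence counts for a fixed vocabulary.

    Returns (Nij, Ni, Nj, N), where Ni and Nj are the marginals of Nij.
    """
    index = {word: i for i, word in enumerate(vocab)}
    V = len(vocab)
    Nij = np.zeros((V, V), dtype=np.float64)
    for sentence in sentences:
        # Out-of-vocabulary words are removed before counting.
        ids = [index[w] for w in sentence if w in index]
        for pos, i in enumerate(ids):
            # Look only to the right; adding both (i, j) and (j, i)
            # makes the counts symmetric, like a symmetric window.
            for j in ids[pos + 1 : pos + 1 + window]:
                Nij[i, j] += 1
                Nij[j, i] += 1
    Ni = Nij.sum(axis=1)  # term counts
    Nj = Nij.sum(axis=0)  # context counts
    N = Nij.sum()         # total number of co-occurrences
    return Nij, Ni, Nj, N
```

For example, `extract_cooccurrence_stats([["puerto", "rico", "is", "an", "island"]], vocab, window=5)` would count every pair of in-vocabulary words in that sentence. A real extractor would stream a 10GB corpus in parallel and store Nij sparsely.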

The PMI matrix for this dataset looks like this (sorted alphabetically/sorted by unigram frequency); the white space consists of negative infinities, -∞.

PMI matrix extracted for 5000 most frequent words on Wikipedia corpus, for a context window w=5. Left is sorted alphabetically, right is sorted by word frequency. Y-axis is the context word index, X-axis is the term word index (this matrix is symmetric due to our using a symmetric context window).

This is the matrix that we are attempting to factorize; i.e., the dot products between term vectors and context vectors will attempt to reproduce this matrix. Recall that PMI is defined as log((N * Nij) / (Ni * Nj)), where Nij is the number of times term i co-occurs with context j, Ni and Nj are the total co-occurrence counts of term i and context j, and N is the total number of co-occurrences; these numbers are simply the co-occurrence statistics extracted from the data. While our mathematical analysis from last time revealed that our loss function will attempt to reproduce this matrix, it is more efficient to store these statistics separately, as we show later on.
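Given those four statistics, the PMI matrix plotted above can be rebuilt directly. Here is a NumPy sketch that assumes dense arrays, assigns -∞ to pairs that never co-occur, and applies an optional shift by log(k). The function name is hypothetical.

```python
import numpy as np

def pmi_matrix(Nij, Ni, Nj, N, k=1):
    """Shifted PMI: log(N * Nij / (Ni * Nj)) - log(k), with -inf where Nij == 0."""
    pmi = np.full(Nij.shape, -np.inf)
    observed = Nij > 0
    # np.nonzero and boolean indexing both walk the matrix in row-major order,
    # so rows/cols line up with Nij[observed].
    rows, cols = np.nonzero(observed)
    pmi[observed] = np.log(N * Nij[observed] / (Ni[rows] * Nj[cols]))
    return pmi - np.log(k)
```

Only observed pairs are divided, so words with zero counts never produce 0/0.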

First, however, observe that, despite the very small vocabulary of only 5000 words, 75% of the matrix is empty (i.e., has Nij=0)! That is, 75% of all possible word pairs (for this vocabulary and the 5-word context window) never occur in the dataset. As the vocabulary size increases, the empty proportion of the matrix approaches 90% and even 99%, which is what motivates sparse factorization algorithms like GloVe.
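The 75% figure is simply the fraction of zero entries in Nij. Here is a one-function check, assuming Nij is held as a dense NumPy array (the function name is hypothetical):

```python
import numpy as np

def empty_fraction(Nij):
    """Fraction of (term, context) pairs that never co-occur (Nij == 0)."""
    return np.count_nonzero(Nij == 0) / Nij.size
```

For a sparse matrix, the same quantity is 1 minus the number of stored nonzeros divided by V * V.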

Some of the highest PMI(i,j) values we extracted. Because our context window is symmetric, PMI(i,j) = PMI(j,i).

PMI is a measure of association between two words: the value of PMI answers the question “how much does the presence of a word i suggest that the other word j will surround it?” Looking at the data, the term-context pairs shown on the left have the highest PMIs.

This is exactly what we would expect! Intuitively, if I see the word “puerto” I can be very sure that the word “rico” will be nearby.
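Here is one way to produce a list like the one above from a PMI matrix. The sketch assumes a dense NumPy matrix and a list of vocabulary strings, and the function name is hypothetical. Because the window is symmetric, it searches only the upper triangle so that each pair appears once.

```python
import numpy as np

def top_pmi_pairs(pmi, vocab, n=10):
    """Return the n highest-PMI (term, context, value) triples, each pair once."""
    rows, cols = np.triu_indices(pmi.shape[0], k=1)  # all pairs with i < j
    values = pmi[rows, cols]
    # Sorting the negated values gives descending order; -inf pairs end up last.
    order = np.argsort(-values)[:n]
    return [(vocab[rows[t]], vocab[cols[t]], float(values[t])) for t in order]
```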

As a final inspection, let’s look at the histogram of all PMIs (excluding the -∞).

Histogram of PMI values from the matrix above, excluding -∞.

This is another way to examine the distribution of target dot products contained in the matrix above. As we can see, PMI looks roughly normally distributed around 0. I conjecture that this is what makes it a suitable target objective for learning vector-covector dot products, since machine learning algorithms tend to love normal distributions!
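The histogram only needs the finite entries. Here is a NumPy sketch that drops the -∞ cells and returns the bin counts plus a mean and standard deviation as a quick check of the "centred near 0" observation (the function name and bin count are illustrative choices):

```python
import numpy as np

def finite_pmi_summary(pmi, bins=50):
    """Histogram, mean and standard deviation of the finite PMI values."""
    finite = pmi[np.isfinite(pmi)]  # drops the -inf entries
    counts, edges = np.histogram(finite, bins=bins)
    return counts, edges, finite.mean(), finite.std()
```

To draw the plot itself, the same `finite` array can be passed to a plotting library's histogram function.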

Step 2. Implementing SGNS matrix factorization in TensorFlow 2.0

Now that we have our data and have an understanding of it, it’s time to put theory into practice. Recall that the PMI matrix is being implicitly factorized by the SGNS loss function. In practice, as we derived in Part 2, the loss has a very different form than a typical matrix factorization loss such as that of the SVD, especially since it doesn’t deal with the PMI matrix “directly”, but with an algebraic decomposition of it.

Defining the Model

First, however, it would be useful to define a suitable Keras Model that performs the matrix multiplication between term vectors and context vectors (co-vectors). This is very simple in TensorFlow 2.0:

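import tensorflow as tf
from tensorflow.keras import Model

class MFEmbedder(Model):
    """Holds term vectors V and context covectors W; calling it returns V W^T."""

    def __init__(self, init_func, vocab_size, dim):
        super(MFEmbedder, self).__init__()
        # Term vectors, one row per vocabulary word.
        self.V = tf.Variable(init_func(vocab_size, dim),
                             name="vectors")
        # Context vectors (covectors), one row per vocabulary word.
        self.W = tf.Variable(init_func(vocab_size, dim),
                             name="covectors")

    def call(self, x=None):
        # The input is ignored: every call multiplies all vectors by all covectors.
        return tf.matmul(self.V, self.W, transpose_b=True)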

Note that the model takes no inputs: every time it is called, it simply multiplies all vectors by all covectors. It could alternatively be defined to take a slice object as input, indicating a subset of vectors and covectors to multiply together. Such a pattern is required in large-vocabulary settings, where the full V x V matrix is too massive to hold in memory all at once (in this case it would also be necessary to store Nij as a sparse matrix).
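Here is a rough sketch of the slice-based pattern, written in NumPy rather than as a Keras model for brevity. The generator name and the square shard shape are assumptions. Each shard is a small block of M_hat that a sharded training loop would pair with the matching block of Nij.

```python
import numpy as np

def iter_mhat_shards(V, W, shard_size):
    """Yield (rows, cols, block) for every shard of M_hat = V W^T.

    Only one (shard_size x shard_size) block is materialized at a time.
    """
    for r in range(0, V.shape[0], shard_size):
        for c in range(0, W.shape[0], shard_size):
            rows = slice(r, r + shard_size)
            cols = slice(c, c + shard_size)
            yield rows, cols, V[rows] @ W[cols].T
```

In TensorFlow the block would be `tf.matmul(V[rows], W[cols], transpose_b=True)`. The loss would then sum over blocks instead of over the full matrix.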

Also, we leave init_func as an option for the user to define how to randomly initialize the embeddings, or perhaps to initialize them with a different set of pre-trained embeddings. (One might also opt for the Keras Layer method self.add_weight rather than creating tf.Variable objects directly.)

For example, I found that a scaled-down normal distribution works well in practice. After defining the initialization function (init_func), we construct our model. For the purpose of simple exposition, we use a small vocabulary of 5000 words and learn 50-dimensional embeddings:

normal_init = lambda v, d: tf.random.normal((v, d), 0.0, 1.0/d)
model = MFEmbedder(normal_init, vocab_size=5000, dim=50)

Defining the Loss

Now we are ready to define the loss function, based on the pre-extracted corpus statistics. Recall that it is simpler to retain the individual corpus statistics that compose PMI than to hold the PMI matrix itself in memory.

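class SGNSLoss:
    def __init__(self, Nij, Ni, Nj, N, k):
        self.Nij = Nij  # co-occurrence counts, shape V x V
        self.Ni = Ni    # term counts, shape V
        self.Nj = Nj    # context counts, shape V
        self.N = N      # total number of co-occurrences (constant)
        self.k = k      # number of negative samples (constant)

    def __call__(self, M_hat):
        # Positive samples: observed counts Nij weight log sigmoid(M_hat).
        pos_samples = tf.reduce_sum(self.Nij *
                                    tf.math.log_sigmoid(M_hat))
        # Negative samples: (1 x V) @ (V x V) @ (V x 1) sums
        # Ni[i] * Nj[j] * log sigmoid(-M_hat[i, j]) over all pairs.
        # Rows of M_hat are terms (Ni), columns are contexts (Nj).
        neg_samples = (self.k / self.N) * \
            tf.reshape(self.Ni, (1, -1)) @ \
            tf.math.log_sigmoid(-M_hat) @ \
            tf.reshape(self.Nj, (-1, 1))
        return -(pos_samples + neg_samples) / self.N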

This is a direct implementation of the loss function we derived in Part 2. In practice, we use reduce_sum on the positive samples to minimize GPU load, and we also implicitly sum over the negative samples. The latter is done with the matrix multiplication (1 x V) @ (V x V) @ (V x 1), where V is the vocabulary size, which implicitly takes the outer product between Ni and Nj within the summation.
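To check that the matrix-product trick computes the same sum as an explicit outer product, here is a NumPy transcription of the loss next to a version that materializes Ni Nj^T. It assumes dense NumPy inputs and is a checking aid only, not a replacement for the TensorFlow class; the function names are hypothetical.

```python
import numpy as np

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)) = -log(1 + exp(-x)).
    return -np.logaddexp(0.0, -x)

def sgns_loss(M_hat, Nij, Ni, Nj, N, k):
    """NumPy mirror of SGNSLoss.__call__ for dense arrays."""
    pos_samples = np.sum(Nij * log_sigmoid(M_hat))
    # Ni @ (V x V) @ Nj computes sum_ij Ni[i] * Nj[j] * log_sigmoid(-M_hat[i, j]).
    neg_samples = (k / N) * (Ni @ log_sigmoid(-M_hat) @ Nj)
    return -(pos_samples + neg_samples) / N

def sgns_loss_explicit(M_hat, Nij, Ni, Nj, N, k):
    """Same loss, but materializing the V x V outer product of Ni and Nj."""
    neg_samples = (k / N) * np.sum(np.outer(Ni, Nj) * log_sigmoid(-M_hat))
    return -(np.sum(Nij * log_sigmoid(M_hat)) + neg_samples) / N
```

Both functions agree up to floating-point rounding. The first never allocates the V x V matrix of products Ni * Nj, which is why the TensorFlow version uses it.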

From a theoretical perspective, observe the difference between the positive and negative samples. The positive samples reflect the actual observed corpus statistics, i.e., the real Nij values (self.Nij). The negative samples, on the other hand, reflect what we would expect to observe assuming independence. This is because we implicitly take the outer product between the term and context frequencies, Ni and Nj, which together reflect the independence probability p(i) * p(j).
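Here is the quantity computed by SGNSLoss, written out entry by entry, with \hat m_{ij} denoting entry (i, j) of M_hat. If we read p(i) = Ni/N and p(j) = Nj/N, the weight N p(i) p(j) = Ni Nj / N on the second term is exactly the co-occurrence count expected under independence, scaled by k:

```latex
\mathcal{L} = -\frac{1}{N}\sum_{i,j}\Big[\, N_{ij}\,\log\sigma(\hat m_{ij})
  + k\,\frac{N_i N_j}{N}\,\log\sigma(-\hat m_{ij}) \Big]
```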

Defining the Training Loop

We are almost ready to train some word embeddings! All we need now is to define the training step, and the TensorFlow 2.0 API permits us to do this in a straightforward manner:

# Create the object representing the loss function.
loss_obj = SGNSLoss(Nij, Ni, Nj, N, k=1)

# Now write the train_step constructor function.
def make_train_step():
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.1)
    train_loss = tf.keras.metrics.Mean(name='train_loss')

    @tf.function
    def train_step(scope_model):
        with tf.GradientTape() as tape:
            mhat = scope_model(None)
            # loss_obj returns a 1 x 1 tensor; reduce it to a scalar.
            loss = tf.reduce_sum(loss_obj(mhat))
        gradients = tape.gradient(loss, scope_model.trainable_variables)
        optimizer.apply_gradients(zip(gradients,
                                      scope_model.trainable_variables))
        train_loss(loss)

    return train_loss, train_step

Here, we are defining the TensorFlow 2.0 training function within the make_train_step function, which is essentially a function constructor. Note that, in the training step, the loss_obj is used to compute the SGNS loss on the model’s prediction, mhat.

We use the returned values of this constructor as follows:

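train_loss, train_step = make_train_step()

results = []
n_iters = 500 + 1
print_every = 50
for i in range(n_iters):
    train_step(model)
    # train_loss is a Mean metric, so this is the running average
    # of the loss over all steps taken so far.
    res = train_loss.result()
    results.append(res)

    if i % print_every == 0:
        print(f"step {i:4d} - loss: {res.numpy()}")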

For every iteration, we step the model and record the result, printing every so often. Let’s see what happens!
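For readers without TensorFlow at hand, here is a framework-free sketch of the same factorization. It uses plain full-batch gradient descent with a fixed learning rate, not the Adam optimizer used above, so lr will likely need tuning on real data. It relies on the gradient dL/dM_hat = -(Nij * σ(-M_hat) - k * Ni * Nj / N * σ(M_hat)) / N and assumes Ni and Nj are the marginals of Nij. The function name and default values are hypothetical.

```python
import numpy as np

def sigmoid(x):
    # Logistic function written via tanh, which cannot overflow.
    return 0.5 * (1.0 + np.tanh(0.5 * x))

def train_mf_sgns(Nij, k=1, dim=50, lr=1.0, n_iters=500, seed=0):
    """Full-batch gradient descent on the MF-SGNS loss (NumPy sketch).

    Assumes Ni and Nj are the row and column sums of Nij, and N is its total.
    Returns term vectors V, covectors W and the loss at every step.
    """
    rng = np.random.default_rng(seed)
    vocab_size = Nij.shape[0]
    Ni, Nj, N = Nij.sum(axis=1), Nij.sum(axis=0), Nij.sum()
    expected = k * np.outer(Ni, Nj) / N  # weight of each negative-sample term
    # Same scaled-down normal initialization as normal_init above.
    V = rng.normal(0.0, 1.0 / dim, size=(vocab_size, dim))
    W = rng.normal(0.0, 1.0 / dim, size=(vocab_size, dim))
    losses = []
    for _ in range(n_iters):
        M_hat = V @ W.T
        log_sig_pos = -np.logaddexp(0.0, -M_hat)  # log sigmoid(M_hat)
        log_sig_neg = -np.logaddexp(0.0, M_hat)   # log sigmoid(-M_hat)
        losses.append(-(np.sum(Nij * log_sig_pos)
                        + np.sum(expected * log_sig_neg)) / N)
        # Gradient of the loss with respect to M_hat.
        G = -(Nij * sigmoid(-M_hat) - expected * sigmoid(M_hat)) / N
        # Simultaneous update: both gradients use the old V and W.
        V, W = V - lr * (G @ W), W - lr * (G.T @ V)
    return V, W, losses
```

Setting the gradient to zero gives Nij * σ(-m) = k * Ni * Nj / N * σ(m), i.e. m = PMI - log k, which is the shifted-PMI target.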

Step 3. Observing results over time

To help understand how the model learns word associations, I added a printing step that shows the top 5 most similar words (by cosine similarity) to the word “money” during training. We would expect these neighbours to converge toward words semantically similar to “money”, such as “cash”. Indeed, this is exactly what happens as time progresses:

step   0 - loss: 1.3862977027893066 (2.5257 seconds)
similar to "money": barry operates watched important facilities
step  50 - loss: 1.3117446899414062 (84.5514 seconds)
similar to "money": cash paying credit easier pay
step 100 - loss: 1.2981760501861572 (84.9554 seconds)
similar to "money": cash funds payments taxes benefits
...
step 500 - loss: 1.2863645553588867 (84.6426 seconds)
similar to "money": cash funds payments pay benefits

We found that, between steps 300 and 500, the top 5 most similar words to “money” remained the same, showing rapid convergence.

MF-SGNS loss over time. Trains embeddings in less than 15 minutes!

We also observe a similar pattern of convergence in the loss over time: a very nice and stable learning curve! This model also trains very quickly (less than 15 minutes). Training could likely be sped up by tuning the learning rate, and perhaps by using a learning-rate schedule.
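The starting point of the curve can also be checked by hand. With the scaled-down initialization, every entry of M_hat starts close to 0, where σ(0) = 1/2. If Ni and Nj are the row and column sums of Nij (so the Nij sum to N and the products Ni Nj sum to N²), the loss at M_hat = 0 is a constant:

```latex
\mathcal{L}\big|_{\hat M = 0}
  = -\frac{1}{N}\Big[\, N \log\tfrac{1}{2} + \frac{k}{N}\, N^{2} \log\tfrac{1}{2} \Big]
  = (1 + k)\log 2 \;\approx\; 1.3863 \quad (k = 1)
```

This agrees closely with the step-0 loss of 1.3862977 in the log above. The small difference comes from the initial dot products not being exactly 0.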

Let’s look at a few more word similarities yielded from the model (obtaining the top 5 most similar vectors to the query vector, using cosine similarity):

Query    : Top-5 most similar vectors
drive    : run turn catch push move
america  : europe asia africa mexico canada
east     : west north southeast south eastern
soviet   : allied communist yugoslav russian revolutionary
belgium  : denmark finland portugal austria hungary
brussels : vienna seoul cairo rome athens
1914     : 1915 1917 1939 1919 1912

Very reasonable semantic associations! For example, “1914” (the year World War 1 started) is most similar to “1915”, the next year of WW1, but is also highly similar to “1939”, the year WW2 started. Very cool! We also observe that “brussels” is similar to other capitals, while “belgium” is similar to other countries, exactly as we would expect.
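Here is a NumPy sketch of the cosine-similarity lookup behind these lists. The post does not say whether the vectors, the covectors, or their sum were used, so this sketch simply takes whatever embedding matrix it is given, for example `model.V.numpy()`. The function name is hypothetical.

```python
import numpy as np

def most_similar(query, embeddings, vocab, topn=5):
    """Top-n words by cosine similarity to `query`, excluding the query itself."""
    index = {word: i for i, word in enumerate(vocab)}
    # Normalize rows so that dot products become cosine similarities.
    unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sims = unit @ unit[index[query]]
    sims[index[query]] = -np.inf  # never return the query word
    best = np.argsort(-sims)[:topn]
    return [vocab[i] for i in best]
```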

Do the dot products approximate PMI?

Now, throughout this entire blog series I have claimed that the SGNS vectors approximate the PMI matrix. Let’s look and see if this is actually true!

The original PMI matrix versus our model’s attempted approximation of it. Whitespace on the left is -∞. (Note that I re-coloured the left image from before so that pink marks the centre at 0.)

Visually, this doesn’t look particularly appealing: our model seems to assign finite values to many entries where the PMI should be -∞. However, do we really want our vector dot products to equal -∞? If that happened, the vectors would not be very useful, since they would need infinite norm and would in practice fill up with NaNs.

It is much more intuitive and enlightening to compare the sigmoid of the PMI with the sigmoid of the model approximation. This is because σ(-∞) = 0, a much more reasonable number than -∞! Moreover, it is the sigmoid function that is used to compare the matrices in the loss function. Observe the sigmoid-ed matrices below:

Original PMI matrix versus our model’s approximation of it, after applying the sigmoid function.

Here we can draw conclusions more easily. Namely, the model does approximate the PMI matrix in the regions where PMI ≠ -∞. However, where PMI = -∞, the model actually introduces new associations. This may even be a benefit: just because two words never appear together in a corpus does not mean they are semantically unrelated.
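The comparison in these plots can also be summarized numerically by splitting the cells into those where PMI is finite and those where it is -∞. Here is a NumPy sketch that assumes M_hat is the model output converted to NumPy (e.g. `model(None).numpy()`); the function names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    # tanh form: exact at 0 and maps -inf to exactly 0 without overflow warnings.
    return 0.5 * (1.0 + np.tanh(0.5 * x))

def sigmoid_gap(pmi, M_hat):
    """Mean |sigmoid(PMI) - sigmoid(M_hat)| on observed and unobserved cells."""
    gap = np.abs(sigmoid(pmi) - sigmoid(M_hat))
    unobserved = np.isneginf(pmi)  # cells where Nij = 0
    return gap[~unobserved].mean(), gap[unobserved].mean()
```

A small first value alongside a larger second value would match the visual reading above: close agreement where PMI is finite, and new associations where it is -∞.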

Additionally, we should note that it is impossible to completely reconstruct the original PMI matrix using our embeddings, because our model is a low-rank approximation of the original. The PMI matrix can have rank up to 5000, but we are approximating it with a much simpler matrix of rank at most 50 (since Mhat is the product of two 5000 x 50 matrices). It is therefore not surprising that the approximation shows repetitive square patterns: it is mathematically impossible for it to do otherwise!
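The rank bound is easy to confirm numerically: whatever values V and W hold, V W^T cannot have rank above the embedding dimension. Here is a NumPy sketch with random stand-in factors. The sizes are reduced from 5000 to 500 only to keep the SVD inside matrix_rank quick.

```python
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(size=(500, 50))  # stand-in term vectors
W = rng.normal(size=(500, 50))  # stand-in covectors
M_hat = V @ W.T                 # 500 x 500, like the model's output
# At most 50, however large the matrix is; exactly 50 for generic random factors.
print(np.linalg.matrix_rank(M_hat))
```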

Final Remarks

Improving the matrix factorization embeddings

The implementation we present of the matrix factorization Word2vec embedding algorithm is in its skeletal form. Several improvements would be necessary in order to implement it in a practical setting. Indeed, some small but important features introduced by Mikolov et al. are not currently included in the Colab code.

These include the following algorithmic design decisions, discussed in detail by Levy & Goldberg (2015):

  • Frequent word subsampling. Mikolov et al. found that performance, and running-time, could improve when they undersampled frequent words from the corpus, according to the words’ frequencies. E.g., rather than having “the” occur 300 million times, perhaps it’s only necessary to observe it 100 million times. With some clever reasoning, this can be implemented within the matrix factorization framework without having to re-extract corpus statistics. However, this is less of a problem for matrix factorization, since we update each vector an equal number of times (yet the co-occurrence statistics will still be altered when
  • Context distribution smoothing. Raising all Nj counts to the power 0.75 smooths out the distribution of context words, as discussed by Levy & Goldberg (2015). Doing this can substantially improve performance, and it is easy to implement within the matrix factorization framework by replacing Nj with Nj ** 0.75 and re-computing the independence probabilities in neg_samples accordingly.
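Here is a minimal sketch of the context distribution smoothing step. The rescaling is an assumption made here: it keeps the smoothed counts summing to N, so that Nj_smoothed / N equals the smoothed context distribution p(j) ∝ Nj^0.75 and the existing (k / N) * Ni * Nj weighting can be reused unchanged. The function name is hypothetical.

```python
import numpy as np

def smooth_context_counts(Nj, alpha=0.75):
    """Context distribution smoothing, returned as a rescaled count vector.

    Raises Nj to the power alpha, then rescales so the counts still sum to N.
    """
    Nj = np.asarray(Nj, dtype=np.float64)
    powered = Nj ** alpha
    return powered * (Nj.sum() / powered.sum())
```

Under these assumptions, one would pass the result in place of the raw context counts, e.g. `SGNSLoss(Nij, Ni, smooth_context_counts(Nj), N, k=1)`. The positive term still uses the unmodified Nij.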

Additionally, there are implementation design decisions that must be taken into account:

  • Sharding for MF. Sharding, as in Swivel (Shazeer et al. 2016), allows for large vocabularies without overloading the GPU. In practice, I’ve found that a 12GB GPU can compute shards that are 12,500 x 12,500 in size, and that embeddings for a 50,000-word vocabulary can be trained in less than 3 hours, which is very useful for testing many different hyperparameters.
  • Sparsely representing Nij. One of the most important decisions for a practical implementation is to represent the Nij matrix of corpus statistics with a sparse representation. Recall that 75% of the matrix was 0s (at bigger vocabularies this goes to 90% and 99%), so it is not very space-efficient to represent this matrix as a normal dense matrix.
  • Need a big vocabulary? Think about strategies to sample term and context indices without needing to do a full matrix multiplication. GloVe (Pennington et al. 2014) offers one way: only sample elements with Nij>0, one at a time. However, this misses out on a large part of the matrix, and we can learn a lot “by noticing what’s missing” (Shazeer et al. 2016).
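The positive-sample term only depends on pairs with Nij > 0, so it can be computed from a sparse coordinate list without ever forming M_hat. The negative-sample term, by contrast, still covers every (i, j) pair, so it would typically be computed shard by shard. Here is a NumPy sketch with hypothetical helper names. A real implementation would load Nij straight into a sparse format such as scipy.sparse or tf.sparse.SparseTensor rather than start from a dense array as this toy converter does.

```python
import numpy as np

def dense_to_coo(Nij):
    """Keep only the nonzero co-occurrence counts as (rows, cols, counts)."""
    rows, cols = np.nonzero(Nij)
    return rows, cols, Nij[rows, cols]

def positive_term_sparse(V, W, rows, cols, counts):
    """sum_ij Nij * log sigmoid(v_i . w_j), touching only pairs with Nij > 0."""
    # One dot product per stored pair instead of a full V x V product.
    dots = np.einsum("nd,nd->n", V[rows], W[cols])
    return np.sum(counts * -np.logaddexp(0.0, -dots))
```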

Conclusion

I hope you have enjoyed reading these blog posts, and that perhaps you’ve learned something that inspires your research or product implementations. If you plan on using any of these ideas in your work, please do (and consider citing me)! I’d also love for you to share your ideas with me in personal communication (kiankd@gmail.com), as I find this work very interesting and personally fulfilling.

Finally, stay tuned for Newell’s and my paper on this topic, hopefully to be published within the next 2–3 months! My Master’s Thesis is also on this topic and should be published very soon. So, make sure to check those out and cite them in your work once they are available.

References

Mikolov et al. 2013. Distributed Representations of Words and Phrases and their Compositionality. Neural Information Processing Systems.

Levy & Goldberg 2015. Improving Distributional Similarity with Lessons Learned from Word Embeddings. Transactions of the Association for Computational Linguistics.

Pennington et al. 2014. GloVe: Global Vectors for Word Representation. Empirical Methods in Natural Language Processing.

Shazeer et al. 2016. Swivel: Improving Embeddings by Noticing What’s Missing. arXiv.

NLP & DL; Master’s in Computer Science from McGill/Mila; AI Developer @ Bank of Montreal’s AI Capabilities Team