Introducing DistMult and ComplEx for PyTorch Geometric

Learn how to leverage PyG’s newest knowledge graph embedding tools!

David Kuo
Stanford CS224W GraphML Tutorials
May 15, 2023


Written by David Kuo and Riya Sinha for the CS224W final project

Introduction

Knowledge graph embeddings are a powerful tool for analyzing knowledge graphs, and now, with our implementations of DistMult and ComplEx, they can be easily integrated into any project! In this post, we’ll walk you through the intuition behind DistMult and ComplEx, show you how to use them in PyG, and explore some practical applications of knowledge graph embeddings in graph machine learning. So get ready to take your graph analysis skills to the next level with DistMult and ComplEx in PyG!

Knowledge Graph Primer

Knowledge graphs are a structured way to capture relationships between different entities using sets of ordered triples of the form (head, relation, tail). For example, in a knowledge graph about the Harry Potter books, the triple (Harry Potter, occupation, wizard) states that Harry Potter is a wizard, and (Harry Potter, House, Gryffindor) states that he belongs to House Gryffindor.

Knowledge graph embeddings (KGEs) are models that map entities and relations to low-dimensional vector representations using shallow embeddings, designed so that the relationships between them are preserved. They do this by assigning each (head, relation, tail) triple a score that reflects how plausible it is. In the Harry Potter knowledge graph, for example, an effective knowledge graph embedding would give a high score to (Harry Potter, Friend, Hermione Granger) and a low score to (Harry Potter, Friend, Draco Malfoy). The secret sauce of each knowledge graph embedding scheme is how it scores a given triple: its scoring function!

Each knowledge graph embedding scheme (i.e., each scoring function) has its strengths and weaknesses, so it’s important to consider the needs of your dataset and task when choosing one. For example, many embedding schemes can model only certain types of relations, such as symmetric, antisymmetric, inverse, compositional, one-to-many, or many-to-one relations. Another factor is the number of parameters a scheme requires, which affects how long your model takes to train.

For a more detailed introduction to knowledge graph embeddings and understanding the properties of relations, we recommend the article Simple Schemes for Knowledge Graph Embedding.

DistMult

The DistMult algorithm by Yang et al. embeds entities (heads and tails) and relations as vectors in the same real-valued space. We can see this in the __init__ method of the superclass KGEModel, which creates embeddings with the dimension specified by hidden_channels:

```python
import torch
from torch.nn import Embedding


class KGEModel(torch.nn.Module):
    def __init__(
        self,
        num_nodes: int,
        num_relations: int,
        hidden_channels: int,
        sparse: bool = False,
    ):
        super().__init__()

        self.num_nodes = num_nodes
        self.num_relations = num_relations
        self.hidden_channels = hidden_channels

        # One hidden_channels-dimensional vector per entity and per relation.
        self.node_emb = Embedding(num_nodes, hidden_channels, sparse=sparse)
        self.rel_emb = Embedding(num_relations, hidden_channels, sparse=sparse)
```

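DistMult defines its scoring function as the generalized dot product of these three vectors, where K is the embedding dimension:

$$S(h, r, t) = \langle \mathbf{h}, \mathbf{r}, \mathbf{t} \rangle = \sum_{i=1}^{K} h_i \, r_i \, t_i = \mathbf{h}^\top \operatorname{diag}(\mathbf{r}) \, \mathbf{t}$$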

So what’s going on here? We can actually think of the relation vector as a diagonal matrix! The matrix multiplication of the head and relation then produces another vector, where each component of the resulting vector is the respective component from the head entity, scaled by the respective component from the relation. This allows the model to emphasize or de-emphasize parts of the vector representation based on the relation.

To compare this transformed head representation with the tail representation, we take their dot product. This gives a scalar score that behaves like a similarity between the transformed head and the actual tail (it equals their cosine similarity scaled by the product of the two vectors’ norms).
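As a quick sanity check of the diagonal-matrix view, the sketch below uses random PyTorch tensors as stand-ins for learned embeddings (the names h, r, t and the size K = 4 are purely illustrative) and confirms that h^T diag(r) t matches the elementwise sum used in code.

```python
import torch

torch.manual_seed(0)
K = 4  # illustrative embedding dimension

# Random stand-ins for learned head, relation, and tail embeddings.
h, r, t = torch.randn(K), torch.randn(K), torch.randn(K)

# Matrix view: the relation acts as a diagonal matrix that rescales h.
score_matrix = h @ torch.diag(r) @ t

# Elementwise view: the generalized dot product sum_i h_i * r_i * t_i.
score_elementwise = (h * r * t).sum()

print(torch.allclose(score_matrix, score_elementwise))  # True
```

The elementwise form avoids ever materializing a K x K matrix, which is why implementations score triples this way.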

This translates to the forward method of the model, which takes in a set of triples and calculates the score for each according to the scoring function:

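```python
# DistMult forward (method of the DistMult class)
from torch import Tensor


def forward(
    self,
    head_index: Tensor,
    rel_type: Tensor,
    tail_index: Tensor,
) -> Tensor:
    head = self.node_emb(head_index)  # (batch, K)
    rel = self.rel_emb(rel_type)      # (batch, K)
    tail = self.node_emb(tail_index)  # (batch, K)

    # sum_i h_i * r_i * t_i for every triple in the batch.
    return (head * rel * tail).sum(dim=-1)
```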

Now, given this scoring function, how can a DistMult model learn entity and relation embeddings that score known triples highly while penalizing high scores for triples that do not exist? The paper suggests using a margin ranking loss:

$$\mathcal{L} = \sum_{(h, r, t) \in T} \; \sum_{(h', r, t') \in T'} \max\big(S(h', r, t') - S(h, r, t) + 1,\; 0\big), \qquad T' = \{(e', r, t) : e' \in E\} \cup \{(h, r, e') : e' \in E\}$$

where T is the set of known triples and E is the set of entities.

Here, e′ represents a corrupted head or tail: by replacing either the head or the tail entity of a known triple with another entity, we can generate a plausible negative sample. A good embedding scheme ideally gives known positive triples high scores and negative samples low scores. In the loss above, S is the scoring function we just saw, so minimizing the loss pushes each positive score above the corresponding negative score. However, to keep the model from learning degenerate embeddings that keep decreasing the loss by pushing scores ever further apart, the max(·, 0) stops rewarding it once a pair’s loss reaches 0. The “margin” comes in through the “+ 1” term: the minimum loss of 0 is reached only once the positive sample’s score exceeds the negative sample’s score by at least 1.
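To connect the formula to PyTorch, here is a minimal check, using made-up scores, that F.margin_ranking_loss with a target of 1 computes the same hinge as writing max(S(negative) − S(positive) + margin, 0) by hand. Note that it averages over the batch by default (reduction='mean') rather than summing as in the paper’s formula.

```python
import torch
import torch.nn.functional as F

# Hypothetical scores for 3 positive triples and their corrupted versions.
pos_score = torch.tensor([2.0, 0.5, 1.2])
neg_score = torch.tensor([0.5, 0.8, 1.0])
margin = 1.0

# Hinge written out by hand: max(S(neg) - S(pos) + margin, 0), averaged.
manual = torch.clamp(neg_score - pos_score + margin, min=0).mean()

# Built-in: with target = 1 it computes max(0, -(x1 - x2) + margin).
builtin = F.margin_ranking_loss(
    pos_score, neg_score, target=torch.ones_like(pos_score), margin=margin
)

print(manual.item(), builtin.item())  # both are approximately 0.7
```

The first pair already clears the margin (2.0 − 0.5 ≥ 1), so it contributes 0; only the other two pairs are penalized.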

Our implementation of this loss function takes in a batch of positive triples and calculates the score for them by running them through the forward function we defined above. It then performs the head or tail corruption to generate a sample of random negative triples, and scores those as well. The loss is then calculated using PyTorch’s margin_ranking_loss function. Notably, while the paper specifies 1 as the margin, we have made the margin a tuneable hyperparameter of the model that you can modify!

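```python
# DistMult loss (method of the DistMult class)
import torch
import torch.nn.functional as F
from torch import Tensor


def loss(
    self,
    head_index: Tensor,
    rel_type: Tensor,
    tail_index: Tensor,
) -> Tensor:
    # Score the observed (positive) triples.
    pos_score = self(head_index, rel_type, tail_index)
    # Corrupt heads or tails to build negatives, then score them.
    neg_score = self(*self.random_sample(head_index, rel_type, tail_index))

    # Hinge loss max(0, neg - pos + margin), averaged over the batch.
    return F.margin_ranking_loss(
        pos_score,
        neg_score,
        target=torch.ones_like(pos_score),
        margin=self.margin,
    )
```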
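The corruption step is delegated to random_sample, which is not shown here. As a rough sketch of the idea (a hypothetical corrupt_triples helper, not PyG’s actual implementation), one way to corrupt each triple is to replace either its head or its tail with a uniformly random entity:

```python
import torch


def corrupt_triples(head_index, rel_type, tail_index, num_nodes, generator=None):
    """Hypothetical helper: corrupt either the head or the tail of each triple.

    For every triple we flip a fair coin; on heads we replace the head entity,
    otherwise the tail entity, with an entity drawn uniformly at random.
    """
    n = head_index.numel()
    random_entities = torch.randint(num_nodes, (n,), generator=generator)
    replace_head = torch.rand(n, generator=generator) < 0.5

    new_head = torch.where(replace_head, random_entities, head_index)
    new_tail = torch.where(replace_head, tail_index, random_entities)
    return new_head, rel_type, new_tail


gen = torch.Generator().manual_seed(0)
head = torch.tensor([0, 1, 2, 3])
rel = torch.tensor([0, 0, 1, 1])
tail = torch.tensor([4, 5, 6, 7])
neg_head, neg_rel, neg_tail = corrupt_triples(head, rel, tail,
                                              num_nodes=10, generator=gen)
print(neg_head, neg_rel, neg_tail)
```

Because entities are drawn uniformly, a sampled “negative” can occasionally coincide with a true triple; this sketch makes no attempt to filter those out.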

So when is DistMult a good choice to use?

Let’s first consider the relations DistMult can handle:

  • ✅ Symmetry: It can handle symmetric relations, since the scoring function’s multiplication does not depend on the order of the head and tail entities
  • ✅ 1-to-N/N-to-1: It can also handle 1-to-N and N-to-1 relations, since multiple entity vectors can have the same dot product with the relation-scaled head or tail
  • ❌ Composition: Composing relations amounts to chaining diagonal matrix multiplications, but these only rescale each component independently, so in practice DistMult is unsuitable for modeling compositional relations
  • ❌ Antisymmetry: Because multiplication is commutative, DistMult always gives (h, r, t) and (t, r, h) the same score, so the two cannot be told apart
  • ❌ Inverse: For r2 to be the inverse of r1, we need S(h, r1, t) = S(t, r2, h) for every pair of entities. But DistMult scores both directions equally, so S(t, r2, h) = S(h, r2, t), and the requirement becomes S(h, r1, t) = S(h, r2, t): r1 and r2 collapse into effectively the same (symmetric) relation!
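The symmetry limitation is easy to see numerically. In this sketch (random embeddings, illustrative size K = 8), swapping the head and tail never changes the DistMult score, whatever the embeddings are:

```python
import torch

torch.manual_seed(0)
K = 8  # illustrative embedding dimension
h, r, t = torch.randn(K), torch.randn(K), torch.randn(K)


def distmult_score(head, rel, tail):
    # Generalized dot product sum_i head_i * rel_i * tail_i.
    return (head * rel * tail).sum()


forward_score = distmult_score(h, r, t)
reverse_score = distmult_score(t, r, h)
print(torch.allclose(forward_score, reverse_score))  # True
```

No choice of r can break this tie, which is exactly why antisymmetric and inverse relations are out of reach for DistMult.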

We can therefore see that DistMult is quite limited in what it can express. A major advantage, though, is that it needs only O(K) parameters per entity and per relation, where K is the embedding dimension! This makes DistMult fast to train and a good choice if your relations fall within its scope. Its simplicity also makes it easier to interpret.

But let’s say you need more expressiveness. That’s where ComplEx comes in!

ComplEx

The ComplEx knowledge graph embedding by Trouillon et al. (2016) builds on DistMult by embedding entities and relations in a complex vector space. In practice, we represent the real and imaginary components of these vectors separately.

The ComplEx model gets its real embeddings from its parent KGEModel class, like DistMult, but it also declares another set of embeddings with the same hidden dimension to represent the imaginary components:

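```python
from torch.nn import Embedding
from torch_geometric.nn.kge import KGEModel


class ComplEx(KGEModel):
    def __init__(
        self,
        num_nodes: int,
        num_relations: int,
        hidden_channels: int,
        sparse: bool = False,
    ):
        # The parent class creates node_emb / rel_emb: the real parts.
        super().__init__(num_nodes, num_relations, hidden_channels, sparse)

        # A second pair of tables of the same size for the imaginary parts.
        self.node_emb_im = Embedding(num_nodes, hidden_channels, sparse=sparse)
        self.rel_emb_im = Embedding(num_relations, hidden_channels,
                                    sparse=sparse)

        # reset_parameters() is defined elsewhere in the class (not shown).
        self.reset_parameters()
```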

Now, before we get to the actual scoring function, it may help to refresh how operations on complex numbers and vectors work. As with the dot product in real spaces, we want to define an inner product of two complex vectors. One key property we want to keep is that the inner product of a complex vector with itself is 0 if and only if the vector is zero. Writing each component as v_k = a_k + i b_k, the regular bilinear dot product of a complex vector with itself gives:

$$\mathbf{v}^\top \mathbf{v} = \sum_{k} v_k^2 = \sum_{k} \big(a_k^2 - b_k^2 + 2i\, a_k b_k\big)$$

Unfortunately, this value can be 0 even when the vector is not. For example:

$$\mathbf{v} = (1, i)^\top: \quad \mathbf{v}^\top \mathbf{v} = 1^2 + i^2 = 1 - 1 = 0$$

So instead, the inner product is defined using the complex conjugate of one of the terms. It now becomes:

$$\langle \mathbf{v}, \mathbf{v} \rangle = \bar{\mathbf{v}}^\top \mathbf{v} = \sum_{k} \bar{v}_k v_k = \sum_{k} \big(a_k^2 + b_k^2\big)$$

Notice that every term in this sum is a square, and therefore non-negative, so the sum is 0 if and only if the vector itself is 0! With this property satisfied, we can extend the definition to any two complex vectors u and v, which gives the Hermitian (or sesquilinear) product:

$$\langle \mathbf{u}, \mathbf{v} \rangle = \bar{\mathbf{u}}^\top \mathbf{v} = \sum_{k} \bar{u}_k v_k$$
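Python’s built-in complex type is enough to check the v = (1, i) example: the unconjugated product vanishes while the Hermitian product gives the squared norm. This is a minimal illustration, not part of the PyG implementation.

```python
# v = (1, i): a nonzero complex vector.
v = [1 + 0j, 1j]

# Bilinear "dot product" without conjugation: sum_k v_k * v_k.
bilinear = sum(x * x for x in v)

# Hermitian product: conjugate one argument, sum_k conj(v_k) * v_k.
hermitian = sum(x.conjugate() * x for x in v)

print(bilinear)   # 0j, even though v is not the zero vector
print(hermitian)  # (2+0j), i.e. |1|^2 + |i|^2
```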

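Phew! Now we have the tools to understand the ComplEx scoring function, where ⟨a, b, c⟩ = Σ_k a_k b_k c_k denotes the generalized dot product:

$$
\begin{aligned}
S(h, r, t) &= \operatorname{Re}\big(\langle \mathbf{r}, \mathbf{h}, \bar{\mathbf{t}} \rangle\big) = \operatorname{Re}\Big(\sum_{k=1}^{K} r_k\, h_k\, \bar{t}_k\Big) \\
&= \langle \operatorname{Re}\mathbf{h}, \operatorname{Re}\mathbf{r}, \operatorname{Re}\mathbf{t} \rangle + \langle \operatorname{Im}\mathbf{h}, \operatorname{Re}\mathbf{r}, \operatorname{Im}\mathbf{t} \rangle \\
&\quad + \langle \operatorname{Re}\mathbf{h}, \operatorname{Im}\mathbf{r}, \operatorname{Im}\mathbf{t} \rangle - \langle \operatorname{Im}\mathbf{h}, \operatorname{Im}\mathbf{r}, \operatorname{Re}\mathbf{t} \rangle
\end{aligned}
$$

Like DistMult, ComplEx uses the generalized inner product of the three vectors, but now they are complex and only the tail vector is conjugated. We also keep only the real part of the resulting complex number! This can be seen as a final projection onto the real line to get a real-valued score. As we’ll see, this projection is what gives ComplEx much of its expressivity.

Take special note of the final form of the scoring function. Expanding the complex multiplication, as in the examples above, simplifies the score to a sum of four real generalized dot products over the real and imaginary components of the entity and relation embeddings, with one of the four terms subtracted.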

In code, since we already initialized the real and imaginary vector components as separate embeddings, the score function actually becomes simple arithmetic:

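```python
# ComplEx forward (method of the ComplEx class), plus a module-level helper
from torch import Tensor


def forward(
    self,
    head_index: Tensor,
    rel_type: Tensor,
    tail_index: Tensor,
) -> Tensor:
    # Look up the real and imaginary parts of each embedding.
    head_re = self.node_emb(head_index)
    head_im = self.node_emb_im(head_index)
    rel_re = self.rel_emb(rel_type)
    rel_im = self.rel_emb_im(rel_type)
    tail_re = self.node_emb(tail_index)
    tail_im = self.node_emb_im(tail_index)

    # Re(<r, h, conj(t)>) expanded into four real generalized dot products.
    return (triple_dot(head_re, rel_re, tail_re) +
            triple_dot(head_im, rel_re, tail_im) +
            triple_dot(head_re, rel_im, tail_im) -
            triple_dot(head_im, rel_im, tail_re))


def triple_dot(x: Tensor, y: Tensor, z: Tensor) -> Tensor:
    # Generalized dot product: sum_k x_k * y_k * z_k over the last dimension.
    return (x * y * z).sum(dim=-1)
```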
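To double-check that the four triple_dot terms really equal the real part of the complex trilinear product, this sketch builds random real and imaginary parts (illustrative size K = 6), evaluates the expression from forward(), and compares it with a direct computation using PyTorch complex tensors:

```python
import torch

torch.manual_seed(0)
K = 6  # illustrative embedding dimension

# Real and imaginary parts of head, relation, and tail embeddings.
h_re, h_im = torch.randn(K), torch.randn(K)
r_re, r_im = torch.randn(K), torch.randn(K)
t_re, t_im = torch.randn(K), torch.randn(K)


def triple_dot(x, y, z):
    return (x * y * z).sum(dim=-1)


# Four real generalized dot products, mirroring forward().
four_terms = (triple_dot(h_re, r_re, t_re) + triple_dot(h_im, r_re, t_im) +
              triple_dot(h_re, r_im, t_im) - triple_dot(h_im, r_im, t_re))

# Direct complex computation: Re(sum_k r_k * h_k * conj(t_k)).
h = torch.complex(h_re, h_im)
r = torch.complex(r_re, r_im)
t = torch.complex(t_re, t_im)
direct = (r * h * t.conj()).sum().real

print(torch.allclose(four_terms, direct, atol=1e-5))  # True
```

Storing real and imaginary parts as separate real tables lets the model reuse ordinary Embedding layers and real-valued arithmetic.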

Now that we can score how plausible a given triple is, we again need a loss function to train the embeddings, one that penalizes negative samples whose scores are as high as, or higher than, those of known positive triples. The ComplEx paper uses the negative log-likelihood of a logistic model (a sigmoid link function) with L2 regularization:

$$\min_{\Theta} \sum_{((h, r, t),\, Y) \in \Omega} \log\big(1 + \exp(-Y \cdot S(h, r, t))\big) + \lambda \lVert \Theta \rVert_2^2$$

Here, Y ∈ {−1, +1} is the label of the sample: +1 for a known positive triple and −1 for a negative one, so a single term covers both kinds of samples. Θ denotes all embedding parameters and λ the regularization strength.

To implement this as the model’s loss function, we generate positive and negative scores in the same way as DistMult, but instead apply binary cross-entropy to the scores (treated as logits) with target labels 1 and 0. Once the labels are mapped from {−1, +1} to {0, 1}, this is mathematically equivalent to the log-likelihood term above (averaged rather than summed over the batch). No margin hyperparameter is required with this loss. The resulting loss is then backpropagated to update the head, relation, and tail embeddings via gradient descent.

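```python
# ComplEx loss (method of the ComplEx class)
import torch
import torch.nn.functional as F
from torch import Tensor


def loss(
    self,
    head_index: Tensor,
    rel_type: Tensor,
    tail_index: Tensor,
) -> Tensor:
    # Score positives and randomly corrupted negatives, as in DistMult.
    pos_score = self(head_index, rel_type, tail_index)
    neg_score = self(*self.random_sample(head_index, rel_type, tail_index))
    scores = torch.cat([pos_score, neg_score], dim=0)

    # Label positives 1 and negatives 0.
    pos_target = torch.ones_like(pos_score)
    neg_target = torch.zeros_like(neg_score)
    target = torch.cat([pos_target, neg_target], dim=0)

    # Binary cross-entropy on the raw scores (logits); no L2 term here.
    return F.binary_cross_entropy_with_logits(scores, target)
```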
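The equivalence between the paper’s log(1 + exp(−Y·S)) term and binary cross-entropy can be checked directly. The sketch below uses made-up scores, maps labels from {−1, +1} to {0, 1}, and compares F.softplus(−Y·S), which computes log(1 + exp(−Y·S)), with F.binary_cross_entropy_with_logits. The L2 term is left out, just as in the loss method above.

```python
import torch
import torch.nn.functional as F

# Hypothetical scores for two positive and two negative triples.
scores = torch.tensor([2.0, -0.5, 1.0, -3.0])
y_pm1 = torch.tensor([1.0, 1.0, -1.0, -1.0])  # paper's labels: +1 / -1
y_01 = (y_pm1 + 1) / 2                        # BCE labels: 1 / 0

# Paper form: log(1 + exp(-Y * score)), averaged over the batch.
paper_loss = F.softplus(-y_pm1 * scores).mean()

# BCE form: binary cross-entropy on the raw scores (logits).
bce_loss = F.binary_cross_entropy_with_logits(scores, y_01)

print(torch.allclose(paper_loss, bce_loss))  # True
```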

You might be thinking, “wait a second, there’s no regularization term here!”, and you’d be right. The L2 regularization term is left to the user, who can add it during training through the weight_decay parameter of the optimizer. Definitely something to keep in mind when using ComplEx!
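For optimizers such as torch.optim.Adagrad, PyTorch implements weight_decay by adding weight_decay × p to each parameter’s gradient, which is exactly the gradient of an explicit (weight_decay / 2)·‖p‖² penalty, so weight_decay = 2λ plays the role of the paper’s λ‖Θ‖² term. The sketch below (a toy quadratic loss and illustrative hyperparameters) checks that the two routes give the same parameters after a few steps:

```python
import torch

torch.manual_seed(0)
wd = 1e-2  # illustrative weight decay

# Two identical copies of a small parameter vector.
p_decay = torch.nn.Parameter(torch.randn(5))
p_penalty = torch.nn.Parameter(p_decay.detach().clone())

opt_decay = torch.optim.Adagrad([p_decay], lr=0.1, weight_decay=wd)
opt_penalty = torch.optim.Adagrad([p_penalty], lr=0.1)

target = torch.randn(5)
for _ in range(3):
    # (a) Plain loss; the optimizer adds wd * p to the gradient.
    opt_decay.zero_grad()
    ((p_decay - target) ** 2).sum().backward()
    opt_decay.step()

    # (b) Same loss plus an explicit (wd / 2) * ||p||^2 penalty.
    opt_penalty.zero_grad()
    penalty = 0.5 * wd * (p_penalty ** 2).sum()
    (((p_penalty - target) ** 2).sum() + penalty).backward()
    opt_penalty.step()

print(torch.allclose(p_decay, p_penalty, atol=1e-6))  # True
```

This equivalence holds for optimizers that fold weight decay into the gradient; decoupled schemes such as AdamW behave differently.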

So when should you use ComplEx? Let’s again first look at the relationships ComplEx can handle:

  • ✅ Symmetry: Like DistMult, ComplEx can handle symmetry by setting Im(r) = 0. As the final form of the scoring function shows, this zeroes out the last two terms, so the order of the head and tail entities no longer matters and both directions get the same score
  • ✅ 1-to-N/N-to-1: Like DistMult, ComplEx can handle 1-to-N and N-to-1 relations, since multiple entity vectors can project to the same final score
  • ❌ Composition: Like DistMult, ComplEx relies on diagonal (elementwise) relation transformations, which act on each component independently, so in practice it is unsuitable for modeling compositional relations
  • ✅ Antisymmetry: ComplEx can now model antisymmetry! Swapping the head and tail leaves the first two terms of the final scoring function unchanged but negates the combined contribution of the last two, so (h, r, t) and (t, r, h) can receive different scores whenever Im(r) ≠ 0
  • ✅ Inverse: ComplEx can now model inverse relations by setting the second relation’s embedding to the complex conjugate of the first, r2 = r̄1. Then S(h, r1, t) = S(t, r2, h), so the score function sees the same real projection in both directions!
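The three ✅ properties above that DistMult lacks or shares can be checked numerically with PyTorch complex tensors. In this sketch (random embeddings, illustrative K = 8; complex_score is a hypothetical helper computing Re(Σ r h t̄)), a purely real relation is symmetric, a relation with a nonzero imaginary part scores the two directions differently, and a conjugated relation acts as an inverse:

```python
import torch

torch.manual_seed(0)
K = 8  # illustrative embedding dimension


def rand_complex(k):
    # Random complex vector with independent real and imaginary parts.
    return torch.complex(torch.randn(k), torch.randn(k))


def complex_score(h, r, t):
    # Hypothetical helper: ComplEx score Re(sum_k r_k * h_k * conj(t_k)).
    return (r * h * t.conj()).sum().real


h, t = rand_complex(K), rand_complex(K)

# Symmetry: a purely real relation (Im(r) = 0) scores both directions equally.
r_sym = torch.complex(torch.randn(K), torch.zeros(K))
sym_fwd, sym_rev = complex_score(h, r_sym, t), complex_score(t, r_sym, h)

# Antisymmetry: with Im(r) != 0, the two directions generally differ.
r_asym = rand_complex(K)
asym_fwd, asym_rev = complex_score(h, r_asym, t), complex_score(t, r_asym, h)

# Inverse: with r2 = conj(r1), S(h, r1, t) equals S(t, r2, h).
r1 = rand_complex(K)
r2 = r1.conj()
inv_fwd, inv_rev = complex_score(h, r1, t), complex_score(t, r2, h)

print(torch.allclose(sym_fwd, sym_rev, atol=1e-5))  # True
print(asym_fwd.item(), asym_rev.item())
print(torch.allclose(inv_fwd, inv_rev, atol=1e-5))  # True
```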

ComplEx is therefore a more flexible embedding scheme than DistMult, with more relationship types it can model. However, we have twice the number of parameters compared to DistMult, since we had to create a second set of embeddings to represent the complex space, which makes ComplEx slower to train.
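To put numbers on the parameter cost, this sketch counts the parameters in plain torch.nn.Embedding tables laid out the way the two models are described above, using illustrative sizes (1,000 entities, 10 relations, K = 50) rather than FB15k-237:

```python
import torch

num_nodes, num_relations, K = 1000, 10, 50  # illustrative sizes


def count(modules):
    return sum(p.numel() for m in modules for p in m.parameters())


# DistMult-style: one real table for entities, one for relations.
distmult_tables = [torch.nn.Embedding(num_nodes, K),
                   torch.nn.Embedding(num_relations, K)]

# ComplEx-style: an extra pair of tables for the imaginary parts.
complex_tables = distmult_tables + [torch.nn.Embedding(num_nodes, K),
                                    torch.nn.Embedding(num_relations, K)]

print(count(distmult_tables))  # 50500, i.e. (1000 + 10) * 50
print(count(complex_tables))   # 101000, i.e. 2 * (1000 + 10) * 50
```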

Usage In Practice

Now that you know DistMult and ComplEx inside and out, let’s see them in action with an example on the FB15k-237 dataset, a subset of the Freebase FB15k knowledge graph! The full code for the example is available here, but we’ll walk through its key pieces together step by step below, focusing on ComplEx.

1. Setup

We import our dependencies and load our training, validation, and test splits from PyTorch Geometric onto our training device (preferably a GPU with CUDA, but a CPU works too!). torch_geometric.datasets makes this as easy as importing our dataset class and selecting the train, val, or test split.

```python
import os.path as osp

import torch
import torch.optim as optim

from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx

device = 'cuda' if torch.cuda.is_available() else 'cpu'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FB15k')

train_data = FB15k_237(path, split='train')[0].to(device)
val_data = FB15k_237(path, split='val')[0].to(device)
test_data = FB15k_237(path, split='test')[0].to(device)
```

2. Model

Aside: This step shows how PyTorch Geometric makes doing graph machine learning easy by providing all sorts of graph machine learning models for us out of the box!

We instantiate our ComplEx model with the number of nodes and number of edge types in our dataset. Conveniently, these are provided by train_data.num_nodes and train_data.num_edge_types (this works for any PyG Data object that stores an edge_type attribute).

All that’s left is to choose how many hidden channels we want our knowledge graph embeddings to have. More hidden channels lets our knowledge graph embeddings capture more information about the knowledge graph entities, but means our model is also more computationally expensive to train.

```python
model = ComplEx(
    num_nodes=train_data.num_nodes,
    num_relations=train_data.num_edge_types,
    hidden_channels=50,
).to(device)
```

3. Data Loader

We instantiate our data loader that passes data to our model one batch at a time for learning. As before, the head_index, rel_type, and tail_index for our dataloader are conveniently provided for us in train_data.edge_index[0], train_data.edge_type, and train_data.edge_index[1] respectively.

Aside: train_data.edge_index is a tensor of shape [2, num_edges]: its first row (index 0) stores the head node of each edge, and its second row (index 1) stores the tail node.
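To make the layout concrete, here is a toy example (three made-up triples, not FB15k-237 data) of an edge_index tensor of shape [2, num_edges] alongside its edge_type, and how the head, relation, and tail of each triple are read off:

```python
import torch

# Toy knowledge graph with three made-up triples (not FB15k-237 data).
edge_index = torch.tensor([[0, 0, 2],   # row 0: head entity of each triple
                           [1, 2, 3]])  # row 1: tail entity of each triple
edge_type = torch.tensor([0, 1, 0])     # relation id of each triple

head_index, tail_index = edge_index[0], edge_index[1]
triples = list(zip(head_index.tolist(), edge_type.tolist(),
                   tail_index.tolist()))
print(triples)  # [(0, 0, 1), (0, 1, 2), (2, 0, 3)]
```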

All that’s left for us to decide is our batch size (how many examples to feed our model at a time) and whether to shuffle our data as we pass it to our model.

Generally speaking, it’s a good idea to shuffle our data during training, and to give our model as much data as our GPU or CPU can handle. You will have to experiment a bit to determine what the maximum batch size is for your device. Can your GPU or CPU handle a batch size of 1000?

```python
loader = model.loader(
    head_index=train_data.edge_index[0],
    rel_type=train_data.edge_type,
    tail_index=train_data.edge_index[1],
    batch_size=1000,
    shuffle=True,
)
```
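Under the hood, a loader like this only needs to shuffle triple positions and slice the three index tensors in step. The sketch below reproduces that idea with torch.utils.data.DataLoader on a handful of made-up triples; the collate function is a hypothetical helper, and we are not claiming it matches model.loader’s internals exactly:

```python
import torch
from torch.utils.data import DataLoader

# Made-up triples standing in for edge_index[0], edge_type and edge_index[1].
head_index = torch.tensor([0, 1, 2, 3, 4])
rel_type = torch.tensor([0, 1, 0, 1, 0])
tail_index = torch.tensor([5, 6, 7, 8, 9])


def collate(positions):
    # Hypothetical collate_fn: gather aligned head/relation/tail batches.
    idx = torch.tensor(positions)
    return head_index[idx], rel_type[idx], tail_index[idx]


toy_loader = DataLoader(range(head_index.numel()), batch_size=2,
                        shuffle=True, collate_fn=collate)

for h, r, t in toy_loader:
    print(h.tolist(), r.tolist(), t.tolist())
```

Shuffling positions rather than the tensors themselves keeps each head, relation, and tail aligned within every batch.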

4. Optimizer

We instantiate our optimizer, which improves (i.e., optimizes) our model parameters (passed in as model.parameters()) based on the data we give it. Adagrad is a popular and effective optimizer for machine learning and a good default to try, along with Adam.

Aside: You can easily try other optimizers with torch.optim!

The learning rate tells our optimizer how much it should update our model based on what it’s learned from each batch of data.

Weight decay, also known as L2 regularization, penalizes large model weights. Its job is to make sure that our final model generalizes to new, unseen data rather than simply memorizing the training data.

Learning rate and weight decay are important hyperparameters that can significantly affect the quality of our trained model and choosing them can be a bit of a dark art and generally requires a lot of trial and error. Here, we have provided you with a learning rate and weight decay that we’ve found to work well for ComplEx:

```python
optimizer = optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6)
```

5. Training Loop

We now write our training loop, encapsulated in a function train(). Because certain modules, like Dropout, behave differently during training and testing, always remember to set your model to training mode with model.train() before training!

We will iterate through our training dataset in batches using the data loader we instantiated in Step 3. For each batch, we

  1. Zero out our gradients to give ourselves a clean slate for learning from this new data
  2. Calculate our loss based on how well our model scored each (head, relation, tail) triple in the batch
  3. Backpropagate that loss through the model to compute gradients that tell our optimizer how to update the model based on this batch of data
  4. Make those updates to our model with optimizer.step()
  5. Keep track of our total loss across the full training dataset with total_loss, and return the average loss, total_loss / total_examples

```python
def train():
    model.train()  # enable training-mode behavior (e.g. dropout)
    total_loss = total_examples = 0
    for head_index, rel_type, tail_index in loader:
        optimizer.zero_grad()
        loss = model.loss(head_index, rel_type, tail_index)
        loss.backward()
        optimizer.step()
        # Weight each batch's mean loss by its size for a per-triple average.
        total_loss += float(loss) * head_index.numel()
        total_examples += head_index.numel()
    return total_loss / total_examples
```

6. Test Function

We’re almost there! We now write our test function, which tells us how our model is performing as it learns from the training data. We use the @torch.no_grad() decorator to tell PyTorch not to track any gradients, since the purpose of this function is to evaluate the model, not to learn from the validation data (that would be like peeking at the answers before a test!). As with training, always remember to set your model to evaluation mode with model.eval() before testing!

PyTorch Geometric conveniently provides a model.test() method in torch_geometric.nn.kge, to which we pass our head_index, rel_type, tail_index, and batch_size, just like we did with our data loader in Step 3.

We also specify a new parameter k for the Hits@k score: the fraction of evaluation triples whose true tail entity is ranked within the top k of all candidate tail entities (the higher the better!). Here we choose k=10, but we often report Hits@k over a range of k values to get a more holistic view of the model’s performance.

```python
@torch.no_grad()
def test(data):
    model.eval()
    return model.test(
        head_index=data.edge_index[0],
        rel_type=data.edge_type,
        tail_index=data.edge_index[1],
        batch_size=20000,
        k=10,
    )
```
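To see what these metrics measure, here is a small sketch with a made-up score matrix (rows are evaluation triples, columns are candidate tail entities). The hypothetical rank_metrics helper ranks the true tail among all candidates (rank 1 is best), then averages the ranks and the fraction of ranks within the top k. It uses raw ranks and breaks ties in the true tail’s favor; PyG’s model.test() may differ in such details.

```python
import torch


def rank_metrics(scores, true_tail, k):
    """Hypothetical helper: mean rank and Hits@k over candidate tails.

    scores: (num_triples, num_entities) score of every candidate tail.
    true_tail: (num_triples,) index of the correct tail for each triple.
    """
    true_scores = scores.gather(1, true_tail.view(-1, 1))
    # Rank 1 = best: count candidates scoring strictly higher, plus one.
    ranks = (scores > true_scores).sum(dim=1) + 1
    mean_rank = ranks.float().mean().item()
    hits_at_k = (ranks <= k).float().mean().item()
    return mean_rank, hits_at_k


scores = torch.tensor([
    [0.9, 0.1, 0.3, 0.2],  # true tail 0 -> rank 1
    [0.5, 0.4, 0.8, 0.1],  # true tail 1 -> rank 3
    [0.2, 0.7, 0.6, 0.9],  # true tail 2 -> rank 3
])
true_tail = torch.tensor([0, 1, 2])
print(rank_metrics(scores, true_tail, k=2))
```

Here the ranks are 1, 3, and 3, so the mean rank is 7/3 ≈ 2.33 and Hits@2 is 1/3.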

7. Train!

Let’s put all of this code into action! One epoch is one pass through the training dataset. Here we train our model for 500 epochs, printing the average training loss after each one and evaluating on the validation set every 25 epochs. At the end of training, we evaluate on the test set and print our final results! What do yours look like?

```python
for epoch in range(1, 501):
    loss = train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    if epoch % 25 == 0:
        # Assumes model.test() returns (mean rank, Hits@k); newer PyG
        # releases may also return MRR, so match your installed version.
        rank, hits = test(val_data)
        print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '
              f'Val Hits@10: {hits:.4f}')

rank, hits_at_10 = test(test_data)
print(f'Test Mean Rank: {rank:.2f}, Test Hits@10: {hits_at_10:.4f}')
```

Great job! You’ve just learned a whole lot about knowledge graph embeddings, DistMult, ComplEx, and even trained your first knowledge graph embeddings!

Finale

If you’re still craving some more ComplEx-ity or need just a little more knowledge graph to complete your life, head over to the official PyTorch Geometric documentation for torch_geometric.nn.kge.DistMult and torch_geometric.nn.kge.ComplEx, or check out Stanford CS224W Graph ML Tutorials for more great content about PyTorch Geometric and graph machine learning!
