Using fastText to classify relationships in Knowledge Graphs

TL;DR: In this post, we will examine how a simple model, fastText, learns to represent entities in a subset of the FB15K knowledge graph by classifying the relationship between pairs of entities in the graph.

By: Dhruv Nair

A growing number of machine learning solutions and companies are leveraging knowledge graph data to tackle industries that require deep domain expertise. In fact, knowledge graphs underpin the natural language capabilities of Alexa, Siri, Cortana, and Google Now. Our users are exploring applications such as semantic search, intelligent chatbots, advanced drug research, and dynamic risk analysis.

In this post we will provide an introduction to knowledge graphs and walk through a simple model developed at Facebook that performs surprisingly well at knowledge base completion tasks.

Want to go straight to the good stuff 😎?
Feel free to skip the background section below on Knowledge Graphs and go straight to the sections on 1) Data, 2) fastText model, and 3) Interpreting the results
Here’s the data for this model. See the full code + results in a public project! 📓

Background: What are Knowledge Graphs?

Knowledge Graphs are highly flexible data structures that are used to represent the relationships between real-world entities. Entities such as “Mona Lisa”, or “DaVinci”, are represented as nodes in the graph, while relationships such as “created_by”, are represented as edges.

Informal graph of the sample triples from the Resource Description Framework

These graphs are a way to formally structure domain-specific knowledge and formed the basis for some of the earliest artificial intelligence systems. Google, Facebook, and LinkedIn are a few of the companies that leverage knowledge graphs to power their search and information retrieval tools.

There are many ways to think about the data in these graphs. One approach is to follow the Resource Description Framework standard and represent facts as subject, predicate, object triples (S, P, O), along with a binary score indicating whether the triple is valid or not.
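As a toy illustration of this triple representation (the entity and relation names here are our own, not from FB15K), a small triple store with validity scores might look like:

```python
# Facts as (subject, predicate, object) triples plus a binary validity score.
triples = [
    ("Mona Lisa", "created_by", "DaVinci", 1),
    ("Mona Lisa", "located_in", "Louvre", 1),
    ("DaVinci", "created_by", "Mona Lisa", 0),  # reversed direction: invalid
]

def is_valid(s, p, o, store=triples):
    """Look up whether the triple (s, p, o) is marked valid in the store."""
    return any(ts == s and tp == p and to == o and y == 1
               for ts, tp, to, y in store)

print(is_valid("Mona Lisa", "created_by", "DaVinci"))  # True
print(is_valid("DaVinci", "created_by", "Mona Lisa"))  # False
```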

An example of the standard framework using some example subjects and objects

Let us say that E represents the set of all entities in the graph, and R represents the set of all relations in the graph. Any triple (s, p, o) ∈ E × R × E can be modeled as a binary random variable, and all possible triples in the graph can be grouped in a 3D array Y ∈ {0, 1}^(|E| × |E| × |R|) ¹

Tensor representation of binary relational data ¹

The tensor Y can be quite huge for most knowledge graphs. However, only a few of the possible triples turn out to be true, so the tensor ends up being quite sparse.
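A minimal NumPy sketch of this tensor view, with made-up sizes and indices, shows just how sparse it gets:

```python
import numpy as np

n_entities, n_relations = 5, 3
# Y[i, j, k] = 1 if relation k holds between entity i and entity j.
Y = np.zeros((n_entities, n_entities, n_relations), dtype=np.int8)

# Mark a couple of observed facts (indices here are arbitrary).
Y[0, 1, 2] = 1
Y[3, 4, 0] = 1

sparsity = 1 - Y.sum() / Y.size
print(f"{sparsity:.2%} of possible triples are absent")  # 97.33% of possible triples are absent
```

Even in this tiny 5-entity, 3-relation example, over 97% of the entries are zero; real graphs like Freebase are far sparser.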

For example, in this post, we only consider the first 60 relationships in the Freebase knowledge graph, and what we observe is that each entity pair has an average of one valid relationship in the training set.

Latent Feature Models ¹ are one way to model the interaction between entities in the graph tensor. Without going into too much detail: the key intuition behind these models is that the relationships between entities can be modeled through the interaction of latent features that describe each entity. We call these features “latent” because they are not directly observed in the data, but rather are learned through training. ¹

Typically, the objective of latent feature methods is to organize symbolic objects (e.g., words, entities, concepts) in a way such that their similarity in the latent space reflects their semantic or functional similarity ²

Socher et al. (2013)

Next, we will examine how fastText learns to represent entities in a subset of the FB15K knowledge graph by trying to classify the relationship between entity pairs.

Want to dig into Knowledge Graphs even more? Check out “WTF is a knowledge graph?” by Jo Stichbury.

What is fastText and how does it work?

fastText frames the knowledge-base completion task as a classification problem. For a given entity pair in the knowledge graph, fastText will average the vector representations of the tokenized entities, and feed this representation into a linear classifier that computes the probability distribution over a set of relationship classes.
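The forward pass described above can be sketched in plain NumPy (the weights here are random placeholders, just to show the shapes involved):

```python
import numpy as np

n_entities, emb_dim, n_relations = 10, 4, 6
rng = np.random.default_rng(0)
E = rng.normal(size=(n_entities, emb_dim))   # entity embedding matrix
W = rng.normal(size=(emb_dim, n_relations))  # linear classifier weights
b = np.zeros(n_relations)                    # classifier bias

def predict(pair):
    """Average the two entity vectors, then score every relation class."""
    h = E[list(pair)].mean(axis=0)           # averaged representation
    logits = h @ W + b
    return 1 / (1 + np.exp(-logits))         # sigmoid: one-vs-all scores

probs = predict((2, 7))
print(probs.shape)  # (6,)
```

Each of the six outputs is an independent probability that the corresponding relation holds for the pair, which is exactly the one-vs-all setup used later in the Keras model.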

For example, our fastText model classifies the entity pair (Takeshi Kitano, Japan) with the relationship of (/people/person/nationality).
Note: Takeshi Kitano is a Japanese screenwriter and comedian. You learn something new every day!

What can I use it for?

fastText can serve as a very good baseline for a wide range of tasks, from sentiment analysis and spam detection to, in our case, knowledge base completion.


1. Dataset

We’re using the Freebase Knowledge Graph (FB15K). When you look at the raw data, you’ll see three columns: the first two columns are the Freebase MIDs for the entities in the graph, while the third describes their relationship.

First few rows from the Freebase dataset after indexing.

As part of data preparation for our fastText model, we’ve indexed the entities so that each has a unique integer ID. These IDs will later correspond to the row that represents each entity in our embedding matrix.
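That indexing step can be sketched as follows (the two rows below are illustrative; assume the raw file's columns are head MID, tail MID, relationship):

```python
# Map each Freebase MID and each relationship to a unique integer ID.
rows = [
    ("/m/027rn", "/m/06cx9", "/location/country/form_of_government"),
    ("/m/017dcd", "/m/06v8s0", "/tv/tv_program/regular_cast"),
]

entity_to_id, relation_to_id = {}, {}
indexed = []
for head, tail, rel in rows:
    h = entity_to_id.setdefault(head, len(entity_to_id))
    t = entity_to_id.setdefault(tail, len(entity_to_id))
    r = relation_to_id.setdefault(rel, len(relation_to_id))
    indexed.append((h, t, r))

print(indexed)  # [(0, 1, 0), (2, 3, 1)]
```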

For the full code, see the public project
For more on embeddings, check out:

2. fastText Model

Again, we will be using fastText to learn our embeddings. While it is a simple model, with only a single hidden layer, fastText manages to deliver comparable performance to more complex models, and takes significantly less time to train.³

Onto the actual model!

In this snippet, we initialize the embedding layer with the number of entities we’re considering from the subset of our graph. So our input_dim will equal the number of unique entities in our graph dataset. Our output_dim is the dimension of our embedding vector. This is a hyperparameter that we can optimize. Joulin et al. consider the following embedding sizes: [10, 25, 50, 100, 150, 200] ⁴

In this post, we’ve considered embedding dimensions of size 25 and 200. In addition to the embedding dimension, we can also tweak the learning rate (0.001 or 0.0001) and the batch size (128 or 256).
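Enumerating that small search space is straightforward (a sketch using just the values listed above):

```python
from itertools import product

# The hyperparameter values considered in this post.
embedding_dims = [25, 200]
learning_rates = [0.001, 0.0001]
batch_sizes = [128, 256]

configs = list(product(embedding_dims, learning_rates, batch_sizes))
print(len(configs))  # 8 candidate configurations
```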

After averaging the vector representations of our entities, we feed this representation into a fully connected layer with the same number of neurons as the relationships under consideration (in our case, 60). We then apply a sigmoid activation function to each neuron, which effectively turns it into a one-vs-all classifier for a given class.

from keras.models import Model
from keras.layers import Input, Embedding, GlobalAveragePooling1D, Dense, Activation
from keras.initializers import RandomUniform

main_input = Input(shape=(2,), dtype='int32', name='main_input')
embedding = Embedding(name="embedding",
                      input_dim=n_entities,      # Number of entities
                      output_dim=embedding_dim,  # Size of the embedding vectors
                      embeddings_initializer=RandomUniform(
                          minval=-0.05, maxval=0.05, seed=None))
x = embedding(main_input)
x = GlobalAveragePooling1D()(x)
x = Dense(n_relationships)(x)  # Number of relationships
output = Activation('sigmoid')(x)
model = Model(inputs=main_input, outputs=output)
For the full code as a script, see the public project

3. Analyzing the model results


We track our hyperparameters and evaluation metrics, reporting each model run as an experiment so that we can see its performance and compare different iterations. We can then organize and sort experiments by these hyperparameters (e.g. batch size).

Below you can see the six iterations we ran and reported. The model with the top performance (highest AUC score and lowest loss) was fastText 12 (top row in the screenshot), with an AUC score of 0.89311097 and a loss value of 0.07944. This model was trained with a learning rate of 0.001, an embedding dimension of 25, and a batch size of 256.

Note: We only considered 60 of the possible 1,435 relations in this dataset.
As a test: try improving fastText 12 by raising the embedding dimension to 200 (keeping other variables constant).
What's the AUC score for fastText 12.1?  🤔
The six fastText model iterations we ran with key hyperparameters and metrics such as AUC score and loss.

Since we are dealing with a dataset with multiple target labels, metrics like precision or recall are not sufficient to describe the performance of the classifier. Accuracy is also not the best metric to look at for our fastText model because it doesn’t reflect how the data is actually distributed (it doesn’t account for the class imbalance in the data).

Instead of accuracy, some better metrics to track are:

  • Log Loss (eval_loss)
  • AUC score
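Both metrics can be computed with scikit-learn (a sketch with dummy multi-label targets and predicted probabilities, 3 samples over 4 hypothetical relations):

```python
import numpy as np
from sklearn.metrics import log_loss, roc_auc_score

# Dummy multi-label targets and predicted probabilities.
y_true = np.array([[1, 0, 0, 1],
                   [0, 1, 0, 0],
                   [0, 0, 1, 1]])
y_prob = np.array([[0.9, 0.1, 0.2, 0.8],
                   [0.2, 0.7, 0.1, 0.3],
                   [0.1, 0.2, 0.6, 0.9]])

# Binary log loss over every (sample, relation) cell.
print(log_loss(y_true.ravel(), y_prob.ravel()))
# Micro-averaged AUC across all relation classes.
print(roc_auc_score(y_true, y_prob, average="micro"))
```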
You can display hyperparameters and any other metrics you log as columns on your experiment table.

Here’s an example prediction from fastText 12:

We also compared fastText 12 and fastText 3, where the only difference is the learning rate. It’s possible that fastText 3 latched onto some sub-optimal minima. When we ran the same configuration again (fastText 13), the new AUC score was slightly better at 0.5926, but fastText 12 still blew it out of the water.

See the direct comparison between fastText 12 and fastText 3. You can view code diffs (in the screenshot below) as well as chart, metric, and hyperparameter comparisons.

There are several ways we could continue iterating and (potentially) improving our model:

  • get more data (consider the full set of relationship values)
  • change the loss function

We hope you enjoyed building this fastText model with us! 🚀 🚀 🚀


Dataset from:

[1] Nickel, Maximilian, et al. “A review of relational machine learning for knowledge graphs.” Proceedings of the IEEE 104.1 (2016): 11–33.

[2] Nickel, Maximilian, and Douwe Kiela. “Poincaré embeddings for learning hierarchical representations.” Advances in Neural Information Processing Systems. 2017.

[3] Joulin, Armand, et al. “Bag of tricks for efficient text classification.” arXiv preprint arXiv:1607.01759 (2016).

[4] Joulin, Armand, et al. “Fast Linear Model for Knowledge Graph Embeddings.” arXiv preprint arXiv:1710.10881 (2017).

Dhruv Nair is a Data Scientist on our team. Before joining, he worked as a Research Engineer on the Physical Analytics team at the IBM T.J. Watson Lab.

About us — we are doing for ML what GitHub did for code. Our lightweight SDK enables data science teams to automatically track their datasets, code changes, and experimentation history. This way, data scientists can easily reproduce their models and collaborate on model iteration with their team!
