Augmenting Your Notes Using Graph Neural Networks

Arjun Karanam
Stanford CS224W GraphML Tutorials
May 15, 2023

CS224w Course Project by Arjun Karanam and Michael Elabd

As students, we’re extremely familiar with taking notes. Go to class, write down anything relevant the professor might say, and then review the notes afterward — maybe once or twice for the more diligent, and right before the exam for the rest of us. However, this approach to note-taking assumes that everything we may want to know about the topic is contained in those notes and that our curiosity ends there. But that’s often not the case!

On numerous occasions, I’ve been taking notes and I’ve wondered — where can I read more about this topic? Let’s say I just went to a lecture on Hegel’s Philosophy of History, and want to see what might be relevant. I could just type “Hegel’s Philosophy of History” into Google, but that just returns a list of places I can buy it from and the Wikipedia article for the theory itself. Hmm, not quite what we want. The Wikipedia article is interesting though! Let’s click on it, and see if that’ll help us.

A Google search for “Hegel’s Philosophy of History”

On the Wikipedia page for Hegel’s Philosophy of History, we’re met with a brief description of the work, and some discussion on themes that are presented. Theoretically, we’ve already encountered this information from our original lecture. What’s more interesting are the links that point to other Wikipedia pages. By virtue of these links appearing on the page for Hegel’s Philosophy of History, we know that they must be related to Hegel’s Philosophy of History somehow. But which ones are more important? And what about topics that might be relevant, but aren’t mentioned on this page?

Wikipedia page for Hegel’s Philosophy of History

We could dive down a Wikipedia rabbit hole (i.e., continuously clicking on links) until our curiosity is satisfied, but what if our note doesn’t have an associated Wikipedia page? Where do we start then?

Enter — Graph Neural Networks 💫

We’ve identified that, given a page of notes, we’d like to see which Wikipedia articles are most relevant to it. How do we do that? Well, one great thing about Wikipedia is that it’s one big graph. A really big graph. And what area of Machine Learning is best suited for tasks on graphs? Graph Neural Networks (or GNNs for short)! So what if we frame our problem as a link prediction task, where a GNN has to predict the “missing” links in an otherwise known graph?

To be specific, we asked the following question: let’s say we train a Graph Neural Network over all of Wikipedia (is that even possible?) to predict missing links between existing Wikipedia articles. Can we then give it our note (as if it were a brand new Wikipedia article), and ask it which existing Wikipedia articles are most likely to share an edge with this new article?

The Dataset

There are two pieces of data we need to get started: a few pages of a student’s (i.e., our) notes, and Wikipedia represented as a graph. As you can imagine, one of these is a bit more difficult to obtain than the other.

Student’s Notes

We started with the easier task — extracting some of our notes to use at the very end of our project. We took some notes on articles from our Political Science class on “Comparative Democratic Development” and used OCR to turn them from handwritten notes to typed ones. The notes can be found in the Google Drive folder below.

Wikipedia Dataset

The more challenging task was extracting Wikipedia as a graph. Specifically, we needed the following:

Nodes: A list of all nodes (labeled with their titles) in the Wikipedia dataset.

Edges: A list of all the edges between the nodes, represented as (start_node, end_node) pairs

Page Content: A dictionary mapping each node to its page content (i.e., {Hegel’s Philosophy of History: [all the words on the associated Wikipedia page]})
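As a rough illustration of how these three structures relate, here is a minimal sketch using plain Python containers. The helper name `build_graph_data`, the toy titles, and the page text are hypothetical. The sketch only shows the shapes of the data, not how our crawl stored it on disk.

```python
def build_graph_data(crawl):
    """Split a crawl into a node list, an edge list, and a page-content dict.

    `crawl` maps each crawled title to (page content, [linked titles]). Only
    links whose target was also crawled become edges, so every edge endpoint
    is a known node.
    """
    nodes = sorted(crawl)
    page_content = {title: content for title, (content, _) in crawl.items()}
    edges = [
        (start, end)
        for start, (_, links) in crawl.items()
        for end in links
        if end in crawl
    ]
    return nodes, edges, page_content


# Toy crawl with made-up page text
crawl = {
    "hegel's philosophy of history": ("Lectures on world history.", ["georg wilhelm friedrich hegel", "dialectic"]),
    "georg wilhelm friedrich hegel": ("A German philosopher.", ["dialectic", "immanuel kant"]),
    "dialectic": ("A method of argument.", []),
}
nodes, edges, page_content = build_graph_data(crawl)
print(nodes)
print(edges)
```

Note that "immanuel kant" is dropped from the edges because it was never crawled, so it has no node or page content of its own.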

We first looked for existing datasets, but none of them matched exactly what we were looking for. Datasets like WikiCS came the closest, but they were limited to a narrow set of topics (Computer Science in the case of WikiCS, and chameleons, crocodiles, and squirrels for the SNAP Wikipedia dataset). So, we had to search elsewhere.

The next option was using Wikipedia dumps. Every week or so, Wikipedia creates what’s called a “dump”: a giant archive of the current content of Wikipedia. The problem is that almost no processing is done to it, so it still has HTML tags and various markup tags (denoting tables, lists, and so on) embedded inside. Oh, and the dump itself is 83 gigabytes. While this would’ve been an interesting route to go down, after spending a few days attempting to process the dataset, we realized that processing the full dump would probably take weeks, so we kept this option as a worst-case fallback.

Finally, we struck a potential gold mine. There is a Wikipedia package for Python that lets you access a page by its title and returns both the page’s content and the titles of all the pages it links to. Perfect! That’s exactly what we need. But there was one problem: you can only access one page at a time…

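```python
import wikipedia as wp

# Access the Political Philosophy page
ds = wp.page("political philosophy")

# Once we have the page, we can fetch the links and the content
print(ds.links[:5])      # titles of the first five linked pages
print(ds.summary)        # `summary` is a property, not a method
print(ds.content[:500])  # the start of the full page text
```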

Thus began our process of building a Wikipedia graph, one page at a time. We revisited our Introduction to Data Structures knowledge and built a simple BFS to traverse Wikipedia. We modified the BFS to prioritize nodes with many outlinks, so that we would always be traversing the densest parts of the graph. However, early in our testing, we ran into an issue.

As you’d expect, the Wikipedia API has a rate limit to prevent DDoS attacks. Makes sense. However, that made our job of querying the whole Wikipedia dataset that much harder. After lots of trial and error, we found that a randomized wait time of 5–10 seconds between API calls kept us under the rate limit. Unfortunately, that slowed us down significantly and made querying all of Wikipedia impossible within our time constraints.

English Wikipedia has roughly 6,000,000 articles. Given an expected 7.5 seconds between page calls, extracting all of Wikipedia this way would take about 520 days (assuming a single process). Yikes.
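The back-of-the-envelope numbers here, and for the 20,000-page subset we settled on below, can be reproduced with a quick calculation. It assumes an average of 7.5 seconds per page (the midpoint of the 5–10 second wait) and no other overhead.

```python
SECONDS_PER_PAGE = 7.5  # midpoint of the randomized 5-10 s wait


def crawl_time_seconds(num_pages, num_processes=1):
    """Wall-clock time to fetch num_pages, split evenly across processes."""
    return num_pages * SECONDS_PER_PAGE / num_processes


full_days = crawl_time_seconds(6_000_000) / 86_400            # 86,400 s per day
subset_hours = crawl_time_seconds(20_000) / 3_600             # 3,600 s per hour
subset_hours_2 = crawl_time_seconds(20_000, num_processes=2) / 3_600
print(f"{full_days:.1f} days, {subset_hours:.1f} h, {subset_hours_2:.1f} h")
# prints: 520.8 days, 41.7 h, 20.8 h
```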

Wikipedia — Problem Solving

So, we took two steps to solve this. First, we wrote the code so that the process could be parallelized. Once we’ve accessed a node and gotten its contents and links, we write them to a file. Before we make an API call for a page, we check whether we, or another computer, have already accessed that page. If so, we skip over it. This allowed us to run multiple processes at once! But still, given our compute resources, this would take too much time. So, because our project is a proof of concept, we decided to pick just a subsection of Wikipedia. We selected a size of 20,000 nodes (about 40 hours on a single process, 20 hours on two processes), still nearly twice as large as the aforementioned WikiCS dataset.
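One way to make the “has anyone already fetched this page?” check safe when several processes share a directory is to claim each page atomically before calling the API. The sketch below is an assumption, not our exact code. It relies on `os.open` with `O_CREAT | O_EXCL`, which fails if the file already exists. The helper name `try_claim` and the hashed file-naming scheme are hypothetical. The atomicity guarantee holds on a local filesystem; network filesystems may not provide it.

```python
import hashlib
import os
import tempfile


def try_claim(claims_dir, title):
    """Return True if this process is the first to claim `title`.

    O_CREAT | O_EXCL makes file creation atomic: exactly one process succeeds,
    and every other one gets FileExistsError and skips the page.
    """
    # Hash the title so characters like "/" can't break the file name
    name = hashlib.sha1(title.encode("utf-8")).hexdigest() + ".claim"
    try:
        fd = os.open(os.path.join(claims_dir, name), os.O_CREAT | os.O_EXCL | os.O_WRONLY)
    except FileExistsError:
        return False
    with os.fdopen(fd, "w") as fp:
        fp.write(f"{os.getpid()}\n")  # record which worker claimed the page
    return True


claims_dir = tempfile.mkdtemp()
print(try_claim(claims_dir, "political philosophy"))  # True: first claim wins
print(try_claim(claims_dir, "political philosophy"))  # False: already claimed
```

A worker would call `try_claim` right before `wp.page(...)` and skip the page whenever it returns `False`.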

```python
import json
import os
import time

import numpy as np
import wikipedia as wp
from wikipedia.exceptions import DisambiguationError, PageError


class RelationshipGenerator():
    """Generates relationships between terms, based on Wikipedia links"""

    def __init__(self, save_dir):
        self.links = []       # [start, end, weight]
        self.features = {}    # {page: page_content}
        self.page_links = {}  # {page: [titles the page links to]}
        self.scanned = set()  # pages already used as a scan starting point
        self.save_dir = save_dir
        # Resume from memoized results if an earlier run saved any
        for attr, name in [("features", "features.json"), ("page_links", "page_links.json")]:
            path = os.path.join(self.save_dir, name)
            if os.path.exists(path):
                with open(path, "r") as fp:
                    setattr(self, attr, json.load(fp))
        print("Got memoized features ", list(self.features.keys()))

    def save(self):
        """Write progress to disk so a throttled or interrupted run can resume."""
        with open(os.path.join(self.save_dir, "links.npy"), "wb") as fp:
            np.save(fp, np.array(self.links))
        with open(os.path.join(self.save_dir, "features.json"), "w") as fp:
            json.dump(self.features, fp)
        with open(os.path.join(self.save_dir, "page_links.json"), "w") as fp:
            json.dump(self.page_links, fp)

    def scan(self, start=None, repeat=0):
        """Start scanning from a specific word, or from the internal database.

        Args:
            start (str): the term to start searching from, can be None to let
                the algorithm decide where to start
            repeat (int): the number of additional times to repeat the scan
        """
        while repeat >= 0:
            print("On depth: ", repeat)
            # Links from an explicitly requested term get a higher edge weight
            term_search = start is not None
            if start is None:
                # Let the algorithm decide: expand the unscanned page with the
                # most outlinks, so we keep traversing the densest part of the graph
                candidates = [p for p in self.page_links if p not in self.scanned]
                if not candidates:
                    break
                start = max(candidates, key=lambda p: len(self.page_links[p]))
            self.scanned.add(start)

            # Scan the starting point specified for links
            print(f"Scanning page {start}...")
            try:
                # Fetch the page through the Wikipedia API
                page = wp.page(start)
                self.features[start] = page.content
                self.page_links[start] = [l.lower() for l in page.links]
                links = list(set(page.links))

                # Add the links as edges
                for i, link in enumerate(links):
                    if i % 10 == 0:
                        self.save()  # save iteratively in case we get throttled
                    try:
                        link = link.lower()
                        if link not in self.features or link not in self.page_links:
                            # Randomized 5-10 s wait between calls to stay under the rate limit
                            time.sleep(np.random.uniform(5, 10))
                            page = wp.page(link)
                            self.features[link] = page.content
                            self.page_links[link] = [l.lower() for l in page.links]
                            print("Page Accessed: ", link)
                        else:
                            print("Page has already been accessed: ", link)
                        # Connect this page to every already-known end node it links to
                        total_nodes = set(l[1] for l in self.links)
                        for links_to in set(self.page_links[link]).intersection(total_nodes):
                            self.links.append([link, links_to, 0.1])
                        # Edge from the start page (assumed base weight of 1, +2 for a search term)
                        self.links.append([start, link, 1 + 2 * int(term_search)])
                    except (DisambiguationError, PageError):
                        print("Page could not be retrieved: ", link)
                self.save()
            except (DisambiguationError, PageError):
                # This happens if the page is a disambiguation page or doesn't exist;
                # we just skip it for now
                print("ERROR, could not retrieve page: ", start)

            repeat -= 1
            start = None
```
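The prioritized BFS described earlier can be distilled into a short best-first search, with the API calls and file handling stripped out. This is a simplified sketch rather than our crawler. `get_links` is a hypothetical stand-in for a Wikipedia API call, backed here by a toy dictionary. A min-heap keyed on negative out-degree decides which discovered page to expand next.

```python
import heapq


def prioritized_crawl(start, get_links, max_pages):
    """Expand pages in order of how many outlinks they have (most first).

    get_links(title) -> list of linked titles (hypothetical API wrapper).
    Returns the titles in the order they were expanded.
    """
    visited = set()
    order = []
    links_of = {start: get_links(start)}  # one "API call" per discovered page
    # heapq is a min-heap, so push negative out-degree to pop the densest page first
    frontier = [(-len(links_of[start]), start)]
    while frontier and len(order) < max_pages:
        _, title = heapq.heappop(frontier)
        if title in visited:
            continue
        visited.add(title)
        order.append(title)
        for neighbor in links_of[title]:
            if neighbor not in links_of:
                links_of[neighbor] = get_links(neighbor)
                heapq.heappush(frontier, (-len(links_of[neighbor]), neighbor))
    return order


toy_links = {
    "political philosophy": ["liberalism", "plato", "justice"],
    "liberalism": ["john locke"],
    "plato": ["socrates", "aristotle", "justice", "the republic"],
}
print(prioritized_crawl("political philosophy", lambda t: toy_links.get(t, []), max_pages=4))
# prints: ['political philosophy', 'plato', 'liberalism', 'aristotle']
```

Plain BFS would expand "liberalism" before "plato" only because of list order. The priority queue instead jumps to "plato", which has the most outlinks. Ties are broken alphabetically by the tuple comparison.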

The resulting graph, starting from the page “Political Philosophy”, looked like this (we chart only the most popular nodes, with popularity defined by the number of outlinks, along with the edges between them):

A visualization of the top 100 nodes in the graph
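To produce a plot like this, you first need to pick out the most popular nodes. Here is a minimal sketch of that selection step. It assumes links are stored as (start, end) pairs and popularity is the number of outlinks. The function name is ours, and the plotting itself (e.g., with NetworkX and Matplotlib) is left out.

```python
from collections import Counter


def top_k_subgraph(edges, k):
    """Keep the k nodes with the most outlinks and the edges among them."""
    out_degree = Counter(start for start, _ in edges)
    # Sort by descending out-degree, breaking ties alphabetically for determinism
    ranked = sorted(out_degree, key=lambda node: (-out_degree[node], node))
    top = set(ranked[:k])
    sub_edges = [(s, e) for s, e in edges if s in top and e in top]
    return top, sub_edges


edges = [
    ("political philosophy", "plato"), ("political philosophy", "liberalism"),
    ("political philosophy", "justice"), ("plato", "justice"),
    ("plato", "socrates"), ("liberalism", "john locke"),
]
top, sub_edges = top_k_subgraph(edges, k=3)
print(sorted(top))
print(sub_edges)
```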

Preprocessing

Now we have lists of nodes and edges, plus a dictionary of page contents, each written to its own file. However, there is still quite a bit of preprocessing to do before we can apply GNN techniques. First, we needed to turn the page contents into numerical feature vectors:

Embedding Pages using Doc2Vec

We’ll see this pattern emerge several times in this blog post, but the best way to explain Doc2Vec is to first explain what it grew out of: word2vec. The core problem all these representations are trying to solve is: how do you represent the words in a text document numerically? The naive approach is one-hot encoding. Say you have the sentence “I like pizza.” The word “I” would be encoded as [1, 0, 0], the word “like” as [0, 1, 0], and the word “pizza” as [0, 0, 1]. While this gets the job done, it doesn’t capture how similar words are to each other, because every pair of distinct words is equally far apart. We’d prefer similar words to have similar encodings. This is where word2vec comes in.
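To see the problem concretely, here is a small NumPy sketch that one-hot encodes the example sentence and measures the cosine similarity between word vectors. Every pair of distinct words comes out orthogonal, so the encoding carries no notion of similarity at all. The helper names are ours.

```python
import numpy as np


def one_hot_vocab(words):
    """Map each distinct word to a one-hot row of an identity matrix."""
    vocab = sorted(set(words), key=words.index)  # keep first-seen order
    eye = np.eye(len(vocab))
    return {word: eye[i] for i, word in enumerate(vocab)}


def cosine(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))


vectors = one_hot_vocab(["I", "like", "pizza"])
print(vectors["like"])                           # [0. 1. 0.]
print(cosine(vectors["I"], vectors["pizza"]))    # 0.0: no shared dimensions
print(cosine(vectors["like"], vectors["like"]))  # 1.0
```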

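word2vec learns a dense vector for each word so that related words end up close together and relationships between words become directions in the vector space. This can be seen in the image below, where the words king, man, woman, and queen have been embedded via word2vec. The relationship between man -> woman (i.e., the vector that takes you from one to the other) is approximately the same as the relationship between king -> queen, as we’d expect.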

Vector relations in the word2vec embedding

These embeddings can be learned with what’s called the continuous bag of words (CBOW) method. (word2vec’s other variant, skip-gram, instead predicts the context from the center word.) CBOW can be thought of as a sliding window centered on the current word.

Sliding context window on the sentence “i like natural language processing”

Each of the context words inside the window is represented as a feature vector. These vectors are combined (typically averaged) and used to predict the target word at the center of the window. This is repeated for every word in the dataset.

An example of how word2vec is trained
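The sliding window itself is easy to express in code. The sketch below generates the (context words, target word) training pairs that CBOW learns from. The function name and window size are our choices. A real word2vec model would then average the context vectors and try to predict the target.

```python
def cbow_pairs(tokens, window=2):
    """Return (context, target) pairs from a window centered on each token."""
    pairs = []
    for i, target in enumerate(tokens):
        # Up to `window` words on each side, clipped at the sentence edges
        context = tokens[max(0, i - window):i] + tokens[i + 1:i + 1 + window]
        pairs.append((context, target))
    return pairs


sentence = "i like natural language processing".split()
for context, target in cbow_pairs(sentence, window=2):
    print(target, "<-", context)
```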

Doc2vec was created to extend word2vec from words to whole documents. One could just concatenate all the text in the documents together and run word2vec, but that loses valuable information about which words belong to which document. To incorporate document information, doc2vec runs an algorithm very similar to continuous bag of words, but also feeds in a learned vector for the document ID alongside the context word vectors.

An example of how doc2vec is trained
```python
import nltk
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

# nltk.download("punkt")  # word_tokenize needs the punkt tokenizer data


def use_doc2vec(nodes):
    """Return {page title: doc2vec feature vector} for every crawled node."""
    # List of (page title, page content) tuples, in the same order as `nodes`
    # (`rg` is the RelationshipGenerator used to crawl Wikipedia)
    node_order = {node: i for i, node in enumerate(nodes)}
    features = sorted(
        ((title, content) for title, content in rg.features.items() if title in node_order),
        key=lambda item: node_order[item[0]],
    )

    # Join each title with its content, lowercase, tokenize with NLTK, and tag the docs
    tokenized_docs = [nltk.word_tokenize(" ".join(doc).lower()) for doc in features]
    tagged_docs = [TaggedDocument(words=doc, tags=[str(i)]) for i, doc in enumerate(tokenized_docs)]

    # Build and train a doc2vec model
    model = Doc2Vec(vector_size=300, min_count=1, epochs=50)
    model.build_vocab(tagged_docs)
    model.train(tagged_docs, total_examples=model.corpus_count, epochs=model.epochs)

    # Turn every page into a vector, keyed by title so vectors stay aligned
    # even if some nodes have no saved content
    return {title: model.infer_vector(tokenized_docs[i]) for i, (title, _) in enumerate(features)}
```

And that’s it! This gives every Wikipedia page a 300-dimensional doc2vec vector, which we use as the feature vector of the corresponding node.

Creating a PyG Graph

Now we have all our pieces in place, and creating a PyG (PyTorch Geometric) graph from here is actually quite simple! Because of some plotting tools we used for internal development, we first converted the graph into a NetworkX graph, and then turned that into a PyG graph. The code for how we did that is provided below:

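```python
import networkx as nx
from torch_geometric.utils import from_networkx

G = nx.Graph()  # undirected graph
G.add_nodes_from(nodes)
feature_vectors = use_doc2vec(nodes)
nx.set_node_attributes(G, feature_vectors, name="features")
# Links may carry a weight as a third entry; keep only the endpoints
G.add_edges_from((start, end) for start, end, *_ in links)

# After some plotting, convert the nx graph into a PyG graph, stacking the
# per-node "features" attribute into the node feature matrix `data.x`
PyG_Graph = from_networkx(G, group_node_attrs=["features"])
```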

Model

Now that we have the dataset properly created, we can begin applying Graph ML techniques. As mentioned earlier, we are framing this problem as a link prediction problem. Link prediction is a common framing for Graph ML problems, often used for Knowledge Graph Completion tasks.

However, there are many ways one could apply Graph ML to this task. We studied this survey from NeurIPS, which walks through different approaches to link prediction. The paper highlights SEAL as the best performer; however, SEAL relies on graph-level encodings. As we hope for our graph to grow over time (as we explore more and more of Wikipedia), we expect those graph-level encodings to change dramatically. Thus, we looked for an approach that could operate in an inductive training setting. Luckily, the second-best-performing model, VGAE, checks that box!

Table comparing different graph architectures on a link prediction task

The VGAE Model — The Theory

So how does it work? To understand Variational Graph Auto-Encoders, it’s helpful to first understand the principles behind Variational Auto-Encoders (and to understand that, it’s helpful to first understand Auto-Encoders). Truly turtles all the way down.

Auto-Encoders

Auto-Encoder is simply the name for a class of neural models that take an input X, pass it through an encoder to produce an embedding Z, and then pass Z through a decoder to create an output X_hat that resembles the input.

Traditionally, the embedding space Z is lower-dimensional than the original X. And that’s it! The power of AEs lies in the embedding space Z. One can make small perturbations to an embedding and decode them into outputs X_hat that are slightly different from the inputs X.
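To make the encoder/decoder picture concrete, here is a deliberately tiny autoencoder written in NumPy. A linear encoder squeezes 5-D inputs into a 2-D embedding Z, a linear decoder maps Z back to X_hat, and plain gradient descent minimizes the reconstruction error. Real autoencoders use nonlinear layers and a framework like PyTorch. The toy data, sizes, and learning rate here are our choices, and the sketch only illustrates the idea.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 200 points in 5-D that lie close to a 2-D subspace
X = rng.standard_normal((200, 2)) @ rng.standard_normal((2, 5))
X += 0.01 * rng.standard_normal(X.shape)

# Linear encoder (5-D -> 2-D embedding Z) and linear decoder (2-D -> 5-D X_hat)
W_enc = 0.5 * rng.standard_normal((5, 2))
W_dec = 0.5 * rng.standard_normal((2, 5))


def reconstruction_error(X, W_enc, W_dec):
    """Mean squared error between X and its reconstruction X_hat."""
    X_hat = (X @ W_enc) @ W_dec
    return float(np.mean((X_hat - X) ** 2))


initial_error = reconstruction_error(X, W_enc, W_dec)
lr = 0.02
for step in range(3000):
    Z = X @ W_enc                            # encode
    X_hat = Z @ W_dec                        # decode
    grad_X_hat = 2 * (X_hat - X) / X.size    # gradient of the MSE w.r.t. X_hat
    grad_dec = Z.T @ grad_X_hat              # backprop into the decoder weights
    grad_enc = X.T @ (grad_X_hat @ W_dec.T)  # ...and through Z into the encoder
    W_dec -= lr * grad_dec
    W_enc -= lr * grad_enc

final_error = reconstruction_error(X, W_enc, W_dec)
print(f"MSE before training: {initial_error:.4f}, after: {final_error:.4f}")
```

Because the data really is (almost) two-dimensional, a 2-D embedding is enough to reconstruct it well. Shrinking Z below the data's true dimensionality would force a lossy reconstruction.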

Variational Auto-Encoders

However, this leads to one of the limitations of Auto-Encoders. Regardless of how much you perturb the embedding Z, an Auto-Encoder can only generate outputs that are similar to the original X. Variational Auto-Encoders (VAEs for short), on the other hand, allow you to generate new data not explicitly seen in the original dataset. How?

VAEs do this by encoding the input X as a probability distribution over Z rather than a single point in Z. Then, to get an embedding, one samples from that distribution, q(Z|X) (the probability of Z given X).

If you want to really understand VAEs, there are other things to know, such as training with noise and the reparameterization trick. There are countless tutorials online that go into these subjects in detail.
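For the curious, both of those ideas fit in a few lines. The sketch below uses NumPy and our own function names. It samples Z with the reparameterization trick, z = mu + sigma * eps with eps drawn from a standard normal, so the randomness lives in eps and gradients can flow through mu and sigma. It also computes the closed-form KL divergence between the diagonal Gaussian q(Z|X) and a standard normal prior, which is the extra term a VAE adds to its loss. Following PyTorch Geometric's convention, the encoder is assumed to output log sigma rather than sigma.

```python
import numpy as np


def reparameterize(mu, logstd, rng):
    """Sample z ~ N(mu, sigma^2) as mu + sigma * eps, with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(logstd) * eps


def kl_to_standard_normal(mu, logstd):
    """KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over rows."""
    kl = -0.5 * np.sum(1 + 2 * logstd - mu**2 - np.exp(2 * logstd), axis=1)
    return float(np.mean(kl))


rng = np.random.default_rng(0)
mu = np.array([[0.0, 0.0], [1.0, -1.0]])
logstd = np.zeros_like(mu)                # sigma = 1 everywhere
z = reparameterize(mu, logstd, rng)
print(z.shape)                            # (2, 2)
print(kl_to_standard_normal(mu, logstd))  # 0.5: row 1 matches the prior, row 2 has KL 1
```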

Variational Graph Auto-Encoders

Now that we understand VAEs, we can start tackling the mystical VGAE! In one sentence, a VGAE takes the ideas behind a VAE and applies them to graph-structured data (like a Wikipedia graph!). The first obstacle is that graphs are irregular: they have unordered nodes, variable sizes, differing numbers of neighbors, and so on. So step one is to fix an ordering of the nodes and represent the graph as an adjacency matrix A.
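Concretely, the adjacency matrix A has a 1 at row i, column j whenever nodes i and j share an edge. Fixing a numbering of the nodes is what turns the unordered graph into a matrix. Here is a minimal NumPy sketch of building A from an edge list. It assumes nodes are already numbered 0..n-1 and treats edges as undirected, matching the undirected NetworkX graph built earlier. The node names in the comment are only illustrative.

```python
import numpy as np


def adjacency_matrix(edges, num_nodes):
    """Dense symmetric adjacency matrix from (i, j) index pairs."""
    A = np.zeros((num_nodes, num_nodes))
    for i, j in edges:
        A[i, j] = 1.0
        A[j, i] = 1.0  # undirected: a link in either direction connects both
    return A


# 0 = "political philosophy", 1 = "plato", 2 = "justice", 3 = "liberalism"
A = adjacency_matrix([(0, 1), (0, 3), (1, 2)], num_nodes=4)
print(A)
print(A.sum(axis=1))  # node degrees: [2. 2. 1. 1.]
```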

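Perfect! Next, we need to create the encoder. In a VGAE, the encoder is a two-layer graph convolutional network (explained in further detail here). As input, it takes the adjacency matrix A and a feature matrix X (in our example, that’s the Doc2Vec features for each node).

The first GCN layer serves our purpose of dimensionality reduction, and generates a lower-dimensional feature matrix. Mathematically, it looks like this:

$$\bar{X} = \mathrm{ReLU}\left(\tilde{A} X W_0\right), \qquad \tilde{A} = D^{-1/2} A D^{-1/2}$$

where $W_0$ is a learned weight matrix, $D$ is the diagonal degree matrix of $A$, and $\tilde{A}$ is the symmetrically normalized adjacency matrix (GCN implementations typically add self-loops to $A$ before normalizing).

The second layer of the GCN creates the two variables that characterize the probability distribution, mu and sigma. Mathematically, that looks like this:

$$\mu = \tilde{A} \bar{X} W_\mu, \qquad \log \sigma = \tilde{A} \bar{X} W_\sigma$$

Each node’s latent vector is then sampled as $z_i = \mu_i + \sigma_i \odot \epsilon_i$ with $\epsilon_i \sim \mathcal{N}(0, I)$, and these vectors form the rows of $Z$.

For the decoder, we simply compute the inner product between every pair of latent vectors, i.e., multiply $Z$ by its own transpose (quite simple!). This can be seen mathematically as:

$$\hat{A} = \sigma\left(Z Z^\top\right)$$

where $\hat{A}$ is the reconstructed adjacency matrix, and here $\sigma$ stands for the logistic sigmoid function (not the standard deviation from the encoder).

We then calculate the loss using the formula below:

$$\mathcal{L} = -\mathbb{E}_{q(Z \mid X, A)}\left[\log p(A \mid Z)\right] + \mathrm{KL}\left[q(Z \mid X, A) \,\|\, p(Z)\right]$$

where the prior $p(Z)$ is a standard normal distribution for each node’s latent vector.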

You can break the equation above into two parts. The first part is the reconstruction loss between the input Adjacency Matrix A and the output Reconstructed Adjacency Matrix, and the second term is the KL divergence term, which you can read more about here: https://towardsdatascience.com/light-on-math-machine-learning-intuitive-guide-to-understanding-kl-divergence-2b382ca2b2a8. And that’s all there is to it! We have our graph architecture in place.
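Putting the encoder, sampling, and decoder steps together, here is a NumPy sketch of one forward pass through a VGAE on a toy graph. The weights are random rather than trained, the sizes are arbitrary, and the node features are random stand-ins for doc2vec vectors. The point is only to show how A, X, mu, log sigma, Z, and the reconstructed adjacency fit together. Self-loops are added before normalizing, as GCN implementations typically do.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph: 4 nodes, 3 undirected edges, 8-D node features
A = np.zeros((4, 4))
for i, j in [(0, 1), (0, 3), (1, 2)]:
    A[i, j] = A[j, i] = 1.0
X = rng.standard_normal((4, 8))

# Symmetrically normalized adjacency with self-loops: D^-1/2 (A + I) D^-1/2
A_self = A + np.eye(4)
d_inv_sqrt = 1.0 / np.sqrt(A_self.sum(axis=1))
A_norm = d_inv_sqrt[:, None] * A_self * d_inv_sqrt[None, :]

# Random weights, scaled by 1/sqrt(fan_in) to keep activations small
hidden, latent = 4, 2
W0 = rng.standard_normal((8, hidden)) / np.sqrt(8)
W_mu = rng.standard_normal((hidden, latent)) / np.sqrt(hidden)
W_logstd = rng.standard_normal((hidden, latent)) / np.sqrt(hidden)

# Encoder: first GCN layer (with ReLU), then two GCN heads for mu and log(sigma)
X_bar = np.maximum(A_norm @ X @ W0, 0.0)
mu = A_norm @ X_bar @ W_mu
logstd = A_norm @ X_bar @ W_logstd

# Sample Z with the reparameterization trick
Z = mu + np.exp(logstd) * rng.standard_normal(mu.shape)

# Decoder: inner product between every pair of latent vectors, then sigmoid
A_hat = 1.0 / (1.0 + np.exp(-(Z @ Z.T)))
print(A_hat.round(2))
```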

The VGAE Model — In Practice

Now that we have the theory out of the way, let’s take a look at the code that makes it work in practice:

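```python
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv


class VGAE_Encoder(nn.Module):
    """Two-layer GCN encoder as described in the VGAE paper."""

    def __init__(self, num_features, latent_dim=16):
        super().__init__()
        hidden_layer = 2 * latent_dim
        self.conv_1 = GCNConv(num_features, hidden_layer)
        self.conv_2_mean = GCNConv(hidden_layer, latent_dim)
        self.conv_2_logstd = GCNConv(hidden_layer, latent_dim)

    def forward(self, features, edge_index):
        # Shared first layer: reduce the doc2vec features to a hidden size
        h = F.relu(self.conv_1(features, edge_index))
        # Two heads: the mean and log standard deviation of q(Z | X, A);
        # PyG's VGAE expects the encoder to return (mu, logstd) in this order
        mean = self.conv_2_mean(h, edge_index)
        logstd = self.conv_2_logstd(h, edge_index)
        return mean, logstd
```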

As the decoder is simply an inner product (and thus not a trainable layer), we apply it while calculating the loss, which happens in the training step.

Training the Model

And below is the code for training!

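```python
import torch
from torch_geometric.nn import VGAE
from tqdm import tqdm

# `args` holds command-line options and `data` is the PyG graph with its edges
# split into train/validation sets (train_pos_edge_index, val_pos_edge_index, ...)
num_features = data.num_features

for run in range(args.runs):
    # decoder=None makes PyG use its default InnerProductDecoder
    model = VGAE(VGAE_Encoder(num_features), decoder=None)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Training loop
    for epoch in tqdm(range(args.epochs)):
        model.train()
        optimizer.zero_grad()
        z = model.encode(data.x, data.train_pos_edge_index)
        # Reconstruction loss only; a down-weighted KL term could be added,
        # e.g. + 0.01 * model.kl_loss()
        loss = model.recon_loss(z, data.train_pos_edge_index, neg_edge_index=None)
        loss.backward()
        optimizer.step()

        # Log validation metrics
        if epoch % args.val_freq == 0:
            model.eval()
            with torch.no_grad():
                z = model.encode(data.x, data.train_pos_edge_index)
                auc, ap = model.test(z, data.val_pos_edge_index, data.val_neg_edge_index)
            tqdm.write('Train loss: {:.4f}, Validation AUC-ROC: {:.4f}, '
                       'AP: {:.4f} at epoch {:03d}'.format(loss.item(), auc, ap, epoch))
```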

As mentioned before, the loss (specifically the recon_loss function) is where the decoder gets applied. We can look inside that function, which comes from PyG’s GAE class, to see how it works:

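```python
from typing import Optional

import torch
from torch import Tensor
from torch_geometric.utils import negative_sampling

EPS = 1e-15  # keeps log() finite when a predicted probability is exactly 0


# A method of PyG's GAE class (which VGAE inherits); self.decoder is the
# InnerProductDecoder, which computes sigmoid(z_i . z_j) for each edge (i, j)
def recon_loss(self, z: Tensor, pos_edge_index: Tensor,
               neg_edge_index: Optional[Tensor] = None) -> Tensor:
    r"""Given latent variables :obj:`z`, computes the binary cross
    entropy loss for positive edges :obj:`pos_edge_index` and negative
    sampled edges.

    Args:
        z (Tensor): The latent space :math:`\mathbf{Z}`.
        pos_edge_index (LongTensor): The positive edges to train against.
        neg_edge_index (LongTensor, optional): The negative edges to train
            against. If not given, uses negative sampling to calculate
            negative edges. (default: :obj:`None`)
    """
    # -log p(edge), averaged over the true (positive) edges
    pos_loss = -torch.log(
        self.decoder(z, pos_edge_index, sigmoid=True) + EPS).mean()

    # Sample random node pairs as negatives if none were provided
    if neg_edge_index is None:
        neg_edge_index = negative_sampling(pos_edge_index, z.size(0))

    # -log(1 - p(edge)), averaged over the negative edges
    neg_loss = -torch.log(1 - self.decoder(z, neg_edge_index,
                                           sigmoid=True) + EPS).mean()

    return pos_loss + neg_loss
```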

In other words, we compute the loss for positive edges and for negative edges separately, each as a mean over its own edges, and add them together. Averaging each side separately keeps the loss from being skewed by an imbalance between the two categories.
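To see the pieces without any library machinery, here is a NumPy re-creation of the same loss under simplifying assumptions. Negatives are drawn uniformly at random among ordered node pairs that are not edges, which is a simplification of what a library routine like PyG's `negative_sampling` does. The function names are ours. Because each side is a mean, positives and negatives contribute equally no matter how many of each there are.

```python
import numpy as np

EPS = 1e-15  # keeps log() finite when a probability is exactly 0


def sample_negative_edges(pos_edges, num_nodes, num_samples, rng):
    """Uniformly sample node pairs (i != j) that are not positive edges."""
    existing = set()
    for i, j in pos_edges:
        existing.add((int(i), int(j)))
        existing.add((int(j), int(i)))  # treat edges as undirected
    negatives = []
    while len(negatives) < num_samples:
        i, j = (int(v) for v in rng.integers(0, num_nodes, size=2))
        if i != j and (i, j) not in existing:
            negatives.append((i, j))
    return np.array(negatives)


def recon_loss(Z, pos_edges, neg_edges):
    """Mean BCE on positive edges plus mean BCE on negative edges."""
    def edge_prob(edges):
        logits = np.sum(Z[edges[:, 0]] * Z[edges[:, 1]], axis=1)  # inner-product decoder
        return 1.0 / (1.0 + np.exp(-logits))
    pos_loss = -np.log(edge_prob(pos_edges) + EPS).mean()
    neg_loss = -np.log(1.0 - edge_prob(neg_edges) + EPS).mean()
    return pos_loss + neg_loss


rng = np.random.default_rng(0)
pos_edges = np.array([(0, 1), (0, 3), (1, 2)])
neg_edges = sample_negative_edges(pos_edges, num_nodes=4, num_samples=3, rng=rng)
Z = rng.standard_normal((4, 2))
print(neg_edges)
print(recon_loss(Z, pos_edges, neg_edges))
```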

Results

To evaluate our model’s performance, we used the area under the ROC curve (known as ROC-AUC). The ROC curve plots the true-positive rate against the false-positive rate as the decision threshold varies, and ROC-AUC is the area under that curve. In our link prediction setting, it measures how well our model distinguishes a positive edge (i.e., an edge that exists between two nodes) from a negative edge. This suits our setting, as we’re not trying to recommend every article that might be relevant. Instead, we want to make sure that the articles we rank highest, and therefore recommend, are relevant. Let’s see how our model performs:

Loss Curves over time
The AUC-ROC Curves on train and validation
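ROC-AUC has a handy equivalent reading: it is the probability that a randomly chosen positive edge gets a higher score than a randomly chosen negative edge, with ties counting as half. That makes it easy to compute by brute force on small examples, as in this sketch. The scores below are made up. For real evaluations, a library routine such as scikit-learn's `roc_auc_score` does this efficiently.

```python
from itertools import product


def pairwise_auc(pos_scores, neg_scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    wins = 0.0
    for p, n in product(pos_scores, neg_scores):
        if p > n:
            wins += 1.0
        elif p == n:
            wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


# Hypothetical decoder probabilities for held-out positive and negative edges
pos_scores = [0.9, 0.8, 0.4]
neg_scores = [0.7, 0.3, 0.1]
print(pairwise_auc(pos_scores, neg_scores))  # 8 of the 9 pairs are ordered correctly
```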

To dive in deeper, please check out our GitHub here: https://github.com/QuantumArjun/Augmented-Notes-GNNs

Trying it out on real-world notes

Yay! We now have a model that can fairly accurately predict links on the Wikipedia dataset, using both the graph structure and the page content (represented as doc2vec features). We can now attempt our original task: given a random note, can our model suggest some Wikipedia articles that might be related? This part of the evaluation is fairly subjective, as our notes are not in the original Wikipedia dataset, so there are no ground-truth links to compare against. To do this, we simply take our new note, turn its contents into features using doc2vec, and add it to our graph as a new node. We then ask the model to predict the likelihood of an edge between our new node and every other node in the graph, and return the 3 articles with the highest probabilities.
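The last two steps, scoring the new note against every article and keeping the top 3, can be sketched as follows. This assumes the trained encoder has already produced a latent vector for every Wikipedia node and for the new note node. `recommend`, the titles, and the vectors are hypothetical toy values. The scores use the same inner-product-plus-sigmoid decoder as training.

```python
import numpy as np


def recommend(z_note, Z, titles, k=3):
    """Return the k articles whose predicted edge probability with the note is highest."""
    probs = 1.0 / (1.0 + np.exp(-(Z @ z_note)))  # sigmoid(z_i . z_note) for every node
    top = np.argsort(-probs)[:k]                  # indices of the k largest probabilities
    return [(titles[i], float(probs[i])) for i in top]


titles = ["democracy", "modernization theory", "economic development", "chameleon"]
Z = np.array([[2.0, 0.5], [1.5, 1.5], [1.0, 0.5], [-2.0, -1.0]])  # toy node embeddings
z_note = np.array([1.0, 1.0])                                     # toy note embedding
for title, prob in recommend(z_note, Z, titles):
    print(f"{title}: {prob:.3f}")
```

Since the sigmoid is monotonic, ranking by probability is the same as ranking by the raw inner products. The sigmoid only matters if you want to show calibrated-looking scores.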

Let’s see how it performs!

Note: Modernization Theory

Recommendations

Note: Democracy as a Universal Value

Recommendations

Note: Lee Kuan Yew (the former Singaporean Prime Minister)

Recommendations

At first glance, these look amazing. They’re all related to the original note in some way, and are interesting extensions of the topic. This post is getting a tad lengthy at this point, so we’ll reserve comments, and let you read through and judge for yourself!

Next Steps

So, what’s next? On the data side, we’d obviously like to train and test our model on all of Wikipedia. To do this, we could either use our approach and parallelize across multiple CPUs, or download the Wikipedia dump and process the information that way.

This would also allow us to try more performant architectures, such as the SEAL architecture mentioned earlier, as our training graph would remain relatively static at that point, and we could compute graph-wide embeddings.

On the practical side, we could tinker with which articles we show the user. Some of the articles (such as Democracy for the note “Democracy as a Universal Value”) are a little on the nose, and don’t fit our exploratory goal. And lastly, it would be incredibly useful (for us, at least) to integrate this into our own note-taking process!

Find our implementation here: https://github.com/QuantumArjun/Augmented-Notes-GNNs

Unfortunately, due to issues with loading graphs into Colab, you will have to follow the GitHub instructions to train and test our model. The Colab notebook below just walks through how we created our dataset.

Find our Google Colab (which walks through a subset of our code) here: https://colab.research.google.com/drive/1EojTdUDdM-NuFIjveF-6XeADb1msxqLO?usp=sharing
