# OhMyGraphs: GraphSAGE and inductive representation learning


This post assumes you know a little something about graphs and their role in graph neural networks. It's the first in a series on cool graph neural network / graph representation learning papers I've come across!

# What is GraphSAGE?

GraphSAGE [1] is an iterative algorithm that learns an embedding for every node in a given graph. The novelty of GraphSAGE is that it was the first work to create inductive node embeddings in an unsupervised manner! Just like in NLP, embeddings are very useful for downstream tasks: GNNs can use node embeddings for node classification, link prediction, community detection, network analysis, and more.

The need for GraphSAGE:

Prior to GraphSAGE, most node embedding models were based on spectral decomposition/matrix factorization methods.

The problem? Matrix factorization methods are inherently transductive! Simply put, a transductive method cannot generalize to data it has never seen: these methods expect the entire graph structure to be present at training time in order to generate node embeddings. If a new node is added to the graph later, the model has to be retrained.

Conversely, an inductive approach would be one that can generalize to unseen data — obviously more useful, right? Let’s dig into the intuition behind GraphSAGE.

# The main idea behind GraphSAGE:

You are known by the company you keep.

In the graph above, if you're a 90s kid, it'd be pretty easy to guesstimate who Fred, Velma, Daphne and Shaggy are all connected to. If you guessed Scooby, it's because you realized that whoever the middle node was had to have a relationship with all the neighbouring nodes. What you secretly did in your head was approximate a representation for the Scooby-Doo node based on its neighbouring nodes!

# The real GraphSAGE

The goal of GraphSAGE is to learn a representation, denoted h, for every node based on some combination of its neighbouring nodes.

Recall that every node can have its own feature vector x_v (stacked together, these form the feature matrix X). Let's assume for now that all the feature vectors have the same size. GraphSAGE runs for K iterations (each iteration can be thought of as one layer), so there is a representation h for every node at every iteration k.

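We'll use the following notation: h_v^k is the representation of node v after iteration k, h_v^0 = x_v is its input feature vector, and N(v) is the set of v's immediate neighbours.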

Since every node can be defined by its neighbours, the embedding for node A can be represented by some combination of its neighbours' embedding vectors. One round of the GraphSAGE algorithm gives us a new representation for node A, and the same process is applied to every node in the original graph.

The GraphSAGE algorithm follows a two-step process. Since it is iterative, there is first an initialization step that sets every node's initial embedding vector to its feature vector (h_v^0 = x_v); k then iterates from 1 to K.

1. Aggregate.

Aggregate the neighbouring node representations for our target node. The f_aggregate function is a placeholder for any differentiable function: it could be as simple as an averaging function or as complex as a neural network. In words:

Aggregate the embedding vectors h_u^{k-1} of all nodes u in the immediate neighbourhood N(v) of our target node v. The result is the aggregated representation for node v, denoted a_v.
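Written as an equation (a reconstruction in this post's notation: h_u^{k-1} is neighbour u's representation from the previous iteration, N(v) is v's immediate neighbourhood, and the superscript k on a_v marks the iteration):

$$
a_v^{k} = f_{\text{aggregate}}\left(\left\{\, h_u^{k-1} \;:\; u \in \mathcal{N}(v) \,\right\}\right)
$$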

2. Update.

After obtaining an aggregated representation for node v from its neighbours, update node v's current representation using a combination of its previous representation and the aggregated one. The f_update function is, once again, a placeholder for any differentiable function, and can be as simple as an averaging function or as complex as a neural network.

In words:

Create an updated representation for node v from its aggregated neighbourhood representation and its own previous representation.
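As an equation (again a reconstruction in the same notation, with a_v^k the aggregated representation from the aggregation step):

$$
h_v^{k} = f_{\text{update}}\left(h_v^{k-1},\; a_v^{k}\right)
$$

In the paper, f_update is a concatenation followed by a learned linear map and a non-linearity, h_v^k = σ(W^k · CONCAT(h_v^{k-1}, a_v^k)), as the algorithm notes further down describe.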

So the basics of GraphSAGE are aggregating and updating node representations. But what about K? The depth K tells the algorithm how many neighbourhoods, i.e. how many hops, to use when computing the representation for node v.

To illustrate, look at the image below. Instead of initializing node B's representation to its feature vector, we can run the same aggregate-update step on B to get a representation for node B based on its own neighbours. We can do the same for nodes C and D in the k=1 layer. In the k=0 layer, we initialize the neighbours' node embeddings to their initial feature vectors.

In the above example, we set K=2 and use the neighbours, and the neighbours' neighbours, of node A to get the final representation of the target node. You could experiment with more neighbourhoods, i.e. larger values of K. However, too many hops may dilute the representation of node v, while too few limit how much neighbourhood information reaches it (with K=0 no neighbours are used at all, and the model amounts to an MLP applied to each node's features independently). Food for thought!
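One way to see what K controls, as a small sketch: after K aggregate-update rounds, only nodes within K hops of the target can have influenced its representation. The graph below is a hypothetical toy graph (node "A" with neighbours B, C and D), not the one in the figure, and the function name `receptive_field` is made up for this illustration. It assumes full neighbourhoods, with no neighbour sampling.

```python
from collections import deque


def receptive_field(adj: dict[str, list[str]], target: str, K: int) -> set[str]:
    """Nodes whose input features can reach h_target^K after K aggregate-update rounds.

    Each round pulls information one hop further, so this is just the set of
    nodes within K hops of the target, found with a breadth-first search.
    """
    hops = {target: 0}
    queue = deque([target])
    while queue:
        node = queue.popleft()
        if hops[node] == K:
            continue  # already K hops out: do not expand further
        for nbr in adj[node]:
            if nbr not in hops:
                hops[nbr] = hops[node] + 1
                queue.append(nbr)
    return set(hops)


# Hypothetical toy graph (undirected, stored as an adjacency list).
adj = {
    "A": ["B", "C", "D"],
    "B": ["A", "E"],
    "C": ["A", "F"],
    "D": ["A"],
    "E": ["B", "H"],
    "F": ["C"],
    "H": ["E"],
}

for K in range(4):
    print(K, sorted(receptive_field(adj, "A", K)))
# 0 ['A']
# 1 ['A', 'B', 'C', 'D']
# 2 ['A', 'B', 'C', 'D', 'E', 'F']
# 3 ['A', 'B', 'C', 'D', 'E', 'F', 'H']
```

With K=0 the target only sees its own features (the MLP case). In denser graphs each extra hop can grow this set very quickly, which is one reason the paper samples a fixed number of neighbours per hop instead of using all of them.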

Great! We should now have no problem understanding the following algorithm snippet from the original paper (Algorithm 1, which generates the embeddings, i.e. the forward pass):

Some things to note about the paper’s implementation:

• Line 4: The authors experiment with several aggregator functions, including mean aggregation, max-pooling and even an LSTM aggregator. Because an LSTM is not permutation invariant (it reads its inputs as an ordered sequence), the LSTM aggregator is applied to a random permutation of the node's neighbours, so that no particular neighbour ordering is favoured.
• Line 4: What we generalized as f_aggregate is actually represented as AGGREGATE_k in the paper.
• Line 5: In the paper, f_update concatenates the node's previous representation with the aggregated neighbourhood representation. If each representation has F dimensions, the concatenated vector has shape (2F,1). It is then multiplied by a weight matrix W^k, which here maps it back down to (F,1) (in general the output size is a design choice), and finally passed through a non-linearity.
• Line 5: There is a separate weight matrix W^k for each iteration k, so the model can learn how much each neighbourhood depth should matter for the target node.
• Line 7: Each node embedding is normalized by dividing it by its L2 norm, so every embedding has unit length. This keeps magnitudes bounded across iterations and helps prevent exploding gradients.
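To tie these notes together, here is a minimal pure-Python sketch of the embedding-generation loop. It assumes a mean aggregator, full neighbourhoods (no neighbour sampling), untrained random weights, and that every node has at least one neighbour. The names `graphsage_forward`, `mean_aggregate` and the helpers are hypothetical, chosen for this post; this is not the authors' reference implementation. Comments point at the algorithm lines discussed in the notes above.

```python
import math
import random
from typing import Callable

Vector = list[float]
Matrix = list[Vector]  # row-major: W[i] is the i-th row


def matvec(W: Matrix, x: Vector) -> Vector:
    """Matrix-vector product W @ x."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]


def relu(x: Vector) -> Vector:
    return [max(0.0, xi) for xi in x]


def l2_normalize(x: Vector, eps: float = 1e-12) -> Vector:
    """Divide by the L2 norm (eps guards against an all-zero vector)."""
    norm = math.sqrt(sum(xi * xi for xi in x))
    return [xi / max(norm, eps) for xi in x]


def mean_aggregate(neigh_reprs: list[Vector]) -> Vector:
    """Mean aggregator: element-wise average of the neighbours' representations."""
    if not neigh_reprs:
        raise ValueError("mean aggregation needs at least one neighbour")
    return [sum(col) / len(neigh_reprs) for col in zip(*neigh_reprs)]


def graphsage_forward(
    X: list[Vector],
    adj: list[list[int]],
    weights: list[Matrix],
    aggregate: Callable[[list[Vector]], Vector] = mean_aggregate,
    sigma: Callable[[Vector], Vector] = relu,
) -> list[Vector]:
    """Return z_v = h_v^K for every node; K = len(weights), one W^k per iteration."""
    H = [list(x) for x in X]  # line 1: h_v^0 = x_v
    for W in weights:  # k = 1..K, a separate W^k each time
        H_new = []
        for v in range(len(H)):
            a_v = aggregate([H[u] for u in adj[v]])  # line 4: AGGREGATE_k over N(v)
            concat = H[v] + a_v  # line 5: CONCAT(h_v^{k-1}, a_v), length 2F
            H_new.append(sigma(matvec(W, concat)))  # line 5: W^k, then non-linearity
        H = [l2_normalize(h) for h in H_new]  # line 7: unit L2 norm
    return H


# Usage on a hypothetical toy graph: node 0 plays the role of "A".
random.seed(0)
F, K = 3, 2
adj = [[1, 2, 3], [0, 4], [0], [0], [1]]
X = [[random.gauss(0, 1) for _ in range(F)] for _ in adj]
# Each W^k is F x 2F, mapping the (2F,1) concatenation back to (F,1).
weights = [[[random.gauss(0, 1) for _ in range(2 * F)] for _ in range(F)] for _ in range(K)]
Z = graphsage_forward(X, adj, weights)
print(len(Z), len(Z[0]))  # 5 3
```

Swapping the `aggregate` argument for another function (for example max-pooling, or an LSTM over a random permutation of the neighbours) is where the paper's aggregator variants would plug in. With random weights the output is only a shape check; the embeddings become meaningful once the weights are trained.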

# Unsupervised loss function

So, how does one actually train a GraphSAGE GNN?

The authors train both unsupervised and supervised GraphSAGE models. The supervised setting uses a regular cross-entropy loss for a prediction task such as node classification. The unsupervised case, however, tries to preserve graph structure through a loss with two terms.
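A reconstruction of that loss in this post's notation, following the paper's form but with u as a node close to the target v and with the ε discussed at the end of this section: σ is the sigmoid, z_v = h_v^K is a final embedding, and Q is the number of negative samples drawn from P_n(v).

$$
J(z_v) = -\log\left(\sigma\left(z_u^{\top} z_v\right) + \epsilon\right) \;-\; Q \cdot \mathbb{E}_{u_n \sim P_n(v)}\left[\log\left(\sigma\left(-z_{u_n}^{\top} z_v\right) + \epsilon\right)\right]
$$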

The first term of the loss enforces that if nodes u and v are close in the actual graph (in the paper, they co-occur on a fixed-length random walk), their embeddings should be similar. In the ideal case, the inner product of z_u and z_v is a large positive number, its sigmoid is pushed towards 1, and log(1) = 0, so the term contributes nothing.

The second term enforces the opposite! If nodes u and v are far apart in the actual graph, we expect their embeddings to be different, or even opposite. In the ideal case, the inner product of z_u and z_v is a large negative number, meaning the two embeddings point more than 90 degrees apart. The loss negates this inner product, turning it into a large positive number; its sigmoid is pushed towards 1, and again log(1) = 0. Since most nodes in the graph are typically far from the target node v, we don't sum over all of them. Instead, we sample only a few negative nodes u from P_n(v), a distribution over nodes far away from v. This keeps the loss balanced between close and far pairs during training (and cheap to compute).

The addition of epsilon ensures we never take log(0).
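The two terms can be sketched in plain Python. Assumptions: embeddings are plain lists, ε = 1e-7 is an arbitrary choice, and the hypothetical `sample_negatives` stands in for P_n(v) with a uniform distribution over nodes that are neither v nor its neighbours, an illustrative choice, not necessarily the paper's. Summing over the Q sampled negatives is the usual Monte Carlo estimate of the Q · E[...] term. All function names are made up for this sketch.

```python
import math
import random


def sigmoid(x: float) -> float:
    """Numerically stable logistic sigmoid."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def dot(a: list[float], b: list[float]) -> float:
    return sum(ai * bi for ai, bi in zip(a, b))


def sample_negatives(adj: list[list[int]], v: int, Q: int, rng: random.Random) -> list[int]:
    """Hypothetical P_n(v): uniform over nodes that are neither v nor one of its neighbours."""
    candidates = [u for u in range(len(adj)) if u != v and u not in adj[v]]
    return rng.sample(candidates, Q)


def unsupervised_loss(z_v: list[float], z_pos: list[float],
                      z_negs: list[list[float]], eps: float = 1e-7) -> float:
    """First term pulls the close pair together; second pushes sampled negatives apart."""
    close_term = -math.log(sigmoid(dot(z_pos, z_v)) + eps)
    # Sum over Q samples = Q * (sample mean), estimating Q * E_{u_n ~ P_n(v)}[...]
    far_term = -sum(math.log(sigmoid(-dot(z_n, z_v)) + eps) for z_n in z_negs)
    return close_term + far_term


# Ideal case: close pair aligned, negative pointing the other way -> loss near 0.
good = unsupervised_loss([5.0, 0.0], [5.0, 0.0], [[-5.0, 0.0]])
# Worst case: roles swapped -> large loss (eps caps each term at about -log(eps)).
bad = unsupervised_loss([5.0, 0.0], [-5.0, 0.0], [[5.0, 0.0]])
print(good < 1e-3, bad > 10)  # True True
```

The ε inside each log is exactly what keeps the worst case finite: without it, a sigmoid that underflows to 0 would make `math.log` raise an error.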

# TL;DR: GraphSAGE

GraphSAGE is a way to aggregate neighbouring node embeddings for a given target node and combine them with the node's own representation. One round of GraphSAGE produces a new representation for every node in the graph, and several stacked layers of GraphSAGE can create complex structural and semantic features for any downstream task!

In a future post, I will implement GraphSAGE for simple tasks like node classification. Stay tuned!