Self-Supervised Learning For Graphs

By Paridhi Maheshwari, Jian Vora and Sharmila Reddy Nangi, as part of the Stanford CS 224W course project.

A large part of deep learning revolves around finding rich representations of unstructured data such as images, text and graphs. Conventional methods learn these representations by optimizing directly for a specific end task. This is typically done in a supervised setting where we have labelled data. However, in many real-world applications, labels are scarce or unavailable, while unlabelled data is abundant.

In self-supervised learning, the aim is to use large amounts of unlabelled data and learn embeddings by identifying underlying structures or patterns in the data. If we can learn good representations using only unlabelled data, then we can perform a variety of downstream tasks by reusing the pretrained embeddings and adding only a few task-specific final layers, which can be trained with relatively little labelled data. In supervised learning, on the other hand, the entire model is trained end-to-end with labels and requires a lot of training data. Pictorially, this can be depicted as follows:

The above image depicts the difference between traditional supervised methods where data is labelled versus self-supervised training where the latent space is learnt using different data augmentation strategies. Here, we show the application of self-supervised learning for the image classification task. (source)

Ideally, we want to learn embeddings such that data points which are “similar” to each other have embeddings close to each other and embeddings of “dissimilar” data points are far from each other, under a suitable distance or similarity function in the high-dimensional embedding space (such as L2 distance or cosine similarity).

Self-supervised methods exploit this idea to learn useful embeddings from unlabelled data. For a given data point, we create one or more “views” or “augmentations” of it that should remain similar to the original; the data point and its view form a positive pair. Views of other, unrelated data points act as negatives. The model is trained with a loss function that pulls the embeddings of positive pairs close together while pushing the embeddings of negative pairs apart. This idea is known as contrastive self-supervised learning.

The above visualization explains how we want to attract embeddings of similar image views and repel the embeddings obtained from images unrelated to each other. (source)

Self-supervised learning has shown a lot of promise in the image and text domains, where it is used to learn general-purpose representations that help build state-of-the-art models across a variety of tasks. These techniques have even been shown to beat supervised models on certain tasks.

The above image shows the promising results of self-supervised training for image classification on the ImageNet dataset. SimCLR comes very close to supervised methods without needing labels to train the entire model end-to-end. (source)

With this high-level introduction to self-supervised learning, let’s see how we can apply this technique to graphs. In the remainder of the blog, we will explore the following questions, along with code snippets in PyTorch Geometric (a popular framework for Graph ML built on top of PyTorch):

  1. How to create different “views” or “augmentations” of a graph?
  2. How to build models that operate on graphs?
  3. How to train these models in a self-supervised manner?
  4. How to leverage self-supervised learning for downstream tasks?

You can also use our Google Colab and GitHub repository to train self-supervised models from scratch and test them on the downstream task of graph classification.

Data Augmentations

In this section, we explore how to perturb graphs to form positive pairs whose embeddings we expect to be close to each other. At a high level, we want the augmented graphs to have a structure similar to that of the original graph. Some augmentations change the structure of the graph, while others keep the structure intact but perturb the node features in order to make the representations invariant to the initial node attributes. Below is a list of augmentations that can be used to create positive graph pairs.

The above figure illustrates different data augmentation techniques for graphs.

1. Edge Perturbation: In these augmentations, we randomly add or drop edges with a small probability to create new graphs. We also cap the fraction of edges that can be perturbed, so that we do not change the fundamental structure of the graph.
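A minimal sketch of edge perturbation, written in plain NumPy rather than PyG so it stays self-contained. It assumes the graph is undirected and stored as a `(2, E)` `edge_index` array with each edge listed once; the function name `perturb_edges` and the choice to add back as many random edges as are dropped are illustrative assumptions, not necessarily the exact scheme used in our repository.

```python
import numpy as np


def perturb_edges(edge_index, num_nodes, ratio=0.1, rng=None):
    """Randomly drop and add edges in an undirected graph (hypothetical helper).

    edge_index: (2, E) integer array listing each undirected edge once.
    ratio: fraction of edges to perturb; that many edges are dropped and the
           same number of new random edges are added, so E stays constant.
    """
    rng = np.random.default_rng() if rng is None else rng
    num_edges = edge_index.shape[1]
    num_perturb = int(ratio * num_edges)

    # Drop: keep a random subset of the existing edges.
    keep = rng.permutation(num_edges)[: num_edges - num_perturb]
    kept = edge_index[:, keep]

    # Add: sample random node pairs that are neither self-loops nor existing edges.
    existing = {tuple(sorted(e)) for e in edge_index.T.tolist()}
    added = []
    attempts = 0
    while len(added) < num_perturb and attempts < 100 * (num_perturb + 1):
        attempts += 1
        u, v = (int(i) for i in rng.integers(0, num_nodes, size=2))
        pair = (min(u, v), max(u, v))
        if u != v and pair not in existing:
            existing.add(pair)
            added.append(pair)
    if added:
        new_edges = np.array(added, dtype=edge_index.dtype).T
        kept = np.concatenate([kept, new_edges], axis=1)
    return kept
```

Recent PyG versions ship `torch_geometric.utils.dropout_edge`, which covers the dropping half of this augmentation.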

2. Diffusion: In these augmentations, the adjacency matrix is transformed into a diffusion matrix using a heat kernel, which provides a global view of the graph as opposed to the local view provided by the adjacency matrix.
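The heat-kernel diffusion matrix can be written as $S = \exp\big(t\,(T - I)\big) = \sum_{k=0}^{\infty} e^{-t} \frac{t^k}{k!} T^k$, with the transition matrix $T = A D^{-1}$ and diffusion time $t$. The sketch below truncates this series after a fixed number of terms; `t=5.0` and `num_terms=30` are illustrative defaults, not tuned values, and the dense matrices only suit small graphs.

```python
import math

import numpy as np


def heat_diffusion(adj, t=5.0, num_terms=30):
    """Approximate S = exp(t (T - I)) with T = A D^{-1} via a truncated power series."""
    deg = adj.sum(axis=0).astype(float)
    deg[deg == 0] = 1.0  # avoid division by zero for isolated nodes
    trans = adj / deg  # divides column j by deg[j], i.e. A D^{-1}
    result = np.zeros_like(trans)
    power = np.eye(adj.shape[0])  # T^0
    for k in range(num_terms):
        # Poisson weight e^{-t} t^k / k! for the k-step transition matrix.
        result += math.exp(-t) * t**k / math.factorial(k) * power
        power = power @ trans
    return result
```

Since every power of $T$ is column-stochastic, each column of $S$ sums to (almost exactly) one, and for a connected graph every entry is positive: every node receives some weight from every other node, which is the "global view". PyG also offers a `GDC` transform that supports heat-kernel diffusion.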

3. Node Dropping: In these augmentations, we randomly drop a small fraction of the nodes to create new graphs. All edges incident to a dropped node are deleted as well.
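A minimal NumPy sketch of node dropping, assuming node features `x` of shape `(N, F)` and a `(2, E)` `edge_index`. Surviving nodes are relabelled to contiguous ids, which is what a GNN expects; the helper name and the 20% default drop ratio are assumptions for illustration.

```python
import numpy as np


def drop_nodes(x, edge_index, drop_ratio=0.2, rng=None):
    """Drop a random fraction of nodes together with all incident edges.

    Returns the remaining features and the edge_index relabelled to 0..N'-1.
    """
    rng = np.random.default_rng() if rng is None else rng
    num_nodes = x.shape[0]
    num_drop = int(drop_ratio * num_nodes)
    keep = np.sort(rng.permutation(num_nodes)[num_drop:])

    # Map old node ids to new contiguous ids; dropped nodes map to -1.
    new_id = -np.ones(num_nodes, dtype=np.int64)
    new_id[keep] = np.arange(len(keep))

    # Keep only edges whose two endpoints both survive.
    src, dst = edge_index
    mask = (new_id[src] >= 0) & (new_id[dst] >= 0)
    new_edges = np.stack([new_id[src[mask]], new_id[dst[mask]]])
    return x[keep], new_edges
```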

4. Random Walk based Sampling: In these augmentations, we perform a random walk on the graph, keep adding the visited nodes until we reach a fixed, pre-decided number of nodes, and form the subgraph induced by these nodes. By random walk, we mean that at each step we move from the current node to one of its neighbors chosen uniformly at random.
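A small pure-Python sketch of random-walk sampling, assuming the graph is given as an adjacency-list dictionary with integer node ids. The `max_steps` cap is a safeguard we add (an assumption, not part of the description above) so that the walk terminates even when the start node's connected component is smaller than the target size.

```python
import random


def random_walk_subgraph(adj_list, start, target_size, max_steps=1000, rng=None):
    """Grow a node set by a simple random walk until it holds target_size nodes.

    adj_list: dict mapping each node id to a list of its neighbour ids.
    Returns the sampled node set and the edges of the induced subgraph.
    """
    rng = random.Random() if rng is None else rng
    visited = {start}
    current = start
    for _ in range(max_steps):
        if len(visited) >= target_size or not adj_list[current]:
            break
        current = rng.choice(adj_list[current])  # uniformly random neighbour
        visited.add(current)
    # Induced subgraph: every original edge whose endpoints were both visited.
    edges = [(u, v) for u in visited for v in adj_list[u] if v in visited and u < v]
    return visited, edges
```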

5. Node Attribute Masking: In these augmentations, we mask out the features of some nodes to create an augmented graph. The masked features are replaced with values sampled from a Gaussian with pre-specified mean and variance. The hope is to learn representations that are invariant to perturbations of the node features and depend mainly on the structure of the graph.
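The masking step can be sketched as follows. The mean and standard deviation of 0.5 and the 20% masking ratio are illustrative assumptions; the function replaces entire feature rows of the selected nodes with Gaussian noise and leaves the input array untouched.

```python
import numpy as np


def mask_node_features(x, mask_ratio=0.2, mean=0.5, std=0.5, rng=None):
    """Replace the features of a random subset of nodes with Gaussian noise."""
    rng = np.random.default_rng() if rng is None else rng
    x_aug = x.astype(float, copy=True)
    num_mask = int(mask_ratio * x.shape[0])
    idx = rng.choice(x.shape[0], size=num_mask, replace=False)
    x_aug[idx] = rng.normal(loc=mean, scale=std, size=(num_mask, x.shape[1]))
    return x_aug
```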

Graph Neural Networks

Graph Neural Networks (GNNs) are a class of deep learning models that operate on graph inputs. These networks have gained immense popularity in recent years because of their wide applicability in domains such as knowledge graphs, social networks and molecular biology. They can be used to learn embeddings for graph entities (nodes, edges, or entire graphs) and have shown remarkable performance on tasks like graph classification, node classification, link prediction and so on.

This figure explains the mechanics behind a single layer of Graph Neural Networks (source)

This step constitutes a single layer of GNN and it is repeated multiple times by stacking GNN layers. As we increase the number of layers, we increase the receptive field of every node by including neighbors of neighbors and so on. In a K-layer GNN, every node has a receptive field of its K-hop neighborhood.
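To make the K-hop claim concrete, the sketch below tracks which nodes can influence a given node after each message-passing round, assuming a symmetric adjacency matrix. The function name `receptive_field` is hypothetical.

```python
import numpy as np


def receptive_field(adj, node, num_layers):
    """Nodes whose input features can reach `node` after num_layers GNN layers."""
    reach = np.zeros(adj.shape[0], dtype=bool)
    reach[node] = True
    for _ in range(num_layers):
        # One layer lets information flow one hop along the edges.
        reach = reach | (adj[reach].sum(axis=0) > 0)
    return set(np.flatnonzero(reach).tolist())
```

On a path graph 0-1-2-3-4, node 0 sees {0, 1} after one layer and {0, 1, 2} after two.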

This figure shows how a stack of GNN layers operates (source)

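Let us delve deeper and look under the hood. There are two main operations in a GNN layer: aggregating messages from all neighboring nodes, and combining the aggregated message with the previous embedding of the given node. Mathematically, for a node v with neighborhood N(v) at layer k:

$$m_v^{(k)} = \text{AGGREGATE}^{(k)}\left( \{ h_u^{(k-1)} : u \in N(v) \} \right), \qquad h_v^{(k)} = \text{COMBINE}^{(k)}\left( h_v^{(k-1)}, m_v^{(k)} \right)$$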

Here, context is propagated through the graph along its edges. An interesting takeaway of this formulation is that it enables the GNN to operate on graphs of arbitrary shapes and sizes. Because the aggregate function acts on an unordered set of neighbors, the update is also unaffected by the order in which neighbors are listed, a property commonly known as permutation invariance. Most existing GNN algorithms can be written in this form, and different choices of aggregate and combine functions yield different GNN models. We will now discuss some popular architectures along with their implementations in PyG.
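A framework-agnostic sketch of the aggregate/combine template, with both functions passed in as arguments. The particular choices here (a sum aggregator and a combine step of ReLU(previous embedding + message)) are just one illustrative instance, and the function names are hypothetical. Shuffling each node's neighbor list does not change the output, which is the permutation invariance described above.

```python
import numpy as np


def message_passing_layer(h, neighbours, aggregate, combine):
    """One generic GNN layer.

    h: (N, F) node embeddings from the previous layer.
    neighbours: list where neighbours[v] lists the neighbour ids of node v.
    aggregate: maps an (n_v, F) array of neighbour embeddings to an (F,) message.
    combine: maps (previous embedding, message) to the new embedding.
    """
    out = []
    for v in range(h.shape[0]):
        if neighbours[v]:
            msg = aggregate(h[neighbours[v]])
        else:
            msg = np.zeros(h.shape[1])  # isolated node: empty message
        out.append(combine(h[v], msg))
    return np.stack(out)


def sum_aggregate(msgs):
    return msgs.sum(axis=0)


def relu_combine(prev, msg):
    return np.maximum(prev + msg, 0.0)
```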

  • GraphSAGE [1]: One of the first works to propose an inductive learning framework for graphs, i.e., one that can generalize to nodes unseen during training. While the original paper presents several options for the aggregate function, a popular one is the pooling aggregator: it transforms every neighbor embedding with a fully connected layer and then applies element-wise max-pooling to aggregate information from all neighbors. The combine function concatenates the output of aggregate with the node's previous embedding and passes the result through another fully connected layer with ReLU activation.

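In PyG, GraphSAGE is available as torch_geometric.nn.SAGEConv, whose aggr argument selects the neighbor aggregation (for example "max").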
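For readers who want to see the arithmetic, here is a minimal NumPy sketch of one GraphSAGE layer with the pooling aggregator, following the description above. The weight shapes are assumptions: `w_pool` is `(P, F)`, `b_pool` is `(P,)` and `w` is `(F_out, F + P)`. The L2 normalization step from the original paper is omitted for brevity.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def sage_pool_layer(h, neighbours, w_pool, b_pool, w):
    """GraphSAGE layer with the pooling aggregator.

    aggregate: element-wise max over ReLU(W_pool h_u + b_pool) for neighbours u
    combine:   ReLU(W [h_v || aggregate])
    """
    out = []
    for v in range(h.shape[0]):
        if neighbours[v]:
            pooled = relu(h[neighbours[v]] @ w_pool.T + b_pool).max(axis=0)
        else:
            pooled = np.zeros(w_pool.shape[0])  # no neighbours: empty message
        out.append(relu(w @ np.concatenate([h[v], pooled])))
    return np.stack(out)
```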

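  • Graph Convolutional Network (GCN) [2]: The aggregate and combine functions are integrated into a single update equation:

$$h_v^{(k)} = \sigma\left( W^{(k)} \sum_{u \in N(v) \cup \{v\}} \frac{h_u^{(k-1)}}{\sqrt{\tilde{d}_u \, \tilde{d}_v}} \right)$$

where $\tilde{d}_v = |N(v)| + 1$ is the degree of node v after adding a self-loop and σ is a non-linearity such as ReLU. The primary difference from GraphSAGE is the use of symmetric degree normalization, a weighted mean over the neighborhood that includes the node itself, instead of max pooling.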

In PyG, this layer is available as torch_geometric.nn.GCNConv.
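In matrix form (with node embeddings as rows of $H$), the GCN update is $\mathrm{ReLU}(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2} H W)$ with $\tilde{A} = A + I$. A dense NumPy sketch, assuming a small graph given as an adjacency matrix and ReLU as the non-linearity:

```python
import numpy as np


def gcn_layer(adj, h, w):
    """One GCN layer: ReLU(D^-1/2 (A + I) D^-1/2 H W), D = degree matrix of A + I."""
    a_hat = adj + np.eye(adj.shape[0])  # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    norm_adj = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    return np.maximum(norm_adj @ h @ w, 0.0)
```

For two connected nodes with scalar features 2 and 4 and an identity weight, both nodes end up with (2 + 4) / 2 = 3, the symmetric-normalized average of themselves and their neighbor.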

  • Graph Attention Network (GAT) [3]: This method implements an anisotropic aggregation, where importance scores (or attention weights) are assigned to each edge in the graph. These scores determine how much each neighboring node contributes to the updated embedding of a given node. An added advantage is that the learned attention weights can be inspected to analyze and interpret the graph.

In PyG, this layer is available as torch_geometric.nn.GATConv, whose heads argument enables multi-head attention.
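A single-head GAT layer can be sketched densely in NumPy as below. Following the GAT paper, each score is LeakyReLU (slope 0.2) of a learned vector `a` applied to the concatenated transformed features of the two endpoints, scores are normalized with a softmax over each node's neighbors plus itself, and the output passes through an ELU. The dense masking approach is our simplification and only suits small graphs.

```python
import numpy as np


def gat_layer(adj, h, w, a, negative_slope=0.2):
    """Single-head GAT layer; returns (ELU output, attention matrix alpha[v, u])."""
    z = h @ w.T  # (N, F') transformed features
    f_out = z.shape[1]
    # e[v, u] = LeakyReLU(a^T [z_v || z_u]), split into target and neighbour parts.
    e = (z @ a[:f_out])[:, None] + (z @ a[f_out:])[None, :]
    e = np.where(e > 0, e, negative_slope * e)
    # Only attend over neighbours and the node itself.
    mask = (adj + np.eye(adj.shape[0])) > 0
    e = np.where(mask, e, -np.inf)
    e = e - e.max(axis=1, keepdims=True)  # numerical stability
    alpha = np.exp(e)
    alpha = alpha / alpha.sum(axis=1, keepdims=True)
    out = alpha @ z
    # ELU activation, computed without overflow for large positive values.
    return np.where(out > 0, out, np.expm1(np.minimum(out, 0.0))), alpha
```

With `a` set to zeros every score is equal, so attention reduces to a uniform average over each node's neighborhood; a learned `a` is what makes the aggregation anisotropic.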

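  • Graph Isomorphism Network (GIN) [4]: A simple yet powerful model that generalizes the Weisfeiler-Lehman (WL) graph isomorphism test (refer to this blog for more information) and matches its power to distinguish graphs. Its update uses a sum aggregator followed by an MLP:

$$h_v^{(k)} = \text{MLP}^{(k)}\left( (1 + \epsilon^{(k)})\, h_v^{(k-1)} + \sum_{u \in N(v)} h_u^{(k-1)} \right)$$

Unlike max or mean aggregators, the sum is injective over multisets of neighbor features, so it can distinguish neighborhoods that max or mean would map to the same message.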

In PyG, this layer is available as torch_geometric.nn.GINConv, which takes the MLP as a module argument.
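The GIN update fits in a couple of lines. The sketch below uses a dense adjacency matrix and a hypothetical two-layer ReLU MLP built from given weight matrices.

```python
import numpy as np


def gin_layer(adj, h, mlp, eps=0.0):
    """GIN update: h_v' = MLP((1 + eps) * h_v + sum of neighbour embeddings)."""
    return mlp((1.0 + eps) * h + adj @ h)


def two_layer_mlp(w1, w2):
    """Return a small MLP x -> ReLU(x W1) W2 (hypothetical weights)."""
    return lambda x: np.maximum(x @ w1, 0.0) @ w2
```

As an illustration of the multiset point: in a star with center 0 and leaves 1 and 2, all with feature 1, the center's summed message is 2 while each leaf's is 1. A mean aggregator would give every node the same message of 1.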

  • Simple Graph Convolution (SGC) [5]: This work hypothesizes that the non-linearity in every GCN layer is not critical, and that most of the benefit comes from neighborhood aggregation. The authors therefore remove the activation functions between layers.

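The resulting K-layer network is linear in the weight parameters and collapses into a single linear model on pre-propagated features:

$$\hat{Y} = \text{softmax}\left( S^K X W \right), \qquad S = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$$

where $\tilde{A} = A + I$ is the adjacency matrix with self-loops, $\tilde{D}$ is its degree matrix, and W absorbs the weights of all K layers. The model nonetheless has the same receptive field as a K-layer GCN. In PyG, it is available as torch_geometric.nn.SGConv, whose K argument sets the number of propagation steps.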
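Because there are no non-linearities, the propagation $S^K X$ can be precomputed once and the model reduces to logistic regression on those features. A dense NumPy sketch, assuming a small graph; the function names are hypothetical.

```python
import numpy as np


def sgc_features(adj, x, k):
    """Precompute S^K X with S = D^-1/2 (A + I) D^-1/2 (no non-linearities)."""
    a_hat = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    s = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]
    for _ in range(k):
        x = s @ x
    return x


def sgc_logits(adj, x, w, k=2):
    """SGC logits S^K X W (softmax is applied by the loss)."""
    return sgc_features(adj, x, k) @ w
```

Stacking K GCN layers without ReLU gives exactly `sgc_logits` with `w` set to the product of the per-layer weights, which is why the precomputation is valid.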

Once we have the embeddings for all nodes, we can use various pooling techniques (max, sum, mean) to obtain the embedding for the entire graph. Each architecture has its own advantages and disadvantages, and no single network is best across the board. The right choice of architecture depends on the dataset and task at hand and is often made empirically.
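A sketch of the readout step, assuming several graphs are batched together and `batch[i]` records which graph node i belongs to (the same convention PyG uses). PyG provides these operations as `global_mean_pool`, `global_add_pool` and `global_max_pool` in `torch_geometric.nn`; the NumPy version below just makes the arithmetic explicit and assumes every graph has at least one node.

```python
import numpy as np


def global_pool(node_emb, batch, num_graphs, mode="mean"):
    """Pool node embeddings into one embedding per graph."""
    out = np.zeros((num_graphs, node_emb.shape[1]))
    for g in range(num_graphs):
        members = node_emb[batch == g]  # embeddings of the nodes in graph g
        if mode == "sum":
            out[g] = members.sum(axis=0)
        elif mode == "mean":
            out[g] = members.mean(axis=0)
        elif mode == "max":
            out[g] = members.max(axis=0)
        else:
            raise ValueError(f"unknown pooling mode: {mode}")
    return out
```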

Now, like any other machine learning model, we need to define an objective to train our GNN. In a supervised setup, where we have access to ground-truth labels, we can simply train the network to optimize standard loss functions (such as cross entropy for classification). In the case of self-supervised learning, we want to learn a latent space such that positive (similar) samples are closer, and negative (dissimilar) samples are further apart. This brings us to the contrastive losses that can enforce this constraint.

Contrastive Losses

This figure gives an example of how data augmentation techniques can be applied at random to generate positive graph pairs for self-supervised learning.

The goal is to score the agreement between positive pairs higher than that between negative pairs. For a given graph, its positive is constructed using the data augmentation methods discussed earlier, and all other graphs in the mini-batch serve as negatives. Our self-supervised model can be trained using the InfoNCE objective [6] or the Jensen-Shannon Estimator [7].

While the derivation of these objectives is beyond the scope of this blog, the intuition is rooted in information theory: both act as tractable estimators of the mutual information between the two views, and maximizing that estimate encourages the embeddings of the views to agree.
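A NumPy sketch of both objectives on a mini-batch of graph embeddings, where row i of `z1` and `z2` holds the embeddings of two augmented views of graph i. Two simplifications are assumptions on our part: the InfoNCE version uses cosine similarity with a temperature of 0.5 and keeps the positive in the denominator, and the Jensen-Shannon version scores pairs with a plain dot product (Deep Graph Infomax instead uses a learned discriminator between node and graph summaries) and drops the constant log 2 terms.

```python
import numpy as np


def log_softmax_rows(logits):
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))


def softplus(z):
    return np.logaddexp(0.0, z)  # numerically stable log(1 + exp(z))


def infonce_loss(z1, z2, temperature=0.5):
    """InfoNCE: (z1[i], z2[i]) is the positive pair; other graphs are negatives."""
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / temperature  # scaled cosine similarities
    return -np.mean(np.diag(log_softmax_rows(logits)))


def jsd_loss(z1, z2):
    """Negative Jensen-Shannon MI estimate: diagonal scores are positives,
    off-diagonal scores are negatives."""
    scores = z1 @ z2.T
    pos_mask = np.eye(scores.shape[0], dtype=bool)
    e_pos = np.mean(-softplus(-scores[pos_mask]))
    e_neg = np.mean(softplus(scores[~pos_mask]))
    return e_neg - e_pos
```

Both losses are lower when each graph's two views agree with each other more than with other graphs in the batch, which is the behavior the test checks.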

We can now piece together the building blocks we have seen so far and train our models without any labelled data. For a more hands-on experience, please refer to our Colab Notebook, which combines the various techniques for self-supervised learning. We provide an easy-to-use interface for training your own models, along with the flexibility to try out different augmentations, GNNs and contrastive losses. Our entire codebase can be found on GitHub.

Downstream Tasks

So far, we have learnt about the various steps involved in self-supervised learning that can potentially result in better embeddings for the graph data. In this section, we will see how this applies to real-world downstream tasks.

The above image illustrates the task of classifying molecular graphs into multiple classes for smells (source)

Let us consider the task of graph classification, which refers to classifying entire graphs into different classes based on their structural properties. Here, we want to embed entire graphs so that they are separable in the latent space for the task at hand. Our model consists of a GNN encoder followed by a classifier head.
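A minimal forward-pass sketch of such a model in NumPy: stacked GCN layers and mean pooling form the encoder, and a linear layer with softmax forms the classifier head. Layer sizes, the mean readout and the function names are illustrative assumptions; training (gradients and the optimizer) is left to a framework such as PyG.

```python
import numpy as np


def normalised_adjacency(adj):
    a_hat = adj + np.eye(adj.shape[0])  # add self-loops
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


def encode_graph(adj, x, gcn_weights):
    """Encoder: stacked GCN layers, then mean pooling into one graph embedding."""
    s = normalised_adjacency(adj)
    h = x
    for w in gcn_weights:
        h = np.maximum(s @ h @ w, 0.0)
    return h.mean(axis=0)


def classify_graph(adj, x, gcn_weights, w_head, b_head):
    """Classifier head: linear layer + softmax on top of the graph embedding."""
    logits = encode_graph(adj, x, gcn_weights) @ w_head + b_head
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()
```

In the self-supervised setup, `encode_graph` is the part that gets pretrained with a contrastive loss; the head weights are added afterwards and trained on the labelled data.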

Dataset: TU Dortmund University has collected a wide range of graph datasets, known as the TUDatasets, which are accessible via torch_geometric.datasets.TUDataset in PyG. We experiment on one of the smaller datasets, MUTAG. Each graph in this dataset represents a chemical compound and has an associated binary label that represents its “mutagenic effect on a specific gram negative bacterium”. The dataset contains 188 graphs, with about 18 nodes and 20 edges per graph on average. We perform binary classification on this dataset.

Data Preprocessing: We split the dataset into 131 train, 37 validation and 20 test graph samples. We also add additional features to each node by representing node degrees as one-hot encodings. Conventional hand-crafted features like node centrality, clustering coefficients and graphlet counts can also be included to get richer representations.
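The preprocessing can be sketched as below. It assumes each undirected edge appears in both directions in `edge_index` (the usual PyG convention), clips degrees at `max_degree`, and uses a seeded random permutation for the 131/37/20 split; the seed and helper names are hypothetical. PyG also provides a `torch_geometric.transforms.OneHotDegree` transform for the feature step.

```python
import numpy as np


def one_hot_degree(edge_index, num_nodes, max_degree):
    """Node features = one-hot encoding of each node's degree (clipped at max_degree)."""
    deg = np.bincount(edge_index[0], minlength=num_nodes)
    deg = np.minimum(deg, max_degree)
    feats = np.zeros((num_nodes, max_degree + 1))
    feats[np.arange(num_nodes), deg] = 1.0
    return feats


def split_indices(num_graphs, n_train, n_val, seed=0):
    """Shuffle graph indices and split them into train / validation / test sets."""
    perm = np.random.default_rng(seed).permutation(num_graphs)
    return perm[:n_train], perm[n_train:n_train + n_val], perm[n_train + n_val:]
```

For MUTAG, `split_indices(188, 131, 37)` leaves the remaining 20 graphs for testing.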

Training: When trained from scratch with a GCN encoder, cross-entropy loss and the Adam optimizer, the model achieves a classification accuracy of 60%. The accuracy is not very high, largely because of the limited amount of labelled data.

Now, let’s see if we can use the self-supervision techniques covered earlier to improve performance. We can use data augmentation techniques such as Edge Perturbation and Node Dropping to pretrain the GNN encoder without labels and learn better graph embeddings. The pretrained encoder, together with a classifier head, can then be fine-tuned on the available labelled data. On the MUTAG dataset, we observed that the accuracy jumps to 75%, a whopping improvement of 15 percentage points!

We also visualize the embeddings from our pretrained GNN encoder in a low-dimensional space. Even without access to any labels, the self-supervised model is able to separate the two classes, which is a remarkable feat!

t-SNE plots of the pretrained embeddings on MUTAG dataset, where color represents the class of the graph.

Here are some more examples where we performed the same experiment on different datasets. Note that all of them use a GCN encoder, with self-supervision applied through Edge Perturbation and Node Dropping augmentations and the InfoNCE objective.

Comparison of classification accuracies with and without self-supervised pretraining on multiple datasets.

This self-supervised pretraining is very effective, especially when we have a limited amount of labelled data. Consider a setup where we have access to only 20% of the labelled training data. Once again, self-supervised learning comes to our rescue and boosts the model performance significantly!

Comparison of classification accuracies with and without self-supervised pretraining on multiple datasets. Here, we use only 20% of the labelled data for training.

To experiment with more datasets and self-supervision techniques, follow the instructions in our Google Colab or the GitHub repository for this work.

Conclusion

To sum up, in this blog we learnt about self-supervised learning for graphs by understanding different data augmentation techniques and how they are combined with graph neural networks through contrastive learning. We also saw a significant improvement in performance on the task of graph classification.

Recently, a lot of research has focused on finding the right augmentation strategies for learning better representations for various graph applications. Here, we have summarized some of the most popular methods exploring self-supervised learning for graphs. Happy reading!

Popular methods for contrastive self-supervised learning on graphs.

To conclude, you do not necessarily need more labelled data to improve model performance. Just get more unlabelled data, and self-supervision does the job!

References

  1. Hamilton, Will, Zhitao Ying, and Jure Leskovec. “Inductive representation learning on large graphs.” Advances in neural information processing systems 30 (2017).
  2. Kipf, Thomas N., and Max Welling. “Semi-supervised classification with graph convolutional networks.” arXiv preprint arXiv:1609.02907 (2016).
  3. Veličković, Petar, et al. “Graph attention networks.” arXiv preprint arXiv:1710.10903 (2017).
  4. Xu, Keyulu, et al. “How powerful are graph neural networks?” arXiv preprint arXiv:1810.00826 (2018).
  5. Wu, Felix, et al. “Simplifying graph convolutional networks.” International conference on machine learning. PMLR, 2019.
  6. Oord, Aaron van den, Yazhe Li, and Oriol Vinyals. “Representation learning with contrastive predictive coding.” arXiv preprint arXiv:1807.03748 (2018).
  7. Veličković, Petar, et al. “Deep graph infomax.” International Conference on Learning Representations (2019).
  8. You, Yuning, et al. “Graph contrastive learning with augmentations.” Advances in Neural Information Processing Systems 33 (2020): 5812–5823.
  9. Qiu, Jiezhong, et al. “GCC: Graph contrastive coding for graph neural network pre-training.” Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2020.
  10. Hassani, Kaveh, and Amir Hosein Khasahmadi. “Contrastive multi-view representation learning on graphs.” International Conference on Machine Learning. PMLR, 2020.
