Simplifying Graph Learning (GNNs): Understanding the Basics and Their Applications.

AMA
10 min read · Mar 2, 2024


If you are reading this, I guess you have just started your journey into GNNs, which stand for Graph Neural Networks, the subject looks complex, and you would love a simplified explanation of what they are and how they work.

Perhaps you are already familiar with the topic but need a refresher. My own journey into GNNs started with my dissertation, in which I was interested in developing an inclusive recommendation system using GNNs with heterogeneous data. However, that is not the topic of this article (it will be covered at the end of this series on GNNs). Doing research is quite interesting, intriguing, and sometimes frustrating, because this journey took me all over the internet: Stanford's CS224W lectures on GNNs by Jure Leskovec on YouTube, Aleksa Gordic's YouTube playlist on GNNs, books like Graph Representation Learning by William H. and Hands-On GNNs Using Python by Maxime L., and a lot of other resources.

However, I realized that GNNs are complex, and much of the writing on the subject is highly technical; reading it can feel like a physician talking to someone from an arts background, because the lingua franca of GNNs is itself still complex. Thus, the goal of this series of articles is to make GNNs intuitive and second nature by giving you a mental model for visualizing and approaching graph problems using GNNs.

Why bother studying GNNs?

It is always interesting to acquire new knowledge, but why bother to go through the hassle of learning about GNNs? The answer is simple: GNNs have a wide variety of applications. They are at the heart of some of the apps we use daily and at the forefront of some groundbreaking scientific advancements of the last decade, including:

  • AlphaFold by DeepMind, which tackled the protein folding challenge and could accelerate the discovery of new drugs.
  • Google Maps ETA (Estimated Time of Arrival), which helps most of us plan our daily commute time and route so as not to be late for an interview, business meeting, or date.
  • Pinterest recommendation systems that recommend relevant Pins to users based on their topic of interest.
  • Determining the effects of different interactions between drugs taken together by patients during treatment, enabling efficient medication recommendation.
  • Fraud prevention for banks and fintech by detecting fraudulent behaviors and activities on customer accounts.

I hope this glimpse of some of the problems that can be solved using GNNs has stimulated your curiosity if this is your first time learning about them; for those with domain knowledge, this may feel like déjà vu. The possibilities of GNNs are almost limitless given the amount of relational data available today in nearly all domains and across industries, disciplines, and sectors. In short, most problems can be solved using GNNs if we can formulate them as graphs. So what are graphs, and how can we formulate a problem as a graph?

One of my favorite descriptions is that of Jure Leskovec (2021), who states that graphs are a general language for analyzing entities with relations or interactions. In other words, a graph is composed of nodes (or vertices), which can represent objects, connected to each other through edges, giving us a structural way of mapping entities and their relationships.

Figure 1.1 Illustration of a Graph

From the definition and illustration above, we can already conceive of the many types of graphs around us, such as family trees, disease pathways, computer networks, food webs, and transportation networks. With this powerful underlying structure, we can analyze entities and their relationships to uncover insights about the patterns they contain.
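To make the definition concrete, here is a minimal sketch of a small graph stored as an adjacency list in plain Python. The node names are made up for illustration; any set of entities and relationships would do.

```python
# A tiny undirected graph: nodes are entities, edges are relationships.
graph = {
    "Alice": ["Bob", "Carol"],
    "Bob": ["Alice"],
    "Carol": ["Alice", "Dave"],
    "Dave": ["Carol"],
}

def degree(g, node):
    """Number of edges attached to a node."""
    return len(g[node])

def has_edge(g, u, v):
    """True if u and v are directly connected."""
    return v in g[u]

print(degree(graph, "Alice"))          # 2
print(has_edge(graph, "Bob", "Dave"))  # False
```

The same structure can encode a family tree, a road network, or a social network; only the meaning of the nodes and edges changes.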

Fig 1.2 Jure Leskovec, Stanford CS224W: Machine Learning with Graphs

We observe from the examples above that the graphs in these domains are explicit: they simply pop out, or are intuitive. However, because of the versatile nature of graphs, we can also formulate problems as graphs implicitly in domains with fewer obvious structural relationships between entities. As we can see in the image below, an image can be represented as a graph where pixels are nodes and edges map the relationships between them, which can be used for graph-based image and computer vision tasks. Moreover, we also have graphs in domains such as NLP, involving text analysis and generation. And in contexts where the underlying structure is dynamic and evolving, with nodes and edges that keep changing, we can use spatio-temporal graphs and machine learning to derive valuable insights.

Fig 1.3 Spatio-temporal graphs

Graph Machine Learning Tasks

So probably one of the questions that has come to mind by now is: how do we understand complex domains that have relational data which is not explicit? In other words, how do we take advantage of relational structure to make better predictions? This is where graph machine learning comes into play, which is simply the application of machine learning to graph data. When I talk about using machine learning to make predictions on graphs, I am referring to tasks such as predicting links between pairs of nodes, predicting node labels, or generating graphs and subgraphs. The most common machine learning tasks on graphs include:

  • Node classification: This is the same as classification problems in ordinary machine learning; in the case of graphs, it involves predicting the class of a node. In a movie-and-actor graph, this could mean categorizing movies by their characteristics into Horror, Action, or Thriller, and actors into different awards. The model is trained to make predictions on unlabeled nodes based on labeled nodes and their attributes.
  • Relation prediction: This is commonly referred to as link prediction, which is widely used in recommendation systems, and it entails predicting missing links between pairs of nodes in a graph. It is also used in knowledge graphs, where the goal is to complete a graph by inferring missing links between entities.
  • Graph classification: In this task, the input is an entire graph, and the objective is to develop a classifier that accurately predicts the class of the graph. This involves categorizing graphs into different classes, which is useful in domains such as molecular biology, where a graph representing a molecule can be used to predict its properties for drug design.
  • Graph generation: This is the process of generating a new graph after training on the properties of existing graph data. One of its major applications is drug discovery.

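To give a feel for relation (link) prediction, here is a toy sketch of the classic common-neighbour heuristic on a made-up graph. This is a simple baseline, not a GNN: pairs of nodes that share many neighbours are scored as likely missing links.

```python
# Common-neighbour score: a classic baseline for link (relation) prediction.
# Nodes that share many neighbours are likely to be connected.
graph = {
    "A": {"B", "C"},
    "B": {"A", "C", "D"},
    "C": {"A", "B", "D"},
    "D": {"B", "C"},
}

def common_neighbors(g, u, v):
    """Number of neighbours shared by u and v."""
    return len(g[u] & g[v])

# Score every non-adjacent pair; a higher score suggests a missing link.
candidates = [
    (u, v, common_neighbors(graph, u, v))
    for u in graph for v in graph
    if u < v and v not in graph[u]
]
candidates.sort(key=lambda t: -t[2])
print(candidates[0])  # ('A', 'D', 2): A and D share neighbours B and C
```

A GNN-based link predictor replaces this hand-crafted score with one computed from learned node representations, but the task framing is the same.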
In machine learning, we are familiar with the three main types of learning, supervised, semi-supervised, and unsupervised, which are mostly used to frame a machine learning problem and choose a suitable ML algorithm to solve it. Having read the different graph tasks above, we can pause and take a quick guess at the type of ML to which each belongs.

Supervised learning on graphs: In supervised learning, we have labeled data that we are trying to predict or classify, which is used to train our model. Graph tasks such as node classification and relation prediction belong to supervised learning, given that in node classification the labels and features of neighboring nodes are used to classify nodes or make predictions based on neighborhood similarity.

Unsupervised learning on graphs: As a quick reminder, in unsupervised learning we have unlabeled data, and the objective of the ML algorithm is to uncover hidden patterns and similarities between entities by grouping them. Clustering and community detection are the typical unsupervised tasks on graphs, as they involve grouping nodes based on their relations, using the edges between nodes to detect communities.

Semi-supervised learning on graphs: In this case, part of the data is labeled and the rest is unlabeled. The ML algorithm is trained on the labeled part of the graph, and the node features learned during training are then used to infer labels for the unlabeled nodes, based on the underlying similarity of the neighborhood, known in graph machine learning as structural equivalence (more on this later).

Graph Learning Techniques

Now that we can map the different graph tasks onto the different types of machine learning, let's proceed to examine the different families of graph learning techniques.

i. Graph signal processing applies traditional signal processing tools, such as the Fourier transform and spectral analysis, to graphs, with the objective of discovering their intrinsic properties (connectivity and structure).
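The central object in these spectral methods is the graph Laplacian L = D − A, where D is the diagonal degree matrix and A the adjacency matrix; its eigenvectors play the role that sinusoids play in the classical Fourier transform. A minimal sketch of computing it, using plain Python lists as matrices:

```python
# The (combinatorial) graph Laplacian L = D - A underpins spectral methods
# such as the graph Fourier transform.
def laplacian(adj):
    """adj: adjacency matrix as a list of 0/1 rows. Returns L = D - A."""
    n = len(adj)
    return [
        [(sum(adj[i]) if i == j else 0) - adj[i][j] for j in range(n)]
        for i in range(n)
    ]

# Path graph 0 - 1 - 2
A = [[0, 1, 0],
     [1, 0, 1],
     [0, 1, 0]]
print(laplacian(A))
# [[1, -1, 0], [-1, 2, -1], [0, -1, 1]]
```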

ii. Matrix factorization is used to identify patterns in graph data that explain the relationships in the original matrix; it seeks a low-dimensional representation of a large matrix, leading to a compact and interpretable representation of the data.
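A toy sketch of the idea: approximate an adjacency matrix A by the product of two low-rank factors, trained with plain gradient descent on the squared reconstruction error. Real systems use optimized solvers (SVD, ALS); the loop below is only for illustration.

```python
import random

def factorize(A, k=2, steps=1000, lr=0.05, seed=0):
    """Approximate A (n x n) by U @ V^T with rank k via gradient descent."""
    rng = random.Random(seed)
    n = len(A)
    U = [[rng.uniform(-0.1, 0.1) for _ in range(k)] for _ in range(n)]
    V = [[rng.uniform(-0.1, 0.1) for _ in range(k)] for _ in range(n)]
    for _ in range(steps):
        for i in range(n):
            for j in range(n):
                pred = sum(U[i][f] * V[j][f] for f in range(k))
                err = A[i][j] - pred
                for f in range(k):
                    u, v = U[i][f], V[j][f]
                    U[i][f] += lr * err * v
                    V[j][f] += lr * err * u
    return U, V

A = [[0, 1, 1, 0],
     [1, 0, 1, 0],
     [1, 1, 0, 1],
     [0, 0, 1, 0]]
U, V = factorize(A)
# Total squared reconstruction error; predicting all zeros would score 8.0.
err = sum(
    (A[i][j] - sum(U[i][f] * V[j][f] for f in range(2))) ** 2
    for i in range(4) for j in range(4)
)
print(round(err, 2))  # small: the rank-2 factors capture most of A
```

The rows of U (and V) are exactly the kind of compact node representations the paragraph above describes.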

iii. Random walks are a mathematical technique used to simulate movement in a graph between entities through their edges. These simulated movements gather information about the relationships between nodes.
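A minimal sketch of a uniform random walk on a made-up graph: at each step, hop to a randomly chosen neighbour. This sampling step is the core of walk-based embedding methods such as DeepWalk and node2vec, which treat the visited sequences like sentences.

```python
import random

def random_walk(graph, start, length, seed=0):
    """Walk `length` steps from `start`, hopping to a random neighbour."""
    rng = random.Random(seed)
    walk = [start]
    for _ in range(length):
        neighbours = graph[walk[-1]]
        if not neighbours:  # dead end: stop early
            break
        walk.append(rng.choice(neighbours))
    return walk

graph = {
    "A": ["B", "C"],
    "B": ["A", "C"],
    "C": ["A", "B", "D"],
    "D": ["C"],
}
walk = random_walk(graph, "A", length=5)
print(walk)  # every consecutive pair in the list is an edge of the graph
```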

iv. Deep learning is the use of neural networks with several layers, and on graphs it is used to effectively encode and represent graph data as vectors.

The techniques mentioned above are very powerful, but they are not mutually exclusive and in practice are often combined to form hybrid models.

Graph Data Representation

Now that we have a good foundation in the different graph tasks, machine learning types, and graph learning techniques, it is important to look at the fundamental building block, which is obviously the data. How is graph data represented when it arrives in tabular form, as is the case with most of our datasets in comma-separated values or spreadsheets? In spreadsheets we mostly use rows to denote entities and columns to denote their attributes or features, and if we are using a relational database we denote relationships with foreign keys. However, most tabular datasets analyze entities without explicitly encoding the relationships between them, and this is where graph datasets become powerful: they map relationships as edges between entities.

From the dataset above, we can observe that there are five members of the family, each with three features. The tabular data has no edges, whereas its graph counterpart allows us to understand the relationships within the family. The relationships between nodes are vital for gaining deep insights into the data, and this is why graph data is growing in popularity in the deep learning community.

Figure 1.4 Representation of Family tree in Tabular versus a graph dataset.
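To illustrate the conversion, here is a hypothetical family table (the article's actual columns are not reproduced here, so the names and features below are invented): each row is an entity with three features, and a parent reference implicitly encodes the relationships. Turning it into a graph makes those relationships explicit edges.

```python
# Hypothetical family table: five members, three features each,
# with relationships hidden inside a "parent" reference column.
rows = [
    {"name": "Ada",  "age": 62, "city": "Lagos", "parent": None},
    {"name": "Ben",  "age": 60, "city": "Lagos", "parent": None},
    {"name": "Cara", "age": 34, "city": "Accra", "parent": "Ada"},
    {"name": "Dan",  "age": 31, "city": "Lagos", "parent": "Ada"},
    {"name": "Eve",  "age": 7,  "city": "Accra", "parent": "Cara"},
]

# Graph view: rows become nodes with features, and the implicit
# parent reference becomes an explicit edge list.
nodes = {r["name"]: {"age": r["age"], "city": r["city"]} for r in rows}
edges = [(r["parent"], r["name"]) for r in rows if r["parent"]]
print(edges)  # [('Ada', 'Cara'), ('Ada', 'Dan'), ('Cara', 'Eve')]
```

Nothing was added to the data; the graph view only surfaces the structure the table was hiding in a column.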

Building upon our knowledge of graph data and the different types of tasks, notably with deep learning, we will now proceed to examine GNNs.

Introduction to GNNs

What are GNNs?

GNNs, which stand for Graph Neural Networks, are a kind of deep learning architecture designed and optimized for graph-structured data. Most traditional deep learning models process text or images, which have a sequence-like or grid-like pattern; graphs, by contrast, are sparse and heterogeneous, and GNNs are explicitly designed to process and analyze this kind of data.

How do GNNs work?

GNNs, or Graph Neural Networks, work by creating a vector representation of each node in a graph, utilizing information from various sources. This representation includes the original node features, edge features, and global features. For instance, in a social network, a node’s features could be name, age, and gender, while edge features might represent the strength of relationships between nodes, and global features could be network-wide statistics.

Unlike traditional machine learning techniques, GNNs enrich the original node features with attributes from neighboring nodes, edges, and global features, resulting in a more comprehensive and meaningful representation. These new node representations are then used to perform specific tasks, such as node classification, regression, or link prediction.

GNNs achieve this by defining a graph convolution operation that aggregates information from neighboring nodes and edges to update the node representation. This process is repeated iteratively, allowing the model to learn more complex relationships between nodes as the number of iterations increases.
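The aggregation step can be sketched in a few lines. This is a deliberately simplified, weight-free version using mean aggregation; real GNN layers add learned weight matrices and a non-linearity on top of exactly this pattern.

```python
# One round of mean-aggregation message passing: each node's new
# representation averages its own feature vector with its neighbours'.
def message_passing_step(graph, features):
    new_features = {}
    for node, neighbours in graph.items():
        group = [features[node]] + [features[n] for n in neighbours]
        dim = len(features[node])
        new_features[node] = [
            sum(vec[d] for vec in group) / len(group) for d in range(dim)
        ]
    return new_features

# Toy path graph A - B - C with 2-dimensional node features.
graph = {"A": ["B"], "B": ["A", "C"], "C": ["B"]}
features = {"A": [1.0, 0.0], "B": [0.0, 1.0], "C": [1.0, 1.0]}

h = message_passing_step(graph, features)
print(h["A"])  # [0.5, 0.5] -- the mean of A's and B's features
```

Stacking several such rounds is what lets information flow beyond immediate neighbours: after two steps, A's representation already carries a trace of C, two hops away.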

There are various types of GNNs and GNN layers, each with its unique structure and method of aggregating information from neighboring nodes. When choosing the right GNN architecture for a particular problem, it’s essential to understand the characteristics of the graph data and the desired outcome.

GNNs are most effective when applied to specific, high-complexity problems where learning good representations is critical to solving the task. They require a substantial amount of data to perform effectively. While traditional machine learning techniques might be a better fit for smaller datasets, they do not scale as well as GNNs. GNNs can process larger datasets thanks to parallel and distributed training and can exploit additional information more efficiently, leading to better results.


AMA

Data and ML engineer, I love exploring data and building ML systems to production. Love animes, books, music and having deep conversations.