For graphs, why use Graph Neural Networks instead of traditional Deep Learning frameworks?

Preeti Singh Chauhan
4 min read · Nov 28, 2022


Have you ever wondered about that too?

Say you have already applied all the classic graph algorithms to your graph problem. Now you want that extra performance from representation learning on the graph, and you have implemented other deep learning models like RNNs and CNNs countless times on structured/unstructured data. Why not use these traditional deep learning approaches on graphs? Why reach for graph-specific deep learning frameworks like GNN, GCN, GAT or GraphSAGE?

The answer is simple: traditional deep learning architectures are built for simple sequences and grids, not for complex structures like graphs!

Let’s reverse engineer this: apply a traditional neural network to a graph structure, and see why more graph-centric models are needed.

Let’s say you have a graph G,

where:

  • V is the vertex set
  • A is the adjacency matrix
  • X is the node feature matrix
  • v is a node in V
  • N(v) is the neighbourhood of v
  • Node features could be, for example, a publisher’s profile in a publication network or a user’s profile in a social network.

Now let’s feed this to a neural network. For that to happen, first:

  • Take your network and represent it as an adjacency matrix A
  • Merge the adjacency matrix A and the node feature matrix X
  • Feed them to a deep neural network as the training set
  • Each node gives one input row, so the number of inputs in the input layer is the number of nodes plus the number of features per node
Image credit: Jure Leskovec
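The merge step above can be sketched as follows. This is a minimal illustration with a made-up 4-node graph and 2 features per node; the numbers are arbitrary:

```python
import numpy as np

# Toy graph: 4 nodes, symmetric adjacency matrix A
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]], dtype=float)

# Node feature matrix X: 2 features per node (e.g. profile attributes)
X = np.array([[0.1, 0.9],
              [0.4, 0.6],
              [0.8, 0.2],
              [0.5, 0.5]])

# Merge: each node's training row is its adjacency row plus its features
inputs = np.concatenate([A, X], axis=1)
print(inputs.shape)  # (4, 6): input layer size = |V| + features per node
```

Each of the 4 rows would be one training example for the plain neural network, which is exactly where the problems below come from.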

Now, let’s see what are the problems with this idea:

  1. Problem 1: O(|V|) parameters. We have one training example per node, yet more parameters than training examples. This makes training very unstable and prone to overfitting.
  2. Problem 2: The framework doesn’t transfer across graph sizes. If we train on a graph of 5 nodes and tomorrow the graph grows, or we want to transfer-learn to a graph of 10 nodes, we have no way to fit 10 nodes into 5 inputs.
  3. Problem 3: Sensitive to node ordering. A graph is permutation invariant, but its adjacency matrix is not: relabel the nodes and the rows and columns are permuted, even though the information is the same. The model would then be completely confused and wouldn’t know what to output.
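To make Problem 3 concrete, here is a small sketch (using an arbitrary 3-node path graph) showing that relabeling the nodes of the same graph produces a different adjacency matrix, and therefore a different flattened input for a plain neural network:

```python
import numpy as np

# Path graph with edges (0,1) and (1,2)
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]])

# Permutation matrix swapping node labels 0 and 1 (same graph, relabeled)
P = np.array([[0, 1, 0],
              [1, 0, 0],
              [0, 0, 1]])
A_perm = P @ A @ P.T  # rows and columns permuted together

# The graph is identical, but the matrix — and hence the flattened
# input a plain NN would see — is different
print(np.array_equal(A, A_perm))          # False
print(np.array_equal(A.ravel(), A_perm.ravel()))  # False
```

A model trained on one labeling would see the relabeled version of the very same graph as a completely different input.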

How does a GNN tackle these problems?

Take the idea from Convolutional Neural Networks (CNNs): collect neighbourhood information from an n×n pixel window, apply a transformation, and represent the result as a single cell in the next layer (then slide and repeat). Now extend this idea to graphs.

Image credit: Jure Leskovec

So, to predict for any node, say v, we have to:

  • determine the computational graph for node v (neighbours, neighbours of neighbours; how many levels?)
  • starting from the outer layer of the computational graph, take all the inputs (messages) from each node’s neighbours and transform them
  • aggregate them into a new message
  • pass the message up to the next level, and the next, until all the messages are aggregated and reach node v

This is just one layer of computation, but we can unfold it over multiple layers.
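A minimal sketch of one such message-passing layer, using mean aggregation over neighbours. The weight matrices W and B and the ReLU nonlinearity here are illustrative placeholders, not a specific published model:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy graph: adjacency A (4 nodes, each with at least one neighbour)
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 1],
              [1, 0, 0, 1],
              [0, 1, 1, 0]], dtype=float)
X = rng.normal(size=(4, 2))  # 2 features per node

# Illustrative weights: W transforms the aggregated neighbour message,
# B transforms the node's own previous embedding
W = rng.normal(size=(2, 2))
B = rng.normal(size=(2, 2))

def gnn_layer(A, H, W, B):
    """One round of message passing with mean aggregation."""
    deg = A.sum(axis=1, keepdims=True)          # neighbour counts
    mean_msg = (A @ H) / deg                    # average of neighbour embeddings
    return np.maximum(0, mean_msg @ W + H @ B)  # ReLU(aggregate + self)

# Unfold over two layers: each extra layer pulls in one more hop of neighbours
H1 = gnn_layer(A, X, W, B)
H2 = gnn_layer(A, H1, W, B)
print(H2.shape)  # (4, 2): one embedding per node, regardless of graph size
```

Notice that the same weights W and B are shared by every node, and mean aggregation ignores the order of neighbours — which is exactly how the problems above get solved.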

Now you can see how this solves the problems mentioned above.

  • The computation per node is bounded by the size of the computational graph around it, not by the whole network.
  • Since each node builds its own computational graph, the size of the full graph doesn’t matter; the learning can even be transferred to a larger graph.
  • The order in which messages are aggregated doesn’t matter, as long as we use permutation-invariant functions like sum(), mean(), max() and similar.

Note: each node in the graph defines its own neural network, given by the computational graph built around it.

This is how a GNN works in a nutshell. I’ll deep dive into GNNs in another story. GCNs, GATs and GraphSAGE work on the same intuition, with some variations.

I hope it is now clear why we do not use traditional deep learning architectures on graphs, and why graph neural networks and their variants work better. Watch this space for more knowledge graph concepts and their applications.
