Incorporating Edge Features into Graph Neural Networks for Country GDP Predictions

By Andre Turati, Peter Boennighausen, Rahul Shiv as part of the Stanford CS224W course project

Graph neural networks (GNNs) are an extremely flexible technique that can be applied to a variety of domains, as they generalize convolutional and sequential models that assume a more rigid data structure. In this post, we use an attention-based GNN to predict country-by-country GDP in a given year from international trade flows and nation-specific data.

We describe our approaches in this article, but feel free to follow along in our Google Colab as well.

Setup, Data Sources, and Preprocessing

The first step in any graph machine learning problem is to understand the type of task under consideration. Since graphs are so flexible, we have many degrees of freedom in deciding what is an edge, what is a node, and what is our desired outcome. In our case, however, the setup is relatively straightforward:

  • The nodes are countries
  • The node features are country-level statistics like population
  • The edges are between trading partners
  • The edge features are the quantities of particular items being traded

Our desired outcome is also pretty simple, as we are doing node-level prediction of a continuous random variable, namely GDP. We in effect are learning a predictor for GDP, which we can evaluate using mean-squared error.

To gather the data, we relied on two sources. The first is BACI [1], a project by the French global economics institute CEPII that captures bilateral trade flows between all countries dating back to 1995. The dataset is large and exhaustive, as it contains product-level information for all pairs of countries. The second data source is the World Bank, which collects country-level statistics for the entire world in its role as an international institution dedicated to development. We use this source both for our labels (GDP [2]) and for the three node-level features we added to the model: unemployment rate [4], inflation rate [5], and population [3].

Methods

Graph Neural Networks Overview

Before diving into any code, it is helpful to review how GNNs actually work. At the simplest level, GNNs attempt to represent nodes, edges, or graphs as vectors, which can then be used for traditional, downstream machine learning tools such as multi-layer perceptrons. Ideally, we would want these vector representations to have some meaningful relationship to the original graph. We hypothesize that by combining the node-level features, the edge-level features, and the relationships of nodes, we can learn a representation that correlates with GDP. GNNs learn to map countries to such vector representations through a technique called ‘message passing, aggregation, and update.’

Figure 1: GNNs use both a node’s features and its relationships with other nodes to find a suitable vector representation. Left: Zachary’s Karate Club Network [6], a graph with 34 nodes that fall into one of two categories: “Mr. Hi” and “Officer.” Middle: An oversimplified diagram of a GNN. Right: Vector representations of the nodes in Zachary’s Karate Club. Notice the separation between nodes of different categories.

Like any neural network, a GNN is composed of several layers. At each layer, each node updates its vector representation by collecting, transforming, and then aggregating the vector representations of its immediate neighbors. A node might, for instance, apply a linear transformation to the features of each neighboring node, sum the results, then apply a LeakyReLU activation as a final step. The result is that node’s updated vector representation, which is passed to the next layer of the GNN. Formulated more mathematically, given the vector embedding 𝒽ᵥˡ of node v at layer l of our GNN and that node’s local neighborhood 𝒩(v), the embedding of node v at the next layer of the GNN will be:

$$h_v^{l+1} = \sigma\left( \sum_{u \in \mathcal{N}(v)} W^{l} h_u^{l} \right)$$

In the above update equation, Wˡ is a learnable weight matrix, 𝒽ᵤˡ is the embedding of one of node v’s neighbors u at layer l, and σ is a non-linear activation function such as LeakyReLU. We stack as many of these layers as we need for the model to be expressive enough. We note that 𝒽ᵥ⁰ is simply the initial node features we describe above for node v.
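The transform, sum, and activate steps described above can be sketched in a few lines of plain PyTorch. This is only an illustration on a made-up four-node graph with random weights and a dense adjacency matrix; it is not the code we use for the trade network.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, in_dim, out_dim = 4, 3, 2
H = torch.randn(num_nodes, in_dim)   # row v holds the embedding of node v at layer l
W = torch.randn(in_dim, out_dim)     # learnable weight (random here, not trained)

# A[v, u] = 1 if u is a neighbor of v (made-up undirected toy graph).
A = torch.tensor([[0, 1, 1, 0],
                  [1, 0, 0, 1],
                  [1, 0, 0, 0],
                  [0, 1, 0, 0]], dtype=torch.float)

def gnn_layer(H, A, W):
    messages = H @ W            # transform every node's embedding (row-vector convention)
    aggregated = A @ messages   # sum the transformed embeddings of each node's neighbors
    return F.leaky_relu(aggregated)  # non-linear update

H_next = gnn_layer(H, A, W)
print(H_next.shape)
```

Because node 2 has node 0 as its only neighbor, `H_next[2]` equals `leaky_relu(H[0] @ W)`, which is a quick way to check the aggregation by hand.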

Figure 2: Visualization of message passing, aggregation, and update for a single node in a GNN [7]

Graph Attention Networks and Edge Features

What we’ve just explained is how a simple, vanilla GNN might work. In our trade-flow problem, we make use of a Graph Attention Network (GAT) [8]. With GATs, we gain two new capabilities that make our model even more expressive. First, we allow each node to decide which of its neighbors are more important; we do this through an attention parameter that gets attached to each of node v’s neighbors. Like Wˡ, this attention weight is learnable. Adding this attention parameter to the above update equation, we get:

$$h_v^{l+1} = \sigma\left( \sum_{u \in \mathcal{N}(v)} \alpha_{uv} W^{l} h_u^{l} \right)$$

where our learnable attention parameter is αᵤᵥ (the attention weight from u to v). For more information on how this attention parameter is computed, you can check out the original paper [8].
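For readers who want a feel for the attention weights, the sketch below follows the scoring rule from the original GAT paper: each pair gets a score LeakyReLU(aᵀ[W𝒽ᵥ ‖ W𝒽ᵤ]), and the scores are normalized with a softmax over node v’s neighbors. It uses a single attention head, a dense mask, a made-up graph, and random parameters, and it leaves out self-loops for simplicity, so treat it as an illustration rather than a reference implementation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, in_dim, out_dim = 4, 3, 2
H = torch.randn(num_nodes, in_dim)   # node embeddings at layer l
W = torch.randn(in_dim, out_dim)     # shared linear transform (random here)
a = torch.randn(2 * out_dim)         # attention vector (random here)

# A[v, u] is True if u is a neighbor of v; every node needs at least one neighbor.
A = torch.tensor([[0, 1, 1, 0],
                  [1, 0, 0, 1],
                  [1, 0, 0, 0],
                  [0, 1, 0, 0]], dtype=torch.bool)

def gat_layer(H, A, W, a):
    Z = H @ W                                    # transformed embeddings, shape (N, out_dim)
    d = Z.shape[1]
    target_part = Z @ a[:d]                      # contribution of the receiving node v
    source_part = Z @ a[d:]                      # contribution of the neighbor u
    scores = F.leaky_relu(target_part[:, None] + source_part[None, :],
                          negative_slope=0.2)    # scores[v, u]
    scores = scores.masked_fill(~A, float("-inf"))  # ignore non-neighbors
    alpha = torch.softmax(scores, dim=1)         # alpha[v, u], each row sums to 1
    return F.leaky_relu(alpha @ Z), alpha

H_next, alpha = gat_layer(H, A, W, a)
print(alpha)
```

Each row of `alpha` sums to one over that node’s neighbors, so the layer computes a learned weighted average of neighbor messages instead of a plain sum.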

The second exciting capability, available in PyTorch Geometric’s GAT implementation, is support for edge features. This is quite useful, as our trade-flow dataset documents not just which countries trade with one another, but what and how much they trade. In particular, we can look at the top 10 most traded items globally in our data and the quantity of each of these items traded between two countries. Using GATs, we can include such trade statistics and encode them as our edge features. Exactly how these edge features are incorporated into the model is beyond the scope of this post, but the interested reader can learn more by reading PyTorch Geometric’s GAT documentation!

Figure 3: Graph Attention Networks (GATs) incorporate local neighborhood information when computing node embeddings. To compute the embedding 𝒽ᵥˡ⁺¹ of node v at layer l + 1, node v linearly transforms the embeddings of its four neighbors using the weight Wˡ, applies a per-neighbor attention mechanism αᵤᵥ, aggregates the results, then applies a non-linear activation σ.

GATs for GDP Prediction

Now that we’ve got the basics out of the way, let’s move on to implementing a GAT for our GDP Prediction task. We will implement the following pipeline:

Figure 4: Model architecture pipeline. Note the final ReLU is to enforce that GDP is non-negative.

First, let’s load in and pre-process our dataset by reading in the BACI trade flow and our World Bank datasets and extracting our edge features and node features respectively:
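A minimal sketch of this kind of preprocessing is shown below. It uses tiny in-memory tables in place of the real files, and every column name (`exporter`, `importer`, `product`, `quantity`, `country`, and the three statistics) is a hypothetical stand-in rather than the actual schema of BACI or the World Bank data. The idea is to keep the globally most traded products and pivot them into one feature vector per trading pair.

```python
import pandas as pd
import torch

# Toy stand-in for trade records (hypothetical column names and values).
trade = pd.DataFrame({
    "exporter": ["USA", "USA", "CHN", "DEU", "CHN"],
    "importer": ["CHN", "DEU", "USA", "USA", "DEU"],
    "product":  ["oil", "cars", "phones", "cars", "phones"],
    "quantity": [5.0, 2.0, 9.0, 4.0, 3.0],
})

# Toy stand-in for country-level statistics (made-up values).
node_stats = pd.DataFrame({
    "country":      ["USA", "CHN", "DEU"],
    "population":   [3.3, 14.0, 0.8],
    "unemployment": [4.0, 5.0, 3.5],
    "inflation":    [2.0, 2.5, 1.5],
})

# Keep only the top-k products by total traded quantity (k = 10 in the post, 2 here).
top_k = 2
top_products = trade.groupby("product")["quantity"].sum().nlargest(top_k).index
filtered = trade[trade["product"].isin(top_products)]

# One row per (exporter, importer) pair, one column per retained product.
edge_table = filtered.pivot_table(index=["exporter", "importer"], columns="product",
                                  values="quantity", aggfunc="sum", fill_value=0.0)

# Map country codes to integer node ids and build the tensors.
country_to_idx = {c: i for i, c in enumerate(node_stats["country"])}
src = [country_to_idx[c] for c in edge_table.index.get_level_values("exporter")]
dst = [country_to_idx[c] for c in edge_table.index.get_level_values("importer")]

edge_index = torch.tensor([src, dst], dtype=torch.long)
edge_attr = torch.tensor(edge_table.to_numpy(), dtype=torch.float)
x = torch.tensor(node_stats[["population", "unemployment", "inflation"]].to_numpy(),
                 dtype=torch.float)
print(edge_index.shape, edge_attr.shape, x.shape)
```

Note that a pair of countries that only trades products outside the top-k list ends up with no edge in this sketch; whether to keep such edges with all-zero features is a design choice.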

Next, we use PyG’s GATConv layer to build a simple, two-layer graph attention network, followed by a single linear layer. This ‘baseline’ model (i.e. a GAT without any edge features) relies only on the individual node features of each country and on which countries trade with one another.

As a comparison, we’ll build a separate model that does incorporate edge features. We’ll put these models head-to-head and see which one performs the best:

Now, we train both models and compare their performances:
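A training loop for this kind of node-level regression might look like the sketch below. To keep the example self-contained, it trains a stand-in linear model on synthetic data; the helper name `fit`, the mask-based train/validation split, the Adam optimizer, and the learning rate are all assumptions for illustration. A GAT would be trained the same way by passing `(x, edge_index)` or `(x, edge_index, edge_attr)` as `inputs`.

```python
import torch
import torch.nn.functional as F

def fit(model, inputs, y, train_mask, val_mask, epochs=200, lr=0.05):
    """Train on the training nodes and record validation MSE after every epoch."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        pred = model(*inputs)
        loss = F.mse_loss(pred[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            val_pred = model(*inputs)
            history.append(F.mse_loss(val_pred[val_mask], y[val_mask]).item())
    return history

torch.manual_seed(0)
x = torch.randn(40, 3)                                   # synthetic node features
y = x @ torch.tensor([1.0, -2.0, 0.5]) + 3.0             # synthetic targets
train_mask = torch.zeros(40, dtype=torch.bool)
train_mask[:30] = True                                   # first 30 nodes for training
val_mask = ~train_mask                                   # remaining 10 for validation

# Stand-in model: a linear layer whose (N, 1) output is flattened to (N,).
model = torch.nn.Sequential(torch.nn.Linear(3, 1), torch.nn.Flatten(0))
history = fit(model, (x,), y, train_mask, val_mask)
print(history[0], history[-1])
```

Running `fit` once per model with the same masks gives two validation curves that can be plotted on a log scale for a direct comparison.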

Figure 5: Comparing log(loss) values for the GAT model with edge features vs. the baseline. The model with edge features edges out the baseline.
Figure 6: Model predictions increase in accuracy after 2000 epochs, showing trends of predicting GDP accurately

Results

First, it’s clear that both models are learning over time. As shown in Figure 5, the MSE on the validation dataset decreases rather smoothly until about 2000 iterations. This indicates that the learning signal is being passed through the model correctly and that the model can generalize to the unseen validation set. Note that the y-axis is log-scaled to bring out the small but steady gains that the model is making.

A second observation is that the model with edge features slightly but consistently outperformed the baseline model without edge features. The validation MSE, which is calculated on the same dataset for both models, is lower for the model with edge features, indicating that our edge features, which are quantities of goods traded, add information about the size of a country’s economy.

That’s the good news. Let’s move on to examine some shortcomings of the model and discuss ways to improve it further. First, the log scaling of the labels may make the MSE seem more impressive than it actually is. At the end of training, the MSE on the validation set is around 5.4 for the model with edge features. Taking the square root gives a root-mean-squared error of about 2.3, so a typical prediction is about 2.3 off the true value. While it would be incredible if the model were $2.30 off the true GDP of a country, recall that we log-scaled our data, so the loss means that the model is about 2.3 orders of magnitude off the correct value. The upshot is that we’re not going to be putting the World Bank out of a job anytime soon.
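The arithmetic behind that reading can be checked in a few lines. The "orders of magnitude" interpretation assumes the labels were scaled with a base-10 logarithm; the sketch also shows what the same error would mean under a natural logarithm, since the multiplicative error depends on the base.

```python
import math

mse = 5.4                       # validation MSE on log-scaled GDP (from the text)
rmse = math.sqrt(mse)           # typical error in log units

# How far off a prediction is as a multiplicative factor, depending on the log base.
factor_base10 = 10 ** rmse      # if labels are log10(GDP)
factor_natural = math.exp(rmse) # if labels are ln(GDP)

print(f"RMSE in log units: {rmse:.2f}")
print(f"Multiplicative error if base 10: {factor_base10:.0f}x")
print(f"Multiplicative error if base e:  {factor_natural:.1f}x")
```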

The second point comes out when we analyze how the model improves its accuracy. Consider Figure 6, which shows the model’s predictions vs. actual values on the test data at initialization and after 2000 epochs of training. These two plots show that much of the gain in the model’s accuracy can be attributed to simply learning the mean GDP and shifting all predictions up by a commensurate amount. The good news is that there is a positive correlation of about 0.27 between the predicted values and the actual values, indicating that the predictions incorporate additional information beyond the mere mean of the world’s GDP. Still, this figure further underscores that we should not declare victory due to the seemingly low MSE.
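The distinction between "learning the mean" and "learning the ranking" can be illustrated with synthetic numbers. In the sketch below, shifting every prediction by a constant sharply reduces the MSE but leaves the Pearson correlation unchanged, which is why correlation is a useful complementary check. The data are random and are not our model’s outputs.

```python
import numpy as np

rng = np.random.default_rng(0)

actual = rng.normal(25.0, 2.0, size=200)                            # synthetic log-GDP values
# Weakly informative predictions centered near zero, like an untrained model.
pred_init = 0.3 * (actual - 25.0) + rng.normal(0.0, 2.0, size=200)
# Same predictions after "learning the mean": every value shifted by a constant.
pred_shifted = pred_init + actual.mean()

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def pearson(a, b):
    return float(np.corrcoef(a, b)[0, 1])

print("MSE before shift:", mse(pred_init, actual))
print("MSE after shift: ", mse(pred_shifted, actual))
print("Correlation before:", pearson(pred_init, actual))
print("Correlation after: ", pearson(pred_shifted, actual))
```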

What to do next? We tried training the model beyond 2000 iterations, but the validation error started to increase, indicating that we were probably overfitting. We also tried different combinations of hyperparameters and model architectures to no significant effect. The simplest and probably best prescription would be to add more features, either at the node or edge level. While one would run the risk of overfitting if the dimensionality of the features outgrows the size of the dataset, more features would probably improve model accuracy. Follow-up work could investigate how additional features and their dimensionality affect performance.

Conclusion

We learned how to implement Graph Attention Networks, a powerful architecture that allows nodes to learn how important their neighbors’ messages are. Further, we improved the GAT’s performance by including edge features, which encode information about the quantity of certain goods traded between two countries. Edge features make the model more expressive and let it exploit additional patterns that may be present in the data. They can be applied not only to our trade network, but to many other applications as well.

[1] Gaulier, Guillaume, and Soledad Zignago. “BACI: International Trade Database at the Product-Level. The 1994–2007 Version.” CEPII Working Paper N°2010–23, October 2010, http://www.cepii.fr/CEPII/en/bdd_modele/presentation.asp?id=37

[2] GDP 1995–2019, The World Bank, Dec 2021, https://data.worldbank.org/indicator/NY.GDP.MKTP.CD. Accessed 1 Dec. 2021. Dataset.

[3] Population 1995–2019, The World Bank, Dec 2021, https://data.worldbank.org/indicator/SP.POP.TOTL. Accessed 1 Dec. 2021. Dataset.

[4] Unemployment 1995–2019, The World Bank, Dec 2021, https://data.worldbank.org/indicator/SL.UEM.TOTL.ZS. Accessed 1 Dec. 2021. Dataset.

[5] CPI 1995–2019, The World Bank, Dec 2021, https://data.worldbank.org/indicator/FP.CPI.TOTL.ZG. Accessed 1 Dec. 2021. Dataset.

[6] Zachary, Wayne W. “An Information Flow Model for Conflict and Fission in Small Groups.” Journal of Anthropological Research, vol. 33, no. 4, [University of New Mexico, University of Chicago Press], 1977, pp. 452–73, http://www.jstor.org/stable/3629752.

[7] From Lecture 6, Slide 59 of Professor Jure Leskovec’s Graph Machine Learning Course, CS 224W

[8] Veličković, Petar et al. “Graph Attention Networks”. 2018. Web.
