Rolling with ROLAND: from Static to Dynamic GNN

In this post we investigate the performance of the ROLAND model which generalizes any static GNN to dynamic graphs. Our results improve on the task of link prediction on Bitcoin datasets, and our Google Colab Notebook can be accessed here.

Miria Feng
Stanford CS224W GraphML Tutorials
13 min read · Apr 26, 2023

By Miria Feng, Andrew Cheng, and Yujie Gao as part of the Stanford CS224W course project.

Introduction

Graph Neural Networks (GNNs) are an exciting frontier of Deep Learning and have successfully modeled a vast array of problems across numerous domains. Applications range from protein folding and molecular design to traffic prediction and question-answering systems! However, although much progress has been made on static GNNs, many interesting real-world problems are dynamic and evolve over time.

Visualization of retweeting behavior of Twitter influencers over time by https://towardsdatascience.com/animate-dynamic-graphs-with-gephi-d6bd9faf5aec

In this blog post we investigate ROLAND¹, a novel graph learning framework for dynamic graphs implemented on GraphGym¹¹, which can extend any static GNN architecture to a dynamic setting. Dynamic graph machine learning techniques open the door to an exciting new class of GNNs that can handle dynamic data and capture its evolution through time.

Overview of the roadmap ahead in this post:

  • Why Dynamic graphs? Static vs dynamic
  • Introducing ROLAND
  • Dynamic datasets and link prediction task
  • Three methods of extending ROLAND and code walkthrough
  • Insights, discussion, and future work

Why Dynamic graphs? Static vs dynamic

Unlike static graphs, dynamic graphs change over time. In some cases the node and edge sets may differ completely from one point in time to the next, which makes for a more challenging setting. Practical scenarios in which graphs evolve dynamically include social interactions, networks of financial transactions, recommender systems where nodes and edges appear or change over time, and traffic updates as new accidents or road closures are reported. It is the dynamic behavior of these real-world systems that conveys important insights, which are otherwise lost in a static setting. So how do we extend successful static GNN models to a dynamic graph? ROLAND provides an elegant solution.

Introducing ROLAND

ROLAND is an experimental framework developed in 2022 that aims to provide flexible learning on dynamic graphs by making it possible to repurpose any static GNN for a dynamic graph setting. Since static graph representation learning has inspired numerous creative architectural designs such as skip connections, attention, and batch normalization, the natural motivation is to extend these structures to dynamic problems. ROLAND accomplishes this by viewing the node embeddings at different GNN layers as hierarchical node states, and then defining how to recurrently update these states over time as newly observed nodes and edges are introduced. This update module refreshes node embeddings hierarchically and dynamically, and can be inserted into any static GNN. As a proof of concept, ROLAND explores adding skip connections to a dynamic GNN to great effect, and achieves an average 15.5% relative mean reciprocal rank (MRR) improvement over its baseline models.

Dynamic Datasets and Link Prediction Task

Dynamic data for GNNs differs from traditional static GNN data in that nodes and edges gradually appear over time. Therefore, in addition to its traditional features, each node and edge also carries a timestamp. This is illustrated in the figure below.

There is an additional time feature associated with nodes and edges in dynamic graphs¹

The data is then viewed as a stream of snapshots taken at a regular frequency, and is split deterministically along the time axis for training. For example, in our setting the first 8 months of data are used for training, the following month for validation, and the last month as the test set. Our experimental setting is consistent with ROLAND’s and uses the snapshot-based representation of dynamic graphs, where nodes and edges arrive in batches over time. The following shows a visualization of the training procedure with snapshots arriving in sequence.

Visualization of training procedure as snapshots arrive in sequence
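To make the snapshot view and the chronological split described above concrete, here is a minimal sketch in plain Python. The function names (`build_snapshots`, `chronological_split`), the toy edge stream, and the 30-day period are our own illustrative assumptions and are not part of the ROLAND codebase.

```python
from collections import defaultdict

def build_snapshots(edges, period):
    """Group (src, dst, timestamp) edges into snapshots of `period` time units."""
    buckets = defaultdict(list)
    for src, dst, ts in edges:
        buckets[ts // period].append((src, dst))
    # Sort by bucket index so the snapshots are in chronological order.
    return [buckets[k] for k in sorted(buckets)]

def chronological_split(snapshots, train_frac=0.8, val_frac=0.1):
    """Deterministic split: earliest snapshots for training, latest for testing."""
    n = len(snapshots)
    n_train = int(n * train_frac)
    n_val = int(n * val_frac)
    return (snapshots[:n_train],
            snapshots[n_train:n_train + n_val],
            snapshots[n_train + n_val:])

# Toy edge stream (hypothetical data): one edge every 3 days over 300 days.
edges = [(i % 5, (i + 1) % 5, day) for i, day in enumerate(range(0, 300, 3))]
snapshots = build_snapshots(edges, period=30)   # 30-day "monthly" snapshots
train, val, test = chronological_split(snapshots)
print(len(snapshots), len(train), len(val), len(test))
```

Because the split follows time order rather than random shuffling, the model never sees edges from the future while training.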

The big difference between static and dynamic GNNs is that in the dynamic scenario we represent the input as a time-varying sequence of graph snapshots, where each individual snapshot is one static graph. In this way we can handle node addition and deletion, since graph snapshots can have different sets of nodes. The table below summarizes the datasets used for experiments in this blog post.

BitcoinOTC and BitcoinAlpha datasets

We use these two Bitcoin datasets on a link prediction task in order to benchmark against prior results from ROLAND. BitcoinOTC and BitcoinAlpha contain who-trusts-whom networks of people who trade on the Bitcoin OTC and Bitcoin Alpha platforms respectively.

The link prediction task in a dynamic setting can then be formulated as follows: at each time t, the model uses the information accumulated up to time t to predict the edges in snapshot t+1. We use mean reciprocal rank (MRR) instead of ROC AUC to evaluate performance, since negative labels significantly outnumber positive labels in our task. For each node u with a positive edge (u, v) at t+1, we randomly sample 1000 negative edges emanating from u and identify the rank of the prediction score of (u, v) among these negative edges. The MRR score is then defined as the mean of the reciprocal ranks over all nodes u. We consider two different train-test split methods: the fixed split method evaluates models using all edges from the last 10% of snapshots, while the live update method evaluates model performance over all available snapshots. In each snapshot we randomly choose 10% of the edges to determine the early-stopping condition.
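The MRR computation above can be sketched in a few lines of plain Python. This is a toy illustration under our own assumptions: `toy_score` stands in for a trained model, negatives are drawn from all other nodes without checking whether they are true non-edges, the reciprocal ranks are averaged per positive edge, and ties are resolved in favor of the positive edge.

```python
import random

def reciprocal_rank(pos_score, neg_scores):
    """1 / rank of the positive edge among the sampled negative edges."""
    # Rank 1 means the positive edge scored at least as high as every negative.
    rank = 1 + sum(1 for s in neg_scores if s > pos_score)
    return 1.0 / rank

def mean_reciprocal_rank(score_fn, positive_edges, num_nodes, num_neg=1000, seed=0):
    """Average reciprocal rank over all positive edges (u, v)."""
    rng = random.Random(seed)
    reciprocal_ranks = []
    for u, v in positive_edges:
        # Candidate negative targets for source node u (toy version).
        candidates = [w for w in range(num_nodes) if w not in (u, v)]
        negatives = rng.sample(candidates, min(num_neg, len(candidates)))
        neg_scores = [score_fn(u, w) for w in negatives]
        reciprocal_ranks.append(reciprocal_rank(score_fn(u, v), neg_scores))
    return sum(reciprocal_ranks) / len(reciprocal_ranks)

def toy_score(u, v):
    """Hypothetical scorer: nodes with nearby ids are considered more likely to link."""
    return -abs(u - v)

positive_edges = [(0, 1), (10, 12), (20, 29)]
mrr = mean_reciprocal_rank(toy_score, positive_edges, num_nodes=50)
print(f"MRR = {mrr:.3f}")
```

A perfect model ranks every positive edge first and reaches an MRR of 1, while a positive edge buried below many negatives contributes almost nothing, which is why MRR remains informative when negatives vastly outnumber positives.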

Three Methods of extending ROLAND and Code Walkthrough

In order to extend ROLAND’s results we view the machine learning for graphs problem as two components: first, the model architecture aims to produce a smoother loss function on the optimization landscape; then, the optimizer helps us reach our goal on that loss landscape. We therefore break down the task of improving dynamic graph learning into adding more expressive complexity to the baseline model architecture of ROLAND in two different ways, and exploring one more customized optimization strategy for the task of link prediction.

  1. Dynamic Graphs with Attention

Although the ROLAND model has demonstrated the success of RNNs in link prediction tasks, it has been shown that RNNs struggle with long-range dependencies. ROLAND will therefore likely have trouble with dynamic graphs that involve many time steps. The successful incorporation of the attention mechanism⁶ in static GNNs motivates this extension to a dynamic setting, especially since transformer and attention-based architectures are parallelizable, which allows more flexibility in snapshot processing. We experiment with attention such that the current graph snapshot can attend to all previous time steps, resulting in richer contextualized features. Analytically, the attention weights indicate which time steps are the most meaningful, offering deeper and more interpretable insight into the results.

We implement multi-head causal attention by treating consecutive graph snapshots as a batch, which is then fed into the model sequentially. The breakdown of our implementation process and a diagram of the attention architecture are as follows:

  • Input: A batch of T dynamic graph snapshots each with node features of shape (N, E)
  • Run each graph snapshot in batch through GNN
  • Stack node features into (T, N, E)
  • Transpose to (N, T, E)
  • Do MHA using a causal attention mask to prevent looking into future
  • Transpose back to (T, N, E)
  • Unpack back to list of (N,E) for next GNN layer
  • Repeat K times to get contextualized K-hop info

Attention architecture incorporated into the dynamic GNN setting
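The causal mask used in the MHA step of the list above can be illustrated independently of any GNN. The following is a minimal pure-Python sketch for a single node across T snapshots; the function names and the uniform raw scores are our own assumptions. The mask follows the boolean convention in which True marks a blocked (future) position.

```python
import math

def causal_mask(T):
    """mask[i][j] is True when snapshot i is NOT allowed to attend to snapshot j."""
    return [[j > i for j in range(T)] for i in range(T)]

def causal_attention_weights(scores):
    """Row-wise softmax over a T x T score matrix, ignoring future timesteps."""
    T = len(scores)
    mask = causal_mask(T)
    weights = []
    for i in range(T):
        # Blocked (future) positions contribute zero weight.
        exps = [0.0 if mask[i][j] else math.exp(scores[i][j]) for j in range(T)]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Uniform raw attention scores for T = 4 snapshots of one node.
scores = [[0.0] * 4 for _ in range(4)]
for row in causal_attention_weights(scores):
    print([round(w, 2) for w in row])
```

Row i of the result spreads its weight only over snapshots 0 through i, so the embedding at time step i never depends on later snapshots.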

Lastly we run the result through the classifier head and compute the loss between the labels predicted from the i-th graph and the ground-truth labels of the (i+1)-th graph. The code breakdown for our training implementation is below.

import torch

def train_step(model, optimizer, scheduler, dataset, t: int, batch_size: int = 10) -> dict:
    """
    Given task to predict t+1, train the model on (t-batch_size, t] batch.
    This is equivalent to using the last batch_size timesteps as context.
    After receiving ground truth from a particular task, update the model by
    performing back-propagation.

    Ex) Train the model using G[t-batch_size+1], ..., G[t] for message passing and
    corresponding labels: label[t-batch_size+2], ..., label[t+1] as target.
    """
    optimizer.zero_grad()
    torch.cuda.empty_cache()
    model.train()
    # get_task_batch_sequence_ and compute_loss are helpers defined elsewhere
    # in the project code.
    X, Y = get_task_batch_sequence_(dataset, t=t, batch_size=batch_size)
    preds, labels = model(batch=X, targets=Y)
    # Accumulate the loss over every snapshot in the batch.
    total_loss = 0.0
    for pred, label in zip(preds, labels):
        loss, pred_score = compute_loss(pred, label)
        total_loss += loss
    # One backward pass and one optimizer/scheduler step per batch of snapshots.
    total_loss.backward()
    optimizer.step()
    scheduler.step()
    return {'loss': total_loss}

We implement the attention mechanism inside the GNNLayer class, as summarized in the following code block.

from typing import List

import deepsnap.graph
import torch
import torch.nn as nn
import torch.nn.functional as F

# cfg (the GraphGym config object) and ResidualEdgeConv are provided by the
# ROLAND / GraphGym codebase.

class GNNLayer(nn.Module):
    """
    The most general wrapper for graph recurrent layer, users can customize
    (1): the GNN block for message passing.
    (2): the update block takes {previous embedding, new node feature} and
    returns new node embedding.
    """
    def __init__(self, dim_in, dim_out, has_act=True, has_bn=True,
                 has_l2norm=False, id=0, **kwargs):
        super(GNNLayer, self).__init__()
        self.has_l2norm = has_l2norm
        self.layer_id = id
        has_bn = has_bn and cfg.gnn.batchnorm
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.layer = ResidualEdgeConv(dim_in, dim_out, bias=not has_bn, **kwargs)
        layer_wrapper = []
        if has_bn:
            layer_wrapper.append(nn.BatchNorm1d(
                dim_out, eps=cfg.bn.eps, momentum=cfg.bn.mom))
        if cfg.gnn.dropout > 0:
            layer_wrapper.append(nn.Dropout(
                p=cfg.gnn.dropout, inplace=cfg.mem.inplace))
        if has_act:
            layer_wrapper.append(nn.PReLU())
        self.post_layer = nn.Sequential(*layer_wrapper)
        self.mha = nn.MultiheadAttention(embed_dim=dim_out, num_heads=2, batch_first=True)
        self.ln = nn.LayerNorm(dim_out)

    def _stack_node_features(self, batch: List[deepsnap.graph.Graph]) -> torch.Tensor:
        """Given a list of deepsnap graphs, stack their node features,
        each with shape [N, E] to [num_graphs, N, E]
        """
        return torch.stack([g.node_feature for g in batch], dim=0)

    def _reassign_node_features(self, batch: List[deepsnap.graph.Graph],
                                node_feats: torch.Tensor
                                ) -> List[deepsnap.graph.Graph]:
        """Update node_features for each graph and turn into
        a batched list of deepsnap graphs again
        """
        for i, g in enumerate(batch):
            batch[i].node_feature = node_feats[i]
        return batch

    def forward(self, batch: List[deepsnap.graph.Graph]):
        # Message passing for all graphs in batch.
        for i, g in enumerate(batch):
            g = self.layer(g)
            g.node_feature = self.post_layer(g.node_feature)
            if self.has_l2norm:
                g.node_feature = F.normalize(g.node_feature, p=2, dim=1)
            batch[i] = g

        # Stacks node_features of batch_size graphs
        # into tensor with shape [N, batch_size, embed_size]
        node_feats = self._stack_node_features(batch).transpose(0, 1)
        residual = node_feats
        # Boolean causal mask of shape [batch_size, batch_size]: True marks
        # (query, key) pairs where the key lies in the future and is blocked.
        # The mask is passed explicitly because some PyTorch versions reject
        # the is_causal hint when no attn_mask is given.
        T = node_feats.size(1)
        causal_mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=node_feats.device),
            diagonal=1)
        # Multihead causal attention: each snapshot can only see the past.
        # Q, K, V all have shape [N, batch_size, embed_size]
        node_feats, attention_weights = self.mha(
            query=node_feats,
            key=node_feats,
            value=node_feats,
            attn_mask=causal_mask
        )
        # Residual connection + Layer norm according to Vaswani et al.
        node_feats = self.ln(residual + node_feats).transpose(0, 1)
        # Updates node_features of the batch list for next message passing
        batch = self._reassign_node_features(batch, node_feats)
        return batch

2. Varying Embedding Update Methods

ROLAND offers proof of concept that generalizing a static GNN to a dynamic setting can be defined by an update method for hierarchical node states as new nodes and edges appear. The authors experiment with three embedding update methods: Moving Average, MLP, and a Gated Recurrent Unit (GRU). We replicate the experiment results and extend ROLAND’s embedding update methods to get a better understanding of the latent embedding space, and its impact on results.

The current moving average update function computes the new state as a weighted sum of the historical node state and the node state from the lower layer. Like a simple moving average in finance, which divides the sum of prices over a window by the number of periods in that window, it naturally captures the dynamics of the embedding. To extend this update method, we implement an Exponential Moving Average² (EMA), which is more popular in real-world financial problems. In comparison to the simple moving average, EMA gives a higher weighting to recent values and is more reactive to the latest changes as new data comes in.
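The contrast between the two update rules described above can be seen on a toy example. This is a minimal sketch under our own assumptions: embeddings are plain Python lists, the running mean weighs all past states equally, and `alpha` is a hypothetical smoothing hyperparameter. It is not the weighting scheme used inside ROLAND itself.

```python
def running_mean_update(h_prev, h_new, t):
    """Simple running average: each of the t states seen so far weighs 1/t."""
    return [(p * (t - 1) + n) / t for p, n in zip(h_prev, h_new)]

def ema_update(h_prev, h_new, alpha=0.5):
    """Exponential moving average: weight alpha on the newest state."""
    return [(1 - alpha) * p + alpha * n for p, n in zip(h_prev, h_new)]

# One-dimensional "embedding" that jumps from 0 to 1 at the last snapshot.
stream = [[0.0], [0.0], [0.0], [1.0]]
mean_state, ema_state = stream[0], stream[0]
for t, h in enumerate(stream[1:], start=2):
    mean_state = running_mean_update(mean_state, h, t)
    ema_state = ema_update(ema_state, h, alpha=0.5)
print(mean_state, ema_state)
```

After the jump, the EMA state has moved further toward the new value than the running mean, which is the "more reactive" behavior that motivates the extension.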

Since the strongest results currently reported by ROLAND use GRU as an embedding update method, naturally the next step will be to investigate the impact of using LSTM³ which is more effective at storing long-term dependencies due to its added complexity. LSTM has three gates and a cell state to store information, while the simpler GRU only uses two gates (an update gate and a reset gate). Since the task should be sensitive to retaining long-term dependencies, LSTM is promising due to its additional cell state for storage⁴. The following code block summarizes our implementation of the more expressive LSTM cell.

# LSTM gate structure (defined inside the layer's __init__)
# forget gate
self.LSTM_F = nn.Sequential(
    nn.Linear(dim_in + dim_out, dim_out, bias=True),
    nn.Sigmoid())
# input gate
self.LSTM_I = nn.Sequential(
    nn.Linear(dim_in + dim_out, dim_out, bias=True),
    nn.Sigmoid())
# candidate values used to update the cell state
self.LSTM_C = nn.Sequential(
    nn.Linear(dim_in + dim_out, dim_out, bias=True),
    nn.Tanh())
# output gate
self.LSTM_O = nn.Sequential(
    nn.Linear(dim_in + dim_out, dim_out, bias=True),
    nn.Sigmoid())
# activation applied to the cell state before the output gate
self.LSTM_H = nn.Tanh()

The output of the LSTM module⁵ is then computed with the following code.

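# Compute output from LSTM module
C_prev = batch.node_cells[self.layer_id]
# Every gate reads the same concatenated input [X, H_prev].
XH = torch.cat([X, H_prev], dim=1)
# The gates are named *_gate so that F does not shadow the common
# torch.nn.functional alias.
F_gate = self.LSTM_F(XH)
I_gate = self.LSTM_I(XH)
O_gate = self.LSTM_O(XH)
C_tilde = self.LSTM_C(XH)
# Forget part of the old cell state and write the gated candidate values.
C = F_gate * C_prev + I_gate * C_tilde
H_LSTM = O_gate * self.LSTM_H(C)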
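For reference, the computation in the code above corresponds to the standard LSTM cell equations. The notation is our own: X_t is the incoming node feature from the lower layer, H_{t-1} and C_{t-1} are the previous hidden and cell states, [·, ·] denotes concatenation, σ is the sigmoid, and ⊙ is element-wise multiplication.

```latex
\begin{aligned}
F_t &= \sigma\left(W_F [X_t, H_{t-1}] + b_F\right) && \text{(forget gate)} \\
I_t &= \sigma\left(W_I [X_t, H_{t-1}] + b_I\right) && \text{(input gate)} \\
O_t &= \sigma\left(W_O [X_t, H_{t-1}] + b_O\right) && \text{(output gate)} \\
\tilde{C}_t &= \tanh\left(W_C [X_t, H_{t-1}] + b_C\right) && \text{(candidate cell state)} \\
C_t &= F_t \odot C_{t-1} + I_t \odot \tilde{C}_t \\
H_t &= O_t \odot \tanh(C_t)
\end{aligned}
```

The separate cell state C_t is what distinguishes the LSTM from the GRU, which folds memory and output into a single hidden state.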

3. Optimization with Nesterov Momentum

So far, considerable effort has gone into developing deeper, more complex architectures for dynamic GNNs, such as incorporating batch normalization, attention, and residual connections. These modifications can produce a smoother loss function, but developing a better optimization technique to deal with these non-convex loss landscapes is also important for achieving better results.

We have explored extending static graph architectures with attention and with varying node embedding update methods, but from the perspective of Deep Learning as an optimization problem we propose that the optimizer used to reach our goal can also be improved. Currently most GNNs default to the Adam optimizer, since it is the most popular choice and generally works well. However, dynamic GNNs deal with time-varying data, so it seems natural to investigate better ways of learning that take time history into consideration.

Momentum optimization⁹ methods tend to derive their motivation from modeling the optimization problem as an ordinary differential equation (ODE). This class of methods keeps some time history of the optimization trajectory by mimicking Newton’s second law of motion, and views the parameter search procedure as a particle moving through the loss landscape toward the global minimum. The following animation visualizes momentum versus gradient descent without momentum.

From: https://paperswithcode.com/method/sgd-with-momentum

The idea is that instead of always using only the current search direction, we should incorporate the effects of prior search directions, making the optimization method less localized and hopefully more robust. This provides an intuitive link to GNNs, which are intimately related to differential equations. We are therefore motivated to incorporate the Nesterov momentum technique in our dynamic GNN setting, which improves on the classic momentum method by changing where the gradient is evaluated. Nesterov momentum first takes a forward Euler predictor step along the velocity (the first derivative), and the gradient is then evaluated at this new look-ahead location rather than at the current parameters. A brief summary of our approach to custom optimizer implementation is shown below; please refer to the Colab link above for the full class.

import torch.optim as optim

# cfg and register come from GraphGym; Nadam is our custom optimizer class
# (see the Colab notebook for its full definition).

def create_optimizer(params):
    # Materialize the trainable parameters as a list so that they can be
    # passed to more than one candidate constructor.
    params = list(filter(lambda p: p.requires_grad, params))
    # Try to load customized optimizer
    for func in register.optimizer_dict.values():
        optimizer = func(params)
        if optimizer is not None:
            return optimizer
    if cfg.optim.optimizer == 'adam':
        optimizer = optim.Adam(params, lr=cfg.optim.base_lr,
                               weight_decay=cfg.optim.weight_decay)
    elif cfg.optim.optimizer == 'sgd':
        optimizer = optim.SGD(params, lr=cfg.optim.base_lr,
                              momentum=cfg.optim.momentum,
                              weight_decay=cfg.optim.weight_decay)
    # create custom optimizers, specify each calculation in a separate class
    elif cfg.optim.optimizer == 'nadam_custom':
        optimizer = Nadam(params, lr=cfg.optim.base_lr,
                          momentum_decay=0.004,
                          weight_decay=cfg.optim.weight_decay)
    # add Adam with Nesterov momentum
    elif cfg.optim.optimizer == 'nadam':
        optimizer = optim.NAdam(params, lr=cfg.optim.base_lr,
                                momentum_decay=0.004,
                                weight_decay=cfg.optim.weight_decay)
    else:
        raise ValueError('Optimizer {} not supported'.format(
            cfg.optim.optimizer))

    return optimizer
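To make the Nesterov look-ahead step concrete, here is a toy comparison of classical and Nesterov momentum on the one-dimensional loss f(x) = 0.5 · x². The loss, the hyperparameters, and the function names are our own illustrative assumptions and are unrelated to the dynamic GNN training setup.

```python
def grad(x):
    """Gradient of the toy loss f(x) = 0.5 * x**2."""
    return x

def classical_momentum(x, steps=50, lr=0.1, mu=0.9):
    v = 0.0
    for _ in range(steps):
        v = mu * v - lr * grad(x)           # gradient at the current position
        x = x + v
    return x

def nesterov_momentum(x, steps=50, lr=0.1, mu=0.9):
    v = 0.0
    for _ in range(steps):
        lookahead = x + mu * v              # predictor step along the velocity
        v = mu * v - lr * grad(lookahead)   # gradient at the look-ahead point
        x = x + v
    return x

x_classical = classical_momentum(1.0)
x_nesterov = nesterov_momentum(1.0)
print(abs(x_classical), abs(x_nesterov))
```

The only difference between the two loops is where the gradient is evaluated. On this toy problem the Nesterov variant damps the oscillations faster and ends closer to the minimum at x = 0 after the same number of steps.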

The default method currently used in ROLAND is the Adam⁸ optimizer, which combines ideas from adaptive learning rates and momentum. This makes Adam especially suitable for handling sparse gradients on complex problems with a large number of features, such as GNNs. However, recent works have shown that although Adam converges more quickly than SGD with momentum, it often converges to a worse solution. In our dynamic GNN setting, SGD with momentum is not practical, since these are high-dimensional problems with huge datasets. Some combination of Adam with Nesterov momentum is therefore the natural choice, which inspires our implementation of the Nadam optimizer⁷. The equations below outline the difference between Adam and Nadam, and numerical results are summarized in the next section.

Comparison of Adam (left) versus Nadam (right). Note the additional momentum decay variable in Nadam.¹⁰
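As a rough companion to the comparison above, the two update rules can be written out for a single scalar parameter. This is a simplified sketch based on our reading of the update rules summarized in the torch.optim.NAdam documentation¹⁰: weight decay is omitted, the state dictionary and function names are our own, and it should not be taken as the library's implementation.

```python
import math

def new_state():
    """Per-parameter optimizer state (hypothetical layout for this sketch)."""
    return {"t": 0, "m": 0.0, "v": 0.0, "mu_prod": 1.0}

def adam_step(theta, g, state, lr=0.002, b1=0.9, b2=0.999, eps=1e-8):
    state["t"] += 1
    t = state["t"]
    state["m"] = b1 * state["m"] + (1 - b1) * g        # first moment
    state["v"] = b2 * state["v"] + (1 - b2) * g * g    # second moment
    m_hat = state["m"] / (1 - b1 ** t)                 # bias correction
    v_hat = state["v"] / (1 - b2 ** t)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)

def nadam_step(theta, g, state, lr=0.002, b1=0.9, b2=0.999, eps=1e-8,
               momentum_decay=0.004):
    state["t"] += 1
    t = state["t"]
    # Momentum schedule controlled by the extra momentum decay variable.
    mu_t = b1 * (1 - 0.5 * 0.96 ** (t * momentum_decay))
    mu_next = b1 * (1 - 0.5 * 0.96 ** ((t + 1) * momentum_decay))
    state["mu_prod"] *= mu_t
    state["m"] = b1 * state["m"] + (1 - b1) * g
    state["v"] = b2 * state["v"] + (1 - b2) * g * g
    # Nesterov-style estimate: look-ahead momentum term plus current gradient.
    m_hat = (mu_next * state["m"] / (1 - state["mu_prod"] * mu_next)
             + (1 - mu_t) * g / (1 - state["mu_prod"]))
    v_hat = state["v"] / (1 - b2 ** t)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)

# Minimize the toy loss f(x) = x**2 (gradient 2x) with both optimizers.
x_adam, x_nadam = 1.0, 1.0
s_adam, s_nadam = new_state(), new_state()
for _ in range(100):
    x_adam = adam_step(x_adam, 2 * x_adam, s_adam, lr=0.05)
    x_nadam = nadam_step(x_nadam, 2 * x_nadam, s_nadam, lr=0.05)
print(x_adam, x_nadam)
```

The second-moment estimate is shared by both methods. The difference lies in the first-moment term, where Nadam mixes a look-ahead momentum contribution with the current gradient.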

Insights, Discussion, and Future work

ROLAND provides an easy way to extend any static GNN to a dynamic graph setting, but there is exciting room for experimentation and improvement! We investigate the capabilities and limitations of ROLAND on the link prediction task of forecasting future transactions on two Bitcoin datasets. Our results show that although the dynamic GNN generalization problem is defined by the node embedding update method, varying the expressiveness and complexity of the update modules had very little effect on the results. This leads us to conclude that ROLAND’s strategy of modeling dynamic graphs is particularly robust to changes in the embedding update technique. The successful incorporation of the attention mechanism in the dynamic graph setting is extremely promising, and invites further work on varying how dynamic GNN data is processed in order to take advantage of attention. The following table summarizes the results.

Table of summary MRR results averaged across 3 runs, benchmarked against ROLAND

Finally, the addition of the Nadam optimization technique, which combines Nesterov momentum with Adam, achieved an approximately 30% MRR improvement over the current baselines in ROLAND. This offers meaningful insight into the optimization landscape of the dynamic GNN problem, and suggests that modeling the optimization procedure more closely on a physics or ODE problem might achieve significant improvements. It also leads us to conclude that the loss landscape in ROLAND most likely presents more local minima and is less smooth than that of standard GNN models. Promising directions for future work include fine-tuning the optimization method for dynamic graphs to behave more like a particle moving through the optimization landscape, and adjusting the formulation of momentum in the underlying mathematical techniques. Experimenting with larger financial datasets to take advantage of the attention architecture will also be of particular interest.

We hope this blog post has provided the reader with an overview of dynamic GNNs and ROLAND, as well as the exciting potential to solve real world problems with machine learning techniques!

¹Jiaxuan You, Tianyu Du, and Jure Leskovec. Roland: Graph learning framework for dynamic graphs. Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining, 2022.

²Frank Klinker. Exponential moving average versus moving exponential average. Math. Semesterber, 2020.

³Junyoung Chung, Çaglar Gülçehre, Kyunghyun Cho, and Yoshua Bengio. Empirical evaluation of gated recurrent neural networks on sequence modeling. ArXiv, abs/1412.3555, 2014.

⁴Cho, K., Van Merriënboer, B., Gulcehre, C., Bahdanau, D., Bougares, F., Schwenk, H. and Bengio, Y. Learning phrase representations using RNN encoder-decoder for statistical machine translation. arXiv preprint arXiv:1406.1078, 2014.

⁵Hochreiter S, Schmidhuber J. Long short-term memory. Neural computation. 1997 Nov 15;9(8):1735–80.

⁶Vaswani, A., Shazeer, N.M., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). Attention is All you Need. ArXiv, abs/1706.03762

⁷Dozat, T. ICLR (2016). Incorporating Nesterov Momentum into Adam. https://openreview.net/pdf/OM0jvwB8jIp57ZJj

⁸Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

⁹Ilya Sutskever, James Martens, George Dahl, and Geoffrey Hinton. On the importance of initialization and momentum in deep learning. In Proceedings of the 30th International Conference on Machine Learning (ICML-13), pp. 1139–1147, 2013.

¹⁰https://pytorch.org/docs/stable/generated/torch.optim.NAdam.html

¹¹https://github.com/snap-stanford/GraphGym
