Adding subgraph support in PyG for more expressive GNNs

Zhiqi Li
Stanford CS224W GraphML Tutorials
May 14, 2023

By Matúš Jurák, Zhiqi Li, and Sofian Zalouk as part of the Stanford CS224W course project.

Deep learning on graphs is incredibly useful and appealing in a world full of data in the form of networks — large directed or undirected graphs. In contrast with the traditional machine learning methods that work well with information encoded as sequences or grids, machine learning on graphs can tackle problems that handle data with more complex structures and dependencies. Message passing neural networks (MPNNs) are a class of graph deep-learning models composed of several layers which perform message passing and aggregation across a node’s neighbors.

Figure 1. MPNN diagram taken from the Stanford CS224W lecture slides.

MPNNs are appealing due to their locality of computations, since they only require the 1-hop neighbors of a node at each layer. Furthermore, they are scalable since their complexity scales linearly in the number of edges in the graph. Altogether, the locality and scalability of MPNNs have made them the leading deep-learning architecture for graph-structured data.

However, the locality of MPNNs limits their expressive power. By expressive power, we mean whether the local information of the graphs is sufficient to distinguish non-isomorphic graphs, i.e., to determine whether two graphs are topologically equivalent. It has been shown that MPNNs are at most as expressive as the Weisfeiler-Lehman (WL) graph isomorphism test [9], an algorithm that attempts to answer the graph isomorphism question in polynomial time.

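The WL test takes a graph and assigns a color to each distinct rooted subtree of the graph. In particular, nodes with identical rooted subtree structures are assigned identical colors. This node-coloring process is done iteratively: at each step, a node's new color is determined by its current color and the multiset of its neighbors' colors. Given the color assignments of two graphs, the WL test compares how many nodes receive each color in the two graphs. If these color counts differ, the graphs are certainly non-isomorphic, whereas identical counts are inconclusive.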
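The following minimal plain-Python sketch illustrates this procedure. It is not part of our PyG code. The adjacency-dict graph format and the function names `wl_refine` and `wl_distinguishes` are our own illustrative choices. Colors are compressed through a palette shared by all input graphs, so the resulting color histograms can be compared directly.

```python
from collections import Counter


def wl_refine(graphs, num_iters=3):
    """Run WL color refinement jointly on several graphs.

    Each graph is a dict {node: [neighbors]}. All nodes start with color 0.
    In every round, a node's new color is determined by its current color
    and the sorted multiset of its neighbors' colors; a palette shared by
    all graphs maps each such signature to a small integer, so colors are
    comparable across graphs. Returns one color histogram (Counter) per graph.
    """
    colorings = [{v: 0 for v in g} for g in graphs]
    for _ in range(num_iters):
        palette = {}
        new_colorings = []
        for g, colors in zip(graphs, colorings):
            new_colors = {}
            for v in g:
                signature = (colors[v], tuple(sorted(colors[u] for u in g[v])))
                new_colors[v] = palette.setdefault(signature, len(palette))
            new_colorings.append(new_colors)
        colorings = new_colorings
    return [Counter(c.values()) for c in colorings]


def wl_distinguishes(g1, g2, num_iters=3):
    """True if the WL color histograms of g1 and g2 differ."""
    h1, h2 = wl_refine([g1, g2], num_iters)
    return h1 != h2


def cycle(n, offset=0):
    """Cycle graph on the nodes offset, ..., offset + n - 1."""
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n]
            for i in range(n)}


hexagon = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}
path3 = {0: [1], 1: [0, 2], 2: [1]}

print(wl_distinguishes(hexagon, two_triangles))  # False
print(wl_distinguishes(path3, cycle(3)))         # True
```

The hexagon and the pair of disjoint triangles are a classic example of this failure; they are not the graphs of Figure 2. The two graphs are non-isomorphic, but every node in both has degree 2. Refinement therefore never splits the initial color, and the histograms coincide.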

Figure 2: An example of two non-isomorphic graphs indistinguishable by the WL test.

However, non-isomorphic graphs can have identical rooted subtree structures, in which case their WL node colorings are also identical. For example, consider the two graphs in Figure 2: in both graphs, the WL test produces three colors, shared by 2, 2, and 4 nodes, respectively. Despite being non-isomorphic, these simple graphs are indistinguishable by the WL test. This raises an important research question:

How can we improve the expressive power of MPNNs, which is limited to the expressive power of the WL test?

Increasing the expressive power of GNNs using subgraphs

We’d like to draw your attention to a particularly interesting class of methods that uses subgraphs to create more expressive Graph Neural Networks (GNNs).

Figure 3: After removing one edge (red, dotted) from each graph, the two graphs become distinguishable by the WL test.

Consider again the two graphs in Figure 2, which are indistinguishable by the WL test. Suppose we minimally perturb the input graphs by removing a single edge from each (the red, dotted lines in Figure 3). We then obtain two subgraphs that the WL test does distinguish: the graph on the left has 7 distinct colors, while the graph on the right has 8. At their core, subgraph methods rest on the idea that perturbing graphs in this way can make them distinguishable by the WL test, thereby increasing the expressive power of GNNs.
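A small plain-Python sketch can check the effect of such a perturbation, loosely in the spirit of the DS-WL test discussed in [2]. Instead of running WL once on each graph, we run it on every edge-deleted subgraph and compare the resulting multisets of color histograms. The example pair is our own choice, not the graphs of Figure 3: a hexagon versus two disjoint triangles, where every node has degree 2. All function names are hypothetical.

```python
from collections import Counter


def wl_refine(graphs, num_iters=3):
    """Joint WL refinement of {node: [neighbors]} graphs with a shared
    palette; returns one color histogram (Counter) per graph."""
    colorings = [{v: 0 for v in g} for g in graphs]
    for _ in range(num_iters):
        palette = {}
        colorings = [
            {v: palette.setdefault(
                (c[v], tuple(sorted(c[u] for u in g[v]))), len(palette))
             for v in g}
            for g, c in zip(graphs, colorings)]
    return [Counter(c.values()) for c in colorings]


def edge_deleted_bag(g):
    """All subgraphs obtained by deleting exactly one undirected edge."""
    edges = sorted({tuple(sorted((u, v))) for u in g for v in g[u]})
    return [{v: [u for u in g[v] if {u, v} != {a, b}] for v in g}
            for a, b in edges]


def bag_signature(histograms):
    """Order-independent, hashable summary of a multiset of histograms."""
    return Counter(tuple(sorted(h.items())) for h in histograms)


def bag_wl_distinguishes(g1, g2, num_iters=3):
    """Compare the multisets of WL histograms of the two edge-deleted bags."""
    bag1, bag2 = edge_deleted_bag(g1), edge_deleted_bag(g2)
    hists = wl_refine(bag1 + bag2, num_iters)
    return bag_signature(hists[:len(bag1)]) != bag_signature(hists[len(bag1):])


def cycle(n, offset=0):
    return {offset + i: [offset + (i - 1) % n, offset + (i + 1) % n]
            for i in range(n)}


hexagon = cycle(6)
two_triangles = {**cycle(3), **cycle(3, offset=3)}

h_hex, h_tri = wl_refine([hexagon, two_triangles])
print(h_hex == h_tri)                                # True: plain WL fails
print(bag_wl_distinguishes(hexagon, two_triangles))  # True: the bags differ
```

Deleting any edge of the hexagon leaves a 6-node path. Deleting any edge of the two triangles leaves a 3-node path next to a triangle. WL tells these apart after two rounds, even though it cannot tell the original graphs apart.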

Several recent works have explored subgraph GNNs, and [3] surveys many of them. One example is DropGNN [5], a simple yet surprisingly effective method. It executes several independent runs of an MPNN, randomly removes nodes from the graph in each run, and aggregates the results. Each node thus sees several different perturbed neighborhoods of the graph, which implicitly creates different subgraph embeddings.

Another interesting approach extends the WL test to a subgraph-based variant [10]. In the WL test, the iterative node-coloring process only considers the given root node and its neighbors at each step. By extending stars (subtrees formed by root nodes with their neighbors) to subgraphs (the k-hop neighborhoods of root nodes), as shown in Figure 4, the power of the graph isomorphism test greatly improves, even for small values of k.

Figure 4: A stars-to-subgraphs GNN taking 1-hop subgraphs for every node (taken from [10])

For our project, we will be focusing on a more recent subgraph GNN method, Equivariant Subgraph Aggregation Networks (ESAN), proposed in [2]. We present an implementation of the ESAN model as a pull request to the PyG torch_geometric library. In this blog, we will give a high-level explanation of the ESAN framework and a walk-through of our implementation.

Understanding Equivariant Subgraph Aggregation Networks (ESAN)

Various subgraph GNNs differ in how they process subgraphs and aggregate information. We are drawn to ESAN because it explicitly exploits the symmetry structure of the bag of subgraphs. Like many subgraph GNNs, ESAN uses a subgraph selection step that maps an input graph G to “a bag of subgraphs,” encoding the original graph data as a multi-set of subgraphs. ESAN defines the symmetry group of this multi-set as the direct product of subgraph permutations and node permutations. Here, the node permutation acts on the nodes of all subgraphs in the multi-set simultaneously, so it can be viewed as an aggregate step in the GNN framework.

Figure 5: (left panel) model architectures of DSS-GNN and DS-GNN; (right panel) the structure of H-Equivariant layers for DSS-GNN and DS-GNN.

ESAN introduces two different model architectures: DSS-GNN and its variant DS-GNN. Both models consist of the following three components (see the left panel of Figure 5):

$$F = E_{\text{sets}} \circ R_{\text{subgraphs}} \circ E_{\text{subgraphs}}$$

The first component, E_subgraphs, is an equivariant feature encoder. It comprises several layers, called H-equivariant layers, that respect the symmetry structure of the bag of subgraphs. These layers map bags of subgraphs to bags of subgraphs. Specifically, each H-equivariant layer L combines two graph encoders: (1) a Siamese network L¹ that processes each subgraph in the multi-set independently, and (2) an information-sharing module L² that shares information across all subgraphs. For the i-th subgraph, with adjacency matrix A_i and feature matrix X_i, each layer is defined as follows:

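$$\big(L(\mathcal{A}, \mathcal{X})\big)_i = L^1(A_i, X_i) + L^2\Big(\sum_{j=1}^{m} A_j,\ \sum_{j=1}^{m} X_j\Big)$$

where $\mathcal{A} = (A_1, \dots, A_m)$ and $\mathcal{X} = (X_1, \dots, X_m)$ denote the adjacency matrices and feature matrices of the bag of m subgraphs.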
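To see why such layers are equivariant, here is a small plain-Python sketch of a single layer, (L(A, X))_i = L¹(A_i, X_i) + L²(Σ_j A_j, Σ_j X_j). As a simplifying assumption, both L¹ and L² are the hypothetical toy layer `linear_gnn(A, X, W) = A·X·W`, which sums the neighbors' features and then projects them; ESAN itself uses full GNN layers. Matrices are nested lists of integers, so the checks are exact.

```python
def matmul(a, b):
    """Product of two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def matadd(a, b):
    return [[x + y for x, y in zip(r, s)] for r, s in zip(a, b)]


def bag_sum(mats):
    total = mats[0]
    for m in mats[1:]:
        total = matadd(total, m)
    return total


def linear_gnn(A, X, W):
    """Toy message-passing layer: sum neighbor features, then project."""
    return matmul(matmul(A, X), W)


def dss_layer(As, Xs, W1, W2):
    """One DSS layer with linear_gnn as both L1 (Siamese) and L2 (sharing)."""
    shared = linear_gnn(bag_sum(As), bag_sum(Xs), W2)  # same for every i
    return [matadd(linear_gnn(A, X, W1), shared) for A, X in zip(As, Xs)]


# Bag of the three edge-deleted subgraphs of a triangle on nodes 0, 1, 2
As = [[[0, 0, 1], [0, 0, 1], [1, 1, 0]],   # edge (0, 1) deleted
      [[0, 1, 0], [1, 0, 1], [0, 1, 0]],   # edge (0, 2) deleted
      [[0, 1, 1], [1, 0, 0], [1, 0, 0]]]   # edge (1, 2) deleted
Xs = [[[1, 0], [0, 1], [1, 1]]] * 3        # same node features in each copy
W1 = [[1, 2], [0, 1]]
W2 = [[1, 0], [1, 1]]

out = dss_layer(As, Xs, W1, W2)

# Reordering the subgraphs reorders the outputs in the same way ...
assert dss_layer(As[::-1], Xs[::-1], W1, W2) == out[::-1]

# ... and relabeling the nodes of all subgraphs at once relabels the rows.
perm = [2, 0, 1]
pAs = [[[A[perm[i]][perm[j]] for j in range(3)] for i in range(3)] for A in As]
pXs = [[X[perm[i]] for i in range(3)] for X in Xs]
assert dss_layer(pAs, pXs, W1, W2) == [[O[p] for p in perm] for O in out]
```

The two assertions check the symmetries of the bag: subgraph permutations, and node permutations applied to all subgraphs simultaneously. Dropping the `shared` term gives the corresponding DS-GNN layer.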

By enabling or disabling the information-sharing module L², as illustrated in the right panel of Figure 5, we obtain either:

  1. DSS-GNN, which consists of both the Siamese network and the information-sharing module, or
  2. DS-GNN, which consists only of the Siamese network, i.e., each H-equivariant layer has no L² component:

$$\big(L(\mathcal{A}, \mathcal{X})\big)_i = L^1(A_i, X_i)$$

The second component, the readout layer R_subgraphs, takes the output of the feature encoder, pools each subgraph independently, and returns an invariant feature vector for each subgraph. Finally, the third component, the set encoder E_sets, encodes the set of invariant feature vectors into one invariant vector for the original graph G.
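The following plain-Python sketch covers these last two components. It assumes mean pooling for both R_subgraphs and E_sets, which is also what our DSnetwork code later in this post uses, and it omits the final MLP. The function names are hypothetical. Shuffling the subgraphs, and the node order within each subgraph, leaves the graph embedding unchanged up to floating-point rounding.

```python
import math
import random


def readout(node_feats):
    """R_subgraphs: mean-pool the node features of one subgraph."""
    n = len(node_feats)
    return [sum(col) / n for col in zip(*node_feats)]


def set_encoder(vectors):
    """E_sets without a final MLP: mean over the bag of subgraph vectors."""
    m = len(vectors)
    return [sum(col) / m for col in zip(*vectors)]


def graph_embedding(bag):
    return set_encoder([readout(h) for h in bag])


random.seed(0)
# A bag of 4 subgraphs, each with 5 nodes and 3-dimensional node features
bag = [[[random.random() for _ in range(3)] for _ in range(5)]
       for _ in range(4)]

reference = graph_embedding(bag)
# Shuffle the subgraphs and the node rows inside each subgraph
shuffled = [random.sample(h, len(h)) for h in random.sample(bag, len(bag))]
print(all(math.isclose(a, b)
          for a, b in zip(reference, graph_embedding(shuffled))))  # True
```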

The ESAN architectures place no restrictions on the subgraph selection policy, i.e., the map from a graph to its bag of subgraphs. In fact, the paper [2] investigates the performance of the ESAN models under several subgraph selection policies, including node deletion, edge deletion, and ego-networks.

Implementing ESAN in PyG

In our implementation, we follow the PyG contribution guidelines and use the authors' official ESAN implementation as a reference. We reuse as much existing PyG functionality as possible and integrate our new functionality fully with the PyG framework. As explained in the previous section, the ESAN framework has two main components: (1) a subgraph selection policy and (2) a bag-of-graphs encoder architecture. We now walk through our implementation of these two components.

Subgraph Selection Policies

We aim to provide support for all subgraph selection policies discussed in [2], namely node-deletion, edge-deletion, and ego-networks. We plan to contribute the subgraph selection policies to the torch_geometric.transforms package.

To do so, we create a base class SubgraphPolicy for all subgraph selection policies, derived from PyG’s abstract base class for transforms, torch_geometric.transforms.BaseTransform. A SubgraphPolicy object can be passed as the pre_transform argument of PyG Dataset classes. It takes in a PyG Data object and returns its transformed version, which is then saved to disk.

When called, a SubgraphPolicy transforms the given Data object into a SubgraphData object. We implemented the SubgraphData class on top of the base class Data from torch_geometric.data, with additional attributes that encode important information about the transformed data. Using Batch.from_data_list from torch_geometric.data.Batch, we merge the list of generated subgraphs into one SubgraphData object, which represents the multi-set of a bag of subgraphs.

Here is a code snippet for our implementation of SubgraphData and SubgraphPolicy.

Note: to run the following code snippets, make sure the required packages have been installed and imported. Please refer to our GitHub repo for more details.

```python
from typing import Any, List, Optional

import torch
from torch_geometric.data import Batch, Data
from torch_geometric.transforms import BaseTransform


class SubgraphData(Data):
    r"""A data object describing a collection of subgraphs generated
    from a given subgraph selection policy.

    It has several additional (**) properties:

    From :class:`Data`:

    * :obj:`x`
    * :obj:`edge_index`
    * :obj:`edge_attr`
    * :obj:`y`
    * :obj:`pos`

    Additional:

    ** :obj:`subgraph_id` (Tensor): The indices of the subgraphs
    ** :obj:`subgraph_batch` (Tensor): The batch vector of the subgraphs
    ** :obj:`subgraph_n_id` (Tensor): The indices of nodes in the subgraphs
    ** :obj:`orig_edge_index` (Tensor): The edge index of the original graph
    ** :obj:`orig_edge_attr` (Tensor): The edge attributes of the original
       graph
    ** :obj:`num_subgraphs` (int): The number of generated subgraphs
    ** :obj:`num_nodes_per_subgraph` (int): The number of nodes in the
       original graph
    """
    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        if key == 'orig_edge_index':
            # Original edges index into the nodes of the original graph
            return self.num_nodes_per_subgraph
        elif key == 'subgraph_batch':
            # Subgraph ids stay local to each bag when bags are batched
            return 0
        else:
            return super().__inc__(key, value, *args, **kwargs)


class SubgraphPolicy(BaseTransform):
    r"""Base class for subgraph selection policies. Subclasses implement
    :meth:`graph_to_subgraphs`."""
    def __init__(self, subgraph_transform: Optional[Any] = None):
        self.subgraph_transform = subgraph_transform

    def graph_to_subgraphs(self, data: Data) -> List[Data]:
        raise NotImplementedError

    def __call__(self, data: Data) -> SubgraphData:
        assert data.is_undirected()

        subgraphs = self.graph_to_subgraphs(data)
        if self.subgraph_transform is not None:
            subgraphs = [
                self.subgraph_transform(subgraph) for subgraph in subgraphs
            ]
        # Merge the bag of subgraphs into one disconnected graph
        subgraph_batch = Batch.from_data_list(subgraphs)

        # Store the batched subgraphs together with the original graph
        out = SubgraphData(x=subgraph_batch.x, y=data.y,
                           edge_index=subgraph_batch.edge_index,
                           edge_attr=subgraph_batch.edge_attr,
                           subgraph_batch=subgraph_batch.batch,
                           subgraph_id=subgraph_batch.subgraph_id,
                           subgraph_n_id=subgraph_batch.subgraph_n_id,
                           orig_edge_index=data.edge_index,
                           orig_edge_attr=data.edge_attr,
                           num_nodes_per_subgraph=data.num_nodes,
                           num_subgraphs=len(subgraphs))

        return out
```
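The `__inc__` override controls how index attributes are shifted when several bags are batched together. The plain-Python sketch below is a simplified, hypothetical model of that shifting, not PyG's actual collation code. It uses lists of pairs instead of `[2, E]` tensors and a hand-built bag: a 2-node graph with one undirected edge, turned into two subgraphs that each keep both nodes, for example by the 1-hop ego-network policy.

```python
def collate_bags(bags):
    """Simplified model of the index shifting that Batch.from_data_list
    applies to several SubgraphData objects (pairs instead of tensors)."""
    edge_index, orig_edge_index, subgraph_batch = [], [], []
    node_offset = orig_offset = 0
    for bag in bags:
        # Default rule for edge_index: shift by the bag's total node count
        edge_index += [(u + node_offset, v + node_offset)
                       for u, v in bag['edge_index']]
        # Overridden rule: shift by the number of nodes of the original graph
        orig_edge_index += [(u + orig_offset, v + orig_offset)
                            for u, v in bag['orig_edge_index']]
        # Overridden rule: subgraph_batch is not shifted (increment 0)
        subgraph_batch += bag['subgraph_batch']
        node_offset += bag['num_nodes']
        orig_offset += bag['num_nodes_per_subgraph']
    return edge_index, orig_edge_index, subgraph_batch


bag = {
    'num_nodes': 4,               # 2 subgraphs x 2 nodes
    'num_nodes_per_subgraph': 2,  # nodes of the original graph
    'edge_index': [(0, 1), (1, 0), (2, 3), (3, 2)],
    'orig_edge_index': [(0, 1), (1, 0)],
    'subgraph_batch': [0, 0, 1, 1],
}
edge_index, orig_edge_index, subgraph_batch = collate_bags([bag, bag])
print(orig_edge_index)  # [(0, 1), (1, 0), (2, 3), (3, 2)]
print(subgraph_batch)   # [0, 0, 1, 1, 0, 0, 1, 1]
```

After batching, `orig_edge_index` indexes the nodes of the original graphs, which is what the information-sharing module of DSS-GNN operates on. `subgraph_batch` keeps local subgraph ids, and globally unique ids are recovered later from `num_subgraphs`.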

Node Deletion
The idea of the node-deletion policy is very simple: each subgraph is generated as an induced subgraph by excluding a single node from the original graph. DropGNN [5] considers similar ideas. In our implementation, we consider all subgraphs induced by removing exactly one node. In the code below, the deleted node is kept as an isolated node, and only its incident edges are removed, so every subgraph shares the node indexing of the original graph. We implement the transform class NodeDeletionPolicy using the base class SubgraphPolicy:

```python
from typing import List

import torch
from torch_geometric.data import Data
from torch_geometric.utils import subgraph


class NodeDeletionPolicy(SubgraphPolicy):
    r"""Performs node-level deletions to create a bag of subgraphs for a
    given graph. Each subgraph is generated as an induced subgraph by
    excluding a single node from the original graph.

    Args:
        subgraph_transform (Optional[Any]): An optional transform applied
            to each generated subgraph before the subgraphs are batched
            (e.g., :obj:`Constant()` adds a constant feature to each
            subgraph; ESAN also used :obj:`OneHotDegree` for certain
            datasets). (default: :obj:`None`)
    """
    def graph_to_subgraphs(self, data: Data) -> List[Data]:
        r"""Generates subgraphs using a node-deletion policy.

        Args:
            data (Data): Input graph from which to create subgraphs.
        """
        subgraphs = []
        num_nodes = data.num_nodes
        nodes_index = torch.arange(num_nodes)

        for i in range(num_nodes):
            # Keep every edge that does not touch node i; node i itself
            # stays in the node set as an isolated node
            subgraph_edge_index, subgraph_edge_attr = subgraph(
                subset=torch.cat([nodes_index[:i], nodes_index[i + 1:]]),
                edge_index=data.edge_index, edge_attr=data.edge_attr,
                num_nodes=num_nodes)
            subgraph_id = torch.tensor(i)
            subgraph_data = Data(x=data.x, edge_index=subgraph_edge_index,
                                 edge_attr=subgraph_edge_attr,
                                 subgraph_id=subgraph_id,
                                 subgraph_n_id=nodes_index,
                                 num_nodes=num_nodes)
            subgraphs.append(subgraph_data)

        return subgraphs
```

Edge Deletion
The edge-deletion policy considers all subgraphs generated by removing a single edge from the original graph. The ESAN framework only considered undirected graphs, so we use the PyG methods filter_adj (from torch_geometric.utils.dropout) and to_undirected (from torch_geometric.utils) to ensure that exactly one undirected edge, i.e., both of its directions, is deleted from the original graph. We first keep only one direction (head < tail) of every edge, delete one of the remaining edges, and then restore the reverse directions.

```python
from typing import List

import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from torch_geometric.utils.dropout import filter_adj


class EdgeDeletionPolicy(SubgraphPolicy):
    r"""The Edge Deletion policy considers all subgraphs generated by
    removing a single edge from the original graph. Since only undirected
    graphs were considered within the ESAN framework, only those are
    supported.

    Args:
        subgraph_transform (Optional[Any]): An optional transform applied
            to each generated subgraph before the subgraphs are batched
            (e.g., :obj:`Constant()` adds a constant feature to each
            subgraph; ESAN also used :obj:`OneHotDegree` for certain
            datasets). (default: :obj:`None`)
    """
    def graph_to_subgraphs(self, data: Data) -> List[Data]:
        r"""Generates subgraphs using an edge-deletion policy.

        Args:
            data (Data): Input graph from which to create subgraphs.
        """
        subgraphs = []
        num_nodes = data.num_nodes
        nodes_index = torch.arange(num_nodes)

        # Handle the edge case of a graph without edges
        if data.num_edges == 0:
            subgraphs.append(
                Data(x=data.x, edge_index=data.edge_index,
                     edge_attr=data.edge_attr, subgraph_id=torch.tensor(0),
                     subgraph_n_id=nodes_index, num_nodes=num_nodes))
            return subgraphs

        # Keep a single direction (head < tail) of every undirected edge
        head, tail = data.edge_index
        head, tail, edge_attr = filter_adj(row=head, col=tail,
                                           edge_attr=data.edge_attr,
                                           mask=head < tail)

        for i in range(head.size(0)):
            # Delete the i-th undirected edge
            subgraph_head = torch.cat([head[:i], head[i + 1:]])
            subgraph_tail = torch.cat([tail[:i], tail[i + 1:]])
            subgraph_edge_attr = torch.cat([
                edge_attr[:i], edge_attr[i + 1:]
            ]) if edge_attr is not None else edge_attr
            subgraph_edge_index = torch.stack([subgraph_head, subgraph_tail],
                                              dim=0)

            # Restore the reverse direction of the remaining edges. Passing
            # edge_attr only when it exists keeps the return type stable
            # across PyG versions: a tuple when edge_attr is passed, just
            # the edge index when it is omitted.
            if subgraph_edge_attr is None:
                subgraph_edge_index = to_undirected(subgraph_edge_index,
                                                    num_nodes=num_nodes)
            else:
                subgraph_edge_index, subgraph_edge_attr = to_undirected(
                    subgraph_edge_index, edge_attr=subgraph_edge_attr,
                    num_nodes=num_nodes)

            subgraph_id = torch.tensor(i)
            subgraph_data = Data(x=data.x, edge_index=subgraph_edge_index,
                                 edge_attr=subgraph_edge_attr,
                                 subgraph_id=subgraph_id,
                                 subgraph_n_id=nodes_index,
                                 num_nodes=num_nodes)
            subgraphs.append(subgraph_data)

        return subgraphs
```

Extending this policy to directed graphs would be straightforward and even simpler: remove one column at a time from edge_index, together with the corresponding row of edge_attr.
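Here is a minimal sketch of that directed variant. It is written in plain Python, with edge_index stored as a pair of lists. `directed_edge_deletion_bag` is a hypothetical helper and is not part of our PR. In PyTorch, the same step could be written with a boolean mask over the columns of edge_index.

```python
def directed_edge_deletion_bag(edge_index, edge_attr=None):
    """One subgraph per deleted column of a directed edge_index.

    edge_index is [sources, targets]; edge_attr (optional) holds one entry
    per edge. Returns a list of (edge_index, edge_attr) pairs.
    """
    src, dst = edge_index
    bag = []
    for i in range(len(src)):
        sub_index = [src[:i] + src[i + 1:], dst[:i] + dst[i + 1:]]
        sub_attr = (None if edge_attr is None
                    else edge_attr[:i] + edge_attr[i + 1:])
        bag.append((sub_index, sub_attr))
    return bag


# A directed 3-cycle 0 -> 1 -> 2 -> 0 with one scalar attribute per edge
bag = directed_edge_deletion_bag([[0, 1, 2], [1, 2, 0]],
                                 edge_attr=[10, 20, 30])
print(len(bag))  # 3
print(bag[0])    # ([[1, 2], [2, 0]], [20, 30])
```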

EGO Networks
The ego-networks policy EGO of a specified depth k takes an input graph and generates a set of subgraphs induced by the k-hop neighborhoods of all nodes. Such a subgraph is also called a k-ego-network of a given node. The ESAN model also considers a variant of EGO, denoted by EGO+, that attaches an additional identifying feature to the root node of each subgraph.

We implement both EGO and EGO+ in the transform class EgoNetPolicy, which has two parameters: num_hops (the depth k) and ego_plus (whether to use EGO+ instead of EGO). We use the PyG method k_hop_subgraph from torch_geometric.utils to obtain the k-hop neighborhoods. Under the EGO+ policy, we prepend a two-dimensional feature to every node: the root node is identified by the vector [1, 0] and all other nodes by [0, 1].

```python
from typing import Any, List, Optional

import torch
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph


class EgoNetPolicy(SubgraphPolicy):
    r"""The EGO policy of a specified depth k takes an input graph and
    generates a set of subgraphs induced by the k-hop neighborhoods of all
    nodes. Such a subgraph is also called a k-ego-network of a given node.
    The ESAN model also considers a variant of EGO, denoted by EGO+, that
    attaches an additional identifying feature to the root node of each
    subgraph.

    Args:
        num_hops (int): Number of hops used to create the neighborhoods.
        subgraph_transform (Optional[Any]): An optional transform applied
            to each generated subgraph before the subgraphs are batched
            (e.g., :obj:`Constant()` adds a constant feature to each
            subgraph; ESAN also used :obj:`OneHotDegree` for certain
            datasets). (default: :obj:`None`)
        ego_plus (bool): Whether to prepend root-identifying features to
            the node features (EGO+). (default: :obj:`False`)
    """
    def __init__(self, num_hops: int, subgraph_transform: Optional[Any] = None,
                 ego_plus: bool = False):
        super().__init__(subgraph_transform)
        self.ego_plus = ego_plus
        self.num_hops = num_hops

    def graph_to_subgraphs(self, data: Data) -> List[Data]:
        r"""Generates subgraphs using a rooted EGO/EGO+ policy.

        Args:
            data (Data): Input graph from which to create subgraphs.
        """
        subgraphs = []
        num_nodes = data.num_nodes
        nodes_index = torch.arange(num_nodes)

        for i in range(num_nodes):
            subgraph_id = torch.tensor(i)
            # Edges of the k-hop neighborhood of node i (original node ids)
            _, subgraph_edge_index, _, edge_mask = k_hop_subgraph(
                i, self.num_hops, data.edge_index, num_nodes=num_nodes)
            subgraph_edge_attr = data.edge_attr[
                edge_mask] if data.edge_attr is not None else data.edge_attr
            subgraph_x = data.x

            # Add node features for the EGO+ policy
            if self.ego_plus:
                # For the central node i, prepend the feature [1, 0];
                # for all non-central nodes, prepend the feature [0, 1]
                prepend_features = torch.tensor(
                    [[0, 1] if j != i else [1, 0] for j in range(num_nodes)],
                    dtype=torch.float, device=subgraph_edge_index.device)
                subgraph_x = torch.hstack([
                    prepend_features, subgraph_x
                ]) if subgraph_x is not None else prepend_features

            subgraph_data = Data(x=subgraph_x, edge_index=subgraph_edge_index,
                                 edge_attr=subgraph_edge_attr,
                                 subgraph_id=subgraph_id,
                                 subgraph_n_id=nodes_index,
                                 num_nodes=num_nodes)
            subgraphs.append(subgraph_data)

        return subgraphs
```

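It’s also important to note that after pre-processing the data with a SubgraphPolicy, batching should be done with follow_batch=['subgraph_id']. This adds a subgraph_id_batch vector that assigns each subgraph to the original graph it was generated from. The encoders below use this vector to pool subgraph representations into graph representations. Here’s a simple use case: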

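```python
import torch
from torch_geometric.data import Batch, Data

# A toy star graph with 4 nodes and 8-dimensional node features
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
data = Data(x=x, edge_index=edge_index)

# Pre-process the graph with a subgraph policy (our repo also provides the
# helper subgraph_policy('node_deletion', num_hops=1), which takes the policy
# name and returns a subgraph policy)
transform = NodeDeletionPolicy()
subgraphs = transform(data)

# Create a Batch object from several bags of subgraphs (here, three copies
# of the same bag); follow_batch adds the `subgraph_id_batch` vector
ls_subgraphs = [subgraphs] * 3
batched_data = Batch.from_data_list(ls_subgraphs,
                                    follow_batch=['subgraph_id'])
```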

Encoder architecture

For the model architectures, we plan to contribute our code to the torch_geometric.nn package. We implement the new classes DSnetwork for DS-GNN and DSSnetwork for DSS-GNN. DSnetwork inherits from torch.nn.Module and DSSnetwork inherits from DSnetwork; both directly implement the PyTorch layers and functions required for these models.

As we’ve discussed earlier, the DSS-GNN and DS-GNN models both comprise three components: (1) an equivariant feature encoder consisting of H-equivariant layers, (2) a subgraph readout layer, and (3) a set encoder. The key difference between the two models is the additional information-sharing component in the H-equivariant layers of the DSS-GNN model. Therefore, we implemented the DSnetwork class for DS-GNN and used it as the base class for DSSnetwork.

We implement the subgraph readout layer as subgraph_pool. It takes the node features h_node, the batched subgraph data batched_data, and a pooling function pool, and pools the node representations of each subgraph independently into an invariant feature vector:

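```python
import torch


def subgraph_pool(h_node, batched_data, pool):
    r"""Represents each subgraph as the pool of its node representations."""
    # Number of subgraphs in each bag (one bag per original graph)
    num_subgraphs = batched_data.num_subgraphs
    # Index of the first subgraph of each bag: [0, cumsum(num_subgraphs)]
    tmp = torch.cat([
        torch.zeros(1, device=num_subgraphs.device, dtype=num_subgraphs.dtype),
        torch.cumsum(num_subgraphs, dim=0)
    ])
    graph_offset = tmp[batched_data.batch]
    # Globally unique subgraph index for every node in the batch
    subgraph_id = batched_data.subgraph_batch + graph_offset

    return pool(h_node, subgraph_id)
```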
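The following plain-Python sketch replays subgraph_pool and the subsequent graph-level pooling on a tiny hand-built batch to make the index arithmetic concrete. It uses lists instead of tensors, and a hypothetical scalar `mean_pool` helper stands in for `global_mean_pool` and `scatter`.

```python
from itertools import accumulate


def mean_pool(values, index, size):
    """Scalar mean-pooling: average the values that share an index."""
    sums, counts = [0.0] * size, [0] * size
    for v, i in zip(values, index):
        sums[i] += v
        counts[i] += 1
    return [s / c for s, c in zip(sums, counts)]


# Two bags produced by node deletion: graph 0 has 2 nodes (2 subgraphs of
# 2 nodes each), graph 1 has 3 nodes (3 subgraphs of 3 nodes each)
num_subgraphs = [2, 3]
batch = [0] * 4 + [1] * 9                                    # graph of each node
subgraph_batch = [0, 0, 1, 1] + [0, 0, 0, 1, 1, 1, 2, 2, 2]  # local subgraph id

# Same offset logic as subgraph_pool: [0, cumsum(num_subgraphs)]
offsets = [0] + list(accumulate(num_subgraphs))
subgraph_id = [offsets[g] + s for g, s in zip(batch, subgraph_batch)]
print(subgraph_id)  # [0, 0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]

h_node = [float(v) for v in range(13)]  # one scalar feature per node
h_subgraph = mean_pool(h_node, subgraph_id, sum(num_subgraphs))
print(h_subgraph)   # [0.5, 2.5, 5.0, 8.0, 11.0]

# subgraph_id_batch (created by follow_batch) maps each subgraph to its graph
subgraph_id_batch = [0, 0, 1, 1, 1]
print(mean_pool(h_subgraph, subgraph_id_batch, 2))  # [1.5, 8.0]
```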

Equation 2 gives the H-equivariant layers in the subgraph feature encoders of DS-GNN and DSS-GNN. The Siamese component shared by both models consists of the GNN layers applied to the batched subgraph data, which DSnetwork instantiates. For DSS-GNN, DSSnetwork instantiates additional GNNs for the information-sharing component. In our code, these GNNs run on each node's features averaged across all subgraphs (computed with scatter), using the edges of the original graph (orig_edge_index). Their output is added back to every subgraph copy of the node. Our implementation is as follows:

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.inits import reset
from torch_scatter import scatter


class DSnetwork(torch.nn.Module):
    r"""DeepSets (DS) Network from the `"Equivariant Subgraph Aggregation
    Networks" <https://arxiv.org/abs/2110.02910>`_ paper.

    :class:`DSnetwork` outputs a graph representation based on the
    aggregation of the subgraph representations.
    Both :class:`DSnetwork` and :class:`DSSnetwork` models comprise
    three components:

    1. An equivariant feature encoder consisting of H-equivariant layers
    2. A subgraph readout layer
    3. A set encoder

    .. note::

        For an example of using a pretrained DSnetwork variant, see
        `examples/esan.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        esan.py>`_.

    Args:
        num_layers (int): Number of graph neural network (GNN) layers.
        in_dim (int): Input node feature dimensionality.
        emb_dim (int): Hidden node feature dimensionality.
        num_tasks (int): Number of prediction tasks.
        feature_encoder (torch.nn.Module): Node feature encoder module.
        GNNConv (torch.nn.Module): Graph neural network convolution module.
    """
    def __init__(self, num_layers: int, in_dim: int, emb_dim: int,
                 num_tasks: int, feature_encoder: torch.nn.Module,
                 GNNConv: torch.nn.Module):
        super().__init__()
        self.emb_dim = emb_dim
        self.feature_encoder = feature_encoder

        gnn_list = []
        bn_list = []

        # Create num_layers layers of GNNs and batch normalization (BN) layers
        for i in range(num_layers):
            gnn_list.append(GNNConv(emb_dim if i != 0 else in_dim, emb_dim))
            bn_list.append(torch.nn.BatchNorm1d(emb_dim))

        # Save the GNNs and BN layers as module lists
        self.gnn_list = torch.nn.ModuleList(gnn_list)
        self.bn_list = torch.nn.ModuleList(bn_list)

        # Final layers to produce output predictions
        self.final_layers = torch.nn.Sequential(
            torch.nn.Linear(in_features=emb_dim, out_features=2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=2 * emb_dim, out_features=num_tasks))

    def reset_parameters(self):
        reset(self.gnn_list)
        reset(self.bn_list)
        reset(self.final_layers)

    def forward(self, batched_data):
        # Unpack input batch data
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr

        # Encode node features
        x = self.feature_encoder(x)

        # Apply the Siamese GNN layers to all subgraphs in parallel
        for gnn, bn in zip(self.gnn_list, self.bn_list):
            h1 = bn(gnn(x, edge_index, edge_attr))
            x = F.relu(h1)

        # Pool node features within each subgraph to obtain
        # subgraph representations
        h_subgraph = subgraph_pool(x, batched_data, global_mean_pool)

        # Average subgraph representations to obtain graph representations
        h_graph = scatter(src=h_subgraph, index=batched_data.subgraph_id_batch,
                          dim=0, reduce="mean")

        # Apply final layers and return output
        return self.final_layers(h_graph)


class DSSnetwork(DSnetwork):
    r"""Deep Sets for Symmetric elements (DSS) Network from the
    `"Equivariant Subgraph Aggregation Networks"
    <https://arxiv.org/abs/2110.02910>`_ paper.
    The key additional functionality of :class:`DSSnetwork` compared to
    :class:`DSnetwork` is an information-sharing component in the
    H-equivariant layers of the DSS-GNN model.

    .. note::

        For an example of using a pretrained DSSnetwork variant, see
        `examples/esan.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        esan.py>`_.

    Args:
        num_layers (int): Number of graph neural network (GNN) layers.
        in_dim (int): Input node feature dimensionality.
        emb_dim (int): Hidden node feature dimensionality.
        num_tasks (int): Number of prediction tasks.
        feature_encoder (torch.nn.Module): Node feature encoder module.
        GNNConv (torch.nn.Module): Graph neural network convolution module.
    """
    def __init__(self, num_layers: int, in_dim: int, emb_dim: int,
                 num_tasks: int, feature_encoder: torch.nn.Module,
                 GNNConv: torch.nn.Module):
        super().__init__(num_layers=num_layers, in_dim=in_dim, emb_dim=emb_dim,
                         num_tasks=num_tasks, feature_encoder=feature_encoder,
                         GNNConv=GNNConv)

        gnn_sum_list = []
        bn_sum_list = []

        # Initialize GNNs for the information-sharing module
        for i in range(num_layers):
            gnn_sum_list.append(GNNConv(emb_dim if i != 0 else in_dim,
                                        emb_dim))
            bn_sum_list.append(torch.nn.BatchNorm1d(emb_dim))
        self.gnn_sum_list = torch.nn.ModuleList(gnn_sum_list)
        self.bn_sum_list = torch.nn.ModuleList(bn_sum_list)

    def reset_parameters(self):
        super().reset_parameters()
        reset(self.gnn_sum_list)
        reset(self.bn_sum_list)

    def forward(self, batched_data):
        # Unpack input batch data
        x = batched_data.x
        edge_index = batched_data.edge_index
        edge_attr = batched_data.edge_attr
        batch = batched_data.batch

        # Encode node features
        x = self.feature_encoder(x)

        # Map every node of every subgraph to its node in the original
        # graph (identical for all layers, so it is computed once)
        num_nodes_per_subgraph = batched_data.num_nodes_per_subgraph
        tmp = torch.cat([
            torch.zeros(1, device=num_nodes_per_subgraph.device,
                        dtype=num_nodes_per_subgraph.dtype),
            torch.cumsum(num_nodes_per_subgraph, dim=0)
        ])
        graph_offset = tmp[batch]
        node_idx = graph_offset + batched_data.subgraph_n_id

        # Apply the H-equivariant layers
        for gnn, bn, gnn_sum, bn_sum in zip(self.gnn_list, self.bn_list,
                                            self.gnn_sum_list,
                                            self.bn_sum_list):
            # Siamese component: each subgraph processed independently
            h1 = bn(gnn(x, edge_index, edge_attr))

            # Aggregate (average) each node's features across subgraphs
            x_sum = scatter(src=x, index=node_idx, dim=0, reduce="mean")

            # Information-sharing component on the original graph
            h2 = bn_sum(
                gnn_sum(
                    x_sum, batched_data.orig_edge_index,
                    batched_data.orig_edge_attr
                    if edge_attr is not None else edge_attr))

            # Broadcast h2 back to every subgraph copy of each node and
            # update the node features for the next iteration
            x = F.relu(h1 + h2[node_idx])

        # Pool node features within each subgraph to
        # obtain subgraph representations
        h_subgraph = subgraph_pool(x, batched_data, global_mean_pool)

        # Average subgraph representations to obtain graph representations
        h_graph = scatter(src=h_subgraph, index=batched_data.subgraph_id_batch,
                          dim=0, reduce="mean")

        # Apply final layers and return output
        return self.final_layers(h_graph)
```
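The node_idx bookkeeping in DSSnetwork.forward can also be replayed in plain Python. The sketch below uses a hand-built toy batch and a hypothetical `mean_pool` helper in place of `scatter(..., reduce="mean")`. It omits the gnn_sum and batch-norm layers, so it only shows how information is gathered from, and broadcast back to, the subgraph copies of each node.

```python
from itertools import accumulate


def mean_pool(values, index, size):
    """Scalar mean-pooling: average the values that share an index."""
    sums, counts = [0.0] * size, [0] * size
    for v, i in zip(values, index):
        sums[i] += v
        counts[i] += 1
    return [s / c for s, c in zip(sums, counts)]


# Graph 0 has 2 nodes (2 subgraphs), graph 1 has 3 nodes (3 subgraphs);
# every subgraph keeps all nodes of its original graph
num_nodes_per_subgraph = [2, 3]
batch = [0] * 4 + [1] * 9
subgraph_n_id = [0, 1, 0, 1] + [0, 1, 2] * 3  # node id within its graph

# node_idx: position of each subgraph node among the original graphs' nodes
offsets = [0] + list(accumulate(num_nodes_per_subgraph))
node_idx = [offsets[g] + n for g, n in zip(batch, subgraph_n_id)]
print(node_idx)  # [0, 1, 0, 1, 2, 3, 4, 2, 3, 4, 2, 3, 4]

# Average each original node over its subgraph copies (like x_sum) ...
x = [1.0, 3.0, 5.0, 7.0] + [float(v) for v in range(9)]
x_sum = mean_pool(x, node_idx, sum(num_nodes_per_subgraph))
print(x_sum)  # [3.0, 5.0, 3.0, 4.0, 5.0]

# ... and broadcast it back to every copy, as h2[node_idx] does
print([x_sum[i] for i in node_idx])
```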

Measuring performance on datasets

To test the performance of the ESAN models and the efficacy of our implementation, we performed experiments using a subset of datasets that were used by [2]:

  1. The synthetic datasets CSL [7], EXP, and CEXP [1]. These datasets are constructed so that MPNNs (which are at most as powerful as the WL test) cannot perform better than a random guess.
  2. The MUTAG [4] dataset from TUDatasets [6], a small benchmark dataset in PyG via torch_geometric.datasets.TUDataset
  3. The ZINC dataset [8], a large benchmark dataset available in PyG via torch_geometric.datasets.ZINC.

For each dataset, we experiment in two steps: first, we preprocess the dataset using subgraph policies; then, we train with the encoders. We test on the four subgraph policies, i.e. node-deleted, edge-deleted, ego-network, and ego-network-plus, and the two encoders DS-GNN and DSS-GNN.

Here, we present a toy example that only illustrates the setup, using DSS-GNN and the EGO+ policy on a single small graph. A more detailed example of training on real TUDatasets can be found in our GitHub repo.

```python
import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GraphConv

num_graphs = 2  # number of (identical) graphs in the toy batch
num_tasks = 1   # number of prediction targets

# A toy star graph with 4 nodes and 8-dimensional node features
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
data = Data(x=x, edge_index=edge_index)

# Specify the subgraph policy: EGO+ with 3-hop ego-networks (our repo's
# helper subgraph_policy('ego_plus', num_hops=3) builds it by name)
transform = EgoNetPolicy(num_hops=3, ego_plus=True)

# Pre-process data by the chosen subgraph policy
subgraphs = transform(data)

# Create a Batch object from several bags of subgraphs
ls_subgraphs = [subgraphs] * num_graphs
batched_data = Batch.from_data_list(ls_subgraphs,
                                    follow_batch=['subgraph_id'])

# Specify model hyperparameters
num_layers = 2
in_dim = 8 + 2  # EGO+ prepends a 2-dimensional root-identifying feature
emb_dim = 16
GNNConv = GraphConv
feature_encoder = torch.nn.Identity()  # use the raw node features

model = DSSnetwork(num_layers, in_dim, emb_dim, num_tasks, feature_encoder,
                   GNNConv)
output = model(batched_data)  # shape: [num_graphs, num_tasks]
```

The results of our experiments demonstrate the increased expressive power of the ESAN model. For the synthetic datasets CSL, EXP, and CEXP, our ESAN implementation achieved 100% accuracy, reproducing the outcome in [2]. This result is expected, as the subgraph-augmentation captures more structural information compared to the locality of the WL test or that of MPNNs. For the MUTAG and ZINC datasets, our implementation also reproduces the results from [2], as shown in Table 1. For all of the listed subgraph policies and for both graph encoders, ESAN outperforms the base encoder on MUTAG and ZINC.

Table 1. Performance of our implementation of ESAN on MUTAG and ZINC.

In conclusion, subgraph GNNs can be an answer to the question of increasing the expressive power of GNNs. To further explore the power of ESAN, we can think about adding support for directed graphs as well. To understand subgraph GNNs better, more questions can be asked: how can we choose the best subgraph policy for a given dataset? How can we make subgraph GNNs less computationally expensive? How can we further increase the expressive power on top of subgraph GNNs?

Our implementation of the subgraph policies and the ESAN models in PyG, along with the relevant tests, can be found in our GitHub repo. The demos and code for the experiments can be found here.

References

[1] R. Abboud, I. I. Ceylan, M. Grohe, and T. Lukasiewicz. The surprising power of graph neural networks with random node initialization. ArXiv, abs/2010.01179, 2020.

[2] B. Bevilacqua, F. Frasca, D. Lim, B. Srinivasan, C. Cai, G. Balamurugan, M. M. Bronstein, and H. Maron. Equivariant subgraph aggregation networks. CoRR, abs/2110.02910, 2021.

[3] M. Bronstein. Using subgraphs for more expressive GNNs, Aug 2022.

[4] A. K. Debnath, R. L. Lopez de Compadre, G. Debnath, A. J. Shusterman, and C. Hansch. Structure-activity relationship of mutagenic aromatic and heteroaromatic nitro compounds. Correlation with molecular orbital energies and hydrophobicity. Journal of Medicinal Chemistry, 34(2), 786–797, 1991.

[5] P. A. Papp, K. Martinkus, L. Faber, and R. Wattenhofer. DropGNN: random dropouts increase the expressiveness of graph neural networks. Advances in Neural Information Processing Systems, 34:21997–22009, 2021.

[6] C. Morris, N. M. Kriege, F. Bause, K. Kersting, P. Mutzel, and M. Neumann. TUDataset: A collection of benchmark datasets for learning with graphs. In ICML 2020 Workshop on Graph Representation Learning and Beyond (GRL+ 2020), 2020.

[7] R. Murphy, B. Srinivasan, V. Rao, and B. Ribeiro. Relational pooling for graph representations. In Proceedings of the 36th International Conference on Machine Learning, Proceedings of Machine Learning Research, 97:4663–4673, 2019.

[8] T. Sterling and J. J. Irwin. ZINC 15: Ligand discovery for everyone. Journal of Chemical Information and Modeling, 2015.

[9] K. Xu, W. Hu, J. Leskovec, and S. Jegelka. How powerful are graph neural networks? 2019.

[10] L. Zhao, W. Jin, L. Akoglu, and N. Shah. From stars to subgraphs: Uplifting any GNN with local structure awareness, 2022.
