Intro to Graph Neural Networks with cuGraph-PyG

Alex Barghi
RAPIDS AI
May 11, 2023

Introduction

Graph Neural Networks (GNNs) are one of the fastest-growing tools in machine learning. GNNs combine a rich array of feature data (similar to the input of a traditional neural network) with network structure (represented as a graph). Depending on the workflow and models used, GNNs can be used for node property prediction, edge property prediction, and link prediction. They support a wide variety of use cases, including fraud detection, relationship prediction, and molecule generation. For training, many GNNs use a process called sampling to aggregate information from an entire graph into a collection of batched subgraphs.

PyTorch Geometric (PyG) is one of the leading GNN frameworks. Through an ongoing partnership between the PyG team and NVIDIA, PyG users have access to cutting-edge GPU acceleration for both model performance and sampling.

This blog covers how to use cuGraph to train a GNN with PyG, as well as how to convert an existing PyG workflow to one with cuGraph.

Why cuGraph?

cuGraph brings GPU acceleration to graph analytics at every scale, from small graphs on a single GPU to trillion-edge graphs spread across multiple nodes with multiple GPUs. It offers scalable implementations of 30+ common algorithms, such as PageRank, breadth-first search, and uniform neighbor sampling. The cuGraph-PyG library is a drop-in extension that enables cuGraph acceleration in PyG with minimal code changes, allowing users already familiar with GNNs to incorporate cuGraph into their existing workflows.

In addition to the cuGraph-PyG library, cuGraph also offers accelerated models built on the cuGraph-ops library. Current PyG users can take advantage of the cuGraph-ops performance boost by using CuGraphSAGEConv, CuGraphGATConv, and CuGraphRGCNConv in place of the default SAGEConv, GATConv, and RGCNConv layers.
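
For the SAGE case, the swap is a one-line change to the layer construction, since the cuGraph-ops layer takes the same channel arguments. The sketch below assumes a recent PyG version that exposes CuGraphSAGEConv; the channel sizes are arbitrary. Note that the forward pass of the cuGraph-ops layer expects the sampled graph in CSC format, as shown in the full model later in this post.

from torch_geometric.nn import SAGEConv, CuGraphSAGEConv

# Default PyG layer
conv = SAGEConv(128, 64)

# cuGraph-ops accelerated equivalent: same channel arguments, but its
# forward pass takes a CSC representation of the sampled graph (see below)
conv = CuGraphSAGEConv(128, 64).to("cuda")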

More information on cuGraph performance will be included in an upcoming blog.

The cuGraph-PyG Ecosystem

Illustration of the cuGraph-PyG Ecosystem

Getting Started

To begin, we select our dataset and define the problem to be solved using GNNs. In this case, we will use the Microsoft Academic Graph (MAG) dataset, a publicly available reference dataset for GNNs. MAG contains several types of vertices (including authors, papers, and institutions) and edges (including author-writes-paper, author-affiliated-with-institution, and paper-cites-paper). To keep things simple, we'll use only the paper vertices and the paper-cites-paper edges, discarding the rest of the graph. Our problem is as follows:

Given a graph consisting of academic institutions, authors, papers, and their citations, as well as a set of additional features for each paper, can we predict the venue a paper was published in with reasonable accuracy?

We start tackling this problem by loading the graph. MAG is publicly available through the OGB Python package, which can be installed through pip or conda.

import torch
from ogb.nodeproppred import NodePropPredDataset

# Load the MAG dataset into memory
# (will automatically download the dataset if needed)
dataset = NodePropPredDataset(name="ogbn-mag")

# Get the paper edge index and the number of paper nodes
data = dataset[0]
edge_index = data[0]["edge_index_dict"][("paper", "cites", "paper")]
num_vertices = {"paper": data[0]["num_nodes_dict"]["paper"]}

# Get the paper features and labels
paper_features = data[0]["node_feat_dict"]["paper"]
paper_labels = data[1]["paper"].T[0]

# Move the edge index, features, and labels to PyTorch on the GPU
edge_index = torch.as_tensor(edge_index, device="cuda")
paper_features = torch.as_tensor(paper_features, device="cuda")
paper_labels = torch.as_tensor(paper_labels, device="cuda")

MAG Edge Index:

{
('paper', 'cites', 'paper'): tensor(2 x 5416271),
}

MAG Num Vertices:

{
'paper': tensor(736389)
}

MAG Paper Features:

{
'paper': tensor(736389 x 128)
}

An edge is represented by its source and target vertices, so the edge index is a 2 x E array, where E is the number of edges in the graph. We’ll want to symmetrize the input edgelist so that for each paper p, we’re aware of both the papers citing p and the papers that p cites. This can be done using the stack and concatenate functions in PyTorch.

edge_index = torch.stack([
    torch.concatenate([edge_index[0], edge_index[1]]),
    torch.concatenate([edge_index[1], edge_index[0]]),
])

The edge index is now a 2 x 10,832,542 tensor.

Next, we need to define the train/test splits for the data. We’ll select the train nodes, then create a train mask that will be used to filter to only the train nodes or only the test nodes when needed.

from cuml.model_selection import train_test_split

num_papers = num_vertices["paper"]

train_nodes, _ = train_test_split(
    torch.arange(num_papers, device="cuda"),
    train_size=0.9,
)
train_nodes = torch.as_tensor(train_nodes, device="cuda")

train_mask = torch.full((num_papers,), False, dtype=torch.bool, device="cuda")
train_mask[train_nodes] = True

At this point, all the necessary preprocessing is done, and the graph can be loaded into cuGraph. We'll start by putting the features in cuGraph's feature store, and then use the feature store, along with the edge index and the number of vertices, to construct a CuGraphStore. CuGraphStore implements PyG's remote backend interface and is the primary interface between cuGraph and PyG.

import cugraph
from cugraph_pyg.data import CuGraphStore

fs = cugraph.gnn.FeatureStore(backend="torch")

# Add the paper features
fs.add_data(paper_features, "paper", "x")

# Add the vertex labels ("ground truth" data)
fs.add_data(paper_labels, "paper", "y")

# Add the train mask
fs.add_data(train_mask, "paper", "train")
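
With the features registered, we can construct the CuGraphStore itself from the feature store, the edge index, and the vertex counts. The sketch below assumes the CuGraphStore constructor accepts a feature store, a dictionary mapping edge types to edge indices, and a dictionary mapping vertex types to counts; check the cugraph-pyg documentation for the exact signature in your release.

# Construct the store used by PyG (sketch; constructor arguments assumed
# to be the feature store, an edge index dict, and a num-vertices dict)
cugraph_store = CuGraphStore(
    fs,
    {("paper", "cites", "paper"): edge_index},
    num_vertices,
)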

Next we’ll need to define our model. We’ll use cuGraph-ops models as building blocks to construct a 3-layer model consisting of two CuGraphSAGEConv layers and a linear layer. There are 128 input channels corresponding to the 128 input features for each vertex, 64 hidden channels, and 349 output channels corresponding to the 349 venues.

The SAGE convolution implemented by CuGraphSAGEConv in cuGraph-ops is explained in depth in the GraphSAGE paper. GraphSAGE trains a set of aggregator functions rather than a per-node embedding vector. This allows training to generalize to graphs of any size (batches can contain any number of nodes) and supports inference on unseen nodes without sacrificing performance.
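
Conceptually, each GraphSAGE layer updates a node's representation by combining a linear projection of its current features with a projection of an aggregate (here, the mean) of its sampled neighbors' features. The snippet below is only an illustrative sketch of that update, not the cuGraph-ops implementation; the names sage_mean_update, neighbors, W_self, and W_neigh are hypothetical.

import torch

def sage_mean_update(h, neighbors, W_self, W_neigh):
    # h: (num_nodes, in_dim) node features
    # neighbors: list of 1-D LongTensors; neighbors[v] holds the sampled
    #            neighbor ids of node v
    # W_self, W_neigh: (in_dim, out_dim) learned weight matrices
    out = []
    for v in range(h.size(0)):
        if len(neighbors[v]) > 0:
            agg = h[neighbors[v]].mean(dim=0)  # mean over sampled neighbors
        else:
            agg = torch.zeros_like(h[v])
        out.append(h[v] @ W_self + agg @ W_neigh)
    return torch.relu(torch.stack(out))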

import torch
import torch.nn.functional as F
from torch_geometric.nn import CuGraphSAGEConv

class CuGraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
        super().__init__()

        # num_layers CuGraphSAGEConv layers followed by a linear output layer
        self.convs = torch.nn.ModuleList()
        self.convs.append(CuGraphSAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 1):
            conv = CuGraphSAGEConv(hidden_channels, hidden_channels)
            self.convs.append(conv)

        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge, size):
        # Convert the sampled edge index to the CSC format expected by cuGraph-ops
        edge_csc = CuGraphSAGEConv.to_csc(edge, (size[0], size[0]))
        for conv in self.convs:
            x = conv(x, edge_csc)[: size[1]]
            x = F.relu(x)
            x = F.dropout(x, p=0.5)

        return self.lin(x)

# Two CuGraphSAGEConv layers plus a linear output layer
model = CuGraphSAGE(128, 64, 349, 2).to(torch.float32).to("cuda")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

With the graph and model ready, we can start training. Since MAG is a fairly large graph, we’ll want to sample it. As mentioned before, cuGraph is designed for fast graph sampling, and it directly interfaces with PyG through the remote graph interface. To load samples from cuGraph, we’ll use the CuGraphNeighborLoader, which implements the NodeLoader interface in PyG. Like the NeighborLoader in base PyG, the CuGraphNeighborLoader uses the uniform neighbor sampling algorithm to sample the graph into batched subgraphs.

The uniform neighbor sampling algorithm is similar to breadth-first search: it starts from b input vertices, where b is the batch size, and limits the number of neighbors sampled at each hop k to n_k. The number of hops and the per-hop limits are set with the num_neighbors parameter (also called the fanout), where the length of the provided list is the number of hops k and each entry is n_k. For example, with a batch size of 4 and num_neighbors of [2, 2, 3], we get the subgraph shown below. For many training workflows, a batch size of 500 and a fanout of [10, 25] is effective.

Illustration of Uniform Neighbor Sampling
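
To get a feel for how the fanout bounds the size of a sampled subgraph, the short sketch below (illustrative only) computes an upper bound on the number of vertices drawn at each hop for a given batch size and fanout.

# Upper bound on the number of vertices sampled at each hop
batch_size = 4
fanout = [2, 2, 3]

frontier = batch_size
for k, n_k in enumerate(fanout, start=1):
    frontier *= n_k
    print(f"hop {k}: at most {frontier} sampled vertices")
# hop 1: at most 8 sampled vertices
# hop 2: at most 16 sampled vertices
# hop 3: at most 48 sampled vertices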

For each training epoch, we’ll create a new loader and generate new samples. For simplicity, we’ll just show one epoch here.

from cugraph_pyg.loader import CuGraphNeighborLoader

loader = CuGraphNeighborLoader(
    cugraph_store,
    train_nodes,
    batch_size=500,
    num_neighbors=[10, 25],
)

device = "cuda"

total_loss = 0.0
num_batches = 0
for hetero_data in loader:
    mask = hetero_data.train_dict["paper"]
    y_true = hetero_data.y_dict["paper"]

    y_pred = model(
        hetero_data.x_dict["paper"].to(device).to(torch.float32),
        hetero_data.edge_index_dict[("paper", "cites", "paper")].to(device),
        (len(y_true), len(y_true)),
    )

    y_true = F.one_hot(
        y_true[mask].to(torch.int64), num_classes=349
    ).to(torch.float32)

    y_pred = y_pred[mask]

    loss = F.cross_entropy(y_pred, y_true)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    num_batches += 1

print(f"Average loss after epoch: {total_loss / num_batches}")

Converting an Existing PyG Workflow to cuGraph-PyG

To convert an existing PyG workflow to a cuGraph-PyG workflow, there are three simple steps:

  1. Use the cuGraph-ops models (e.g., CuGraphSAGEConv) in place of the native PyG models (e.g., SAGEConv).
  2. Create a CuGraphStore object instead of a PyG Data or HeteroData object.
  3. Use the CuGraphNeighborLoader in place of the native PyG NeighborLoader.

The cuGraph-ops models, CuGraphStore, and CuGraphNeighborLoader all follow the PyG API and are drop-in replacements for the PyG equivalents (Conv, GraphStore/FeatureStore, NeighborLoader).
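
As a compact reference, the import-level mapping below sketches the three substitutions; it assumes a PyG version that exposes the cuGraph-ops layers and an environment with cugraph-pyg installed.

# Native PyG building blocks...
from torch_geometric.nn import SAGEConv              # step 1: replaced by CuGraphSAGEConv
from torch_geometric.data import HeteroData          # step 2: replaced by CuGraphStore
from torch_geometric.loader import NeighborLoader    # step 3: replaced by CuGraphNeighborLoader

# ...and their cuGraph-PyG drop-in replacements
from torch_geometric.nn import CuGraphSAGEConv
from cugraph_pyg.data import CuGraphStore
from cugraph_pyg.loader import CuGraphNeighborLoader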

Conclusion

As you can see, using cuGraph for GNN training is incredibly simple. With minimal code change, the power of accelerated models and ultra-fast, ultra-scalable sampling is in your hands.

cuGraph is continuously adding new features, including new and better support for GNNs. Coming soon will be a blog on GNN training with cuGraph-DGL, and another blog on using cuGraph to scale GNN training to trillions of edges.

Notes

cuGraph-PyG examples can be found on GitHub.

“PyG” and the PyG logo are property of PyG (https://pyg.org/)

“PyTorch” is property of the Linux Foundation (https://pytorch.org/)


Alex Barghi
RAPIDS AI

Alex Barghi is a senior software engineer on the RAPIDS Graph Team at NVIDIA. She primarily works on cuGraph, and is the lead for cuGraph integration with PyG.