A tour of PyG’s data loaders
By Grant Uy and Huijian Cai as part of the Stanford CS224W course project.
Graph neural networks (GNNs) are powerful tools with broad applicability to many domains because real-world networks, like social networks, are often well modeled by graphs.
However, real-world networks tend to be massive and most useful at scale. For example, imagine Amazon’s “product co-purchasing” graph, where edges connect products (nodes) that are bought together. This is useful at Amazon’s scale for recommending or classifying products, but with millions of products, the graph size makes applying ML difficult.
Such graphs are too large to fit into GPU memory, so when implementing a GNN, we have to batch the data. While this is a simple idea in principle, deciding how to split up the graph is surprisingly nuanced and affects downstream performance, both training time and accuracy.
PyG, built on PyTorch, is a powerful GNN library that provides several different graph data loaders, which we’ll explore in this post!
Loader basics
PyG’s loaders are built on PyTorch DataLoaders. A DataLoader is an abstraction over a dataset that enables batching.
Each PyG loader accepts a Data object, which represents a graph, as well as additional parameters that control how nodes and edges are sampled into batches.
RandomNodeLoader
The most straightforward loader is RandomNodeLoader, which randomly samples batches of nodes from the input graph. For example, given a Data object data, we can construct one like this:
from torch_geometric.loader import RandomNodeLoader
random_loader = RandomNodeLoader(data, num_parts=5, shuffle=True)
num_parts controls the batch size. This RandomNodeLoader splits the graph into 5 subgraphs by partitioning the N nodes into 5 subsets. shuffle tells PyG to pick these subsets randomly. Otherwise, the first batch would always contain the first N/5 nodes (by node ID), and so on. Generally, shuffling is good for training so that the model sees different batches for every epoch.
For example, here’s a toy graph batched with this loader:
Notice that the subgraphs tend to be disconnected since the partitioning was random. While RandomNodeLoader is a useful introduction, we’ll see that more cleverly dividing up the graph can improve performance!
This Colab contains a simple library we built to generate and visualize batches in toy graphs if you want to play around with these loaders! We used it for all of the visualizations in this post.
Retrieving batches
Given a loader, we can retrieve the batches by iterating through it, as with PyTorch DataLoaders. In a training loop, this might look like:
for batch in random_loader:
    # "batch" is another Data instance representing the subgraph
    # ...
    # if the original Data had properties like "x" and "edge_index",
    # "batch" will have the relevant slices of them:
    out = model(batch.x, batch.edge_index)
    # ...
len(random_loader) returns the total number of batches, and the batch indices can be retrieved in a loop via enumerate(random_loader).
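For instance, here’s a quick sketch that inspects each batch (assuming the random_loader from above; any PyG loader iterates the same way):

num_batches = len(random_loader)
for i, batch in enumerate(random_loader):
    # each batch is a Data object, so it exposes counts like num_nodes
    print(f"batch {i + 1}/{num_batches}: {batch.num_nodes} nodes")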
Data splits
In practice, the graph is usually partitioned into training, validation, and test sets. To handle this, we typically stash train_mask, val_mask, and test_mask attributes in the original Data, and then those masks will be propagated to the batches.
For example, PyG’s RandomNodeSplit transform can be used to add these masks to the Data before passing it to the loader. (See PyG’s documentation for details, which are beyond the scope of this post!)
Using this method, each batch will contain a mix of train, validation, and test nodes. Here’s one batch from above:
In this case, we’d use the same loader regardless of which split we’re iterating over (e.g., during the training loop vs. during model evaluation), but we’d use the masks to select the appropriate subset of nodes. For example, batch.x[batch.train_mask] would select only the x values for nodes in the training set.
Some loaders, such as NeighborLoader, handle splits specially and accept the split mask (e.g., train_mask) when the loader is constructed. In such cases, we’ll typically have multiple loaders: train_loader, val_loader, and test_loader. More on this below!
Other useful parameters
A couple more parameters that are the same as in PyTorch are:
- num_workers: enables multi-process data loading [documentation]
- drop_last: skips the last batch if it’s a different size (e.g., when the batch size doesn’t perfectly divide the total size) [documentation]
NeighborLoader
NeighborLoader is another widely applicable node-based loader type. We’ll dive into it since it’s so foundational!
Basics
NeighborLoader partitions nodes into batches like RandomNodeLoader, but for every batch, NeighborLoader also adds neighboring nodes to the subgraph. Each node is a “root” (or “starter”) node in exactly 1 batch, but that node can appear in other batches if pulled in as a neighbor.
A NeighborLoader can be instantiated like this:
from torch_geometric.loader import NeighborLoader

neighbor_loader = NeighborLoader(data,
                                 num_neighbors=[3],
                                 batch_size=2,
                                 shuffle=True)
Suppose data contains 40 nodes. Then, this will partition the 40 nodes into 20 batches of size 2. For each batch, it will add 3 neighbors for each of the 2 “root” nodes. Hence, each batch will be a subgraph with 8 total nodes (not accounting for overlaps). For example:
Unlike RandomNodeLoader, this results in a more connected subgraph!
Parameterization
By default, the resulting subgraph only contains edges that were “followed” when adding neighbors. To change this behavior, pass directed=False to the NeighborLoader. (Despite its name, directed isn’t used to distinguish directed and undirected graphs!) When directed=False, the full node-induced subgraph is used instead. In the example above, directed=False would add these 2 red edges:
Additionally, NeighborLoader supports multiple “hops” for adding neighbors. In the example above, num_neighbors=[3] indicates that we want 1 round of adding 3 neighbors.
Suppose we set num_neighbors=[3,2] (with batch_size=2). Any given batch still starts with 2 “root” nodes, with 3 neighbors added for each. But, for each of the 6 added neighbors, 2 neighbors of those neighbors are added. Here’s 1 color-coded batch for our graph:
Adding more hops increases the neighborhood sizes but exponentially increases the subgraph’s size! We’ll see later how this tradeoff affects performance.
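A common rule of thumb (our framing, not a requirement of the loader) is to give num_neighbors one entry per message-passing layer, so that each GNN layer has a freshly sampled hop to aggregate from. For a 3-layer model, that might look like:

# 3 hops of sampling for a 3-layer GNN; the fan-outs (15, 10, 5)
# are illustrative values, not tuned recommendations
three_hop_loader = NeighborLoader(data,
                                  num_neighbors=[15, 10, 5],
                                  batch_size=2,
                                  shuffle=True)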
Data splits
NeighborLoader accepts a split mask like train_mask via the parameter input_nodes. For example:
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                              num_neighbors=[3,2], batch_size=2,
                              directed=False, shuffle=True)
val_loader = NeighborLoader(data, input_nodes=data.val_mask,
                            num_neighbors=[3,2], batch_size=2,
                            directed=False, shuffle=True)
test_loader = NeighborLoader(data, input_nodes=data.test_mask,
                             num_neighbors=[3,2], batch_size=2,
                             directed=False, shuffle=True)
When input_nodes is set to a mask, only matching nodes will be selected as “root” nodes. These loaders will be used separately; for example, we can iterate over only train_loader during model training.
Each batch from train_loader can still contain nodes not in the training set! (Similarly for val_loader and test_loader.) The mask only affects “root” nodes, not which nodes are added as neighbors. Hence, to avoid training on validation/test nodes, you still have to apply train_mask. For example, a training loop might look like:
for batch in train_loader:
    # ...
    # Run the model on all nodes in the batch so that message passing
    # still works for non-training nodes
    out = model(batch.x, batch.edge_index)
    # Compute the loss only for training nodes.
    loss = loss_fn(out[batch.train_mask], batch.y[batch.train_mask])
    # ...
For a more complete example, see our Colab.
Applications
We recommend NeighborLoader as a good starting point for most node-based tasks. RandomNodeLoader is simpler but tends to perform poorly on large, sparse graphs (as most real-world graphs are!) due to the disconnectedness of batches. NeighborLoader also:
- produces more connected subgraphs but is still simple
- conveniently handles data splits (e.g., won’t produce edge-case batches without training nodes)
- supports heterogeneous graphs out-of-the-box (see the sketch below)!
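For the heterogeneous case, here’s a minimal sketch; the “paper” node type (and its train_mask) is a hypothetical placeholder, not from this post’s datasets:

from torch_geometric.loader import NeighborLoader

# "hetero_data" is assumed to be a HeteroData object with a "paper"
# node type; input_nodes takes a (node type, mask) tuple
hetero_loader = NeighborLoader(hetero_data,
                               num_neighbors=[3,2],
                               input_nodes=('paper', hetero_data['paper'].train_mask),
                               batch_size=2,
                               shuffle=True)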
Performance comparison
We set up a basic ML pipeline for 2 Open Graph Benchmark (OGB) [1] node classification datasets, ogbn-products and ogbn-proteins, and then varied NeighborLoader parameters to see the effects on downstream performance, namely accuracy and training time. You can follow along in our Colab!
For simplicity, we used a 3-layer GCN provided by PyG, since our goal was not to optimize the rest of the pipeline.
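As a rough sketch of such a model, using PyG’s built-in GCN wrapper (the channel sizes here are placeholders, not the exact values from our experiments):

from torch_geometric.nn.models import GCN

# a 3-layer GCN; in_channels and out_channels depend on the dataset,
# e.g., 100-dimensional features and 47 classes for ogbn-products
model = GCN(in_channels=100, hidden_channels=256, num_layers=3,
            out_channels=47)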
ogbn-products is a product co-purchasing network from Amazon [2], as described above, with 200K training nodes. The goal is to classify products into 47 categories. Our NeighborLoaders produced these results:
We had a few key takeaways:
- Models with smaller subgraphs (e.g., neighbor-1-256, which had num_neighbors=[1] and batch_size=256) trained faster but had worse accuracy.
- directed=True (the default!) performed significantly worse than directed=False, since it reduces the connectivity of the subgraphs, and our 3-layer GCN considers 3-hop neighborhoods.
- Loaders with more than 1 neighbor hop performed the best overall, but they also took significantly longer to train.
Here’s the same data, sorted by training time to show the relationship with accuracy:
However, parameter effects vary wildly by use case (e.g., problem domain) due to different graph structures! We repeated the same experiment for ogbn-proteins, which encodes associations between proteins [3]. For a multi-label binary classification task for protein functions, we had the following results:
For instance, notice that directed=True performed better in this case! Loader experimentation is important for optimizing any ML pipeline.
Other node-based loaders
ClusterLoader
ClusterLoader relies on ClusterData, which partitions the nodes into fixed clusters (using METIS graph partitioning):
from torch_geometric.loader import ClusterData, ClusterLoader
# In practice, num_parts would be much higher than 4!
cluster_data = ClusterData(data, num_parts=4)
cluster_loader = ClusterLoader(cluster_data, batch_size=2, shuffle=True)
num_parts is the desired number of clusters, like for RandomNodeLoader. batch_size is the number of clusters to include per batch. In this case, ClusterData forms 4 node clusters, and ClusterLoader randomly picks 2 of those clusters for each batch. The resulting batch represents the node-induced subgraph for all of the nodes in those 2 clusters. For example:
Notably, this includes the edges between the two clusters.
ClusterLoader is useful when the underlying graph has “community-like” structures, like in our example graph or in a social network. ClusterData can handle “communities” of varying sizes, whereas NeighborLoader always picks subgraphs of roughly the same shape. Additionally, NeighborLoader can’t guarantee that an entire related cluster of nodes will make it into the same batch, especially when the cluster is large.
For ClusterLoader, every node appears in exactly one batch, unlike for NeighborLoader. Hence, ClusterLoader treats nodes more equally, whereas highly-connected nodes appear more frequently in NeighborLoader batches. Whether this is good depends on the use case: highly-connected nodes could be important (e.g., for product classification) or poorly representative of a population.
GraphSAINTRandomWalkSampler
GraphSAINTRandomWalkSampler, similar to NeighborLoader, picks batch_size “root” nodes for each batch and then adds surrounding nodes, but instead of adding direct neighbors, it starts a random walk and adds all encountered nodes.
Unlike NeighborLoader, the “root” nodes are not from a partition of the original set of nodes. Rather, they’re sampled with replacement and normalization. Check out the GraphSAINT paper [4] for details!
from torch_geometric.loader import GraphSAINTRandomWalkSampler

saint_walk_sampler = GraphSAINTRandomWalkSampler(data,
                                                 batch_size=2,
                                                 walk_length=6,
                                                 num_steps=3)
GraphSAINTRandomWalkSampler’s key parameters are:
- batch_size: the number of “root” nodes (not each batch’s final number of nodes!)
- walk_length: how many steps to take on each random walk starting from a “root” node
- num_steps: the number of batches to create (not to be confused with walk_length!)
For example, the above parameters produce batches like:
This loader is useful when you care more about “depth” in the graph than “breadth.” NeighborLoader is similar to breadth-first exploration, whereas random walks can be more like depth-first exploration. This can allow more information to pass between nodes that are farther away, especially when you’re using a deep GNN with many layers.
That said, there are other methods (outside the scope of this post) to increase graph connectivity even when using other loaders, such as adding a “virtual node” connected to all nodes [5].
Performance comparison
Choosing the right loader type has meaningful performance implications, sometimes even more than parameterization! We repeated the NeighborLoader experiment on ogbn-products using different loader types (and fewer epochs):
A couple takeaways:
- RandomNodeLoader predictably performed poorly in terms of both accuracy and training time
- GraphSAINTRandomWalkSampler achieved better accuracy than NeighborLoader in significantly less time
Like parameterization, the choice of loader type also depends heavily on the use case and the graph structure. For ogbn-proteins, the same experiment had very different results:
In particular, notice that NeighborLoader outperformed the other types in this case!
While this analysis is neither rigorous nor perfect, it’ll still hopefully inspire you to try different loaders! See our Colab for details.
Even more loader types
We couldn’t cover every PyG loader type in depth, but a few worth calling out are:
- HGTLoader: useful when working with heterogeneous graphs because it samples different types of nodes in a balanced way
- ImbalancedSampler: useful for classification tasks with low-frequency labels, because it allows for oversampling based on the node class
- ShaDowKHopSampler: like a more advanced version of NeighborLoader that more carefully picks subgraphs and locally smooths them
LinkNeighborLoader
PyG also supports loaders that sample edges instead of nodes. This is useful for edge-related tasks like link prediction. LinkNeighborLoader, NeighborLoader’s edge-based analog, is a great starting point!
Negative sampling
Link prediction is essentially binary classification, predicting whether an edge would exist. The graph’s existing edges are considered positive examples, but we also need to make predictions on negative examples, edges that are not actually in the graph. These “negative edges” need to be sampled too! Thankfully, PyG supports this out-of-the-box!
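To make “negative edges” concrete, here’s a hedged sketch using PyG’s standalone negative_sampling utility (LinkNeighborLoader handles this for you, as we’ll see next):

from torch_geometric.utils import negative_sampling

# sample as many non-existent edges as there are real ones;
# each column is a (source, target) pair absent from edge_index
neg_edge_index = negative_sampling(data.edge_index,
                                   num_nodes=data.num_nodes,
                                   num_neg_samples=data.edge_index.size(1))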
Parameters
LinkNeighborLoader is constructed similarly to NeighborLoader:
from torch_geometric.loader import LinkNeighborLoader

link_neighbor_loader = LinkNeighborLoader(
    data,
    num_neighbors=[2],
    neg_sampling_ratio=1.0,
    batch_size=2,
    shuffle=True)
For each batch, LinkNeighborLoader selects batch_size positive edges and neg_sampling_ratio * batch_size negative edges. Then, for each node (i.e., each edge endpoint), neighbors are added as for NeighborLoader. For example:
The resulting batches have two edge indexes:
- batch.edge_index: edges used for message passing
- batch.edge_label_index: positive/negative edges to be labeled (with the ground truth in batch.edge_label)
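As a hedged sketch of how these two edge indexes come together in a training step (the dot-product decoder is our own illustrative choice, and “model” is assumed to produce per-node embeddings):

import torch.nn.functional as F

for batch in link_neighbor_loader:
    # message passing uses all sampled edges...
    z = model(batch.x, batch.edge_index)
    # ...but predictions are made only for the labeled edges
    src, dst = batch.edge_label_index
    pred = (z[src] * z[dst]).sum(dim=-1)  # dot-product edge scores
    loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label.float())
    # ... backward pass, optimizer step, etc.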
Data splits
LinkNeighborLoader is typically constructed in conjunction with RandomLinkSplit (instead of using masks) to handle data splits:
import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.2,   # fraction of edges held out for validation
    num_test=0.2,  # fraction of edges held out for test
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
)
train_data, val_data, test_data = transform(data)

batch_size = 2
num_neighbors = [3,2]
train_loader = LinkNeighborLoader(train_data,
                                  num_neighbors=num_neighbors,
                                  batch_size=batch_size,
                                  # these are only set for train!
                                  neg_sampling_ratio=2.0,
                                  shuffle=True)
val_loader = LinkNeighborLoader(val_data,
                                num_neighbors=num_neighbors,
                                batch_size=batch_size)
test_loader = LinkNeighborLoader(test_data,
                                 num_neighbors=num_neighbors,
                                 batch_size=batch_size)
RandomLinkSplit adds negative edges to val_data and test_data (but not to train_data, due to add_negative_train_samples=False) when they’re constructed, so that these negative edges remain constant for every run through the validation/test data. Conversely, negative edges for training are added by LinkNeighborLoader to each batch so that the negative edges differ for every training epoch. Hence, neg_sampling_ratio is only set for train_loader!
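For completeness, here’s a sketch of evaluating against those fixed negative edges (scikit-learn’s ROC-AUC is our illustrative metric choice, and the dot-product decoder matches the training sketch above):

import torch
from sklearn.metrics import roc_auc_score

preds, labels = [], []
with torch.no_grad():
    for batch in val_loader:
        z = model(batch.x, batch.edge_index)
        src, dst = batch.edge_label_index
        preds.append((z[src] * z[dst]).sum(dim=-1))
        labels.append(batch.edge_label)
auc = roc_auc_score(torch.cat(labels).cpu().numpy(),
                    torch.cat(preds).cpu().numpy())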
Applications
Unlike with node-based loaders, where several candidates compete, LinkNeighborLoader is a more obvious choice for edge-related tasks like link prediction and edge property prediction. There aren’t many options for edge-based loaders, and LinkNeighborLoader is the only one that supports on-the-fly negative sampling.
Performance comparison
We created this Colab to demonstrate how to apply an edge-level loader to a heterogeneous graph, largely following PyG’s blog post, which goes into more detail! Similar to above, our focus was to demonstrate how LinkNeighborLoader parameters affect downstream performance.
We used the MovieLens dataset [6], a heterogeneous graph containing ratings (edges) from users for movies (both node types), for link prediction. Predicting missing ratings can be useful in recommending movies to users.
Using a model similar to the one in the blog post, we got the following results from several LinkNeighborLoader parameterizations:
The main takeaway is that even seemingly minor parameter changes can significantly affect performance, but a couple other call-outs are:
- large batch sizes significantly decreased overall training time since they reduced the total overhead from non-labeled message passing edges
- adding more neighbors increased training time but didn’t necessarily improve performance
See our Colab and the original blog post for more details!
Summary
To recap, we recommend NeighborLoader as a starting point for node-based tasks and LinkNeighborLoader for edge-based tasks.
For node-based tasks, we’d suggest additionally looking into:
- ClusterLoader if your graph has many “community-like” structures
- GraphSAINTRandomWalkSampler if “far away” nodes matter more (relative to local neighborhoods), especially for deeper GNNs
- ShaDowKHopSampler for more advanced neighbor sampling
- HGTLoader for handling heterogeneous nodes in a more balanced way
- ImbalancedSampler for skewed node classification problems
Closing
We hope this was a helpful tour of PyG’s loaders! If you want to dive into the code, here are all the Colabs we used in this post:
- Toy graph generation and visualization
- Node classification pipeline for ogbn-products and ogbn-proteins
- Link prediction pipeline for MovieLens
- Data visualization
Further reading
- PyG loader documentation
- PyTorch data loading tutorial
- Link Prediction on Heterogeneous Graphs with PyG
References
[1] Hu et al. 2020. https://arxiv.org/abs/2005.00687
[2] Bhatia et al. 2016. http://manikvarma.org/downloads/XC/XMLRepository.html
[3] Szklarczyk et al. 2019. https://doi.org/10.1093/nar/gky1131
[4] Zeng et al. 2020. https://arxiv.org/abs/1907.04931
[5] CS224W 2023 lecture 6, slide 18: http://web.stanford.edu/class/cs224w/slides/06-GNN3.pdf
[6] Harper et al. 2015. https://doi.org/10.1145/2827872