A tour of PyG’s data loaders


By Grant Uy and Huijian Cai as part of the Stanford CS224W course project.

Graph neural networks (GNNs) are powerful tools with broad applicability to many domains because real-world networks, like social networks, are often well modeled by graphs.

However, real-world networks tend to be massive and most useful at scale. For example, imagine Amazon’s “product co-purchasing” graph, where edges connect products (nodes) that are bought together. This is useful at Amazon’s scale for recommending or classifying products, but with millions of products, the graph size makes applying ML difficult.

Such graphs are too large to fit into GPU memory, so when implementing a GNN, we have to batch the data. While this is a simple idea in principle, deciding how to split up the graph is surprisingly nuanced and affects downstream performance: both training time and accuracy.

PyG, built on PyTorch, is a powerful GNN library that provides several different graph data loaders, which we’ll explore in this post!

Loader basics

PyG’s loaders are built on PyTorch DataLoaders. A DataLoader is an abstraction over a dataset that enables batching.

Each PyG loader accepts a Data object, which represents a graph, as well as additional parameters that control how nodes and edges are sampled into batches.
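As a quick refresher, here's a minimal sketch of how a tiny graph might be represented as a Data object (the node features and edges below are made up for illustration):

import torch
from torch_geometric.data import Data

# 4 nodes, each with 2 made-up features
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0],
                  [0.5, 0.5]])

# edges stored as a [2, num_edges] tensor of (source, target) pairs
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])

data = Data(x=x, edge_index=edge_index)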

RandomNodeLoader

The most straightforward loader is RandomNodeLoader, which randomly samples batches of nodes from the input graph. For example, given a Data object data, we can construct one like this:

from torch_geometric.loader import RandomNodeLoader

random_loader = RandomNodeLoader(data, num_parts=5, shuffle=True)

num_parts controls the batch size. This RandomNodeLoader splits the graph into 5 subgraphs by partitioning the N nodes into 5 subsets. shuffle tells PyG to pick these subsets randomly. Otherwise, the first batch would always contain the first N/5 nodes (by node ID), and so on. Generally, shuffling is good for training so that the model sees different batches for every epoch.

For example, here’s a toy graph batched with this loader:

Each node is colored based on which of the 5 batches it’s in.

Notice that the subgraphs tend to be disconnected since the partitioning was random. While RandomNodeLoader is a useful introduction, we’ll see that more cleverly dividing up the graph can improve performance!

This Colab contains a simple library we built to generate and visualize batches in toy graphs if you want to play around with these loaders! We used it for all of the visualizations in this post.

Retrieving batches

Given a loader, we can retrieve the batches by iterating through it, like for PyTorch DataLoaders. In a training loop, this might look like:

for batch in random_loader:
    # "batch" is another Data instance representing the subgraph
    # ...

    # if the original Data had properties like "x" and "edge_index",
    # batch will have relevant slices of them:
    out = model(batch.x, batch.edge_index)

    # ...

len(random_loader) returns the total number of batches, and the batch indices can be retrieved in a loop via enumerate(random_loader).

Data splits

In practice, the graph is usually partitioned into training, validation, and test sets. To handle this, we typically stash train_mask, val_mask, and test_mask attributes in the original Data, and then those masks will be propagated to the batches.

For example, PyG’s RandomNodeSplit transform can be used to add these masks to the Data before passing it to the loader. (See PyG’s documentation for details, which are beyond the scope of this post!)
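For instance, a split could be added along these lines (the split fractions below are arbitrary):

import torch_geometric.transforms as T

# hold out 10% of nodes for validation and 10% for test;
# the remaining nodes become training nodes
split = T.RandomNodeSplit(num_val=0.1, num_test=0.1)
data = split(data)
# data now has boolean train_mask, val_mask, and test_mask attributes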

Using this method, each batch will contain a mix of train, validation, and test nodes. Here’s one batch from above:

The red batch is itself split into training (same red as above), validation (dark red), and test (light red).

In this case, we’d use the same loader regardless of which split we’re iterating over (e.g., during the training loop vs. during model evaluation), but we’d use the masks to select the appropriate subset of nodes. For example, batch.x[batch.train_mask] would select only the x values for nodes in the training set.

Some loaders, such as NeighborLoader, specially handle splits and accept the split mask (e.g., train_mask) when the loader is constructed. In such cases, we’ll typically have multiple loaders: train_loader, val_loader, and test_loader. More on this below!

Other useful parameters

A couple more parameters that are the same as in PyTorch are:

  • num_workers: enables multi-process data loading [documentation]
  • drop_last: skips the last batch if it’s a different size (e.g., when batch size doesn’t perfectly divide total size) [documentation]

NeighborLoader

NeighborLoader is another widely applicable node-based loader type. We’ll dive into it since it’s so foundational!

Basics

NeighborLoader partitions nodes into batches like RandomNodeLoader, but for every batch, NeighborLoader also adds neighboring nodes to the subgraph. Each node is a “root” (or “starter”) node in exactly 1 batch, but that node can appear in other batches if pulled in as a neighbor.

A NeighborLoader can be instantiated like:

from torch_geometric.loader import NeighborLoader

neighbor_loader = NeighborLoader(data,
                                 num_neighbors=[3],
                                 batch_size=2,
                                 shuffle=True)

Suppose data contains 40 nodes. Then, this will partition the 40 nodes into 20 batches of size 2. For each batch, it will add 3 neighbors for each of the 2 “root” nodes. Hence, each batch will be a subgraph with 8 total nodes (not accounting for overlaps). For example:

One batch from this NeighborLoader.

Unlike RandomNodeLoader, this results in a more connected subgraph!
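You can see this for yourself by inspecting the batches directly. In NeighborLoader batches, the “root” nodes come first and their count is stored in batch.batch_size; this small sketch (using the loader above) just prints a few sizes:

for batch in neighbor_loader:
    # the first batch.batch_size nodes are the "root" nodes;
    # the remaining nodes were pulled in as neighbors
    print(batch.num_nodes, batch.batch_size)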

Parameterization

By default, the resulting subgraph only contains edges that were “followed” when adding neighbors. To change this behavior, pass directed=False to the NeighborLoader. (Despite its name, directed isn’t used to distinguish directed and undirected graphs!) When directed=False, the full node-induced subgraph is used instead. In the example above, directed=False would add these 2 red edges:

The red edges connect selected nodes but were not “followed” when adding neighbors.

Additionally, NeighborLoader supports multiple “hops” for adding neighbors. In the example above, num_neighbors=[3] indicates that we want 1 round of adding 3 neighbors.

Suppose we set num_neighbors=[3,2] (with batch_size=2). Any given batch still starts with 2 “root” nodes, with 3 neighbors added for each. But, for each of the 6 added neighbors, 2 neighbors of those neighbors are added. Here’s 1 color-coded batch for our graph:

The 2 “root nodes” are red, the 3 added neighbors for each “root” node are in blue, and the neighbors-of-neighbors are in green. (There are also some overlaps, hence the green edge connecting blue nodes.)

Adding more hops increases the neighborhood sizes but exponentially increases the subgraph’s size! We’ll see later how this tradeoff affects performance.

Data splits

NeighborLoader accepts a split mask like train_mask via the parameter called input_nodes. For example:

train_loader = NeighborLoader(data, input_nodes=data.train_mask,
                              num_neighbors=[3,2], batch_size=2,
                              directed=False, shuffle=True)

val_loader = NeighborLoader(data, input_nodes=data.val_mask,
                            num_neighbors=[3,2], batch_size=2,
                            directed=False, shuffle=True)

test_loader = NeighborLoader(data, input_nodes=data.test_mask,
                             num_neighbors=[3,2], batch_size=2,
                             directed=False, shuffle=True)

When input_nodes is set to a mask, only matching nodes will be selected as “root nodes.” These loaders will be used separately; for example, we can iterate over only train_loader during model training.

Each batch from train_loader can still contain nodes not in the training set! (Similarly for val_loader and test_loader.) The mask only affects “root” nodes, not which nodes are added as neighbors. Hence, to avoid training on validation/test nodes, you still have to apply train_mask when computing the loss. For example, a training loop might look like:

for batch in train_loader:
    # ...

    # Run the model on all nodes in the batch so that message passing
    # still works for non-training nodes.
    out = model(batch.x, batch.edge_index)

    # Compute the loss only for training nodes.
    loss = loss_fn(out[batch.train_mask], batch.y[batch.train_mask])

    # ...

For a more complete example, see our Colab.

Applications

We recommend NeighborLoader as a good starting point for most node-based tasks. RandomNodeLoader is simpler but tends to perform poorly on large, sparse graphs (as most real-world graphs are!) due to the disconnectedness of batches. NeighborLoader also:

  • produces more connected subgraphs but is still simple
  • conveniently handles data splits (e.g., won’t produce edge-case batches without training nodes)
  • supports heterogeneous graphs out-of-the-box (see the sketch below)!
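For heterogeneous graphs, the same loader works on a HeteroData object; here's a rough sketch of what that can look like (the “paper” node type, masks, and sizes below are hypothetical):

from torch_geometric.loader import NeighborLoader

# hetero_data is assumed to be a HeteroData object with a "paper" node type
hetero_train_loader = NeighborLoader(
    hetero_data,
    # sample up to 10 neighbors per hop, for each edge type
    num_neighbors={edge_type: [10, 10] for edge_type in hetero_data.edge_types},
    # "root" nodes are training nodes of the "paper" type
    input_nodes=('paper', hetero_data['paper'].train_mask),
    batch_size=128,
    shuffle=True,
)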

Performance comparison

We set up a basic ML pipeline for 2 Open Graph Benchmark (OGB) [1] node classification datasets, ogbn-products and ogbn-proteins, and then varied NeighborLoader parameters to see the effects on downstream performance, namely accuracy and training time. You can follow along in our Colab!

For simplicity, we used a 3-layer GCN provided by PyG, since our goal was not to optimize the rest of the pipeline.
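For reference, a 3-layer GCN like the one we used can be instantiated from PyG's built-in models roughly as follows (the channel sizes here are placeholders, except for the 47 output classes of ogbn-products):

from torch_geometric.nn import GCN

# in_channels should match the dataset's node feature dimension,
# hidden_channels is a tunable placeholder, and out_channels is the
# number of classes (47 for ogbn-products)
model = GCN(in_channels=100, hidden_channels=256, num_layers=3,
            out_channels=47)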

ogbn-products is a product co-purchasing network from Amazon [2], as described above, with 200K training nodes. The goal is to classify products into 47 categories. Our NeighborLoaders produced these results:

Naming scheme: “neighbor-{num_neighbors}-{batch_size}-{directed}”

We had a few key takeaways:

  • Models with smaller subgraphs (e.g., neighbor-1-256, which had num_neighbors=[1] and batch_size=256) trained faster but had worse accuracy.
  • directed=True (the default!) performed significantly worse than directed=False, since it reduces the connectivity of the subgraphs, and our 3-layer GCN considers 3-hop neighborhoods.
  • Loaders with more than 1 neighbor hop performed the best overall, but they also took significantly longer to train.

Here’s the same data, sorted by training time to show the relationship with accuracy:

Longer training time tended to produce better accuracy, but the tradeoff isn’t that simple!

However, parameter effects vary wildly by use case (e.g., problem domain) due to different graph structures! We repeated the same experiment for ogbn-proteins, which encodes associations between proteins [3]. For a multi-label binary classification task for protein functions, we had the following results:

Naming scheme: “proteins-neighbor-{num_neighbors}-{batch_size}-{directed}”

For instance, notice that directed=True performed better for this case! Loader experimentation is important for optimizing any ML pipeline.

Other node-based loaders

ClusterLoader

ClusterLoader relies on ClusterData, which partitions the nodes into fixed clusters:

from torch_geometric.loader import ClusterData, ClusterLoader

# In practice, num_parts would be much higher than 4!
cluster_data = ClusterData(data, num_parts=4)
cluster_loader = ClusterLoader(cluster_data, batch_size=2, shuffle=True)

num_parts is the desired number of clusters, like for RandomNodeLoader. batch_size is the number of clusters to include per batch. In this case, ClusterData forms 4 node clusters, and ClusterLoader randomly picks 2 of those clusters for each batch. The resulting batch represents the node-induced subgraph for all of the nodes in those 2 clusters. For example:

Left: clusters from ClusterData. Right: 1 batch from ClusterLoader.

Notably, this includes the edges between the two clusters.

ClusterLoader is useful when the underlying graph has “community-like” structures, like in our example graph or in a social network. ClusterData can handle “communities” of varying sizes, whereas NeighborLoader always picks subgraphs of roughly the same shape. Additionally, NeighborLoader can’t guarantee that an entire related cluster of nodes will make it into the same batch, especially when the cluster is large.

For ClusterLoader, every node appears in exactly one batch, unlike for NeighborLoader. Hence, ClusterLoader treats nodes more equally, whereas highly-connected nodes appear more frequently in NeighborLoader batches. Whether this is good depends on the use case: highly-connected nodes could be important (e.g., for product classification) or poorly representative of a population.

GraphSAINTRandomWalkSampler

GraphSAINTRandomWalkSampler, similar to NeighborLoader, picks batch_size “root” nodes for each batch and then adds surrounding nodes, but instead of adding direct neighbors, it starts a random walk and adds all encountered nodes.

Unlike NeighborLoader, the “root” nodes are not from a partition of the original set of nodes. Rather, they’re sampled with replacement and normalization. Check out the GraphSAINT paper [4] for details!

from torch_geometric.loader import GraphSAINTRandomWalkSampler

saint_walk_sampler = GraphSAINTRandomWalkSampler(data,
                                                 batch_size=2,
                                                 walk_length=6,
                                                 num_steps=3)

GraphSAINTRandomWalkSampler’s key parameters are:

  • batch_size: the number of “root” nodes (not each batch’s final number of nodes!)
  • walk_length: how many steps to take on each random walk starting from a “root” node
  • num_steps: the number of batches to create (not to be confused with walk_length!)

For example, the above parameters produce batches like:

One batch from the GraphSAINTRandomWalkSampler.

This loader is useful when you care more about “depth” in the graph than “breadth.” NeighborLoader is similar to breadth-first exploration, whereas random walks can be more like depth-first exploration. This can allow more information to pass between nodes that are farther away, especially when you’re using a deep GNN with many layers.

That said, there are other methods (outside the scope of this post) to increase graph connectivity even when using other loaders, such as adding a “virtual node” connected to all nodes [5].

Performance comparison

Choosing the right loader type has meaningful performance implications, sometimes even more than parameterization! We repeated the NeighborLoader experiment on ogbn-products using different loader types (and fewer epochs):

Results for 4 loader types on ogbn-products. “neighbor_best_5epochs” is the top-performing model from above, but we used 10 epochs in the last experiment.

A couple takeaways:

  • RandomNodeLoader predictably performed poorly in terms of both accuracy and training time
  • GraphSAINTRandomWalkSampler achieved better accuracy than NeighborLoader in significantly less time

Like parameterization, the choice of loader type also depends heavily on the use case and the graph structure. For ogbn-proteins, the same experiment had very different results:

Results for different loader types on ogbn-proteins, similar to above.

In particular, notice that the NeighborLoader outperformed the other types in this case!

While this analysis is neither rigorous nor perfect, it’ll still hopefully inspire you to try different loaders! See our Colab for details.

Even more loader types

We couldn’t cover every PyG loader type in depth, but a few worth calling out are:

  • HGTLoader: useful when working with heterogeneous graphs because it samples different types of nodes in a balanced way
  • ImbalancedSampler: useful for classification tasks with low-frequency labels, because it allows for oversampling based on the node class
  • ShaDowKHopSampler: like a more advanced version of NeighborLoader that more carefully picks subgraphs and locally smooths them (see the sketch below)
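As a rough sketch, ShaDowKHopSampler can be constructed like this (the depth, neighbor count, and batch size below are arbitrary; see PyG's documentation for the full parameter list):

from torch_geometric.loader import ShaDowKHopSampler

# sample a 2-hop subgraph around each node, keeping up to 5 neighbors per hop;
# node_idx restricts the "root" nodes to the training set
shadow_loader = ShaDowKHopSampler(data, depth=2, num_neighbors=5,
                                  node_idx=data.train_mask, batch_size=128)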

LinkNeighborLoader

PyG also supports loaders that sample edges instead of nodes. This is useful for edge-related tasks like link prediction. LinkNeighborLoader, NeighborLoader’s edge-based analog, is a great starting point!

Negative sampling

Link prediction is essentially binary classification, predicting whether an edge would exist. The graph’s existing edges are considered positive examples, but we also need to make predictions on negative examples, edges that are not actually in the graph. These “negative edges” need to be sampled too! Thankfully, PyG supports this out-of-the-box!
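For intuition, PyG also exposes a standalone utility for negative sampling; here's a minimal sketch, assuming a Data object named data (though, as we'll see, LinkNeighborLoader can handle this for you):

from torch_geometric.utils import negative_sampling

# sample 10 "negative" edges, i.e., node pairs not connected in data
neg_edge_index = negative_sampling(data.edge_index,
                                   num_nodes=data.num_nodes,
                                   num_neg_samples=10)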

Parameters

LinkNeighborLoader is constructed similarly to NeighborLoader:

from torch_geometric.loader import LinkNeighborLoader

link_neighbor_loader = LinkNeighborLoader(
    data,
    num_neighbors=[2],
    neg_sampling_ratio=1.0,
    batch_size=2,
    shuffle=True)

For each batch, LinkNeighborLoader selects batch_size positive edges and neg_sampling_ratio*batch_size negative edges. Then, for each node (i.e., each edge endpoint), neighbors are added like for NeighborLoader. For example:

One batch, with and without negative sampling. Red edges are positive edges; blue ones are negative edges. Pink edges are added “neighbor” edges.

The resulting batches have two edge indexes:

  • batch.edge_index for edges for message passing
  • batch.edge_label_index for positive/negative edges to be labeled (with the ground truth in batch.edge_label; see the sketch below)
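To make the distinction concrete, here's a minimal sketch of a training step that uses both, assuming a model that produces node embeddings and a simple dot-product score for each labeled edge:

import torch.nn.functional as F

for batch in link_neighbor_loader:
    # message passing uses all sampled edges
    emb = model(batch.x, batch.edge_index)

    # score only the positive/negative edges to be labeled
    src, dst = batch.edge_label_index
    pred = (emb[src] * emb[dst]).sum(dim=-1)

    loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label.float())
    # ...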

Data splits

LinkNeighborLoader is typically constructed in conjunction with RandomLinkSplit (instead of using masks) to handle data splits:

import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.2,  # fraction of data held out for validation
    num_test=0.2,  # fraction of data held out for test
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,
)
train_data, val_data, test_data = transform(data)

batch_size = 2
num_neighbors = [3,2]
train_loader = LinkNeighborLoader(train_data,
                                  num_neighbors=num_neighbors,
                                  batch_size=batch_size,
                                  # these are only set for train!
                                  neg_sampling_ratio=2.0,
                                  shuffle=True)
val_loader = LinkNeighborLoader(val_data,
                                num_neighbors=num_neighbors,
                                batch_size=batch_size)
test_loader = LinkNeighborLoader(test_data,
                                 num_neighbors=num_neighbors,
                                 batch_size=batch_size)

RandomLinkSplit adds negative edges to val_data and test_data (but not for train_data due to add_negative_train_samples=False) when they’re constructed so that these negative edges remain constant for every run through the validation/test data. Conversely, negative edges for training are added by LinkNeighborLoader to each batch so that the negative edges differ for every training epoch. Hence, neg_sampling_ratio is only set for train_loader!

Applications

Unlike with node-based tasks, the choice here is more clear-cut: LinkNeighborLoader is the natural pick for edge-related tasks like link prediction and edge property prediction. There aren’t many edge-based loaders, and LinkNeighborLoader is the only one that supports on-the-fly negative sampling.

Performance comparison

We created this Colab to demonstrate how to apply an edge-level loader to a heterogeneous graph, largely following PyG’s blog post, which goes into more detail! Similar to above, our focus was to demonstrate how LinkNeighborLoader parameters affect downstream performance.

We used the MovieLens dataset [6], a heterogeneous graph containing ratings (edges) from users for movies (both nodes), for link prediction. Predicting missing ratings can be useful in recommending movies to users.

Using a model similar to the one in the blog post, we got the following results from several LinkNeighborLoader parameterizations:

Naming scheme: “linkneighbor-{neg_sampling_ratio}-{num_neighbors}-{batch_size}-{directed}”

The main takeaway is that even seemingly minor parameter changes can significantly affect performance, but a couple other call-outs are:

  • large batch sizes significantly decreased overall training time since they reduced the total overhead from non-labeled message passing edges
  • adding more neighbors increased training time but didn’t necessarily improve performance

See our Colab and the original blog post for more details!

Summary

To recap, we recommend NeighborLoader as a starting point for node-based tasks and LinkNeighborLoader for edge-based tasks.

For node-based tasks, we’d suggest additionally looking into:

  • ClusterLoader if your graph has many “community-like” structures
  • GraphSAINTRandomWalkSampler if “far away” nodes matter more (relative to local neighborhoods), especially for deeper GNNs
  • ShaDowKHopSampler for more advanced neighbor sampling
  • HGTLoader for handling heterogeneous nodes in a more balanced way
  • ImbalancedSampler for skewed node classification problems

Closing

We hope this was a helpful tour of PyG’s loaders! If you want to dive into the code, check out the Colabs linked throughout this post.


References

[1] Hu et al. 2020. https://arxiv.org/abs/2005.00687
[2] Bhatia et al. 2016. http://manikvarma.org/downloads/XC/XMLRepository.html
[3] Szklarczyk et al. 2019. https://doi.org/10.1093/nar/gky1131
[4] Zeng et al. 2020. https://arxiv.org/abs/1907.04931
[5] CS224W 2023 lecture 6, slide 18: http://web.stanford.edu/class/cs224w/slides/06-GNN3.pdf
[6] Harper et al. 2015. https://doi.org/10.1145/2827872
