6.1 Training GNN for Node Classification with Neighborhood Sampling

To train your model stochastically, you need to do the following:

  • Define a neighborhood sampler.

  • Adapt your model for minibatch training.

  • Modify your training loop.

The following sub-subsections address these steps one by one.

Define a neighborhood sampler and data loader

DGL provides several neighborhood sampler classes that generate the computation dependencies needed by each layer, given the nodes we wish to compute on.

The simplest neighborhood sampler is NeighborSampler, or its equivalent function-like interface sample_neighbor(), which makes each node gather messages from a fixed number (the fanout) of randomly sampled neighbors.
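To make the idea of per-layer computation dependencies concrete, here is a minimal pure-Python sketch of multi-layer uniform neighbor sampling. It is not DGL's implementation: the function sample_mfgs, the (src_nodes, dst_nodes, edges) tuple layout and the toy graph are hypothetical. It assumes in-neighbors are stored in a dict and, following the MFG convention, lists each block's destination nodes first among its source nodes.

```python
import random


def sample_mfgs(in_neighbors, seeds, fanouts, rng):
    """Hypothetical multi-layer uniform neighbor sampler (illustration only).

    in_neighbors: dict mapping a node ID to the list of its in-neighbor IDs.
    seeds: IDs of the nodes whose outputs we want (one minibatch).
    fanouts: number of neighbors to sample per layer, first layer first.
    Returns one (src_nodes, dst_nodes, edges) tuple per layer, ordered from
    the first (input) layer to the last (output) layer.
    """
    blocks = []
    dst_nodes = list(seeds)
    # Sample backwards: the last layer's destination nodes are the seeds.
    for fanout in reversed(fanouts):
        src_nodes = list(dst_nodes)  # destination nodes come first
        seen = set(src_nodes)
        edges = []
        for dst in dst_nodes:
            nbrs = in_neighbors.get(dst, [])
            # Keep every neighbor if there are at most `fanout` of them,
            # otherwise draw `fanout` of them without replacement.
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            for src in picked:
                if src not in seen:
                    seen.add(src)
                    src_nodes.append(src)
                edges.append((src, dst))
        blocks.insert(0, (src_nodes, dst_nodes, edges))
        # The source nodes of this layer are the destinations one layer down.
        dst_nodes = src_nodes
    return blocks


# Toy graph: node 0 has in-neighbors 1, 2 and 3, and so on.
in_neighbors = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [0, 4], 4: [1, 3]}
blocks = sample_mfgs(in_neighbors, seeds=[0], fanouts=[2, 2], rng=random.Random(0))
for layer, (src, dst, edges) in enumerate(blocks):
    print(f"layer {layer}: {len(src)} source nodes -> {len(dst)} destination nodes")
```

The set of nodes grows toward the input layer, and the fanouts bound how fast it grows, which is what keeps the cost of each minibatch under control.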

To use a sampler provided by DGL, you also need to combine it with DataLoader, which iterates over a set of indices (nodes in this case) in minibatches.

For example, the following code creates a DataLoader that iterates over the training node IDs of ogbn-arxiv in batches, samples two layers of neighbors for each batch, fetches the node features, and copies the resulting MFGs onto the GPU if one is available.

import dgl
import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-arxiv").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
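The ItemSampler stage above is what turns the training node IDs into minibatches. As a rough mental model, with shuffle=True it behaves like shuffling the IDs once per epoch and cutting them into consecutive chunks of batch_size. The sketch below is a hypothetical pure-Python illustration (iterate_minibatches is not a GraphBolt function), and it assumes the final, smaller chunk is kept rather than dropped.

```python
import random


def iterate_minibatches(item_ids, batch_size, shuffle, rng):
    """Hypothetical sketch of ItemSampler-style batching (illustration only)."""
    ids = list(item_ids)
    if shuffle:
        rng.shuffle(ids)  # a new order every epoch
    for start in range(0, len(ids), batch_size):
        # The last chunk may be smaller than batch_size.
        yield ids[start:start + batch_size]


batches = list(iterate_minibatches(range(2500), 1024, True, random.Random(0)))
print([len(b) for b in batches])  # [1024, 1024, 452]
```

In the real pipeline, each such batch of seed nodes then flows through the neighbor sampling, feature fetching and copy_to stages.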

Iterating over the DataLoader yields MiniBatch objects, each containing a list of specially created graphs that represent the computation dependencies of each layer. To train with DGL, you can access these message flow graphs (MFGs) through mini_batch.blocks.

mini_batch = next(iter(dataloader))
print(mini_batch.blocks)

Note

See the Stochastic Training Tutorial for the concept of message flow graphs.

If you wish to develop your own neighborhood sampler or you want a more detailed explanation of the concept of MFGs, please refer to 6.4 Implementing Custom Graph Samplers.

Adapt your model for minibatch training

If your message passing modules are all provided by DGL, the changes required to adapt your model to minibatch training are minimal. Take a two-layer GCN as an example. If your full-graph model is implemented as follows:

class TwoLayerGCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.conv1 = dglnn.GraphConv(in_features, hidden_features)
        self.conv2 = dglnn.GraphConv(hidden_features, out_features)

    def forward(self, g, x):
        x = F.relu(self.conv1(g, x))
        x = F.relu(self.conv2(g, x))
        return x

Then all you need to do is replace g with the blocks generated above, feeding blocks[0] to the first layer and blocks[1] to the second.

class StochasticTwoLayerGCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.conv1 = dglnn.GraphConv(in_features, hidden_features)
        self.conv2 = dglnn.GraphConv(hidden_features, out_features)

    def forward(self, blocks, x):
        x = F.relu(self.conv1(blocks[0], x))
        x = F.relu(self.conv2(blocks[1], x))
        return x

The DGL GraphConv modules above accept an element of the blocks generated by the data loader as their graph argument.

The API reference of each NN module tells you whether it supports accepting an MFG as an argument.
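A module that accepts an MFG computes outputs only for the block's destination nodes. Since the destination nodes of an MFG are listed first among its source nodes, their own input features are simply the first rows of the input. The following pure-Python sketch illustrates this with a hypothetical mean_layer_on_block function; it uses a simplified mean over a node and its sampled neighbors and does not reproduce GraphConv's weights or normalization.

```python
def mean_layer_on_block(src_feats, num_dst, edges):
    """Hypothetical one-layer update on an MFG (illustration only).

    src_feats: one feature vector per source node; the first num_dst rows
        belong to the destination nodes.
    edges: (src_index, dst_index) pairs using the block's local indices.
    Returns one vector per destination node: the mean of the node's own
    feature and the features of its sampled neighbors.
    """
    dim = len(src_feats[0])
    sums = [list(src_feats[d]) for d in range(num_dst)]  # start from own feature
    counts = [1] * num_dst
    for s, d in edges:
        for k in range(dim):
            sums[d][k] += src_feats[s][k]
        counts[d] += 1
    return [[v / counts[d] for v in sums[d]] for d in range(num_dst)]


# Four source nodes; the first two are also the destination nodes.
src_feats = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
edges = [(2, 0), (3, 0), (0, 1)]
out = mean_layer_on_block(src_feats, num_dst=2, edges=edges)
print(len(out))  # 2: one row per destination node
```

Stacking such layers shrinks the node set from blocks[0]'s source nodes down to the seed nodes of the minibatch.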

If you wish to use your own message passing module, please refer to 6.6 Implementing Custom GNN Module for Mini-batch Training.

Training Loop

The training loop simply iterates over the dataset with the customized batching iterator. During each iteration, which yields a MiniBatch, we:

  1. Access the node features corresponding to the input nodes via data.node_features["feat"]. These features have already been moved to the target device (CPU or GPU) by the data loader.

  2. Access the node labels corresponding to the output nodes via data.labels. These labels have already been moved to the target device (CPU or GPU) by the data loader.

  3. Feed the list of MFGs and the input node features to the multilayer GNN and get the outputs.

  4. Compute the loss and backpropagate.

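# ogbn-arxiv has 128-dimensional node features and 40 classes;
# the hidden size is an arbitrary choice.
in_features, hidden_features, out_features = 128, 256, 40

def compute_loss(labels, predictions):
    # Cross-entropy between the predicted logits and the integer labels.
    return F.cross_entropy(predictions, labels)

model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())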

for data in dataloader:
    input_features = data.node_features["feat"]
    output_labels = data.labels
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()

DGL provides an end-to-end stochastic training example: a GraphSAGE implementation.

For heterogeneous graphs

Training a graph neural network for node classification on a heterogeneous graph is similar.

For instance, we have previously seen how to train a 2-layer RGCN on the full graph. The RGCN implementation for minibatch training looks very similar (with self-loops, non-linearity and basis decomposition removed for simplicity):

class StochasticTwoLayerRGCN(nn.Module):
    def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
                rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
                for rel in rel_names
            })
        self.conv2 = dglnn.HeteroGraphConv({
                rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                for rel in rel_names
            })

    def forward(self, blocks, x):
        x = self.conv1(blocks[0], x)
        x = self.conv2(blocks[1], x)
        return x
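HeteroGraphConv applies the module registered for each relation to the corresponding relation subgraph and, with its default aggregate='sum', sums the results that arrive at the same destination node type. The sketch below mimics that with scalar features in pure Python: each relation averages its incoming messages (analogous to GraphConv's norm='right'), scales them by a per-relation weight standing in for that relation's GraphConv parameters, and the per-relation results are summed. The function hetero_layer, the toy relations and the weights are hypothetical.

```python
def hetero_layer(inputs, relations, weights, num_dst):
    """Hypothetical HeteroGraphConv-style layer on scalar features.

    inputs: dict node type -> list of scalar features of the source nodes.
    relations: dict (src_type, rel, dst_type) -> list of (src_idx, dst_idx).
    weights: dict rel -> scalar standing in for that relation's parameters.
    num_dst: dict node type -> number of destination nodes of that type.
    """
    outputs = {}
    for (src_type, rel, dst_type), edges in relations.items():
        sums = [0.0] * num_dst[dst_type]
        degs = [0] * num_dst[dst_type]
        for s, d in edges:
            sums[d] += inputs[src_type][s]
            degs[d] += 1
        # Per-relation mean of incoming messages, scaled by the relation's weight.
        per_rel = [weights[rel] * (t / n if n else 0.0) for t, n in zip(sums, degs)]
        # Results for the same destination type are summed across relations.
        acc = outputs.setdefault(dst_type, [0.0] * num_dst[dst_type])
        for i, v in enumerate(per_rel):
            acc[i] += v
    return outputs


inputs = {"author": [1.0, 2.0], "paper": [10.0, 20.0]}
relations = {
    ("author", "writes", "paper"): [(0, 0), (1, 0), (1, 1)],
    ("paper", "cites", "paper"): [(1, 0)],
}
out = hetero_layer(inputs, relations, {"writes": 1.0, "cites": 0.5}, {"paper": 2})
print(out)  # {'paper': [11.5, 2.0]}
```

Note that "author" gets no output here because no relation points to it; a node type only receives new representations through relations whose destination is that type.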

The samplers provided by DGL also support heterogeneous graphs. For example, you can still use the provided NeighborSampler and DataLoader classes for stochastic training. The only difference is that the item set is now an instance of ItemSetDict, which maps node types to node IDs.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-mag").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
# For heterogeneous graphs, we need to specify the node feature keys
# for each node type.
datapipe = datapipe.fetch_feature(
    feature, node_feature_keys={"author": ["feat"], "paper": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
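Conceptually, the ItemSetDict behind train_set maps node types to node IDs, and we assume each minibatch drawn from it is again keyed by node type. The sketch below shows one plausible way such batching could work (pooling and shuffling the (node type, ID) pairs, then regrouping each chunk by type). It is a hypothetical pure-Python illustration, not GraphBolt's implementation, and how types are mixed within a batch is an assumption.

```python
import random


def iterate_hetero_minibatches(ids_by_type, batch_size, rng):
    """Hypothetical batching over a {node type: node IDs} mapping."""
    # Pool every (node type, node ID) pair so one batch can mix types.
    items = [(ntype, nid) for ntype, ids in ids_by_type.items() for nid in ids]
    rng.shuffle(items)
    for start in range(0, len(items), batch_size):
        batch = {}
        for ntype, nid in items[start:start + batch_size]:
            batch.setdefault(ntype, []).append(nid)
        yield batch


train_ids = {"paper": [0, 1, 2, 3, 4], "author": [7, 8, 9]}
batches = list(iterate_hetero_minibatches(train_ids, batch_size=3, rng=random.Random(0)))
for batch in batches:
    print(batch)
```

For ogbn-mag only "paper" nodes carry labels, so in practice every batch contains a single key.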

The training loop is almost the same as for homogeneous graphs, except that the input features and the model's predictions are now dictionaries keyed by node type, so compute_loss has to pick out the predictions of the labeled node type.

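# in_features, hidden_features, out_features and etypes (the relation
# names of the graph) are assumed to be defined for your dataset.
def compute_loss(labels, predictions):
    # predictions is a dict keyed by node type; only "paper" nodes are labeled.
    return F.cross_entropy(predictions["paper"], labels)

model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())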

for data in dataloader:
    # For heterogeneous graphs, we need to specify the node type and
    # feature name when accessing the node features. The same applies to labels.
    input_features = {
        "author": data.node_features[("author", "feat")],
        "paper": data.node_features[("paper", "feat")]
    }
    output_labels = data.labels["paper"]
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()

DGL provides an end-to-end stochastic training example: an RGCN implementation.