6.1 Training GNN for Node Classification with Neighborhood Sampling
To train your model stochastically, you need to do the following:
Define a neighborhood sampler.
Adapt your model for minibatch training.
Modify your training loop.
The following sub-subsections address these steps one by one.
Define a neighborhood sampler and data loader
DGL provides several neighborhood sampler classes that generate the computation dependencies needed for each layer, given the nodes we wish to compute on.
The simplest neighborhood sampler is NeighborSampler, or the equivalent function-like interface sample_neighbor(), which makes each node gather messages from its neighbors.
To use a sampler provided by DGL, one also needs to combine it with DataLoader, which iterates over a set of indices (nodes in this case) in minibatches.
For example, the following code creates a DataLoader that iterates over the training node ID set of ogbn-arxiv in batches and puts the list of generated MFGs onto the GPU.
import dgl
import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-arxiv").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
Iterating over the DataLoader will yield a MiniBatch object, which contains a list of specially created graphs representing the computation dependencies of each layer. To train with DGL, you can access the message flow graphs (MFGs) via mini_batch.blocks.
mini_batch = next(iter(dataloader))
print(mini_batch.blocks)
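For instance, you can check how the computation dependencies shrink layer by layer by printing the number of source and destination nodes of each MFG (a small illustrative snippet; the exact numbers depend on the sampled batch):
for layer, block in enumerate(mini_batch.blocks):
    # Each MFG is bipartite-structured: messages flow from the sampled
    # source nodes to the destination nodes computed at this layer.
    print(layer, block.num_src_nodes(), block.num_dst_nodes())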
Note
See the Stochastic Training Tutorial for the concept of message flow graphs.
If you wish to develop your own neighborhood sampler or you want a more detailed explanation of the concept of MFGs, please refer to 6.4 Implementing Custom Graph Samplers.
Adapt your model for minibatch training
If your message passing modules are all provided by DGL, the changes required to adapt your model to minibatch training are minimal. Take a multi-layer GCN as an example. If your model on the full graph is implemented as follows:
class TwoLayerGCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.conv1 = dglnn.GraphConv(in_features, hidden_features)
        self.conv2 = dglnn.GraphConv(hidden_features, out_features)

    def forward(self, g, x):
        x = F.relu(self.conv1(g, x))
        x = F.relu(self.conv2(g, x))
        return x
Then all you need to do is replace g with the blocks generated above.
class StochasticTwoLayerGCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)
        self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)

    def forward(self, blocks, x):
        x = F.relu(self.conv1(blocks[0], x))
        x = F.relu(self.conv2(blocks[1], x))
        return x
The DGL GraphConv modules above accept an element of the blocks generated by the data loader as an argument.
The API reference of each NN module will tell you whether it supports accepting an MFG as an argument.
If you wish to use your own message passing module, please refer to 6.6 Implementing Custom GNN Module for Mini-batch Training.
Training Loop
The training loop simply consists of iterating over the dataset with the customized batching iterator. During each iteration that yields a MiniBatch, we:
Access the node features corresponding to the input nodes via data.node_features["feat"]. These features are already moved to the target device (CPU or GPU) by the data loader.
Access the node labels corresponding to the output nodes via data.labels. These labels are already moved to the target device (CPU or GPU) by the data loader.
Feed the list of MFGs and the input node features to the multilayer GNN and get the outputs.
Compute the loss and backpropagate.
model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for data in dataloader:
    input_features = data.node_features["feat"]
    output_labels = data.labels
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()
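The compute_loss function above is left undefined in this guide; for standard multi-class node classification it could simply be cross entropy over the output nodes (a minimal sketch, an assumption rather than part of the example):
def compute_loss(labels, predictions):
    # Multi-class classification loss; predictions are the logits for the
    # output (seed) nodes of the minibatch.
    return F.cross_entropy(predictions, labels)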
DGL provides an end-to-end stochastic training example: see the GraphSAGE implementation.
For heterogeneous graphs
Training a graph neural network for node classification on a heterogeneous graph is similar.
For instance, we have previously seen how to train a 2-layer RGCN on the full graph. The code for the minibatch-training RGCN implementation looks very similar (with self-loops, non-linearity and basis decomposition removed for simplicity):
class StochasticTwoLayerRGCN(nn.Module):
    def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feat, hidden_feat, norm='right')
            for rel in rel_names
        })
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hidden_feat, out_feat, norm='right')
            for rel in rel_names
        })

    def forward(self, blocks, x):
        x = self.conv1(blocks[0], x)
        x = self.conv2(blocks[1], x)
        return x
The samplers provided by DGL also support heterogeneous graphs. For example, one can still use the provided NeighborSampler class and DataLoader class for stochastic training. The only difference is that the itemset is now an instance of ItemSetDict, which maps node types to node IDs.
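For illustration only, such an itemset might be constructed as in the sketch below (with made-up node IDs and the assumption that ItemSet entries are named "seed_nodes"; the built-in dataset used next already provides a proper train_set):
train_ids = gb.ItemSetDict({
    "paper": gb.ItemSet(torch.arange(1000), names="seed_nodes"),
    "author": gb.ItemSet(torch.arange(500), names="seed_nodes"),
})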
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-mag").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
# For heterogeneous graphs, we need to specify the node feature keys
# for each node type.
datapipe = datapipe.fetch_feature(
    feature, node_feature_keys={"author": ["feat"], "paper": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)
The training loop is almost the same as that for homogeneous graphs, except for the implementation of compute_loss, which here takes the labels of the target node type and a dictionary mapping node types to predictions.
model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for data in dataloader:
    # For heterogeneous graphs, we need to specify the node types and
    # feature names when accessing the node features. The same applies to
    # the labels.
    input_features = {
        "author": data.node_features[("author", "feat")],
        "paper": data.node_features[("paper", "feat")],
    }
    output_labels = data.labels["paper"]
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()
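Here, too, compute_loss is left undefined; one possible implementation (an assumption consistent with the code above, where the predictions form a dictionary keyed by node type and the labels belong to the "paper" nodes) is:
def compute_loss(labels, predictions):
    # Only the "paper" nodes carry labels in this task, so the loss is
    # computed on their predictions.
    return F.cross_entropy(predictions["paper"], labels)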
DGL provides an end-to-end stochastic training example: see the RGCN implementation.