6.4 Implementing Custom Graph Samplers
Implementing a custom sampler involves subclassing the dgl.graphbolt.SubgraphSampler base class and implementing its abstract sample_subgraphs method. The method takes in seed nodes, which are the nodes to sample neighbors from:
def sample_subgraphs(self, seed_nodes):
    return input_nodes, sampled_subgraphs
The method should return the input node IDs and a list of subgraphs, where each subgraph is a SampledSubgraph object. Any other data required during sampling, such as the graph structure and the fanout sizes, should be passed to the sampler via its constructor.
The code below implements a classical neighbor sampler:
@functional_datapipe("customized_sample_neighbor")
class CustomizedNeighborSampler(dgl.graphbolt.SubgraphSampler):
    def __init__(self, datapipe, graph, fanouts):
        super().__init__(datapipe)
        self.graph = graph
        self.fanouts = fanouts

    def sample_subgraphs(self, seed_nodes):
        subgs = []
        # Iterate from the output layer inward, so the fanouts are consumed
        # in reverse and the subgraphs end up ordered from input to output.
        for fanout in reversed(self.fanouts):
            # Sample a fixed number of neighbors of the current seed nodes.
            input_nodes, sg = self.graph.sample_neighbors(seed_nodes, fanout)
            subgs.insert(0, sg)
            # The sampled neighbors become the seeds of the next layer.
            seed_nodes = input_nodes
        return input_nodes, subgs
To use this sampler with DataLoader:
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(g, [10, 10])  # 2 layers.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

for data in dataloader:
    input_features = data.node_features["feat"]
    output_labels = data.labels
    output_predictions = model(data.blocks, input_features)
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()
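As noted earlier, any extra configuration a sampler needs should travel through its constructor. The sketch below extends the example with two such options, replace and probs_name, which FusedCSCSamplingGraph.sample_neighbors accepts for sampling with replacement and for edge-probability-weighted sampling; the class name and registration string are made up for illustration, and the two-value return follows the simplified pseudo-API used above:

@functional_datapipe("weighted_sample_neighbor")  # hypothetical name
class WeightedNeighborSampler(dgl.graphbolt.SubgraphSampler):
    def __init__(self, datapipe, graph, fanouts, replace=False, probs_name=None):
        super().__init__(datapipe)
        self.graph = graph
        self.fanouts = fanouts
        # Extra sampling options live on the sampler itself.
        self.replace = replace
        self.probs_name = probs_name

    def sample_subgraphs(self, seed_nodes):
        subgs = []
        for fanout in reversed(self.fanouts):
            input_nodes, sg = self.graph.sample_neighbors(
                seed_nodes, fanout,
                replace=self.replace, probs_name=self.probs_name,
            )
            subgs.insert(0, sg)
            seed_nodes = input_nodes
        return input_nodes, subgs

A pipeline would then select an edge weight feature by name, e.g. datapipe.weighted_sample_neighbor(g, [10, 10], probs_name="edge_weight"), where the feature name "edge_weight" is hypothetical.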
Sampler for Heterogeneous Graphs
To write a sampler for heterogeneous graphs, one needs to be aware that the graph argument is a heterogeneous graph, while the seeds could be a dictionary of ID tensors. Most of DGL's graph sampling operators (e.g., the sample_neighbors and to_block functions) work on heterogeneous graphs natively, so many samplers are automatically ready for heterogeneous graphs. For example, the above CustomizedNeighborSampler can be used on heterogeneous graphs:
import dgl.graphbolt as gb

hg = gb.FusedCSCSamplingGraph()
train_set = gb.ItemSetDict(
    {
        "user": gb.ItemSet(
            (torch.arange(0, 5), torch.arange(5, 10)),
            names=("seed_nodes", "labels"),
        ),
        "item": gb.ItemSet(
            (torch.arange(5, 10), torch.arange(10, 15)),
            names=("seed_nodes", "labels"),
        ),
    }
)
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.customized_sample_neighbor(hg, [10, 10])  # 2 layers.
datapipe = datapipe.fetch_feature(
    feature, node_feature_keys={"user": ["feat"], "item": ["feat"]}
)
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)

for data in dataloader:
    input_features = {
        ntype: data.node_features[(ntype, "feat")]
        for ntype in data.blocks[0].srctypes
    }
    output_labels = data.labels["user"]
    output_predictions = model(data.blocks, input_features)["user"]
    loss = compute_loss(output_labels, output_predictions)
    opt.zero_grad()
    loss.backward()
    opt.step()
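The example above works because sample_neighbors accepts the dictionary seeds produced by ItemSetDict transparently. If a sampler does need to inspect the per-type seeds itself, it can branch on the type of the seeds. The sketch below adds a hypothetical seed_ntypes option that restricts expansion to selected node types; the class name, registration string, and option are made up for illustration, and the two-value return follows the simplified pseudo-API used earlier:

@functional_datapipe("typed_sample_neighbor")  # hypothetical name
class TypedNeighborSampler(dgl.graphbolt.SubgraphSampler):
    def __init__(self, datapipe, graph, fanouts, seed_ntypes=None):
        super().__init__(datapipe)
        self.graph = graph
        self.fanouts = fanouts
        self.seed_ntypes = seed_ntypes

    def sample_subgraphs(self, seed_nodes):
        if isinstance(seed_nodes, dict) and self.seed_ntypes is not None:
            # Heterogeneous case: seed_nodes maps node types to ID tensors,
            # e.g. {"user": tensor([...]), "item": tensor([...])}.
            # Keep only the requested node types as starting points.
            seed_nodes = {
                ntype: ids
                for ntype, ids in seed_nodes.items()
                if ntype in self.seed_ntypes
            }
        subgs = []
        for fanout in reversed(self.fanouts):
            input_nodes, sg = self.graph.sample_neighbors(seed_nodes, fanout)
            subgs.insert(0, sg)
            seed_nodes = input_nodes
        return input_nodes, subgs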
Exclude Edges After Sampling
In some cases, we may want to exclude seed edges from the sampled subgraph. For example, in link prediction tasks, we want to exclude the edges in the training set from the sampled subgraph to prevent information leakage. To do so, we need to add an additional datapipe right after sampling as follows:
datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.transform(gb.exclude_seed_edges)
Please check the API page of exclude_seed_edges() for more details. The above API is based on exclude_edges().
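When a graph stores each undirected edge as a pair of directed edges, the reverse of a seed edge leaks the same label, so it is usually excluded together with the seed edge itself. A minimal sketch, assuming the include_reverse_edges flag of exclude_seed_edges() (used this way in DGL's link prediction examples):

from functools import partial

# Exclude seed edges together with their reverse counterparts.
# include_reverse_edges is assumed based on DGL's link prediction examples.
exclude_fn = partial(gb.exclude_seed_edges, include_reverse_edges=True)
datapipe = datapipe.transform(exclude_fn)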
If you want to exclude edges from the sampled subgraphs based on some other criterion, you can write your own transform function; the implementation of exclude_seed_edges() is a useful reference.
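A custom transform is simply a function that takes a minibatch and returns it. The sketch below mirrors the structure of such a transform; compute_edges_to_drop is a hypothetical helper standing in for your own criterion, and the sketch assumes the minibatch carries its subgraphs in sampled_subgraphs and that each SampledSubgraph exposes an exclude_edges() method:

def exclude_my_edges(minibatch):
    # Hypothetical helper: returns the edges to remove, in the format
    # expected by SampledSubgraph.exclude_edges().
    edges_to_drop = compute_edges_to_drop(minibatch)
    minibatch.sampled_subgraphs = [
        sg.exclude_edges(edges_to_drop)
        for sg in minibatch.sampled_subgraphs
    ]
    return minibatch

datapipe = datapipe.transform(exclude_my_edges)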
You could also refer to examples in Link Prediction.