SubgraphSampler

class dgl.graphbolt.SubgraphSampler(datapipe, *args, **kwargs)

Bases: MiniBatchTransformer

A subgraph sampler that samples a subgraph of a larger graph starting from a given set of seed nodes.

Functional name: sample_subgraph.
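A functional name means the class can also be reached as a chained method on a datapipe: datapipe.sample_subgraph(...) should construct the sampler with that datapipe as its first argument, just like SubgraphSampler(datapipe, ...). The sketch below is a simplified, framework-free model of that registration mechanism; Pipe, register_functional and DoubleItems are hypothetical stand-ins, not PyTorch or GraphBolt code.

```python
class Pipe:
    """Hypothetical minimal datapipe wrapping an iterable."""

    def __init__(self, source):
        self.source = source

    def __iter__(self):
        return iter(self.source)


def register_functional(name):
    """Hypothetical decorator: makes `pipe.<name>(...)` build `cls(pipe, ...)`."""

    def decorator(cls):
        def method(self, *args, **kwargs):
            # The calling pipe becomes the first constructor argument.
            return cls(self, *args, **kwargs)

        setattr(Pipe, name, method)
        return cls

    return decorator


@register_functional("double_items")
class DoubleItems(Pipe):
    """Toy transformer that doubles every item of its upstream pipe."""

    def __iter__(self):
        return (x * 2 for x in self.source)
```

With this model, list(Pipe([1, 2, 3]).double_items()) and list(DoubleItems(Pipe([1, 2, 3]))) both give [2, 4, 6]; the functional form is only a more readable way to chain stages.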

This class is the base class of all subgraph samplers. A subclass of SubgraphSampler should implement either the sample_subgraphs() method or, to define fine-grained sampling stages that can take advantage of optimizations provided by the GraphBolt DataLoader, the sampling_stages() method.

Parameters:
  • datapipe (DataPipe) – The upstream datapipe, typically yielding MiniBatch objects.

  • args (Non-Keyword Arguments) – Arguments to be passed into sampling_stages.

  • kwargs (Keyword Arguments) – Arguments to be passed into sampling_stages.
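How the constructor arguments reach sampling_stages() can be illustrated with a simplified model of the base class. In this sketch, Pipe (with a lazy transform() method), ToySubgraphSampler and FanoutSampler are hypothetical names, batches are plain dicts instead of MiniBatch objects, and the exact wiring is an assumption modeled on the description above rather than GraphBolt's actual code.

```python
class Pipe:
    """Hypothetical minimal datapipe: `transform(fn)` lazily chains a stage."""

    def __init__(self, source, fn=None):
        self.source = source
        self.fn = fn

    def transform(self, fn):
        return Pipe(self, fn)

    def __iter__(self):
        for item in self.source:
            yield self.fn(item) if self.fn is not None else item


class ToySubgraphSampler(Pipe):
    """Simplified model of the base class wiring (not GraphBolt code)."""

    def __init__(self, datapipe, *args, **kwargs):
        # Extra positional/keyword arguments are forwarded to sampling_stages().
        super().__init__(self.sampling_stages(datapipe, *args, **kwargs))

    def sampling_stages(self, datapipe):
        # Default: one stage that delegates to sample_subgraphs().
        return datapipe.transform(self._sample)

    def _sample(self, batch):
        # This toy model carries no timestamps, so None is passed.
        input_nodes, subgraphs = self.sample_subgraphs(batch["seeds"], None)
        return {**batch, "input_nodes": input_nodes, "sampled_subgraphs": subgraphs}

    def sample_subgraphs(self, seeds, seeds_timestamp):
        raise NotImplementedError


class FanoutSampler(ToySubgraphSampler):
    """Receives the forwarded `fanout` keyword in its own sampling_stages()."""

    def sampling_stages(self, datapipe, fanout):
        self.fanout = fanout
        return super().sampling_stages(datapipe)

    def sample_subgraphs(self, seeds, seeds_timestamp):
        # Fake neighbourhood: seed s gets neighbours s*10, s*10+1, ...
        neighbors = [s * 10 + k for s in seeds for k in range(self.fanout)]
        return sorted(set(seeds) | set(neighbors)), [neighbors]
```

Here FanoutSampler(Pipe([{"seeds": [1, 2]}]), fanout=2) forwards fanout=2 to sampling_stages(), and iterating it yields one batch whose input_nodes are [1, 2, 10, 11, 20, 21]. Because the default sampling_stages() accepts only the datapipe, a subclass that passes extra arguments must also override sampling_stages() to consume them.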

sample_subgraphs(seeds, seeds_timestamp)

Sample subgraphs from the given seeds, possibly with temporal constraints.

Any subclass of SubgraphSampler that relies on the default sampling_stages() must implement this method.

Parameters:
  • seeds (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The seed nodes.

  • seeds_timestamp (Union[torch.Tensor, Dict[str, torch.Tensor]]) – The timestamps of the seed nodes. If given, the sampled subgraphs should not contain any nodes or edges that are newer than the timestamps of the seed nodes. Default: None.

Returns:

  • Union[torch.Tensor, Dict[str, torch.Tensor]] – The input nodes.

  • List[SampledSubgraph] – The sampled subgraphs.
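The roles of seeds, seeds_timestamp and the two return values can be illustrated with a framework-free sketch. It assumes a hypothetical toy graph stored as a dict mapping each destination node to its (source, edge_timestamp) in-edges, uses plain lists instead of tensors and SampledSubgraph objects, and keeps the first fanout eligible neighbours instead of sampling at random. Letting each neighbour inherit the timestamp bound of the node it was reached from is also an assumption of this sketch.

```python
from typing import Dict, List, Optional, Tuple

# Hypothetical toy graph: dst -> list of (src, edge_timestamp) in-edges.
Graph = Dict[int, List[Tuple[int, int]]]


def sample_subgraphs(
    graph: Graph,
    fanouts: List[int],
    seeds: List[int],
    seeds_timestamp: Optional[List[int]] = None,
) -> Tuple[List[int], List[List[Tuple[int, int]]]]:
    """Return (input_nodes, subgraphs); each subgraph is a list of (src, dst)."""
    # Frontier maps node -> timestamp bound (None means no temporal constraint).
    if seeds_timestamp is None:
        frontier = {s: None for s in seeds}
    else:
        frontier = dict(zip(seeds, seeds_timestamp))
    subgraphs = []
    # Start from the output layer; insert(0) keeps input-to-output order.
    for fanout in reversed(fanouts):
        edges = []
        next_frontier = dict(frontier)  # destination nodes stay in the next layer
        for dst, bound in frontier.items():
            # Drop edges that are newer than the bound inherited from the seed.
            eligible = [
                src for src, t in graph.get(dst, []) if bound is None or t <= bound
            ]
            # Deterministic stand-in for random sampling: keep the first `fanout`.
            for src in eligible[:fanout]:
                edges.append((src, dst))
                next_frontier.setdefault(src, bound)
        subgraphs.insert(0, edges)
        frontier = next_frontier
    # Input nodes: every node reached by the last (input-most) layer.
    return list(frontier), subgraphs
```

For graph = {1: [(2, 5), (3, 10)], 2: [(4, 3), (5, 8)]} with fanouts [1, 2] and seed 1, the input nodes are [1, 2, 3, 4]; with seeds_timestamp [6], node 3 (reachable only through an edge at time 10) is excluded and the input nodes become [1, 2, 4].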

Examples

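>>> import torch
>>> from torch.utils.data import functional_datapipe
>>> from dgl.graphbolt import SubgraphSampler
>>> @functional_datapipe("my_sample_subgraph")
... class MySubgraphSampler(SubgraphSampler):
...     def __init__(self, datapipe, graph, fanouts):
...         super().__init__(datapipe)
...         self.graph = graph
...         self.fanouts = fanouts
...     def sample_subgraphs(self, seeds, seeds_timestamp=None):
...         # This sampler ignores seeds_timestamp (no temporal constraints).
...         subgraphs = []
...         subgraphs_nodes = []
...         # Sample layer by layer, starting from the output layer, so that
...         # the list ends up ordered from the input layer to the output layer.
...         for fanout in reversed(self.fanouts):
...             subgraph = self.graph.sample_neighbors(seeds, fanout)
...             subgraphs.insert(0, subgraph)
...             subgraphs_nodes.append(subgraph.nodes)
...             # The nodes reached in this layer become the next seeds.
...             seeds = subgraph.nodes
...         # The input nodes are the deduplicated union of all sampled nodes.
...         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
...         return subgraphs_nodes, subgraphs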

sampling_stages(datapipe)

Define the sampling stages by chaining transformations onto datapipe. The default implementation adds a single stage that calls sample_subgraphs(), so it requires that method to be implemented. Override this method instead to split sampling into fine-grained stages.
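Overriding sampling_stages() replaces the single default stage with a chain of smaller ones. The sketch below models this with a hypothetical Pipe class whose transform() lazily chains a function; TwoStageSampler and its _sample_neighbors and _compact stages are illustrative names, not GraphBolt methods, and the particular split into "sample" and "compact" is only an example of what fine-grained stages might look like.

```python
class Pipe:
    """Hypothetical minimal datapipe: `transform(fn)` lazily chains a stage."""

    def __init__(self, source, fn=None):
        self.source = source
        self.fn = fn

    def transform(self, fn):
        return Pipe(self, fn)

    def __iter__(self):
        for item in self.source:
            yield self.fn(item) if self.fn is not None else item


class TwoStageSampler:
    """Hypothetical sampler that splits sampling into two chained stages."""

    def __init__(self, datapipe, graph):
        self.graph = graph  # dst -> list of src neighbours
        self.pipe = self.sampling_stages(datapipe)

    def __iter__(self):
        return iter(self.pipe)

    def sampling_stages(self, datapipe):
        # Each stage is a separate link in the chain, so in a real pipeline
        # the loader could treat them individually.
        datapipe = datapipe.transform(self._sample_neighbors)
        return datapipe.transform(self._compact)

    def _sample_neighbors(self, batch):
        # Stage 1: collect (src, dst) edges with original node IDs.
        edges = [(src, dst) for dst in batch["seeds"] for src in self.graph.get(dst, [])]
        return {**batch, "edges": edges}

    def _compact(self, batch):
        # Stage 2: deduplicate nodes (seeds first) and relabel edges locally.
        nodes = list(dict.fromkeys(batch["seeds"] + [s for s, _ in batch["edges"]]))
        local = {n: i for i, n in enumerate(nodes)}
        compacted = [(local[s], local[d]) for s, d in batch["edges"]]
        return {**batch, "input_nodes": nodes, "compacted_edges": compacted}
```

For example, iterating TwoStageSampler(Pipe([{"seeds": [1, 2]}]), {1: [2, 3], 2: [3]}) yields one batch with input_nodes [1, 2, 3] and compacted_edges [(1, 0), (2, 0), (2, 1)].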