SAINTSampler

class dgl.dataloading.SAINTSampler(mode, budget, cache=True, prefetch_ndata=None, prefetch_edata=None, output_device='cpu')[source]

Bases: dgl.dataloading.base.Sampler

Random node/edge/walk sampler from GraphSAINT: Graph Sampling Based Inductive Learning Method

On each call, the sampler samples a node subset and returns the subgraph induced by those nodes. There are three options for sampling the node subset:

  • The 'node' sampler samples each node with probability proportional to its out-degree.

  • The 'edge' sampler first samples an edge subset and then uses the end nodes of those edges.

  • The 'walk' sampler uses the nodes visited by random walks. It uniformly selects a number of root nodes and then performs a fixed-length random walk from each root node.
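The three strategies above can be sketched in plain Python. This is an illustrative sketch, not DGL's implementation: the function names `sample_node_subset` and `induced_subgraph` and the edge-list input format are hypothetical, the 'node' and 'walk' branches draw with replacement, and the 'edge' branch assumes edges are picked uniformly, which the description above does not specify.

```python
import random


def sample_node_subset(edges, mode, budget, seed=None):
    """Pick a node subset from a directed graph given as (src, dst) pairs.

    Hypothetical helper for illustration; assumes `edges` is non-empty.
    """
    rng = random.Random(seed)
    nodes = sorted({n for edge in edges for n in edge})

    if mode == "node":
        # Weight every node by its out-degree.
        out_deg = {n: 0 for n in nodes}
        for src, _ in edges:
            out_deg[src] += 1
        weights = [out_deg[n] for n in nodes]
        # Draw `budget` nodes with replacement; duplicates collapse in the set.
        return set(rng.choices(nodes, weights=weights, k=budget))

    if mode == "edge":
        # Assumption: uniform edge sampling without replacement.
        picked = rng.sample(edges, k=min(budget, len(edges)))
        return {n for edge in picked for n in edge}

    if mode == "walk":
        num_roots, length = budget
        successors = {n: [] for n in nodes}
        for src, dst in edges:
            successors[src].append(dst)
        visited = set()
        # Roots are chosen uniformly; each walk takes at most `length` steps.
        for root in rng.choices(nodes, k=num_roots):
            current = root
            visited.add(current)
            for _ in range(length):
                if not successors[current]:
                    break  # dead end: no outgoing edge to follow
                current = rng.choice(successors[current])
                visited.add(current)
        return visited

    raise ValueError(f"unknown mode: {mode!r}")


def induced_subgraph(edges, node_subset):
    """Keep only the edges whose two end nodes are both in the subset."""
    return [(s, d) for s, d in edges if s in node_subset and d in node_subset]


edges = [(0, 1), (1, 2), (2, 0), (2, 3), (3, 0)]
subset = sample_node_subset(edges, mode="walk", budget=(2, 3), seed=0)
print(induced_subgraph(edges, subset))
```

In every mode the final step is the same: only the node subset differs, and the returned subgraph keeps every edge whose two end nodes were both sampled.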

Parameters
  • mode (str) – The sampler to use, which can be 'node', 'edge', or 'walk'.

  • budget (int or tuple[int]) –

    Sampler configuration.

    • For 'node' sampler, budget specifies the number of nodes in each sampled subgraph.

    • For 'edge' sampler, budget specifies the number of edges to sample for inducing a subgraph.

    • For 'walk' sampler, budget is a tuple. budget[0] specifies the number of root nodes from which random walks start. budget[1] specifies the length of each random walk.

  • cache (bool, optional) – If False, the sampler does not cache the probability arrays used for sampling. Set it to False if you want to use the same sampler across different graphs.
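To illustrate why cached probability arrays do not carry over between graphs, here is a toy sketch in plain Python. The class `DegreeProbCache` is hypothetical and is not how DGL stores its cache; it only shows that weights computed from one graph's out-degrees become stale when reused for a graph with different degrees.

```python
class DegreeProbCache:
    """Toy stand-in for a sampler that may cache per-node sampling weights."""

    def __init__(self, cache=True):
        self.cache = cache
        self._weights = None

    def weights(self, out_degrees):
        # With caching on, the first graph's weights are reused for every later call.
        if self.cache and self._weights is not None:
            return self._weights
        total = sum(out_degrees)
        computed = [d / total for d in out_degrees]
        if self.cache:
            self._weights = computed
        return computed


graph_a_degrees = [1, 1]
graph_b_degrees = [3, 1]

cached = DegreeProbCache(cache=True)
cached.weights(graph_a_degrees)
stale = cached.weights(graph_b_degrees)  # still graph A's weights

uncached = DegreeProbCache(cache=False)
fresh = uncached.weights(graph_b_degrees)  # recomputed for graph B

print(stale, fresh)
```

With caching enabled, the second call returns weights that no longer match the graph being sampled, which is the situation the cache=False setting avoids.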

  • prefetch_ndata (list[str], optional) –

    The node data to prefetch for the subgraph.

    See 6.8 Feature Prefetching for a detailed explanation of prefetching.

  • prefetch_edata (list[str], optional) –

    The edge data to prefetch for the subgraph.

    See 6.8 Feature Prefetching for a detailed explanation of prefetching.

  • output_device (device, optional) – The device of the output subgraphs.

Examples

>>> import torch
>>> from dgl.dataloading import SAINTSampler, DataLoader
>>> num_iters = 1000
>>> sampler = SAINTSampler(mode='node', budget=6000)
>>> # Assume g is an existing DGLGraph whose g.ndata['feat'] and
>>> # g.ndata['label'] hold node features and labels
>>> dataloader = DataLoader(g, torch.arange(num_iters), sampler, num_workers=4)
>>> for subg in dataloader:
...     # train_on is a placeholder for your own training step
...     train_on(subg)