class dgl.dataloading.NeighborSampler(fanouts, edge_dir='in', prob=None, replace=False, prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None)

Sampler that builds the computational dependencies of node representations via neighbor sampling for multilayer GNNs.

This sampler makes every node gather messages from a fixed number of neighbors per edge type. By default, the neighbors are picked uniformly.
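A minimal pure-Python sketch of one round of uniform neighbor sampling, for intuition only. It assumes a hypothetical `in_adj` dict mapping each node to the list of its in-neighbors (an out-neighbor dict would mimic edge_dir='out'); the function name is hypothetical and this is not DGL's implementation.

```python
import random

def sample_in_neighbors(in_adj, seeds, fanout, replace=False, rng=None):
    """Return sampled (src, dst) edges with up to `fanout` in-neighbors per seed.

    in_adj is a hypothetical dict {node: [in-neighbors]}; fanout == -1 keeps
    every inbound edge of a seed.
    """
    rng = rng or random.Random()
    edges = []
    for dst in seeds:
        nbrs = in_adj.get(dst, [])
        if fanout == -1 or (not replace and len(nbrs) <= fanout):
            picked = list(nbrs)  # -1, or too few neighbors: keep them all
        elif not nbrs:
            picked = []  # replace=True but nothing to draw from
        elif replace:
            picked = [rng.choice(nbrs) for _ in range(fanout)]  # may repeat
        else:
            picked = rng.sample(nbrs, fanout)  # uniform, no repeats
        edges.extend((src, dst) for src in picked)
    return edges


in_adj = {0: [1, 2, 3, 4], 1: [0], 2: []}
edges = sample_in_neighbors(in_adj, seeds=[0, 1, 2], fanout=2, rng=random.Random(0))
print(edges)  # two in-edges of node 0, the single in-edge of node 1, none for node 2
```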

Parameters
• fanouts (list[int] or list[dict[etype, int]]) –

Number of neighbors to sample per edge type for each GNN layer, with the i-th element being the fanout for the i-th GNN layer.

If an integer rather than a dict is provided for a layer, DGL assumes that every edge type has the same fanout on that layer.

If -1 is provided for an edge type on a layer, all edges of that edge type (inbound edges when edge_dir is 'in') are included.

• edge_dir (str, default 'in') – Either 'in', in which case neighbors are sampled according to incoming edges, or 'out', in which case they are sampled according to outgoing edges; same as in dgl.sampling.sample_neighbors().

• prob (str, optional) – If given, the probability of each neighbor being sampled is proportional to the edge feature value with the given name in g.edata. The feature must be a scalar on each edge.

• replace (bool, default False) – Whether to sample with replacement.

• prefetch_node_feats (list[str] or dict[ntype, list[str]], optional) – The source node data to prefetch for the first MFG, corresponding to the input node features necessary for the first GNN layer.

• prefetch_labels (list[str] or dict[ntype, list[str]], optional) – The destination node data to prefetch for the last MFG, corresponding to the node labels of the minibatch.

• prefetch_edge_feats (list[str] or dict[etype, list[str]], optional) – The edge data names to prefetch for all the MFGs, corresponding to the edge features necessary for all GNN layers.

• output_device (device, optional) – The device of the output subgraphs or MFGs. Defaults to the device of the minibatch of seed nodes.
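The fanouts rules above can be sketched as a normalization step that turns each layer's entry into a dict over all edge types. The helper name normalize_fanouts is hypothetical, and the sketch assumes that a per-layer dict must name every edge type of the graph; -1 is passed through unchanged to mean "take all edges of this type".

```python
def normalize_fanouts(fanouts, canonical_etypes):
    """Return one {canonical_etype: fanout} dict per GNN layer (hypothetical helper)."""
    per_layer = []
    for layer, fanout in enumerate(fanouts):
        if isinstance(fanout, int):
            # a bare integer applies to every edge type on this layer
            per_layer.append({etype: fanout for etype in canonical_etypes})
        else:
            # assumption: a dict must give a fanout for every edge type
            missing = set(canonical_etypes) - set(fanout)
            if missing:
                raise ValueError(f"layer {layer}: no fanout for {sorted(missing)}")
            per_layer.append(dict(fanout))
    return per_layer


etypes = [('user', 'follows', 'user'),
          ('user', 'plays', 'game'),
          ('game', 'played-by', 'user')]
layers = normalize_fanouts(
    [5, {etypes[0]: -1, etypes[1]: 4, etypes[2]: 3}], etypes)
print(layers[0])  # every edge type gets 5 on the first layer
print(layers[1])  # -1 keeps all 'follows' edges on the second layer
```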

Examples

Node classification

To train a 3-layer GNN for node classification on a set of nodes train_nid on a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for the first, second, and third layer respectively (assuming the backend is PyTorch):

>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(blocks)
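To see how a fanout list such as [5, 10, 15] turns seed nodes into one block per layer, here is a plain-Python sketch under stated assumptions: a hypothetical in_adj dict, global node IDs, and uniform sampling without replacement. It is not DGL's code. It returns input nodes, output nodes, and a list of blocks whose first element feeds the first GNN layer, like the loop above. Because the output nodes are known first, it walks the fanouts from the last layer backwards and prepends each block.

```python
import random

def sample_blocks(in_adj, seeds, fanouts, rng):
    """Return (input_nodes, output_nodes, blocks); blocks[i] feeds GNN layer i.

    Each block is a (src_nodes, dst_nodes, edges) triple over global node IDs.
    in_adj is a hypothetical dict {node: [in-neighbors]}.
    """
    output_nodes = list(seeds)
    dst_nodes = output_nodes
    blocks = []
    # The output nodes are known first, so walk from the last layer backwards.
    for fanout in reversed(fanouts):
        edges = []
        for dst in dst_nodes:
            nbrs = in_adj.get(dst, [])
            picked = nbrs if fanout == -1 or len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            edges.extend((src, dst) for src in picked)
        # Destination nodes come first among the source nodes, followed by
        # newly reached neighbors in first-seen order.
        src_nodes, seen = list(dst_nodes), set(dst_nodes)
        for src, _ in edges:
            if src not in seen:
                seen.add(src)
                src_nodes.append(src)
        blocks.insert(0, (src_nodes, dst_nodes, edges))  # prepend: earlier layer
        dst_nodes = src_nodes  # the earlier layer must compute these nodes
    return dst_nodes, output_nodes, blocks


in_adj = {0: [1, 2, 3], 1: [4], 2: [5, 6], 3: [], 4: [0], 5: [], 6: [1]}
input_nodes, output_nodes, blocks = sample_blocks(in_adj, [0], [3, 1], random.Random(0))
for i, (src, dst, edges) in enumerate(blocks):
    print(f"layer {i}: {len(src)} src nodes -> {len(dst)} dst nodes, {len(edges)} edges")
```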


If you are training on a heterogeneous graph and want a different number of neighbors for each edge type, provide a list of dicts instead. Each dict specifies the number of neighbors to pick per edge type.

>>> sampler = dgl.dataloading.NeighborSampler([
...     {('user', 'follows', 'user'): 5,
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3}] * 3)
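Within one layer, a dict fanout is applied relation by relation: a destination node draws up to k neighbors separately from every edge type that ends at its node type. The plain-Python sketch below illustrates this for one layer; the function name and the nested-dict adjacency are hypothetical stand-ins for a DGL heterograph. (As a Python detail, `[d] * 3` reuses the same dict object for all three layers, which is harmless as long as it is not mutated later.)

```python
import random

def sample_hetero(in_adj, seeds, fanout, rng):
    """Sample in-edges separately for each canonical edge type.

    in_adj: hypothetical {(srctype, etype, dsttype): {dst_id: [src_ids]}}.
    seeds:  {ntype: [node IDs]}.
    fanout: {(srctype, etype, dsttype): k}, where -1 means every edge.
    """
    sampled = {}
    for cetype, k in fanout.items():
        dsttype = cetype[2]
        edges = []
        # only seeds whose type matches this relation's destination type
        for dst in seeds.get(dsttype, []):
            nbrs = in_adj.get(cetype, {}).get(dst, [])
            picked = nbrs if k == -1 or len(nbrs) <= k else rng.sample(nbrs, k)
            edges.extend((src, dst) for src in picked)
        sampled[cetype] = edges
    return sampled


follows = ('user', 'follows', 'user')
plays = ('user', 'plays', 'game')
played_by = ('game', 'played-by', 'user')
in_adj = {follows: {0: [1, 2, 3, 4, 5, 6]},
          plays: {0: [0, 1]},                 # users 0 and 1 play game 0
          played_by: {0: [0, 1, 2, 3, 4]}}    # games 0 to 4 link to user 0
layer_fanout = {follows: 5, plays: 4, played_by: 3}
sampled = sample_hetero(in_adj, {'user': [0]}, layer_fanout, random.Random(0))
print({cetype[1]: len(e) for cetype, e in sampled.items()})
# {'follows': 5, 'plays': 0, 'played-by': 3}
```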


If you would like non-uniform neighbor sampling:

>>> g.edata['p'] = torch.rand(g.num_edges())   # any non-negative 1D vector works
>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15], prob='p')
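With prob='p', neighbors with larger edge weights are more likely to be chosen. The sketch below shows one common scheme for weighted sampling without replacement (successive draws proportional to the remaining weights) in plain Python; the function name is hypothetical, and DGL's internal algorithm may differ.

```python
import random

def weighted_sample(nbrs, weights, k, rng):
    """Pick up to k distinct neighbors with probability proportional to weight.

    Successive sampling: draw one neighbor in proportion to the remaining
    weights, remove it, and repeat. Zero-weight neighbors are never picked.
    """
    pool = [(n, w) for n, w in zip(nbrs, weights) if w > 0]
    picked = []
    while pool and len(picked) < k:
        idx = rng.choices(range(len(pool)), weights=[w for _, w in pool])[0]
        picked.append(pool.pop(idx)[0])
    return picked


rng = random.Random(0)
nbrs, p = ['a', 'b', 'c', 'd'], [0.0, 1.0, 2.0, 3.0]  # p plays the role of g.edata['p']
print(weighted_sample(nbrs, p, 2, rng))   # two distinct picks, never 'a'
print(weighted_sample(nbrs, p, 10, rng))  # only three neighbors have nonzero weight
```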


Edge classification and link prediction

This class also works for edge classification and link prediction when wrapped with as_edge_prediction_sampler().

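>>> sampler = dgl.dataloading.NeighborSampler([5, 10, 15])
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)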


See the documentation of as_edge_prediction_sampler() for more details.
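For edge tasks, the fanouts still expand outward from nodes: a minibatch of edges is reduced to the set of its endpoint nodes, and those nodes become the seeds of the neighbor sampler. The sketch below illustrates only that reduction; its name and list-based edge storage are hypothetical, and the real wrapper, including options such as negative sampling and edge exclusion, is described in as_edge_prediction_sampler().

```python
def edge_batch_seeds(src, dst, batch_eids):
    """Turn a minibatch of edge IDs into seed nodes for neighbor sampling.

    src, dst: hypothetical parallel lists holding each edge's endpoints.
    Returns the unique endpoints in first-seen order, plus every batch edge
    rewritten as a pair of positions into that seed list.
    """
    seeds, position = [], {}
    for eid in batch_eids:
        for node in (src[eid], dst[eid]):
            if node not in position:
                position[node] = len(seeds)
                seeds.append(node)
    pairs = [(position[src[e]], position[dst[e]]) for e in batch_eids]
    return seeds, pairs


src = [0, 1, 2, 3]
dst = [1, 2, 3, 0]
seeds, pairs = edge_batch_seeds(src, dst, batch_eids=[0, 1])
print(seeds)  # [0, 1, 2]
print(pairs)  # [(0, 1), (1, 2)]
```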

Notes

For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.
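As a rough mental model, the MFG for one layer is a bipartite graph from that layer's source (input) nodes to its destination (output) nodes. The toy class below is a hypothetical stand-in, not DGL's block type. It assumes the convention, which dgl.to_block() also follows by default, that the destination nodes are the first entries of the source node list, and it shows one mean-aggregation step over that structure.

```python
from dataclasses import dataclass, field

@dataclass
class ToyMFG:
    """Hypothetical stand-in for the MFG of one GNN layer.

    dst_nodes are also the first len(dst_nodes) entries of src_nodes, and
    edges hold local (0-based) indices into those two lists.
    """
    src_nodes: list
    dst_nodes: list
    edges: list = field(default_factory=list)  # (local_src, local_dst)

def to_toy_mfg(global_edges, dst_nodes):
    """Relabel (src, dst) edges with global IDs into a ToyMFG."""
    src_nodes = list(dst_nodes)
    local = {n: i for i, n in enumerate(src_nodes)}
    for s, _ in global_edges:
        if s not in local:
            local[s] = len(src_nodes)
            src_nodes.append(s)
    dst_local = {n: i for i, n in enumerate(dst_nodes)}
    edges = [(local[s], dst_local[d]) for s, d in global_edges]
    return ToyMFG(src_nodes, list(dst_nodes), edges)

def mean_aggregate(mfg, src_feat):
    """Average source features into each destination node (0.0 if no edges)."""
    sums = [0.0] * len(mfg.dst_nodes)
    counts = [0] * len(mfg.dst_nodes)
    for s, d in mfg.edges:
        sums[d] += src_feat[s]
        counts[d] += 1
    return [x / c if c else 0.0 for x, c in zip(sums, counts)]


mfg = to_toy_mfg([(7, 3), (9, 3), (3, 5)], dst_nodes=[3, 5])
print(mfg.src_nodes)                               # [3, 5, 7, 9]
print(mean_aggregate(mfg, [1.0, 2.0, 3.0, 5.0]))   # [4.0, 1.0]
```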

__init__(fanouts, edge_dir='in', prob=None, replace=False, prefetch_node_feats=None, prefetch_labels=None, prefetch_edge_feats=None, output_device=None)

Initialize self. See help(type(self)) for accurate signature.

Methods

• __init__(fanouts[, edge_dir, prob, replace, …]) – Initialize self.

• assign_lazy_features(result) – Assign lazy features for prefetching.

• sample(g, seed_nodes[, exclude_eids]) – Sample a list of blocks from the given seed nodes.

• sample_blocks(g, seed_nodes[, exclude_eids]) – Generates a list of blocks from the given seed nodes.