dgl.dataloading

The dgl.dataloading package contains:

  • Data loader classes for iterating over a set of nodes or edges in a graph and generating their computation dependencies via neighborhood sampling methods.

  • Various sampler classes that perform neighborhood sampling for multi-layer GNNs.

  • Negative samplers for link prediction.

For a holistic explanation of how the different components work together, read the user guide Chapter 6: Stochastic Training on Large Graphs.

Note

This package is experimental and its interfaces may change in future releases. It currently only has implementations for PyTorch.

DataLoaders

DGL's DataLoaders for mini-batch training work similarly to PyTorch's DataLoader. They have a generator interface that returns mini-batches sampled from the given graphs. DGL provides a NodeDataLoader for node classification tasks, an EdgeDataLoader for edge/link prediction tasks, and a GraphDataLoader for graph classification tasks.

class dgl.dataloading.pytorch.NodeDataLoader(g, nids, block_sampler, device=None, use_ddp=False, ddp_seed=0, **kwargs)[source]

PyTorch dataloader for batch-iterating over a set of nodes, generating the list of message flow graphs (MFGs) that form the computation dependency of each minibatch.

Parameters
  • g (DGLGraph) – The graph.

  • nids (Tensor or dict[ntype, Tensor]) – The node set to compute outputs.

  • block_sampler (dgl.dataloading.BlockSampler) – The neighborhood sampler.

  • device (device context, optional) –

    The device of the generated MFGs in each iteration, which should be a PyTorch device object (e.g., torch.device).

    By default this value is the same as the device of g.

  • use_ddp (boolean, optional) –

    If True, tells the DataLoader to split the training set for each participating process appropriately using torch.utils.data.distributed.DistributedSampler.

    Note that set_epoch() must be called at the beginning of every epoch if use_ddp is True.

    Overrides the sampler argument of torch.utils.data.DataLoader.

  • ddp_seed (int, optional) –

    The seed for shuffling the dataset in torch.utils.data.distributed.DistributedSampler.

    Only effective when use_ddp is True.

  • kwargs (dict) – Arguments being passed to torch.utils.data.DataLoader.

Examples

To train a 3-layer GNN for node classification on a set of nodes train_nid on a homogeneous graph where each node takes messages from all neighbors (assume the backend is PyTorch):

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
...     g, train_nid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(input_nodes, output_nodes, blocks)

Using with Distributed Data Parallel

If you are using PyTorch’s distributed training (e.g. when using torch.nn.parallel.DistributedDataParallel), you can train the model by turning on the use_ddp option:

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
...     g, train_nid, sampler, use_ddp=True,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     dataloader.set_epoch(epoch)
...     for input_nodes, output_nodes, blocks in dataloader:
...         train_on(input_nodes, output_nodes, blocks)

Notes

Please refer to Minibatch Training Tutorials and User Guide Section 6 for usage.

Tips for selecting the proper device

  • If the input graph g is on GPU, the output device device must be the same GPU and num_workers must be zero. In this case, the sampling and subgraph construction will take place on the GPU. This is the recommended setting when using a single-GPU and the whole graph fits in GPU memory.

  • If the input graph g is on CPU while the output device device is GPU, then depending on the value of num_workers:

    • If num_workers is set to 0, the sampling will happen on the CPU, and then the subgraphs will be constructed directly on the GPU. This is the recommended setting in multi-GPU configurations (see the sketch after this list).

    • Otherwise, if num_workers is greater than 0, both the sampling and subgraph construction will take place on the CPU. This is the recommended setting when using a single-GPU and the whole graph does not fit in GPU memory.
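
For example, here is a minimal sketch of the CPU-graph, GPU-output configuration described above; the device string, batch size, and training function are illustrative:

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.NodeDataLoader(
...     g, train_nid, sampler,
...     device=torch.device('cuda:0'),  # MFGs are produced on the GPU
...     num_workers=0,                  # subgraphs are constructed directly on the GPU
...     batch_size=1024, shuffle=True, drop_last=False)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(input_nodes, output_nodes, blocks)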

class dgl.dataloading.pytorch.EdgeDataLoader(g, eids, block_sampler, device='cpu', use_ddp=False, ddp_seed=0, **kwargs)[source]

PyTorch dataloader for batch-iterating over a set of edges, generating the list of message flow graphs (MFGs) that form the computation dependency of each minibatch, for edge classification, edge regression, and link prediction.

For each iteration, the object will yield

  • A tensor of input nodes necessary for computing the representation on edges, or a dictionary of node type names and such tensors.

  • A subgraph that contains only the edges in the minibatch and their incident nodes. Note that this subgraph has the same metagraph as the original graph.

  • If a negative sampler is given, another graph that contains the “negative edges”, connecting the source and destination nodes yielded from the given negative sampler.

  • A list of MFGs necessary for computing the representation of the incident nodes of the edges in the minibatch.

For more details, please refer to 6.2 Training GNN for Edge Classification with Neighborhood Sampling and 6.3 Training GNN for Link Prediction with Neighborhood Sampling.

Parameters
  • g (DGLGraph) – The graph. Currently must be on CPU; GPU is not supported.

  • eids (Tensor or dict[etype, Tensor]) – The edge set in graph g to compute outputs.

  • block_sampler (dgl.dataloading.BlockSampler) – The neighborhood sampler.

  • device (device context, optional) –

    The device of the generated MFGs and graphs in each iteration, which should be a PyTorch device object (e.g., torch.device).

    By default this value is the same as the device of g.

  • g_sampling (DGLGraph, optional) –

    The graph where neighborhood sampling is performed.

    One may wish to iterate over the edges in one graph while performing sampling in another graph. This is the case, for example, when iterating over the validation and test edge sets while performing neighborhood sampling on the graph formed by only the training edge set.

    If None, it is assumed to be the same as g.

  • exclude (str, optional) –

    Whether and how to exclude dependencies related to the sampled edges in the minibatch. Possible values are

    • None,

    • 'self',

    • 'reverse_id',

    • 'reverse_types'

    See the description of the argument with the same name in the docstring of EdgeCollator for more details.

  • reverse_eids (Tensor or dict[etype, Tensor], optional) –

    A tensor of reverse edge ID mapping. The i-th element indicates the ID of the i-th edge’s reverse edge.

    If the graph is heterogeneous, this argument requires a dictionary of edge types and the reverse edge ID mapping tensors.

    See the description of the argument with the same name in the docstring of EdgeCollator for more details.

  • reverse_etypes (dict[etype, etype], optional) –

    The mapping from the original edge types to their reverse edge types.

    See the description of the argument with the same name in the docstring of EdgeCollator for more details.

  • negative_sampler (callable, optional) –

    The negative sampler.

    See the description of the argument with the same name in the docstring of EdgeCollator for more details.

  • use_ddp (boolean, optional) –

    If True, tells the DataLoader to split the training set for each participating process appropriately using torch.utils.data.distributed.DistributedSampler.

    Note that set_epoch() must be called at the beginning of every epoch if use_ddp is True.

    The dataloader will have a dist_sampler attribute to set the epoch number, as recommended by PyTorch.

    Overrides the sampler argument of torch.utils.data.DataLoader.

  • ddp_seed (int, optional) –

    The seed for shuffling the dataset in torch.utils.data.distributed.DistributedSampler.

    Only effective when use_ddp is True.

  • kwargs (dict) – Arguments being passed to torch.utils.data.DataLoader.

Examples

The following example shows how to train a 3-layer GNN for edge classification on a set of edges train_eid on a homogeneous undirected graph. Each node takes messages from all neighbors.

Say that you have an array of source node IDs src and another array of destination node IDs dst. One can make it bidirectional by adding another set of edges that connects from dst to src:

>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

The ID of an edge and the ID of its reverse edge then differ by |E|, where |E| is the length of the source/destination array. The reverse edge mapping can be obtained by

>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

Note that the sampled edges as well as their reverse edges are removed from the computation dependencies of the incident nodes. That is, these edges will not be involved in neighbor sampling or message aggregation. This is a common trick to avoid information leakage.

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
...     g, train_eid, sampler, exclude='reverse_id',
...     reverse_eids=reverse_eids,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

To train a 3-layer GNN for link prediction on a set of edges train_eid on a homogeneous graph where each node takes messages from all neighbors (assume the backend is PyTorch), with 5 uniformly chosen negative samples per edge:

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> dataloader = dgl.dataloading.EdgeDataLoader(
...     g, train_eid, sampler, exclude='reverse_id',
...     reverse_eids=reverse_eids, negative_sampler=neg_sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

For heterogeneous graphs, the reverse of an edge may have a different edge type from the original edge. For instance, consider that you have an array of user-item clicks, represented by a user array user and an item array item. You may want to build a heterogeneous graph with a user-click-item relation and an item-clicked-by-user relation.

>>> g = dgl.heterograph({
...     ('user', 'click', 'item'): (user, item),
...     ('item', 'clicked-by', 'user'): (item, user)})

To train a 3-layer GNN for edge classification on a set of edges train_eid with type click, you can write

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
...     g, {'click': train_eid}, sampler, exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

To train a 3-layer GNN for link prediction on a set of edges train_eid with type click, you can write

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> dataloader = dgl.dataloading.EdgeDataLoader(
...     g, train_eid, sampler, exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
...     negative_sampler=neg_sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

Using with Distributed Data Parallel

If you are using PyTorch’s distributed training (e.g. when using torch.nn.parallel.DistributedDataParallel), you can train the model by turning on the use_ddp option:

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> dataloader = dgl.dataloading.EdgeDataLoader(
...     g, train_eid, sampler, use_ddp=True, exclude='reverse_id',
...     reverse_eids=reverse_eids,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     dataloader.set_epoch(epoch)
...     for input_nodes, pair_graph, blocks in dataloader:
...         train_on(input_nodes, pair_graph, blocks)

See also

dgl.dataloading.dataloader.EdgeCollator

Notes

Please refer to Minibatch Training Tutorials and User Guide Section 6 for usage.

For end-to-end usages, please refer to the following tutorial/examples:

  • Edge classification on heterogeneous graph: GCMC

  • Link prediction on homogeneous graph: GraphSAGE for unsupervised learning

  • Link prediction on heterogeneous graph: RGCN for link prediction.

class dgl.dataloading.pytorch.GraphDataLoader(dataset, collate_fn=None, use_ddp=False, ddp_seed=0, **kwargs)[source]

PyTorch dataloader for batch-iterating over a set of graphs, generating the batched graph and the corresponding label tensor (if provided) for each minibatch.

Parameters
  • collate_fn (Function, default is None) – The customized collate function. Will use the default collate function if not given.

  • use_ddp (boolean, optional) –

    If True, tells the DataLoader to split the training set for each participating process appropriately using torch.utils.data.distributed.DistributedSampler.

    Overrides the sampler argument of torch.utils.data.DataLoader.

  • ddp_seed (int, optional) –

    The seed for shuffling the dataset in torch.utils.data.distributed.DistributedSampler.

    Only effective when use_ddp is True.

  • kwargs (dict) – Arguments being passed to torch.utils.data.DataLoader.

Examples

To train a GNN for graph classification on a set of graphs in dataset (assume the backend is PyTorch):

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
...     train_on(batched_graph, labels)

Using with Distributed Data Parallel

If you are using PyTorch’s distributed training (e.g. when using torch.nn.parallel.DistributedDataParallel), you can train the model by turning on the use_ddp option:

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, use_ddp=True, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for epoch in range(start_epoch, n_epochs):
...     dataloader.set_epoch(epoch)
...     for batched_graph, labels in dataloader:
...         train_on(batched_graph, labels)

Neighbor Sampler

Neighbor samplers are classes that control how DataLoaders sample neighbors. All of them inherit from the base BlockSampler class and implement different neighbor sampling strategies by overriding the sample_frontier or sample_blocks methods.
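
As an illustration of this subclassing pattern, below is a minimal sketch of a custom sampler that keeps each inbound edge with a fixed probability at every layer. It assumes a homogeneous graph, and the class name and dropout behavior are illustrative rather than part of the dgl.dataloading API:

import torch
import dgl

class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
    """Keep each inbound edge of the seed nodes with probability 1 - p in every layer."""

    def __init__(self, num_layers, p=0.5):
        super().__init__(num_layers)
        self.p = p

    def sample_frontier(self, block_id, g, seed_nodes):
        # All inbound edges of the seed nodes form the full frontier.
        src, dst = dgl.in_subgraph(g, seed_nodes).all_edges()
        # Randomly keep a subset of those edges.
        keep = torch.rand(src.shape[0]) >= self.p
        # The frontier must keep the same node set as the original graph.
        return dgl.graph((src[keep], dst[keep]), num_nodes=g.number_of_nodes())

Such a sampler can then be passed to the data loaders above in place of the built-in samplers documented below.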

class dgl.dataloading.neighbor.BlockSampler(num_layers, return_eids=False, output_ctx=None)[source]

Abstract class specifying the neighborhood sampling strategy for DGL data loaders.

The main method for BlockSampler is sample_blocks(), which generates a list of message flow graphs (MFGs) for a multi-layer GNN given a set of seed nodes to have their outputs computed.

The default implementation of sample_blocks() repeats the following procedure num_layers times, from the last layer to the first (a simplified sketch follows the list):

  • Obtain a frontier. The frontier is defined as a graph with the same nodes as the original graph but only the edges involved in message passing on the current layer. Customizable via sample_frontier().

  • Optionally, if the task is link prediction or edge classification, remove edges connecting training node pairs. If the graph is undirected, also remove the reverse edges. This is controlled by the exclude_eids argument of the sample_blocks() method.

  • Convert the frontier into an MFG.

  • Optionally assign the IDs of the edges in the original graph selected in the first step to the MFG, controlled by the return_eids argument of the constructor.

  • Prepend the MFG to the MFG list to be returned.
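
The following is a simplified paraphrase of that loop, not the actual DGL implementation; it assumes a homogeneous graph and ignores edge exclusion and return_eids:

import dgl

def default_sample_blocks(sampler, g, seed_nodes):
    """Sketch of the default BlockSampler.sample_blocks procedure."""
    blocks = []
    for block_id in reversed(range(sampler.num_layers)):
        # Obtain the frontier for this layer (customizable via sample_frontier()).
        frontier = sampler.sample_frontier(block_id, g, seed_nodes)
        # Convert the frontier into an MFG whose output nodes are the seeds.
        block = dgl.to_block(frontier, seed_nodes)
        # The input nodes of this MFG become the seeds of the previous layer.
        seed_nodes = block.srcdata[dgl.NID]
        # Prepend, so blocks[0] corresponds to the first GNN layer.
        blocks.insert(0, block)
    return blocks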

All subclasses should override the sample_frontier() method and specify the number of layers to sample via the num_layers argument.

Parameters
  • num_layers (int) – The number of layers to sample.

  • return_eids (bool, default False) – Whether to return the edge IDs involved in message passing in the MFG. If True, the edge IDs will be stored as an edge feature named dgl.EID.

  • output_ctx (DGLContext, default None) – The context the sampled blocks will be stored on. This should only be a CUDA context if multiprocessing is not used in the dataloader (e.g., num_workers is 0). If this is None, the sampled blocks will be stored on the same device as the input graph.

Notes

For the concept of frontiers and MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.

sample_blocks(g, seed_nodes, exclude_eids=None)[source]

Generate a list of MFGs given the destination nodes.

Parameters
  • g (DGLGraph) – The original graph.

  • seed_nodes (Tensor or dict[ntype, Tensor]) –

    The destination nodes by node type.

    If the graph only has one node type, one can just specify a single tensor of node IDs.

  • exclude_eids (Tensor or dict[etype, Tensor]) – The edges to exclude from computation dependency.

Returns

The MFGs generated for computing the multi-layer GNN output.

Return type

list[DGLGraph]

Notes

For the concept of frontiers and MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.

sample_frontier(block_id, g, seed_nodes)[source]

Generate the frontier given the destination nodes.

The subclasses should override this function.

Parameters
  • block_id (int) – Represents which GNN layer the frontier is generated for.

  • g (DGLGraph) – The original graph.

  • seed_nodes (Tensor or dict[ntype, Tensor]) –

    The destination nodes by node type.

    If the graph only has one node type, one can just specify a single tensor of node IDs.

Returns

The frontier generated for the current layer.

Return type

DGLGraph

Notes

For the concept of frontiers and MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.

class dgl.dataloading.neighbor.MultiLayerNeighborSampler(fanouts, replace=False, return_eids=False)[source]

Bases: dgl.dataloading.dataloader.BlockSampler

Sampler that builds the computational dependency of node representations via neighbor sampling for a multi-layer GNN.

This sampler will make every node gather messages from a fixed number of neighbors per edge type. The neighbors are picked uniformly.

Parameters
  • fanouts (list[int] or list[dict[etype, int] or None]) –

    List of neighbors to sample per edge type for each GNN layer, starting from the first layer.

    If the graph is homogeneous, only an integer is needed for each layer.

    If None is provided for one layer, all neighbors will be included regardless of edge type.

    If -1 is provided for one edge type on one layer, then all inbound edges of that edge type will be included. (See the last example below.)

  • replace (bool, default False) – Whether to sample with replacement.

  • return_eids (bool, default False) – Whether to return the edge IDs involved in message passing in the MFG. If True, the edge IDs will be stored as an edge feature named dgl.EID.

Examples

To train a 3-layer GNN for node classification on a set of nodes train_nid on a homogeneous graph where each node takes messages from 5, 10, 15 neighbors for the first, second, and third layer respectively (assuming the backend is PyTorch):

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10, 15])
>>> collator = dgl.dataloading.NodeCollator(g, train_nid, sampler)
>>> dataloader = torch.utils.data.DataLoader(
...     collator.dataset, collate_fn=collator.collate,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for blocks in dataloader:
...     train_on(blocks)

If you are training on a heterogeneous graph and want a different number of neighbors for each edge type, provide a list of dicts instead. Each dict specifies the number of neighbors to pick per edge type.

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([
...     {('user', 'follows', 'user'): 5,
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3}] * 3)
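
Per the fanouts description above, a layer whose fanout is None takes all neighbors regardless of edge type, and a fanout of -1 for an edge type takes all inbound edges of that type. For example (a sketch reusing the edge types above):

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([
...     {('user', 'follows', 'user'): -1,   # all 'follows' neighbors in the first layer
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3},
...     {('user', 'follows', 'user'): 5,
...      ('user', 'plays', 'game'): 4,
...      ('game', 'played-by', 'user'): 3},
...     None])                              # all neighbors of all edge types in the last layer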

Notes

For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.

sample_frontier(block_id, g, seed_nodes)[source]

Generate the frontier given the destination nodes.

The subclasses should override this function.

Parameters
  • block_id (int) – Represents which GNN layer the frontier is generated for.

  • g (DGLGraph) – The original graph.

  • seed_nodes (Tensor or dict[ntype, Tensor]) –

    The destination nodes by node type.

    If the graph only has one node type, one can just specify a single tensor of node IDs.

Returns

The frontier generated for the current layer.

Return type

DGLGraph

Notes

For the concept of frontiers and MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.

class dgl.dataloading.neighbor.MultiLayerFullNeighborSampler(n_layers, return_eids=False)[source]

Bases: dgl.dataloading.neighbor.MultiLayerNeighborSampler

Sampler that builds the computational dependency of node representations by taking messages from all neighbors for a multi-layer GNN.

This sampler will make every node gather messages from every single neighbor per edge type.

Parameters
  • n_layers (int) – The number of GNN layers to sample.

  • return_eids (bool, default False) – Whether to return the edge IDs involved in message passing in the MFG. If True, the edge IDs will be stored as an edge feature named dgl.EID.

Examples

To train a 3-layer GNN for node classification on a set of nodes train_nid on a homogeneous graph where each node takes messages from all neighbors in every layer (assuming the backend is PyTorch):

>>> sampler = dgl.dataloading.MultiLayerFullNeighborSampler(3)
>>> collator = dgl.dataloading.NodeCollator(g, train_nid, sampler)
>>> dataloader = torch.utils.data.DataLoader(
...     collator.dataset, collate_fn=collator.collate,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for blocks in dataloader:
...     train_on(blocks)

Notes

For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.

Collators

Collators are platform-agnostic classes that generate the mini-batches given the graphs and the indices to sample from.

class dgl.dataloading.NodeCollator(g, nids, block_sampler)[source]

DGL collator to combine nodes and their computation dependencies within a minibatch for training node classification or regression on a single graph with neighborhood sampling.

Parameters
  • g (DGLGraph) – The graph.

  • nids (Tensor or dict[ntype, Tensor]) – The node set to compute outputs.

  • block_sampler (dgl.dataloading.BlockSampler) – The neighborhood sampler.

Examples

To train a 3-layer GNN for node classification on a set of nodes train_nid on a homogeneous graph where each node takes messages from all neighbors (assume the backend is PyTorch):

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> collator = dgl.dataloading.NodeCollator(g, train_nid, sampler)
>>> dataloader = torch.utils.data.DataLoader(
...     collator.dataset, collate_fn=collator.collate,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, output_nodes, blocks in dataloader:
...     train_on(input_nodes, output_nodes, blocks)

Notes

For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.

class dgl.dataloading.EdgeCollator(g, eids, block_sampler, g_sampling=None, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None)[source]

DGL collator to combine edges and their computation dependencies within a minibatch for training edge classification, edge regression, or link prediction on a single graph with neighborhood sampling.

Given a set of edges, the collate function will yield

  • A tensor of input nodes necessary for computing the representation on edges, or a dictionary of node type names and such tensors.

  • A subgraph that contains only the edges in the minibatch and their incident nodes. Note that this subgraph has the same metagraph as the original graph.

  • If a negative sampler is given, another graph that contains the “negative edges”, connecting the source and destination nodes yielded from the given negative sampler.

  • A list of MFGs necessary for computing the representation of the incident nodes of the edges in the minibatch.

Parameters
  • g (DGLGraph) – The graph from which the edges are iterated in minibatches and the subgraphs are generated.

  • eids (Tensor or dict[etype, Tensor]) – The edge set in graph g to compute outputs.

  • block_sampler (dgl.dataloading.BlockSampler) – The neighborhood sampler.

  • g_sampling (DGLGraph, optional) –

    The graph where neighborhood sampling and message passing are performed.

    Note that this is not necessarily the same as g.

    If None, it is assumed to be the same as g.

  • exclude (str, optional) –

    Whether and how to exclude dependencies related to the sampled edges in the minibatch. Possible values are

    • None, which excludes nothing.

    • 'self', which excludes the sampled edges themselves but nothing else.

    • 'reverse_id', which excludes the reverse edges of the sampled edges. The said reverse edges have the same edge type as the sampled edges. Only works on edge types whose source node type is the same as its destination node type.

    • 'reverse_types', which excludes the reverse edges of the sampled edges. The said reverse edges have different edge types from the sampled edges.

    If g_sampling is given, exclude is ignored and will always be None.

  • reverse_eids (Tensor or dict[etype, Tensor], optional) –

    A tensor of reverse edge ID mapping. The i-th element indicates the ID of the i-th edge’s reverse edge.

    If the graph is heterogeneous, this argument requires a dictionary of edge types and the reverse edge ID mapping tensors.

    Required and only used when exclude is set to 'reverse_id'.

    For a heterogeneous graph, this will be a dict of edge types and edge IDs. Note that only the edge types whose source node type is the same as the destination node type are needed.

  • reverse_etypes (dict[etype, etype], optional) –

    The mapping from the edge type to its reverse edge type.

    Required and only used when exclude is set to 'reverse_types'.

  • negative_sampler (callable, optional) –

    The negative sampler. Can be omitted if no negative sampling is needed.

    The negative sampler must be a callable that takes in the following arguments:

    • The original (heterogeneous) graph.

    • The ID array of sampled edges in the minibatch, or the dictionary of edge types and ID array of sampled edges in the minibatch if the graph is heterogeneous.

    It should return

    • A pair of source and destination node ID arrays as negative samples, or a dictionary of edge types and such pairs if the graph is heterogeneous.

    A set of builtin negative samplers are provided in the negative sampling module.
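
For a homogeneous graph, a custom negative sampler following this protocol can be as small as the sketch below. The class name is hypothetical; the builtin dgl.dataloading.negative_sampler.Uniform provides essentially the same behavior:

import torch

class PerEdgeUniformNegatives(object):
    """Corrupt the destination of each sampled edge with k uniformly drawn nodes."""

    def __init__(self, k):
        self.k = k

    def __call__(self, g, eids):
        # eids is the tensor of edge IDs sampled in the minibatch (homogeneous graph).
        src, _ = g.find_edges(eids)
        # Repeat each source node k times and pair it with random destination nodes.
        src = src.repeat_interleave(self.k)
        dst = torch.randint(0, g.number_of_nodes(), (src.shape[0],))
        return src, dst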

Examples

The following example shows how to train a 3-layer GNN for edge classification on a set of edges train_eid on a homogeneous undirected graph. Each node takes messages from all neighbors.

Say that you have an array of source node IDs src and another array of destination node IDs dst. One can make it bidirectional by adding another set of edges that connects from dst to src:

>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))

The ID of an edge and the ID of its reverse edge then differ by |E|, where |E| is the length of the source/destination array. The reverse edge mapping can be obtained by

>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])

Note that the sampled edges as well as their reverse edges are removed from the computation dependencies of the incident nodes. This is a common trick to avoid information leakage.

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> collator = dgl.dataloading.EdgeCollator(
...     g, train_eid, sampler, exclude='reverse_id',
...     reverse_eids=reverse_eids)
>>> dataloader = torch.utils.data.DataLoader(
...     collator.dataset, collate_fn=collator.collate,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

To train a 3-layer GNN for link prediction on a set of edges train_eid on a homogeneous graph where each node takes messages from all neighbors (assume the backend is PyTorch), with 5 uniformly chosen negative samples per edge:

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> collator = dgl.dataloading.EdgeCollator(
...     g, train_eid, sampler, exclude='reverse_id',
...     reverse_eids=reverse_eids, negative_sampler=neg_sampler)
>>> dataloader = torch.utils.data.DataLoader(
...     collator.dataset, collate_fn=collator.collate,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

For heterogeneous graphs, the reverse of an edge may have a different edge type from the original edge. For instance, consider that you have an array of user-item clicks, represented by a user array user and an item array item. You may want to build a heterogeneous graph with a user-click-item relation and an item-clicked-by-user relation.

>>> g = dgl.heterograph({
...     ('user', 'click', 'item'): (user, item),
...     ('item', 'clicked-by', 'user'): (item, user)})

To train a 3-layer GNN for edge classification on a set of edges train_eid with type click, you can write

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> collator = dgl.dataloading.EdgeCollator(
...     g, {'click': train_eid}, sampler, exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})
>>> dataloader = torch.utils.data.DataLoader(
...     collator.dataset, collate_fn=collator.collate,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)

To train a 3-layer GNN for link prediction on a set of edges train_eid with type click, you can write

>>> sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> collator = dgl.dataloading.EdgeCollator(
...     g, train_eid, sampler, exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
...     negative_sampler=neg_sampler)
>>> dataloader = torch.utils.data.DataLoader(
...     collator.dataset, collate_fn=collator.collate,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)

Notes

For the concept of MFGs, please refer to User Guide Section 6 and Minibatch Training Tutorials.

class dgl.dataloading.GraphCollator[source]

Given a set of graphs as well as their graph-level data, the collate function will batch the graphs into a batched graph, and stack the tensors into a single bigger tensor. If an example is a container (such as a sequence or a mapping), the collate function preserves the structure and collates each of its elements recursively.

If the set of graphs has no graph-level data, the collate function will yield a batched graph.

Examples

To train a GNN for graph classification on a set of graphs in dataset (assume the backend is PyTorch):

>>> dataloader = dgl.dataloading.GraphDataLoader(
...     dataset, batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
...     train_on(batched_graph, labels)
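
GraphCollator can also be paired directly with torch.utils.data.DataLoader; a sketch, assuming dataset yields (graph, label) pairs:

>>> collator = dgl.dataloading.GraphCollator()
>>> dataloader = torch.utils.data.DataLoader(
...     dataset, collate_fn=collator.collate,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for batched_graph, labels in dataloader:
...     train_on(batched_graph, labels)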

Async Copying to/from GPUs

Data can be copied from the CPU to the GPU while the GPU is being used for computation, using the AsyncTransferer. For the transfer to be fully asynchronous, the context the AsyncTransferer is created with must be a GPU context, and the input tensor must be in pinned memory.

class dgl.dataloading.AsyncTransferer(device)[source]

Class for initiating asynchronous copies to the GPU on a second GPU stream.

To initiate a transfer to a GPU:

>>> tensor_cpu = torch.ones(100000).pin_memory()
>>> transferer = dgl.dataloading.AsyncTransferer(torch.device(0))
>>> future = transferer.async_copy(tensor_cpu, torch.device(0))

Then, to wait for the transfer to finish and get a copy of the tensor on the GPU:

>>> tensor_gpu = future.wait()

__init__(device)[source]

Create a new AsyncTransferer object.

Parameters

device (Device or context object.) – The context in which the second stream will be created. Must be a GPU context for the copy to be asynchronous.

async_copy(tensor, device)[source]

Initiate an asynchronous copy on the internal stream. For this call to be asynchronous, the context the AsyncTransferer is created with must be a GPU context, and the input tensor must be in pinned memory.

Currently, only transfers to the GPU are supported.

Parameters
  • tensor (Tensor) – The tensor to transfer.

  • device (Device or context object.) – The context to transfer to.

Returns

A Transfer object that can be waited on to get the tensor in the new context.

Return type

Transfer