dgl.dataloading.as_edge_prediction_sampler
dgl.dataloading.as_edge_prediction_sampler(sampler, exclude=None, reverse_eids=None, reverse_etypes=None, negative_sampler=None, prefetch_labels=None)
Create an edge-wise sampler from a node-wise sampler.
For each batch of edges, the sampler applies the provided node-wise sampler to their source and destination nodes to extract subgraphs. If a negative sampler is provided, it also generates negative edges and extracts subgraphs for their incident nodes as well.
For each iteration, the sampler will yield
A tensor of input nodes necessary for computing the representation on edges, or a dictionary of node type names and such tensors.
A subgraph that contains only the edges in the minibatch and their incident nodes. Note that this graph has the same metagraph as the original graph.
If a negative sampler is given, another graph that contains the “negative edges”, connecting the source and destination nodes generated by the given negative sampler.
The subgraphs or MFGs returned by the provided node-wise sampler, generated from the incident nodes of the edges in the minibatch (as well as those of the negative edges if applicable).
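To make the data flow above concrete, here is a minimal pure-Python toy model, not DGL's implementation. It assumes graphs are plain `(src, dst)` edge lists indexed by edge ID, and `one_hop_sampler` and `edge_wise_sample` are hypothetical helpers: the edge-wise wrapper collects the endpoints of the positive and negative edges as seed nodes, then hands them to a node-wise sampler together with the edge IDs to exclude. It keeps global node IDs and skips the node relabeling and MFG construction a real sampler performs.

```python
# Toy model: graphs are plain (src, dst) edge lists indexed by edge ID.
# one_hop_sampler and edge_wise_sample are hypothetical helpers, not DGL APIs.


def unique_in_order(nodes):
    """Deduplicate node IDs while keeping first-seen order."""
    seen, out = set(), []
    for n in nodes:
        if n not in seen:
            seen.add(n)
            out.append(n)
    return out


def one_hop_sampler(edges, seeds, exclude_eids):
    """Hypothetical node-wise sampler: keep every in-edge of the seed nodes
    except the excluded edge IDs. Returns (input_nodes, kept_edge_ids)."""
    seed_set, banned = set(seeds), set(exclude_eids)
    kept = [eid for eid, (_, v) in enumerate(edges)
            if v in seed_set and eid not in banned]
    input_nodes = unique_in_order(list(seeds) + [edges[e][0] for e in kept])
    return input_nodes, kept


def edge_wise_sample(edges, batch_eids, node_sampler,
                     negative_sampler=None, exclude_eids=()):
    """Mimic the yielded tuple: (input_nodes, pos_edges[, neg_edges], blocks)."""
    pos_edges = [edges[e] for e in batch_eids]
    neg_edges = negative_sampler(pos_edges) if negative_sampler is not None else []
    # Seeds are the endpoints of the positive and (if any) negative edges.
    seeds = unique_in_order([n for edge in pos_edges + neg_edges for n in edge])
    input_nodes, blocks = node_sampler(edges, seeds, exclude_eids)
    if negative_sampler is not None:
        return input_nodes, pos_edges, neg_edges, blocks
    return input_nodes, pos_edges, blocks
```

With `edges = [(0, 1), (1, 2), (2, 0), (1, 0)]`, sampling edge 0 while excluding it keeps only the in-edges 2 and 3 of nodes 0 and 1, so the minibatch edge itself never carries a message.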
Parameters
sampler (Sampler) – The node-wise sampler object. It additionally requires that its sample method accept an optional third argument exclude_eids representing the edge IDs to exclude from the neighborhood. The argument will be either a tensor for homogeneous graphs or a dict of edge types and tensors for heterogeneous graphs.
exclude (str or callable, optional) –
Whether and how to exclude dependencies related to the sampled edges in the minibatch. Possible values are
None, for not excluding any edges.
'self', for excluding the edges in the current minibatch.
'reverse_id', for excluding not only the edges in the current minibatch but also their reverse edges according to the ID mapping in the argument reverse_eids.
'reverse_types', for excluding not only the edges in the current minibatch but also their reverse edges stored in another type according to the argument reverse_etypes.
A user-defined exclusion rule. It is a callable that takes the edges in the current minibatch as its single argument and returns the edges to be excluded.
reverse_eids (Tensor or dict[etype, Tensor], optional) –
A tensor of reverse edge ID mapping. The i-th element indicates the ID of the i-th edge’s reverse edge.
If the graph is heterogeneous, this argument requires a dictionary of edge types and the reverse edge ID mapping tensors.
reverse_etypes (dict[etype, etype], optional) – The mapping from the original edge types to their reverse edge types.
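negative_sampler (callable, optional) – The negative sampler, for example dgl.dataloading.negative_sampler.Uniform. If not given, no negative graph is generated.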
prefetch_labels (list[str] or dict[etype, list[str]], optional) –
The edge labels to prefetch for the returned positive pair graph.
See 6.8 Feature Prefetching for a detailed explanation of prefetching.
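The homogeneous exclude options listed above can be summarized as a function from the minibatch edge IDs to the edge IDs the node-wise sampler must skip. This is a minimal pure-Python sketch with plain lists standing in for tensors; `find_exclude_eids` is a hypothetical name, not a DGL function, and 'reverse_types' (which applies to heterogeneous graphs) is left out.

```python
def find_exclude_eids(batch_eids, exclude=None, reverse_eids=None):
    """Toy model of the homogeneous exclude modes; returns edge IDs to skip."""
    if exclude is None:
        return []                                   # exclude nothing
    if exclude == "self":
        return list(batch_eids)                     # only the minibatch edges
    if exclude == "reverse_id":
        # The minibatch edges plus each edge's reverse, looked up by ID.
        return list(batch_eids) + [reverse_eids[e] for e in batch_eids]
    if callable(exclude):
        return list(exclude(batch_eids))            # user-defined rule
    raise ValueError(f"unsupported exclude value: {exclude!r}")
```

For a 4-edge graph where edges 0 and 2, and 1 and 3, are reverse pairs (`reverse_eids = [2, 3, 0, 1]`), a minibatch `[0, 3]` excludes `[0, 3]` under 'self' but `[0, 3, 2, 1]` under 'reverse_id'.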
Examples
The following example shows how to train a 3-layer GNN for edge classification on a set of edges train_eid on a homogeneous undirected graph. Each node takes messages from at most 15, 10 and 5 sampled neighbors for the first, second and third layer, respectively.
Given an array of source node IDs src and another array of destination node IDs dst, the following code creates a bidirectional graph:
>>> g = dgl.graph((torch.cat([src, dst]), torch.cat([dst, src])))
In the graph above, edges \(i\) and \(i + |E|\) are reverses of each other for every \(0 \le i < |E|\). Therefore, we can create a reverse edge mapping reverse_eids by:
>>> E = len(src)
>>> reverse_eids = torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])
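The reverse mapping above can be sanity-checked without DGL. This minimal sketch assumes plain Python lists in place of tensors, and the helper names are hypothetical; it rebuilds the bidirectional edge lists and confirms that edge `reverse_eids[i]` runs opposite to edge `i` for every `i`.

```python
def bidirectional_edges(src, dst):
    """Edge lists of the bidirectional graph: original edges, then reverses."""
    return src + dst, dst + src


def make_reverse_eids(num_edges):
    """List equivalent of torch.cat([torch.arange(E, 2 * E), torch.arange(0, E)])."""
    return list(range(num_edges, 2 * num_edges)) + list(range(num_edges))


def is_reverse_mapping(srcs, dsts, reverse_eids):
    """True if edge reverse_eids[i] has the endpoints of edge i swapped, for all i."""
    return all(srcs[reverse_eids[i]] == dsts[i] and dsts[reverse_eids[i]] == srcs[i]
               for i in range(len(srcs)))
```

For `src = [0, 1, 2]` and `dst = [1, 2, 3]`, `make_reverse_eids(3)` gives `[3, 4, 5, 0, 1, 2]`, and the check passes in both directions because the mapping is its own inverse.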
By passing reverse_eids to the edge sampler, the edges in the current minibatch and their reverse edges will be excluded from the extracted subgraphs to avoid information leakage.
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_id', reverse_eids=reverse_eids)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)
For link prediction, one can provide a negative sampler to sample negative edges. The code below uses DGL’s Uniform negative sampler to generate 5 negative samples per edge:
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_id', reverse_eids=reverse_eids,
...     negative_sampler=neg_sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, train_eid, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)
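As we understand DGL's documented behavior for Uniform(k), each positive edge (u, v) yields k negative pairs (u, v') whose destination v' is drawn uniformly from all nodes. The sketch below is a pure-Python approximation of that idea, not DGL's code; `uniform_negative_edges` is a hypothetical name, and a seeded `random.Random` is assumed only to make runs reproducible.

```python
import random


def uniform_negative_edges(pos_edges, num_nodes, k, rng=None):
    """For each positive edge (u, v), emit k pairs (u, v') with v' drawn
    uniformly from range(num_nodes). Pairs that happen to be real edges
    are not filtered out."""
    rng = rng if rng is not None else random.Random()
    neg = []
    for u, _ in pos_edges:
        # Keep the source, corrupt the destination k times.
        neg.extend((u, rng.randrange(num_nodes)) for _ in range(k))
    return neg
```

Because destinations are not checked against the graph, a uniform sampler can occasionally produce a "negative" that is actually an existing edge; on large sparse graphs this is usually rare.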
For heterogeneous graphs, reverse edges may belong to a different relation. For example, the relations ('user', 'click', 'item') and ('item', 'clicked-by', 'user') in the graph below are reverses of each other.
>>> g = dgl.heterograph({
...     ('user', 'click', 'item'): (user, item),
...     ('item', 'clicked-by', 'user'): (item, user)})
To correctly exclude edges from each minibatch, set exclude='reverse_types' and pass the dictionary {'click': 'clicked-by', 'clicked-by': 'click'} to the reverse_etypes argument.
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'})
>>> dataloader = dgl.dataloading.DataLoader(
...     g, {'click': train_eid}, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pair_graph, blocks in dataloader:
...     train_on(input_nodes, pair_graph, blocks)
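The 'reverse_types' rule can be sketched as follows. This minimal pure-Python toy assumes, as building both relations from the same user and item arrays guarantees, that edge i of 'click' and edge i of 'clicked-by' are reverses of each other, so the same IDs are excluded in both types. The function name is hypothetical and edge IDs are plain lists rather than tensors.

```python
def find_exclude_eids_reverse_types(batch_eids, reverse_etypes):
    """Toy model of exclude='reverse_types'.

    batch_eids maps an edge type to the edge IDs in the minibatch. Assumes
    edge i of an edge type and edge i of its reverse type are reverses of
    each other, so the same IDs are excluded in both types.
    """
    # Start with the minibatch edges themselves (copied, not aliased).
    excluded = {etype: list(eids) for etype, eids in batch_eids.items()}
    for etype, eids in batch_eids.items():
        # Add the same IDs under the reverse edge type, without duplicates.
        target = excluded.setdefault(reverse_etypes[etype], [])
        for e in eids:
            if e not in target:
                target.append(e)
    return excluded
```

A minibatch `{'click': [4, 7]}` therefore excludes edges 4 and 7 of both 'click' and 'clicked-by', which is why the example only needs to list training IDs for 'click'.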
For link prediction, provide a negative sampler to generate negative samples:
>>> neg_sampler = dgl.dataloading.negative_sampler.Uniform(5)
>>> sampler = dgl.dataloading.as_edge_prediction_sampler(
...     dgl.dataloading.NeighborSampler([15, 10, 5]),
...     exclude='reverse_types',
...     reverse_etypes={'click': 'clicked-by', 'clicked-by': 'click'},
...     negative_sampler=neg_sampler)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, {'click': train_eid}, sampler,
...     batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
>>> for input_nodes, pos_pair_graph, neg_pair_graph, blocks in dataloader:
...     train_on(input_nodes, pos_pair_graph, neg_pair_graph, blocks)