dgl.DGLGraph.filter_edges
DGLGraph.filter_edges(predicate, edges='__ALL__', etype=None)

Return the IDs of the edges with the given edge type that satisfy the given predicate.
- Parameters
predicate (callable) – A function of signature func(edges) -> Tensor. edges are dgl.EdgeBatch objects. Its output tensor should be a 1D boolean tensor with each element indicating whether the corresponding edge in the batch satisfies the predicate.
edges (edges) –
The edges to filter. The allowed input formats are:
int: A single edge ID.
Int Tensor: Each element is an edge ID. The tensor must have the same device type and ID data type as the graph's.
iterable[int]: Each element is an edge ID.
(Tensor, Tensor): The node-tensors format where the i-th elements of the two tensors specify an edge.
(iterable[int], iterable[int]): Similar to the node-tensors format but stores edge endpoints in Python iterables.
By default, it considers all the edges.
etype (str or (str, str, str), optional) –
The type name of the edges. The allowed type name formats are:
(str, str, str) for source node type, edge type and destination node type.
or one str edge type name if the name can uniquely identify a triplet format in the graph.
Can be omitted if the graph has only one type of edges.
- Returns
A 1D tensor that contains the ID(s) of the edge(s) that satisfy the predicate.
- Return type
Tensor
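The allowed formats for the edges argument can be summarized with a small normalization helper. This is a minimal sketch, not DGL code: `to_edge_ids` and `ALL_EDGES` are hypothetical names, plain Python lists stand in for tensors (so the device and ID data type checks are skipped), and edge i of the graph is assumed to run from `src[i]` to `dst[i]`. Two rules are assumptions of the sketch: a tuple of two ints is read as an iterable of edge IDs rather than as an endpoint pair, and in the endpoint format each (u, v) pair maps to the lowest matching edge ID when parallel edges exist.

```python
from typing import Dict, Iterable, List, Sequence, Tuple, Union

ALL_EDGES = "__ALL__"  # default value of the `edges` argument

EdgeSpec = Union[str, int, Iterable[int], Tuple[Iterable[int], Iterable[int]]]


def to_edge_ids(edges: EdgeSpec, src: Sequence[int], dst: Sequence[int]) -> List[int]:
    """Hypothetical helper: normalize an `edges` argument to a list of edge IDs.

    Edge i of the graph goes from src[i] to dst[i].
    """
    num_edges = len(src)
    if isinstance(edges, str):
        if edges != ALL_EDGES:
            raise ValueError(f"unsupported edge spec {edges!r}")
        ids = list(range(num_edges))  # default: every edge
    elif isinstance(edges, int):
        ids = [edges]  # a single edge ID
    elif (
        isinstance(edges, tuple)
        and len(edges) == 2
        and not isinstance(edges[0], int)
        and not isinstance(edges[1], int)
    ):
        # Endpoint format: the i-th elements of the two sequences specify one edge.
        us, vs = list(edges[0]), list(edges[1])
        if len(us) != len(vs):
            raise ValueError("source and destination sequences differ in length")
        first_id: Dict[Tuple[int, int], int] = {}
        for eid, pair in enumerate(zip(src, dst)):
            first_id.setdefault(pair, eid)  # parallel edges: keep the lowest ID
        ids = []
        for u, v in zip(us, vs):
            if (u, v) not in first_id:
                raise ValueError(f"no edge from {u} to {v}")
            ids.append(first_id[(u, v)])
    else:
        ids = [int(e) for e in edges]  # an iterable of edge IDs
    for e in ids:
        if not 0 <= e < num_edges:
            raise ValueError(f"edge ID {e} out of range")
    return ids
```

For the homogeneous graph used in the examples (src [0, 1, 2], dst [1, 2, 3]), `to_edge_ids(([1, 2], [2, 3]), src, dst)` and `to_edge_ids([1, 2], src, dst)` both select edges 1 and 2.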
Examples
The following example uses the PyTorch backend.
>>> import dgl
>>> import torch
Define a predicate function.
>>> def edges_with_feature_one(edges):
...     # Whether an edge has feature 1
...     return (edges.data['h'] == 1.).squeeze(1)
Filter edges for a homogeneous graph.
>>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
>>> g.edata['h'] = torch.tensor([[0.], [1.], [1.]])
>>> print(g.filter_edges(edges_with_feature_one))
tensor([1, 2])
Filter only among the edges with IDs 0 and 1.
>>> print(g.filter_edges(edges_with_feature_one, edges=torch.tensor([0, 1])))
tensor([1])
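To spell out what the homogeneous examples compute, here is a minimal pure-Python model, assuming plain lists in place of tensors. `EdgeBatchSketch`, `filter_edge_ids` and `has_feature_one` are hypothetical names, not part of DGL, and this is not how DGL implements `filter_edges`. It only mirrors the documented contract: the predicate receives a batch of edges, returns one boolean per edge, and the IDs flagged True are returned.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Iterable, List, Sequence


@dataclass
class EdgeBatchSketch:
    """Hypothetical stand-in for dgl.EdgeBatch: the IDs and features of a batch."""
    eids: List[int]
    data: Dict[str, List[float]] = field(default_factory=dict)


def filter_edge_ids(
    edge_data: Dict[str, List[float]],
    predicate: Callable[[EdgeBatchSketch], Sequence[bool]],
    eids: Iterable[int],
) -> List[int]:
    """Hypothetical model of filter_edges: return the IDs in `eids` that pass `predicate`."""
    ids = list(eids)
    # Gather the features of the requested edges into a single batch.
    batch = EdgeBatchSketch(
        eids=ids,
        data={name: [column[e] for e in ids] for name, column in edge_data.items()},
    )
    # The predicate is called once on the whole batch and yields one flag per edge.
    mask = list(predicate(batch))
    if len(mask) != len(ids):
        raise ValueError("predicate must return exactly one boolean per edge")
    # Keep the IDs whose flag is True, in the order they were requested.
    return [e for e, keep in zip(ids, mask) if keep]


def has_feature_one(batch: EdgeBatchSketch) -> List[bool]:
    # Plays the role of edges_with_feature_one from the example above.
    return [h == 1.0 for h in batch.data["h"]]


h = {"h": [0.0, 1.0, 1.0]}  # same edge features as the homogeneous example
print(filter_edge_ids(h, has_feature_one, range(3)))  # [1, 2]
print(filter_edge_ids(h, has_feature_one, [0, 1]))    # [1]
```

Calling the predicate once per batch rather than once per edge is what allows `edges_with_feature_one` to compare the whole feature column with a single tensor operation.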
Filter edges for a heterogeneous graph.
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),
...                                 torch.tensor([0, 0, 1, 1])),
...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2]))})
>>> g.edges['plays'].data['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> # Filter for 'plays' edges
>>> print(g.filter_edges(edges_with_feature_one, etype='plays'))
tensor([1, 2])
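In the heterogeneous example, etype='plays' is accepted as a bare name because only one canonical edge type in that graph is called 'plays'. The resolution rule stated under etype can be sketched as follows. `resolve_etype` is a hypothetical helper, not a DGL function, and its error messages are illustrative; it only encodes the rules that a triplet must exist, a bare name must identify exactly one triplet, and etype may be omitted only when the graph has a single edge type.

```python
from typing import List, Optional, Sequence, Tuple, Union

CanonicalEtype = Tuple[str, str, str]  # (source node type, edge type, destination node type)


def resolve_etype(
    etype: Optional[Union[str, CanonicalEtype]],
    canonical_etypes: Sequence[CanonicalEtype],
) -> CanonicalEtype:
    """Hypothetical helper: map an `etype` argument to a canonical triplet."""
    if etype is None:
        # Omitting etype is only unambiguous when the graph has one edge type.
        if len(canonical_etypes) != 1:
            raise ValueError("etype is required when the graph has several edge types")
        return canonical_etypes[0]
    if isinstance(etype, tuple):
        if etype not in canonical_etypes:
            raise ValueError(f"unknown canonical edge type {etype!r}")
        return etype
    # A bare name is accepted only if it identifies exactly one triplet.
    matches: List[CanonicalEtype] = [c for c in canonical_etypes if c[1] == etype]
    if len(matches) != 1:
        raise ValueError(f"edge type name {etype!r} matches {len(matches)} triplets")
    return matches[0]


# Canonical edge types of the heterogeneous example graph.
canonical = [("user", "plays", "game"), ("user", "follows", "user")]
print(resolve_etype("plays", canonical))  # ('user', 'plays', 'game')
```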