dgl.DGLGraph.filter_nodes

DGLGraph.filter_nodes(predicate, nodes='__ALL__')

Return a tensor of node IDs that satisfy the given predicate.

Parameters:
  • predicate (callable) – A function with signature func(nodes) -> tensor, where nodes is a NodeBatch object, as in user-defined functions (UDFs). It must return a 1-D boolean tensor in which each element indicates whether the corresponding node in the batch satisfies the predicate.
  • nodes (int, iterable or tensor of ints) – The nodes to filter on. Defaults to all nodes.
Returns:

The IDs of the nodes that satisfy the predicate.

Return type:

tensor
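The parameter contract above can be modeled in plain Python. This is a minimal sketch, not DGL's implementation: `filter_nodes_sketch`, the `features` dictionary standing in for a NodeBatch's `data`, and Python lists standing in for tensors are all hypothetical. Since the method returns node IDs, the sketch assumes that when `nodes` is given, the result is drawn from those IDs rather than from positions within the batch.

```python
from typing import Callable, Dict, List, Sequence, Union

ALL = "__ALL__"  # hypothetical sentinel mirroring the documented default of `nodes`

Features = Dict[str, List[float]]


def filter_nodes_sketch(
    num_nodes: int,
    features: Features,
    predicate: Callable[[Features], List[bool]],
    nodes: Union[str, int, Sequence[int]] = ALL,
) -> List[int]:
    """Hypothetical model of filter_nodes; not DGL's actual code."""
    # Resolve which node IDs form the batch that the predicate sees.
    if isinstance(nodes, str) and nodes == ALL:
        ids = list(range(num_nodes))
    elif isinstance(nodes, int):
        ids = [nodes]
    else:
        ids = list(nodes)
    # Gather the features of those nodes in batch order (stand-in for NodeBatch.data).
    batch = {name: [values[i] for i in ids] for name, values in features.items()}
    mask = predicate(batch)
    # The predicate must return exactly one boolean per node in the batch.
    if len(mask) != len(ids):
        raise ValueError("predicate must return one boolean per node in the batch")
    # Keep the node IDs whose mask entry is True.
    return [i for i, keep in zip(ids, mask) if keep]
```

For example, with features `{"x": [1.0, -1.0, 1.0]}` and a predicate `lambda b: [v == 1.0 for v in b["x"]]`, filtering all three nodes gives `[0, 2]`, while passing `nodes=[1, 2]` gives `[2]`: the mask is aligned with the requested IDs, and the surviving IDs are returned.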

Examples

Note

Here we use PyTorch syntax for the demo. The general idea applies to other frameworks with minor syntax changes (e.g. replace torch.tensor with mxnet.ndarray).

Construct a graph object for the demo.

>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.ndata['x'] = th.tensor([[1.], [-1.], [1.]])

Define a function that selects the nodes whose feature 'x' equals 1.

>>> def has_feature_one(nodes): return (nodes.data['x'] == 1).squeeze(1)

Filter the nodes whose feature 'x' equals 1.

>>> g.filter_nodes(has_feature_one)
tensor([0, 2])
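The predicate above ends with .squeeze(1) because 'x' has shape (3, 1), so the comparison also has shape (3, 1) and must be flattened into the required 1-D mask. The sketch below uses plain PyTorch on the same feature values, without a DGL graph, to show the shapes involved and one way to turn such a mask into node IDs. The choice of `th.nonzero` is illustrative only and is not a claim about what DGL calls internally.

```python
import torch as th

# Same per-node feature values as the demo graph: nodes 0, 1 and 2.
x = th.tensor([[1.], [-1.], [1.]])

raw = x == 1            # boolean tensor of shape (3, 1)
mask = raw.squeeze(1)   # 1-D boolean mask of shape (3,), one entry per node

# Positions where the mask is True; when the batch covers all nodes, these are the node IDs.
ids = th.nonzero(mask, as_tuple=True)[0]
print(raw.shape, mask.shape, ids)
```

Without the squeeze, the predicate would return a 2-D tensor, which does not match the 1-D boolean contract described under Parameters.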

See also

filter_edges()