"""Module for message propagation."""
from __future__ import absolute_import
from . import backend as F, traversal as trv
from .heterograph import DGLGraph
# Public API of this module: functional wrappers that forward to
# DGLGraph.prop_nodes / DGLGraph.prop_edges, plus variants whose frontiers
# are produced by BFS, topological order, or labeled DFS traversal.
__all__ = [
"prop_nodes",
"prop_nodes_bfs",
"prop_nodes_topo",
"prop_edges",
"prop_edges_dfs",
]
def prop_nodes(
    graph,
    nodes_generator,
    message_func="default",
    reduce_func="default",
    apply_node_func="default",
):
    """Functional method for :func:`dgl.DGLGraph.prop_nodes`.

    All arguments after ``graph`` are forwarded positionally, unchanged, to
    ``graph.prop_nodes``. Whatever that method returns is discarded, so this
    function always returns ``None``.

    Parameters
    ----------
    graph : DGLGraph
        The graph whose ``prop_nodes`` method performs the propagation.
    nodes_generator : iterable
        The node frontiers, in the order they should be processed. Callers
        in this module pass a fully materialized list of frontiers.
    message_func : callable, optional
        The message function.
    reduce_func : callable, optional
        The reduce function.
    apply_node_func : callable, optional
        The update function.

    See Also
    --------
    dgl.DGLGraph.prop_nodes
    """
    graph.prop_nodes(
        nodes_generator, message_func, reduce_func, apply_node_func
    )
def prop_edges(
    graph,
    edges_generator,
    message_func="default",
    reduce_func="default",
    apply_node_func="default",
):
    """Functional form of :func:`dgl.DGLGraph.prop_edges`.

    Hands every argument after ``graph`` to ``graph.prop_edges`` in the same
    positional order. Any value returned by that method is dropped, so the
    result of this call is always ``None``.

    Parameters
    ----------
    graph : DGLGraph
        The graph whose ``prop_edges`` method does the work.
    edges_generator : iterable
        The edge frontiers to process, in order.
    message_func : callable, optional
        The message function.
    reduce_func : callable, optional
        The reduce function.
    apply_node_func : callable, optional
        The update function.

    See Also
    --------
    dgl.DGLGraph.prop_edges
    """
    forwarded = (edges_generator, message_func, reduce_func, apply_node_func)
    graph.prop_edges(*forwarded)
def prop_nodes_bfs(
    graph,
    source,
    message_func,
    reduce_func,
    reverse=False,
    apply_node_func=None,
):
    """Message propagation using node frontiers generated by BFS.

    The BFS is run on a CPU copy of ``graph``. Each resulting frontier is
    copied back to ``graph.device``, and the list of frontiers is passed to
    :func:`prop_nodes`.

    Parameters
    ----------
    graph : DGLGraph
        The graph object. It must have exactly one canonical edge type.
    source : list, tensor of nodes
        Source nodes. If this is a tensor, it is copied to CPU along with the
        graph before traversal, so it may live on the graph's device.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    apply_node_func : callable, optional
        The update function.

    Raises
    ------
    AssertionError
        If ``graph`` is not a ``DGLGraph`` or does not have exactly one
        canonical edge type (only while assertions are enabled).

    See Also
    --------
    dgl.traversal.bfs_nodes_generator
    """
    assert isinstance(
        graph, DGLGraph
    ), "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "prop_nodes_bfs only support homogeneous graph"
    # TODO(murphy): Graph traversal currently is only supported on
    # CPU graphs. Move graph to CPU as a workaround,
    # which should be fixed in the future.
    # The traversal receives the CPU copy of the graph, so a tensor source
    # must be moved to CPU as well. Otherwise a GPU-resident source would not
    # match the graph's device.
    if F.is_tensor(source):
        source = F.copy_to(source, F.cpu())
    nodes_gen = trv.bfs_nodes_generator(graph.cpu(), source, reverse)
    nodes_gen = [F.copy_to(frontier, graph.device) for frontier in nodes_gen]
    prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(
    graph, message_func, reduce_func, reverse=False, apply_node_func=None
):
    """Propagate messages over ``graph`` one topological frontier at a time.

    The frontiers are produced by
    :func:`dgl.traversal.topological_nodes_generator` on a CPU copy of the
    graph. Each one is moved to ``graph.device``, and the resulting list is
    handed to :func:`prop_nodes`.

    Parameters
    ----------
    graph : DGLGraph
        Graph to propagate on. It must have exactly one canonical edge type.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    apply_node_func : callable, optional
        The update function.

    See Also
    --------
    dgl.traversal.topological_nodes_generator
    """
    wrong_type_msg = "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
    multi_etype_msg = "prop_nodes_topo only support homogeneous graph"
    assert isinstance(graph, DGLGraph), wrong_type_msg
    assert len(graph.canonical_etypes) == 1, multi_etype_msg

    # TODO(murphy): traversal is currently only supported on CPU graphs, so
    # the frontiers are computed on a CPU copy and then moved back. This
    # workaround should be removed once traversal supports other devices.
    cpu_frontiers = trv.topological_nodes_generator(graph.cpu(), reverse)
    device_frontiers = []
    for cpu_frontier in cpu_frontiers:
        device_frontiers.append(F.copy_to(cpu_frontier, graph.device))

    prop_nodes(
        graph, device_frontiers, message_func, reduce_func, apply_node_func
    )
def prop_edges_dfs(
    graph,
    source,
    message_func,
    reduce_func,
    reverse=False,
    has_reverse_edge=False,
    has_nontree_edge=False,
    apply_node_func=None,
):
    """Message propagation using edge frontiers generated by labeled DFS.

    The DFS is run on a CPU copy of ``graph`` and asked for edges only
    (``return_labels=False``). Each edge frontier is copied back to
    ``graph.device``, and the list is passed to :func:`prop_edges`.

    Parameters
    ----------
    graph : DGLGraph
        The graph object. It must have exactly one canonical edge type.
    source : list, tensor of nodes
        Source nodes. If this is a tensor, it is copied to CPU along with the
        graph before traversal, so it may live on the graph's device.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    has_reverse_edge : bool, optional
        If true, REVERSE edges are included.
    has_nontree_edge : bool, optional
        If true, NONTREE edges are included.
    apply_node_func : callable, optional
        The update function.

    Raises
    ------
    AssertionError
        If ``graph`` is not a ``DGLGraph`` or does not have exactly one
        canonical edge type (only while assertions are enabled).

    See Also
    --------
    dgl.traversal.dfs_labeled_edges_generator
    """
    assert isinstance(
        graph, DGLGraph
    ), "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
    assert (
        len(graph.canonical_etypes) == 1
    ), "prop_edges_dfs only support homogeneous graph"
    # TODO(murphy): Graph traversal currently is only supported on
    # CPU graphs. Move graph to CPU as a workaround,
    # which should be fixed in the future.
    # The traversal receives the CPU copy of the graph, so a tensor source
    # must be moved to CPU as well. Otherwise a GPU-resident source would not
    # match the graph's device.
    if F.is_tensor(source):
        source = F.copy_to(source, F.cpu())
    edges_gen = trv.dfs_labeled_edges_generator(
        graph.cpu(),
        source,
        reverse,
        has_reverse_edge,
        has_nontree_edge,
        return_labels=False,
    )
    edges_gen = [F.copy_to(frontier, graph.device) for frontier in edges_gen]
    prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func)