# Source code for dgl.propagate

"""Module for message propagation."""
from __future__ import absolute_import

from . import backend as F, traversal as trv
from .heterograph import DGLGraph

# Names exported by ``from dgl.propagate import *``.
__all__ = [
    "prop_nodes", "prop_nodes_bfs", "prop_nodes_topo",
    "prop_edges", "prop_edges_dfs",
]


def prop_nodes(
    graph,
    nodes_generator,
    message_func="default",
    reduce_func="default",
    apply_node_func="default",
):
    """Functional method for :func:`dgl.DGLGraph.prop_nodes`.

    All arguments after ``graph`` are forwarded positionally, unchanged,
    to ``graph.prop_nodes``.

    Parameters
    ----------
    graph : DGLGraph
        The graph on which messages are propagated.
    nodes_generator : iterable
        The node frontiers, in propagation order (a generator, or e.g. the
        list of frontiers built by :func:`prop_nodes_bfs`).
    message_func : callable, optional
        The message function. Defaults to the string ``"default"``, which
        is passed through to ``graph.prop_nodes`` as-is.
    reduce_func : callable, optional
        The reduce function. Defaults to ``"default"`` (passed through).
    apply_node_func : callable, optional
        The update function. Defaults to ``"default"`` (passed through).

    See Also
    --------
    dgl.DGLGraph.prop_nodes
    """
    graph.prop_nodes(
        nodes_generator, message_func, reduce_func, apply_node_func
    )
def prop_edges(
    graph,
    edges_generator,
    message_func="default",
    reduce_func="default",
    apply_node_func="default",
):
    """Run edge-frontier message propagation on ``graph``.

    Thin functional wrapper around :func:`dgl.DGLGraph.prop_edges`: every
    argument after ``graph`` is forwarded positionally to that method.

    Parameters
    ----------
    graph : DGLGraph
        Graph whose ``prop_edges`` method performs the propagation.
    edges_generator : iterable
        The edge frontiers, in propagation order.
    message_func : callable, optional
        The message function. Defaults to the string ``"default"``, which
        is passed through unchanged.
    reduce_func : callable, optional
        The reduce function. Defaults to ``"default"`` (passed through).
    apply_node_func : callable, optional
        The update function. Defaults to ``"default"`` (passed through).

    See Also
    --------
    dgl.DGLGraph.prop_edges
    """
    forwarded = (edges_generator, message_func, reduce_func, apply_node_func)
    graph.prop_edges(*forwarded)
def prop_nodes_bfs(
    graph,
    source,
    message_func,
    reduce_func,
    reverse=False,
    apply_node_func=None,
):
    """Message propagation using node frontiers generated by BFS.

    Parameters
    ----------
    graph : DGLGraph
        The graph object. Must have exactly one canonical edge type.
    source : list, tensor of nodes
        Source nodes.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    apply_node_func : callable, optional
        The update function.

    Raises
    ------
    AssertionError
        If ``graph`` is not a :class:`DGLGraph`, or if it does not have
        exactly one canonical edge type.

    See Also
    --------
    dgl.traversal.bfs_nodes_generator
    """
    # Explicit raises instead of ``assert`` so the checks are not stripped
    # under ``python -O``; AssertionError is kept so existing callers that
    # catch it keep working.
    if not isinstance(graph, DGLGraph):
        raise AssertionError(
            "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
        )
    if len(graph.canonical_etypes) != 1:
        raise AssertionError("prop_nodes_bfs only support homogeneous graph")
    # TODO(murphy): Graph traversal currently is only supported on
    #               CPP graphs. Move graph to CPU as a workaround,
    #               which should be fixed in the future.
    nodes_gen = trv.bfs_nodes_generator(graph.cpu(), source, reverse)
    # The traversal ran on a CPU copy; move each frontier back to the
    # graph's own device before propagating.
    nodes_gen = [F.copy_to(frontier, graph.device) for frontier in nodes_gen]
    prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_nodes_topo(
    graph, message_func, reduce_func, reverse=False, apply_node_func=None
):
    """Message propagation using node frontiers generated by topological order.

    Parameters
    ----------
    graph : DGLGraph
        The graph object. Must have exactly one canonical edge type.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    apply_node_func : callable, optional
        The update function.

    Raises
    ------
    AssertionError
        If ``graph`` is not a :class:`DGLGraph`, or if it does not have
        exactly one canonical edge type.

    See Also
    --------
    dgl.traversal.topological_nodes_generator
    """
    # Explicit raises instead of ``assert`` so the checks are not stripped
    # under ``python -O``; AssertionError is kept so existing callers that
    # catch it keep working.
    if not isinstance(graph, DGLGraph):
        raise AssertionError(
            "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
        )
    if len(graph.canonical_etypes) != 1:
        raise AssertionError("prop_nodes_topo only support homogeneous graph")
    # TODO(murphy): Graph traversal currently is only supported on
    #               CPP graphs. Move graph to CPU as a workaround,
    #               which should be fixed in the future.
    nodes_gen = trv.topological_nodes_generator(graph.cpu(), reverse)
    # The traversal ran on a CPU copy; move each frontier back to the
    # graph's own device before propagating.
    nodes_gen = [F.copy_to(frontier, graph.device) for frontier in nodes_gen]
    prop_nodes(graph, nodes_gen, message_func, reduce_func, apply_node_func)
def prop_edges_dfs(
    graph,
    source,
    message_func,
    reduce_func,
    reverse=False,
    has_reverse_edge=False,
    has_nontree_edge=False,
    apply_node_func=None,
):
    """Message propagation using edge frontiers generated by labeled DFS.

    Parameters
    ----------
    graph : DGLGraph
        The graph object. Must have exactly one canonical edge type.
    source : list, tensor of nodes
        Source nodes.
    message_func : callable
        The message function.
    reduce_func : callable
        The reduce function.
    reverse : bool, optional
        If true, traverse following the in-edge direction.
    has_reverse_edge : bool, optional
        If true, REVERSE edges are included.
    has_nontree_edge : bool, optional
        If true, NONTREE edges are included.
    apply_node_func : callable, optional
        The update function.

    Raises
    ------
    AssertionError
        If ``graph`` is not a :class:`DGLGraph`, or if it does not have
        exactly one canonical edge type.

    See Also
    --------
    dgl.traversal.dfs_labeled_edges_generator
    """
    # Explicit raises instead of ``assert`` so the checks are not stripped
    # under ``python -O``; AssertionError is kept so existing callers that
    # catch it keep working.
    if not isinstance(graph, DGLGraph):
        raise AssertionError(
            "DGLHeteroGraph is merged with DGLGraph, Please use DGLGraph"
        )
    if len(graph.canonical_etypes) != 1:
        raise AssertionError("prop_edges_dfs only support homogeneous graph")
    # TODO(murphy): Graph traversal currently is only supported on
    #               CPP graphs. Move graph to CPU as a workaround,
    #               which should be fixed in the future.
    edges_gen = trv.dfs_labeled_edges_generator(
        graph.cpu(),
        source,
        reverse,
        has_reverse_edge,
        has_nontree_edge,
        # Only the edge frontiers are needed for propagation, not labels.
        return_labels=False,
    )
    # The traversal ran on a CPU copy; move each frontier back to the
    # graph's own device before propagating.
    edges_gen = [F.copy_to(frontier, graph.device) for frontier in edges_gen]
    prop_edges(graph, edges_gen, message_func, reduce_func, apply_node_func)