"""Built-in message function."""
from __future__ import absolute_import
import sys
from itertools import product
from .base import BuiltinFunction, TargetCode
from .._deprecate.runtime import ir
from .._deprecate.runtime.ir import var
# Public API of this module. _register_builtin_message_func() appends the
# generated names (u_add_v, e_dot_u, ...) to this list at import time, so it
# must stay a mutable list.
__all__ = [
    "src_mul_edge",
    "copy_src",
    "copy_edge",
    "copy_u",
    "copy_e",
    "BinaryMessageFunction",
    "CopyMessageFunction",
]
class MessageFunction(BuiltinFunction):
    """Abstract base class of every builtin message function.

    Concrete subclasses must override both :attr:`name` and :meth:`_invoke`;
    the implementations here only raise ``NotImplementedError``.
    """

    @property
    def name(self):
        """Identifier of this builtin message function.

        Raises
        ------
        NotImplementedError
            Always; subclasses supply the actual name.
        """
        raise NotImplementedError

    def _invoke(self, graph, src_frame, dst_frame, edge_frame, out_size,
                src_map, dst_map, edge_map, out_map, reducer="none"):
        """Build the symbolic runtime IR that evaluates this message function.

        Raises
        ------
        NotImplementedError
            Always; subclasses implement the IR construction.
        """
        raise NotImplementedError
class BinaryMessageFunction(MessageFunction):
    """Builtin message function of the form ``<lhs>_<op>_<rhs>``.

    Instances are normally produced by the generated module-level builtins
    such as ``u_mul_e`` rather than constructed directly.

    Parameters
    ----------
    binary_op : str
        Name of the binary operator (e.g. ``"mul"``); forwarded to
        ``ir.BINARY_REDUCE``.
    lhs, rhs : TargetCode value
        Which side (source, destination or edge) each operand is read from.
    lhs_field, rhs_field : str
        Feature field names of the two operands.
    out_field : str
        Name of the output message field.

    See Also
    --------
    u_mul_e
    """

    def __init__(self, binary_op, lhs, rhs, lhs_field, rhs_field, out_field):
        self.binary_op = binary_op
        self.lhs, self.rhs = lhs, rhs
        self.lhs_field, self.rhs_field = lhs_field, rhs_field
        self.out_field = out_field

    def _invoke(self, graph, src_frame, dst_frame, edge_frame, out_size,
                src_map, dst_map, edge_map, out_map, reducer="none"):
        """Emit the runtime IR that computes (and optionally reduces) the
        binary message.

        Each operand column is read from the frame selected by its target
        code, and the resulting variables are passed to ``ir.BINARY_REDUCE``
        together with ``reducer`` and ``out_size``.
        """
        # NOTE: keep the var/ir calls in this exact order; variables may be
        # numbered in creation order by the runtime.
        graph_var = var.GRAPH(graph)
        # Ordered so a TargetCode selects the matching entry (assumes
        # SRC/DST/EDGE are 0/1/2 -- confirm against TargetCode in .base).
        frames = (src_frame, dst_frame, edge_frame)
        maps = (src_map, dst_map, edge_map)
        lhs_col = ir.READ_COL(frames[self.lhs], var.STR(self.lhs_field))
        rhs_col = ir.READ_COL(frames[self.rhs], var.STR(self.rhs_field))
        lhs_map_var = var.MAP(maps[self.lhs])
        rhs_map_var = var.MAP(maps[self.rhs])
        out_map_var = var.MAP(out_map)
        return ir.BINARY_REDUCE(reducer, self.binary_op, graph_var,
                                self.lhs, self.rhs, lhs_col, rhs_col,
                                out_size, lhs_map_var, rhs_map_var,
                                out_map_var)

    @property
    def name(self):
        """Name built as ``<lhs>_<binary_op>_<rhs>`` from ``TargetCode.CODE2STR``."""
        codes = TargetCode.CODE2STR
        return "{}_{}_{}".format(codes[self.lhs], self.binary_op,
                                 codes[self.rhs])
class CopyMessageFunction(MessageFunction):
    """Builtin message function that copies one feature field as the message.

    Instances are normally produced by :func:`copy_u` or :func:`copy_e`.

    Parameters
    ----------
    target : TargetCode value
        Which side (source, destination or edge) the feature is read from.
    in_field : str
        Name of the feature field to copy.
    out_field : str
        Name of the output message field.

    See Also
    --------
    copy_u
    copy_e
    """

    def __init__(self, target, in_field, out_field):
        self.target = target
        self.in_field, self.out_field = in_field, out_field

    def _invoke(self, graph, src_frame, dst_frame, edge_frame, out_size,
                src_map, dst_map, edge_map, out_map, reducer="none"):
        """Emit the runtime IR that copies (and optionally reduces) the
        selected feature column via ``ir.COPY_REDUCE``.
        """
        # NOTE: keep the var/ir calls in this exact order; variables may be
        # numbered in creation order by the runtime.
        graph_var = var.GRAPH(graph)
        # Ordered so a TargetCode selects the matching entry (assumes
        # SRC/DST/EDGE are 0/1/2 -- confirm against TargetCode in .base).
        frames = (src_frame, dst_frame, edge_frame)
        maps = (src_map, dst_map, edge_map)
        column = ir.READ_COL(frames[self.target], var.STR(self.in_field))
        in_map_var = var.MAP(maps[self.target])
        out_map_var = var.MAP(out_map)
        return ir.COPY_REDUCE(reducer, graph_var, self.target, column,
                              out_size, in_map_var, out_map_var)

    @property
    def name(self):
        """Name built as ``copy_<target>`` from ``TargetCode.CODE2STR``."""
        target_str = TargetCode.CODE2STR[self.target]
        return "copy_{}".format(target_str)
def copy_u(u, out):
    """Builtin message function that computes message using source node
    feature.

    Parameters
    ----------
    u : str
        The source feature field.
    out : str
        The output message field.

    Returns
    -------
    CopyMessageFunction
        Message function copying source field ``u`` into message ``out``.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.copy_u('h', 'm')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.src['h']}
    """
    return CopyMessageFunction(TargetCode.SRC, u, out)
def copy_e(e, out):
    """Builtin message function that computes message using edge feature.

    Parameters
    ----------
    e : str
        The edge feature field.
    out : str
        The output message field.

    Returns
    -------
    CopyMessageFunction
        Message function copying edge field ``e`` into message ``out``.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.copy_e('h', 'm')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.data['h']}
    """
    return CopyMessageFunction(TargetCode.EDGE, e, out)
###############################################################################
# Generate all of the following builtin message functions:
# element-wise message functions:
# u_add_v, u_sub_v, u_mul_v, u_div_v
# u_add_e, u_sub_e, u_mul_e, u_div_e
# v_add_u, v_sub_u, v_mul_u, v_div_u
# v_add_e, v_sub_e, v_mul_e, v_div_e
# e_add_u, e_sub_u, e_mul_u, e_div_u
# e_add_v, e_sub_v, e_mul_v, e_div_v
#
# dot message functions:
# u_dot_v, u_dot_e, v_dot_e
# v_dot_u, e_dot_u, e_dot_v
# Short target letters used in generated builtin names (u_add_v, e_dot_u, ...)
# mapped to the corresponding TargetCode constants.
_TARGET_MAP = dict(
    u=TargetCode.SRC,
    v=TargetCode.DST,
    e=TargetCode.EDGE,
)
def _gen_message_builtin(lhs, rhs, binary_op):
    """Create the builtin message function ``<lhs>_<binary_op>_<rhs>``.

    Parameters
    ----------
    lhs, rhs : str
        Keys of ``_TARGET_MAP`` (``"u"``, ``"v"`` or ``"e"``).
    binary_op : str
        Name of the binary operator, e.g. ``"add"`` or ``"dot"``.

    Returns
    -------
    callable
        ``func(lhs_field, rhs_field, out)`` returning a
        ``BinaryMessageFunction``. Both its ``__name__`` and ``__qualname__``
        are ``"<lhs>_<binary_op>_<rhs>"``.

    Raises
    ------
    KeyError
        If ``lhs`` or ``rhs`` is not a key of ``_TARGET_MAP``.
    """
    name = "{}_{}_{}".format(lhs, binary_op, rhs)
    # Resolve target codes once; they are reused by the docstring and closure.
    lhs_code = _TARGET_MAP[lhs]
    rhs_code = _TARGET_MAP[rhs]
    lhs_name = TargetCode.CODE2STR[lhs_code]
    rhs_name = TargetCode.CODE2STR[rhs_code]
    docstring = """Builtin message function that computes a message on an edge
by performing element-wise {} between features of {} and {}
if the features have the same shape; otherwise, it first broadcasts the features
to a new shape and performs the element-wise operation.
Broadcasting follows NumPy semantics. Please see
https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
for more details about the NumPy broadcasting semantics.
Parameters
----------
lhs_field : str
The feature field of {}.
rhs_field : str
The feature field of {}.
out : str
The output message field.
Examples
--------
>>> import dgl
>>> message_func = dgl.function.{}('h', 'h', 'm')
""".format(binary_op, lhs_name, rhs_name, lhs_name, rhs_name, name)

    def func(lhs_field, rhs_field, out):
        return BinaryMessageFunction(
            binary_op, lhs_code, rhs_code, lhs_field, rhs_field, out)

    func.__name__ = name
    # Without this, __qualname__ stays "_gen_message_builtin.<locals>.func".
    # pickle then refuses the builtin as a local object, and reprs are
    # misleading. _register_builtin_message_func sets the function as a module
    # attribute under `name`, so pickle can find it by reference once the
    # qualname matches.
    func.__qualname__ = name
    func.__doc__ = docstring
    return func
def _register_builtin_message_func():
    """Generate every ``<lhs>_<op>_<rhs>`` builtin and publish it.

    For each ordered pair of distinct targets drawn from ``u``/``v``/``e`` and
    each operator in add/sub/mul/div/dot, the generated function is set as an
    attribute of this module and its name is appended to ``__all__``.
    """
    module = sys.modules[__name__]
    sides = ["u", "v", "e"]
    operators = ["add", "sub", "mul", "div", "dot"]
    distinct_pairs = [(left, right)
                      for left, right in product(sides, sides)
                      if left != right]
    for left, right in distinct_pairs:
        for op in operators:
            builtin = _gen_message_builtin(left, right, op)
            setattr(module, builtin.__name__, builtin)
            __all__.append(builtin.__name__)


_register_builtin_message_func()
##############################################################################
# Deprecated aliases, kept for backward compatibility
def src_mul_edge(src, edge, out):
    """Builtin message function that computes message by performing
    binary operation mul between src feature and edge feature.

    Notes
    -----
    This function is deprecated. Please use u_mul_e instead.

    Parameters
    ----------
    src : str
        The source feature field.
    edge : str
        The edge feature field.
    out : str
        The output message field.

    Returns
    -------
    BinaryMessageFunction
        The same object ``u_mul_e(src, edge, out)`` returns.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.src_mul_edge('h', 'e', 'm')

    The above example is equivalent to
    ``dgl.function.u_mul_e('h', 'e', 'm')``.
    """
    # u_mul_e is created dynamically by _register_builtin_message_func, so it
    # is looked up on the module instead of being referenced by name.
    return getattr(sys.modules[__name__], "u_mul_e")(src, edge, out)
def copy_src(src, out):
    """Builtin message function that computes message using source node
    feature.

    Notes
    -----
    This function is deprecated. Please use copy_u instead.

    Parameters
    ----------
    src : str
        The source feature field.
    out : str
        The output message field.

    Returns
    -------
    CopyMessageFunction
        The same object ``copy_u(src, out)`` returns.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.copy_src('h', 'm')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.src['h']}
    """
    return copy_u(src, out)
def copy_edge(edge, out):
    """Builtin message function that computes message using edge feature.

    Notes
    -----
    This function is deprecated. Please use copy_e instead.

    Parameters
    ----------
    edge : str
        The edge feature field.
    out : str
        The output message field.

    Returns
    -------
    CopyMessageFunction
        The same object ``copy_e(edge, out)`` returns.

    Examples
    --------
    >>> import dgl
    >>> message_func = dgl.function.copy_edge('h', 'm')

    The above example is equivalent to the following user defined function:

    >>> def message_func(edges):
    ...     return {'m': edges.data['h']}
    """
    return copy_e(edge, out)