"""Tensorflow Module for Relational graph convolution layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import tensorflow as tf
from tensorflow.keras import layers
from .... import function as fn
from .. import utils
class RelGraphConv(layers.Layer):
    r"""

    Description
    -----------
    Relational graph convolution layer.

    Relational graph convolution is introduced in "`Modeling Relational Data with Graph
    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__"
    and can be described as below:

    .. math::

       h_i^{(l+1)} = \sigma(\sum_{r\in\mathcal{R}}
       \sum_{j\in\mathcal{N}^r(i)}\frac{1}{c_{i,r}}W_r^{(l)}h_j^{(l)}+W_0^{(l)}h_i^{(l)})

    where :math:`\mathcal{N}^r(i)` is the neighbor set of node :math:`i` w.r.t. relation
    :math:`r`. :math:`c_{i,r}` is the normalizer equal
    to :math:`|\mathcal{N}^r(i)|`. :math:`\sigma` is an activation function. :math:`W_0`
    is the self-loop weight. This layer does not compute :math:`1/c_{i,r}` itself; supply it
    through the ``norm`` argument of :meth:`call` (messages are multiplied by it when given).

    The basis regularization decomposes :math:`W_r` by:

    .. math::

       W_r^{(l)} = \sum_{b=1}^B a_{rb}^{(l)}V_b^{(l)}

    where :math:`B` is the number of bases, :math:`V_b^{(l)}` are linearly combined
    with coefficients :math:`a_{rb}^{(l)}`.

    The block-diagonal-decomposition regularization decomposes :math:`W_r` into :math:`B`
    block diagonal matrices. We refer to :math:`B` as the number of bases.

    The block regularization decomposes :math:`W_r` by:

    .. math::

       W_r^{(l)} = \oplus_{b=1}^B Q_{rb}^{(l)}

    where :math:`B` is the number of bases, :math:`Q_{rb}^{(l)}` are block
    bases with shape :math:`\mathbb{R}^{(d^{(l+1)}/B)\times(d^{(l)}/B)}`.

    Parameters
    ----------
    in_feat : int
        Input feature size; i.e, the number of dimensions of :math:`h_j^{(l)}`.
    out_feat : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
    num_rels : int
        Number of relations.
    regularizer : str, optional
        Which weight regularizer to use, "basis" or "bdd".
        "basis" is short for basis decomposition.
        "bdd" is short for block-diagonal-decomposition. Default: ``"basis"``.
    num_bases : int, optional
        Number of bases. If ``None``, non-positive, or larger than ``num_rels``,
        ``num_rels`` is used instead. Default: ``None``.
    bias : bool, optional
        True if bias is added. Default: ``True``.
    activation : callable, optional
        Activation function. Default: ``None``.
    self_loop : bool, optional
        True to include self loop message. Default: ``True``.
    low_mem : bool, optional
        True to use low memory implementation of relation message passing function.
        This option trades speed with memory consumption, and will slow down the
        forward/backward pass. Turn it on when you encounter OOM problems during
        training or evaluation. Default: ``False``.
    dropout : float, optional
        Dropout rate. Default: ``0.0``
    layer_norm : bool, optional
        Add layer norm. Must be ``False``: the TensorFlow backend does not support
        layer norm. Default: ``False``

    Raises
    ------
    ValueError
        If ``regularizer`` is neither "basis" nor "bdd", or if ``regularizer`` is
        "bdd" and ``in_feat``/``out_feat`` are not divisible by the effective
        number of bases.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import tensorflow as tf
    >>> from dgl.nn import RelGraphConv
    >>>
    >>> with tf.device("CPU:0"):
    ...     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    ...     feat = tf.ones((6, 10))
    ...     conv = RelGraphConv(10, 2, 3, regularizer='basis', num_bases=2)
    ...     etype = tf.convert_to_tensor(np.array([0,1,2,0,1,2]).astype(np.int64))
    ...     res = conv(g, feat, etype)
    >>> res
    <tf.Tensor: shape=(6, 2), dtype=float32, numpy=
    array([[-0.02938664,  1.7932655 ],
           [ 0.1146394 ,  0.48319   ],
           [-0.02938664,  1.7932655 ],
           [ 1.2054908 , -0.26098895],
           [ 0.1146394 ,  0.48319   ],
           [ 0.75915515,  1.1454091 ]], dtype=float32)>

    >>> # One-hot input
    >>> with tf.device("CPU:0"):
    ...     one_hot_feat = tf.convert_to_tensor(np.array([0,1,2,3,4,5]).astype(np.int64))
    ...     res = conv(g, one_hot_feat, etype)
    >>> res
    <tf.Tensor: shape=(6, 2), dtype=float32, numpy=
    array([[-0.24205256, -0.7922753 ],
           [ 0.62085056,  0.4893622 ],
           [-0.9484881 , -0.26546806],
           [-0.2163915 , -0.12585883],
           [-0.14293689,  0.77483284],
           [ 0.091169  , -0.06761569]], dtype=float32)>
    """

    def __init__(self,
                 in_feat,
                 out_feat,
                 num_rels,
                 regularizer="basis",
                 num_bases=None,
                 bias=True,
                 activation=None,
                 self_loop=True,
                 low_mem=False,
                 dropout=0.0,
                 layer_norm=False):
        super(RelGraphConv, self).__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_rels = num_rels
        self.regularizer = regularizer
        self.num_bases = num_bases
        # Fall back to one basis per relation when num_bases is unset or outside
        # the meaningful range 1..num_rels. A value of 0 would yield zero-size
        # bases ("basis") or a division by zero ("bdd").
        if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
            self.num_bases = self.num_rels
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop
        self.low_mem = low_mem
        assert layer_norm is False, 'TensorFlow currently does not support layer norm.'

        xinit = tf.keras.initializers.glorot_uniform()
        zeroinit = tf.keras.initializers.zeros()

        if regularizer == "basis":
            # Basis matrices V_b, shape (num_bases, in_feat, out_feat).
            self.weight = tf.Variable(initial_value=xinit(
                shape=(self.num_bases, self.in_feat, self.out_feat),
                dtype='float32'), trainable=True)
            if self.num_bases < self.num_rels:
                # Linear combination coefficients a_rb, shape (num_rels, num_bases).
                # When num_bases == num_rels, self.weight is used directly as W_r.
                self.w_comp = tf.Variable(initial_value=xinit(
                    shape=(self.num_rels, self.num_bases), dtype='float32'), trainable=True)
            self.message_func = self.basis_message_func
        elif regularizer == "bdd":
            # Validate against the effective (clamped) number of bases: the raw
            # ``num_bases`` argument may be None or larger than num_rels.
            if in_feat % self.num_bases != 0 or out_feat % self.num_bases != 0:
                raise ValueError(
                    'Feature size must be a multiplier of num_bases.')
            # Each relation stores num_bases flattened blocks of shape
            # (submat_in, submat_out); divisibility is guaranteed by the check above.
            self.submat_in = in_feat // self.num_bases
            self.submat_out = out_feat // self.num_bases
            self.weight = tf.Variable(initial_value=xinit(
                shape=(self.num_rels, self.num_bases *
                       self.submat_in * self.submat_out),
                dtype='float32'), trainable=True)
            self.message_func = self.bdd_message_func
        else:
            raise ValueError("Regularizer must be either 'basis' or 'bdd'")

        if self.bias:
            self.h_bias = tf.Variable(initial_value=zeroinit(
                shape=(out_feat,), dtype='float32'), trainable=True)

        # Self-loop weight W_0, applied to each node's own input feature.
        if self.self_loop:
            self.loop_weight = tf.Variable(initial_value=xinit(
                shape=(in_feat, out_feat), dtype='float32'), trainable=True)

        self.dropout = layers.Dropout(rate=dropout)

    def basis_message_func(self, edges):
        """Message function for the basis regularizer.

        Builds the per-relation weights W_r from the bases and computes
        ``h_src @ W_{type(e)}`` for every edge. The result is optionally scaled
        by ``edges.data['norm']``.

        Parameters
        ----------
        edges : EdgeBatch
            Edge batch; reads ``edges.src['h']``, ``edges.data['type']`` and,
            if present, ``edges.data['norm']``.

        Returns
        -------
        dict
            ``{'msg': tf.Tensor}`` with one ``out_feat``-wide row per edge.
        """
        if self.num_bases < self.num_rels:
            # W_r = sum_b a_rb * V_b, computed as
            # (num_rels, num_bases) @ (num_bases, in*out),
            # then reshaped to (num_rels, in_feat, out_feat).
            weight = tf.reshape(self.weight, (self.num_bases,
                                              self.in_feat * self.out_feat))
            weight = tf.reshape(tf.matmul(self.w_comp, weight), (
                self.num_rels, self.in_feat, self.out_feat))
        else:
            weight = self.weight

        # Compute msg @ W_r before putting msg onto the edges.
        # If src is tf.int64 we expect it to be an index select, so the low-memory
        # loop (which needs dense features for matmul) is skipped in that case.
        if edges.src['h'].dtype != tf.int64 and self.low_mem:
            # Low-memory path: process one edge type at a time, so no per-edge
            # weight tensor is materialized. Each subset's messages are scattered
            # back to their original edge positions.
            etypes, _ = tf.unique(edges.data['type'])
            msg = tf.zeros([edges.src['h'].shape[0], self.out_feat])
            idx = tf.range(edges.src['h'].shape[0])
            for etype in etypes:
                loc = (edges.data['type'] == etype)
                w = weight[etype]
                src = tf.boolean_mask(edges.src['h'], loc)
                sub_msg = tf.matmul(src, w)
                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
        else:
            msg = utils.bmm_maybe_select(
                edges.src['h'], weight, edges.data['type'])
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}

    def bdd_message_func(self, edges):
        """Message function for the block-diagonal-decomposition regularizer.

        Multiplies each edge's source feature by the block-diagonal W_{type(e)}.
        This is done block by block: the input is split into ``num_bases`` chunks
        of width ``submat_in``, and each chunk maps to ``submat_out`` outputs.

        Parameters
        ----------
        edges : EdgeBatch
            Edge batch; reads ``edges.src['h']``, ``edges.data['type']`` and,
            if present, ``edges.data['norm']``.

        Returns
        -------
        dict
            ``{'msg': tf.Tensor}`` of shape ``(num_edges, out_feat)``.

        Raises
        ------
        TypeError
            If source features are a 1-D int64 (node ID) tensor, which block
            decomposition cannot consume.
        """
        if ((edges.src['h'].dtype == tf.int64) and
                len(edges.src['h'].shape) == 1):
            raise TypeError(
                'Block decomposition does not allow integer ID feature.')

        # Compute msg @ W_r before putting msg onto the edges.
        if self.low_mem:
            # Low-memory path: handle one edge type at a time and scatter the
            # sub-messages back to their original edge positions.
            etypes, _ = tf.unique(edges.data['type'])
            msg = tf.zeros([edges.src['h'].shape[0], self.out_feat])
            idx = tf.range(edges.src['h'].shape[0])
            for etype in etypes:
                loc = (edges.data['type'] == etype)
                w = tf.reshape(self.weight[etype],
                               (self.num_bases, self.submat_in, self.submat_out))
                src = tf.reshape(tf.boolean_mask(edges.src['h'], loc),
                                 (-1, self.num_bases, self.submat_in))
                # (n, B, in_b) x (B, in_b, out_b) -> (n, B, out_b): block-wise product.
                sub_msg = tf.einsum('abc,bcd->abd', src, w)
                sub_msg = tf.reshape(sub_msg, (-1, self.out_feat))
                indices = tf.reshape(tf.boolean_mask(idx, loc), (-1, 1))
                msg = tf.tensor_scatter_nd_update(msg, indices, sub_msg)
        else:
            # Gather each edge's flattened blocks and view them as
            # (num_edges * num_bases, submat_in, submat_out). Pair them with the
            # source features viewed as (num_edges * num_bases, 1, submat_in).
            weight = tf.reshape(tf.gather(
                self.weight, edges.data['type']), (-1, self.submat_in, self.submat_out))
            node = tf.reshape(edges.src['h'], (-1, 1, self.submat_in))
            msg = tf.reshape(tf.matmul(node, weight), (-1, self.out_feat))
        if 'norm' in edges.data:
            msg = msg * edges.data['norm']
        return {'msg': msg}

    def call(self, g, x, etypes, norm=None):
        """Forward computation.

        Parameters
        ----------
        g : DGLGraph
            The graph. Must be homogeneous; convert heterographs with
            ``to_homogeneous`` and pass the edge types via ``etypes``.
        x : tf.Tensor
            Input node features. Could be either

            * :math:`(|V|, D)` dense tensor
            * :math:`(|V|,)` int64 vector, representing the categorical values of each
              node. We then treat the input feature as a one-hot encoding feature.
        etypes : tf.Tensor
            Edge type tensor. Shape: :math:`(|E|,)`. Cast to int64 internally.
        norm : tf.Tensor, optional
            Optional edge normalizer tensor. Shape: :math:`(|E|, 1)`. When given,
            every message is multiplied by it.

        Returns
        -------
        tf.Tensor
            New node features.
        """
        assert g.is_homogeneous, \
            "not a homogeneous graph; convert it with to_homogeneous " \
            "and pass in the edge type as argument"
        # local_scope keeps the temporary 'h'/'type'/'norm' fields off the caller's graph.
        with g.local_scope():
            g.ndata['h'] = x
            g.edata['type'] = tf.cast(etypes, tf.int64)
            if norm is not None:
                g.edata['norm'] = norm
            # Computed from the original input before message passing
            # overwrites ndata['h'].
            if self.self_loop:
                loop_message = utils.matmul_maybe_select(x, self.loop_weight)

            # Message passing: sum relation-specific messages into each destination node.
            g.update_all(self.message_func, fn.sum(msg='msg', out='h'))

            # Apply bias, self-loop, activation, then dropout, in that order.
            node_repr = g.ndata['h']
            if self.bias:
                node_repr = node_repr + self.h_bias
            if self.self_loop:
                node_repr = node_repr + loop_message
            if self.activation:
                node_repr = self.activation(node_repr)
            node_repr = self.dropout(node_repr)
            return node_repr