GINConv

class dgl.nn.tensorflow.conv.GINConv(apply_func, aggregator_type, init_eps=0, learn_eps=False)

Bases: tensorflow.python.keras.engine.base_layer.Layer

Graph Isomorphism Network layer from the paper "How Powerful are Graph Neural Networks?"

\[h_i^{(l+1)} = f_\Theta \left((1 + \epsilon) h_i^{(l)} + \mathrm{aggregate}\left(\left\{h_j^{(l)}, j\in\mathcal{N}(i) \right\}\right)\right)\]
Parameters
  • apply_func (callable activation function/layer or None) – If not None, apply this function to the updated node feature, the \(f_\Theta\) in the formula.

  • aggregator_type (str) – Aggregator type to use (sum, max or mean).

  • init_eps (float, optional) – Initial \(\epsilon\) value, default: 0.

  • learn_eps (bool, optional) – If True, \(\epsilon\) will be a learnable parameter. Default: False.
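
To make the roles of the parameters concrete, the update rule above can be written out in plain NumPy. This is only an illustrative sketch, not DGL's implementation: the function name gin_update, the edge-list arguments src and dst, and the choice to give nodes without in-neighbors a zero aggregate are all assumptions made for the example.

```python
import numpy as np

def gin_update(src, dst, feat, apply_func=None, aggregator_type="sum", eps=0.0):
    """Hypothetical NumPy reference for one GIN update (not DGL's code)."""
    num_nodes = feat.shape[0]

    # For each destination node i, collect the features h_j of its in-neighbors j.
    neighbors = [[] for _ in range(num_nodes)]
    for s, d in zip(src, dst):
        neighbors[d].append(feat[s])

    reducers = {"sum": np.sum, "max": np.max, "mean": np.mean}
    reduce_fn = reducers[aggregator_type]

    agg = np.zeros_like(feat)
    for i, msgs in enumerate(neighbors):
        if msgs:  # assumption: nodes with no in-neighbors keep a zero aggregate
            agg[i] = reduce_fn(np.stack(msgs), axis=0)

    # (1 + eps) * h_i + aggregate({h_j}), then the optional f_Theta.
    out = (1.0 + eps) * feat + agg
    return out if apply_func is None else apply_func(out)

# Tiny graph with edges 0->2, 1->2, 0->1 and one scalar feature per node.
feat = np.array([[1.0], [2.0], [4.0]])
src, dst = [0, 1, 0], [2, 2, 1]
print(gin_update(src, dst, feat, aggregator_type="sum").ravel())   # [1. 3. 7.]
print(gin_update(src, dst, feat, aggregator_type="max").ravel())   # [1. 3. 6.]
print(gin_update(src, dst, feat, aggregator_type="mean").ravel())  # [1.  3.  5.5]
```

Here eps plays the role of init_eps. With learn_eps=True the layer would treat that value as a trainable variable rather than a constant, which this sketch does not model.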

Example

>>> import dgl
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import GINConv
>>>
>>> with tf.device("CPU:0"):
...     g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
...     feat = tf.ones((6, 10))
...     lin = tf.keras.layers.Dense(10)
...     conv = GINConv(lin, 'max')
...     res = conv(g, feat)
...     res
<tf.Tensor: shape=(6, 10), dtype=float32, numpy=
array([[-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,
        1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],
    [-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,
        1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],
    [-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,
        1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],
    [-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,
        1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],
    [-0.1090256 ,  1.9050574 , -0.30704725, -1.995831  , -0.36399186,
        1.10414   ,  2.4885745 , -0.35387516,  1.3568261 ,  1.7267858 ],
    [-0.0545128 ,  0.9525287 , -0.15352362, -0.9979155 , -0.18199593,
        0.55207   ,  1.2442873 , -0.17693758,  0.67841303,  0.8633929 ]],
    dtype=float32)>
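
The last row of the output is exactly half of the other rows, and the short NumPy check below suggests why. It assumes that a node with no incoming edges receives a zero aggregate (the printed values are consistent with this) and that the Dense layer still has its default zero bias, so it acts as a purely linear map. The variable names are illustrative only.

```python
import numpy as np

# Destination node of each edge in the example graph.
dst = np.array([1, 2, 3, 4, 0, 3])
in_deg = np.bincount(dst, minlength=6)
print(in_deg)  # [1 1 1 2 1 0] -> node 5 has no incoming edges

# All features are 1, so the 'max' over in-neighbors is 1 where any exist
# and (by assumption) 0 where there are none.
agg = (in_deg > 0).astype(np.float32)

eps = 0.0  # init_eps default
pre = (1.0 + eps) * np.ones(6, dtype=np.float32) + agg
print(pre)  # [2. 2. 2. 2. 2. 1.]
```

Each node's ten input features therefore all equal 2 for nodes 0 to 4 and 1 for node 5. A linear layer with zero bias maps the first input to twice the output of the second, which matches the ratio between the rows above.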