GlobalAttentionPooling
----------------------
class dgl.nn.tensorflow.glob.GlobalAttentionPooling(*args, **kwargs)

Bases: tensorflow.python.keras.engine.base_layer.Layer
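Global Attention Pooling layer from Gated Graph Sequence Neural Networks:

\[r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate} \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)\]

where \(x^{(i)}_k\) is the feature of node \(k\) in graph \(i\), \(N_i\) is the number of nodes in graph \(i\), and the softmax is taken over the \(N_i\) nodes of graph \(i\). The result \(r^{(i)}\) is one readout vector per graph.

Parameters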
gate_nn (tf.layers.Layer): A neural network (\(f_{gate}\) in the formula) that computes an attention score for each node; it should output one scalar per node.
feat_nn (tf.layers.Layer, optional): A neural network (\(f_{feat}\) in the formula) applied to each node's features before they are weighted by the attention scores.
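To make the formula concrete, here is a minimal plain-Python sketch of the same computation. It assumes the nodes of a batch of graphs are stored back to back, with a list of per-graph node counts \(N_i\). The names `global_attention_pool`, `gate_fn` and `feat_fn` are hypothetical and chosen for illustration; this is not the DGL API and it does not use TensorFlow. `gate_fn` plays the role of `gate_nn` (one scalar score per node) and `feat_fn` the role of `feat_nn`; treating a missing `feat_fn` as the identity is an assumption of this sketch.

```python
import math
from typing import Callable, List, Optional, Sequence

Vector = List[float]


def global_attention_pool(
    node_feats: Sequence[Vector],
    graph_sizes: Sequence[int],
    gate_fn: Callable[[Vector], float],
    feat_fn: Optional[Callable[[Vector], Vector]] = None,
) -> List[Vector]:
    """Pool node features into one readout vector per graph.

    node_feats holds the nodes of all graphs back to back;
    graph_sizes[i] is N_i, the number of nodes in graph i.
    (Hypothetical helper, not part of DGL.)
    """
    if any(n <= 0 for n in graph_sizes):
        raise ValueError("every graph must have at least one node")
    if sum(graph_sizes) != len(node_feats):
        raise ValueError("graph_sizes must sum to the number of nodes")
    # Assumption: f_feat is the identity when no feature function is given.
    if feat_fn is None:
        feat_fn = list

    readouts: List[Vector] = []
    start = 0
    for n in graph_sizes:
        nodes = node_feats[start:start + n]
        start += n
        # f_gate(x_k): one scalar score per node.
        scores = [gate_fn(x) for x in nodes]
        # Softmax over the nodes of this graph only; shifting by the max
        # keeps exp() from overflowing without changing the result.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # r = sum_k softmax_k * f_feat(x_k), computed per feature dimension.
        feats = [feat_fn(x) for x in nodes]
        dim = len(feats[0])
        readouts.append(
            [sum(w * f[d] for w, f in zip(weights, feats)) for d in range(dim)]
        )
    return readouts


# Two graphs: graph 0 has 2 nodes, graph 1 has 1 node.
# Illustrative gate: score each node by the sum of its features.
out = global_attention_pool([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]], [2, 1], gate_fn=sum)
print(out)  # → [[0.5, 0.5], [2.0, 2.0]]
```

Because the softmax is normalized separately within each graph, every graph's attention weights sum to 1 regardless of its size; a graph with a single node therefore simply returns that node's (transformed) features.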