GlobalAttentionPooling

class dgl.nn.tensorflow.glob.GlobalAttentionPooling(*args, **kwargs)[source]

Bases: tensorflow.python.keras.engine.base_layer.Layer

Global Attention Pooling from Gated Graph Sequence Neural Networks. Computes a graph-level representation as an attention-weighted sum of node features:

\[r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{\mathrm{gate}}\left(x^{(i)}_k\right)\right) f_{\mathrm{feat}}\left(x^{(i)}_k\right)\]
Parameters
  • gate_nn (tf.layers.Layer) – A neural network that computes an attention score for each node feature.

  • feat_nn (tf.layers.Layer, optional) – A neural network applied to each node feature before it is weighted by its attention score and summed.
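The readout formula above can be sketched in plain NumPy. This is a minimal illustration, not the DGL implementation: the gate and feature networks are replaced by hypothetical linear maps (`W_gate`, `W_feat`), and the softmax is taken over the nodes of a single graph.

```python
import numpy as np

def softmax(x, axis=0):
    # Numerically stable softmax along the node axis.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)

# Hypothetical stand-ins for gate_nn and feat_nn: simple linear maps.
W_gate = rng.normal(size=(16, 1))   # f_gate: feature -> scalar score
W_feat = rng.normal(size=(16, 16))  # f_feat: optional feature transform

x = rng.normal(size=(5, 16))        # N_i = 5 nodes, 16-dim features

scores = softmax(x @ W_gate, axis=0)     # (5, 1) attention weights, sum to 1
r = (scores * (x @ W_feat)).sum(axis=0)  # graph readout r^(i), shape (16,)
```

In DGL itself, the same computation would be invoked by constructing the layer with a gate network (e.g. a `Dense(1)` Keras layer) and calling it on a graph and its node feature tensor; the softmax is then computed per graph in a batch.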