"""TensorFlow modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import tensorflow as tf
from tensorflow.keras import layers
from ...readout import sum_nodes, mean_nodes, max_nodes, \
softmax_nodes, topk_nodes
# Public pooling layers exported by this module.
__all__ = [
    'SumPooling',
    'AvgPooling',
    'MaxPooling',
    'SortPooling',
    'WeightAndSum',
    'GlobalAttentionPooling',
]
class SumPooling(layers.Layer):
    r"""Apply sum pooling over the nodes in the graph.

    .. math::

        r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
    """

    def __init__(self):
        super().__init__()

    def call(self, graph, feat):
        r"""Compute sum pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
        """
        # The temporary 'h' field is written inside local_scope so it does
        # not persist on the caller's graph after pooling.
        with graph.local_scope():
            graph.ndata['h'] = feat
            return sum_nodes(graph, 'h')
class AvgPooling(layers.Layer):
    r"""Apply average pooling over the nodes in the graph.

    .. math::

        r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
    """

    def __init__(self):
        super().__init__()

    def call(self, graph, feat):
        r"""Compute average pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
        """
        # Stash the features under a scratch key; local_scope discards it
        # when the block exits.
        with graph.local_scope():
            graph.ndata['h'] = feat
            mean_feat = mean_nodes(graph, 'h')
        return mean_feat
class MaxPooling(layers.Layer):
    r"""Apply max pooling over the nodes in the graph.

    .. math::

        r^{(i)} = \max_{k=1}^{N_i}\left( x^{(i)}_k \right)
    """

    def __init__(self):
        super().__init__()

    def call(self, graph, feat):
        r"""Compute max pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
        """
        with graph.local_scope():
            # Scratch node field consumed by the max readout below.
            graph.ndata['h'] = feat
            pooled = max_nodes(graph, 'h')
        return pooled
class SortPooling(layers.Layer):
    r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification
    <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__) over the nodes in the graph.

    Parameters
    ----------
    k : int
        The number of nodes to hold for each graph.
    """

    def __init__(self, k):
        super().__init__()
        self.k = k

    def call(self, graph, feat):
        r"""Compute sort pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, D)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, k * D)`, where
            :math:`B` refers to the batch size.
        """
        with graph.local_scope():
            # Order the entries of each node's feature vector ascending.
            sorted_feat = tf.sort(feat, -1)
            graph.ndata['h'] = sorted_feat
            # Pick k nodes per graph, ranked by their last feature column.
            top_feat = topk_nodes(graph, 'h', self.k, sortby=-1)[0]
            # Flatten the k selected rows of width D into one vector per graph.
            flat_dim = self.k * sorted_feat.shape[-1]
            ret = tf.reshape(top_feat, (-1, flat_dim))
        return ret
class GlobalAttentionPooling(layers.Layer):
    r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
    <https://arxiv.org/abs/1511.05493>`__) over the nodes in the graph.

    .. math::

        r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
        \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)

    Parameters
    ----------
    gate_nn : tf.keras.layers.Layer
        A neural network that computes attention scores for each feature.
        Its output must have size 1 on the last axis.
    feat_nn : tf.keras.layers.Layer, optional
        A neural network applied to each feature before combining them
        with attention scores. If ``None`` (the default), the input
        features are used unchanged.
    """

    def __init__(self, gate_nn, feat_nn=None):
        super(GlobalAttentionPooling, self).__init__()
        self.gate_nn = gate_nn
        self.feat_nn = feat_nn

    def call(self, graph, feat):
        r"""Compute global attention pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : tf.Tensor
            The input feature with shape :math:`(N, D)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        tf.Tensor
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
        """
        with graph.local_scope():
            gate = self.gate_nn(feat)
            assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
            # Compare against None explicitly: truth-testing a layer object
            # would silently skip any feat_nn whose __len__/__bool__ is False.
            if self.feat_nn is not None:
                feat = self.feat_nn(feat)
            # Normalize the gate scores with a softmax over the nodes of each graph.
            graph.ndata['gate'] = gate
            gate = softmax_nodes(graph, 'gate')
            graph.ndata.pop('gate')
            # (N, 1) gate broadcasts over the feature axis before the per-graph sum.
            graph.ndata['r'] = feat * gate
            readout = sum_nodes(graph, 'r')
            graph.ndata.pop('r')
        return readout
class WeightAndSum(layers.Layer):
    """Compute importance weights for atoms and perform a weighted sum.

    Parameters
    ----------
    in_feats : int
        Input atom feature size
    """

    def __init__(self, in_feats):
        super(WeightAndSum, self).__init__()
        self.in_feats = in_feats
        # Sequential takes its layers as ONE list argument. Passing them as
        # separate positional arguments binds the sigmoid layer to the
        # ``name`` parameter instead of stacking it after the Dense layer.
        self.atom_weighting = tf.keras.Sequential([
            layers.Dense(1),
            layers.Activation(tf.nn.sigmoid),
        ])

    def call(self, g, feats):
        """Compute molecule representations out of atom representations

        Parameters
        ----------
        g : DGLGraph
            DGLGraph with batch size B for processing multiple molecules in parallel
        feats : FloatTensor of shape (N, self.in_feats)
            Representations for all atoms in the molecules
            * N is the total number of atoms in all molecules

        Returns
        -------
        FloatTensor of shape (B, self.in_feats)
            Representations for B molecules
        """
        with g.local_scope():
            g.ndata['h'] = feats
            # One scalar weight in (0, 1) per atom from the Dense(1) + sigmoid stack.
            g.ndata['w'] = self.atom_weighting(g.ndata['h'])
            # Weighted sum of atom features within each molecule.
            h_g_sum = sum_nodes(g, 'h', 'w')
        return h_g_sum