"""MXNet modules for graph global pooling."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
from mxnet import gluon, nd
from mxnet.gluon import nn
from ...readout import sum_nodes, mean_nodes, max_nodes, broadcast_nodes,\
softmax_nodes, topk_nodes
__all__ = ['SumPooling', 'AvgPooling', 'MaxPooling', 'SortPooling',
'GlobalAttentionPooling', 'Set2Set']
class SumPooling(nn.Block):
    r"""Apply sum pooling over the nodes in the graph.

    .. math::
        r^{(i)} = \sum_{k=1}^{N_i} x^{(i)}_k
    """
def __init__(self):
super(SumPooling, self).__init__()
    def forward(self, graph, feat):
        r"""Compute sum pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        mxnet.NDArray
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
        """
with graph.local_scope():
graph.ndata['h'] = feat
readout = sum_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'SumPooling()'
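
# Usage sketch (not part of the original module; ``bg`` and ``feat`` are assumed
# names). Given a batched graph ``bg`` built with ``dgl.batch`` that holds B graphs
# and N nodes in total, and node features ``feat`` of shape (N, D) as an
# ``mxnet.NDArray``:
#
#     pool = SumPooling()
#     out = pool(bg, feat)   # shape (B, D): per-graph sum of node features
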
class AvgPooling(nn.Block):
    r"""Apply average pooling over the nodes in the graph.

    .. math::
        r^{(i)} = \frac{1}{N_i}\sum_{k=1}^{N_i} x^{(i)}_k
    """
def __init__(self):
super(AvgPooling, self).__init__()
    def forward(self, graph, feat):
        r"""Compute average pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        mxnet.NDArray
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
        """
with graph.local_scope():
graph.ndata['h'] = feat
readout = mean_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'AvgPooling()'
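
# Usage sketch (illustrative, reusing the assumed ``bg`` and ``feat`` from the
# SumPooling sketch above):
#
#     avgpool = AvgPooling()
#     out = avgpool(bg, feat)   # shape (B, D): per-graph mean of node features
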
class MaxPooling(nn.Block):
    r"""Apply max pooling over the nodes in the graph.

    .. math::
        r^{(i)} = \max_{k=1}^{N_i} \left( x^{(i)}_k \right)
    """
def __init__(self):
super(MaxPooling, self).__init__()
    def forward(self, graph, feat):
        r"""Compute max pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray
            The input feature with shape :math:`(N, *)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        mxnet.NDArray
            The output feature with shape :math:`(B, *)`, where
            :math:`B` refers to the batch size.
        """
with graph.local_scope():
graph.ndata['h'] = feat
readout = max_nodes(graph, 'h')
graph.ndata.pop('h')
return readout
def __repr__(self):
return 'MaxPooling()'
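
# Usage sketch (same assumed ``bg`` and ``feat`` as above):
#
#     maxpool = MaxPooling()
#     out = maxpool(bg, feat)   # shape (B, D): elementwise max over each graph's nodes
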
class SortPooling(nn.Block):
    r"""Apply Sort Pooling (`An End-to-End Deep Learning Architecture for Graph Classification
    <https://www.cse.wustl.edu/~ychen/public/DGCNN.pdf>`__) over the nodes in the graph.

    Parameters
    ----------
    k : int
        The number of nodes to hold for each graph.
    """
def __init__(self, k):
super(SortPooling, self).__init__()
self.k = k
    def forward(self, graph, feat):
        r"""Compute sort pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray
            The input feature with shape :math:`(N, D)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        mxnet.NDArray
            The output feature with shape :math:`(B, k * D)`, where
            :math:`B` refers to the batch size.
        """
# Sort the feature of each node in ascending order.
with graph.local_scope():
feat = feat.sort(axis=-1)
graph.ndata['h'] = feat
# Sort nodes according to their last features.
ret = topk_nodes(graph, 'h', self.k, sortby=-1)[0].reshape(
-1, self.k * feat.shape[-1])
return ret
def __repr__(self):
return 'SortPooling(k={})'.format(self.k)
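
# Usage sketch (illustrative; ``bg`` and ``feat`` of shape (N, D) are assumed as
# above, and k = 2 is an arbitrary choice):
#
#     sortpool = SortPooling(k=2)
#     out = sortpool(bg, feat)   # shape (B, 2 * D): the k nodes kept per graph,
#                                # ranked by their last (sorted) feature value
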
class GlobalAttentionPooling(nn.Block):
    r"""Apply Global Attention Pooling (`Gated Graph Sequence Neural Networks
    <https://arxiv.org/abs/1511.05493.pdf>`__) over the nodes in the graph.

    .. math::
        r^{(i)} = \sum_{k=1}^{N_i}\mathrm{softmax}\left(f_{gate}
        \left(x^{(i)}_k\right)\right) f_{feat}\left(x^{(i)}_k\right)

    Parameters
    ----------
    gate_nn : gluon.nn.Block
        A neural network that computes attention scores for each feature.
    feat_nn : gluon.nn.Block, optional
        A neural network applied to each feature before combining them
        with attention scores.
    """
def __init__(self, gate_nn, feat_nn=None):
super(GlobalAttentionPooling, self).__init__()
with self.name_scope():
self.gate_nn = gate_nn
self.feat_nn = feat_nn
    def forward(self, graph, feat):
        r"""Compute global attention pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray
            The input feature with shape :math:`(N, D)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        mxnet.NDArray
            The output feature with shape :math:`(B, D)`, where
            :math:`B` refers to the batch size.
        """
with graph.local_scope():
gate = self.gate_nn(feat)
assert gate.shape[-1] == 1, "The output of gate_nn should have size 1 at the last axis."
feat = self.feat_nn(feat) if self.feat_nn else feat
graph.ndata['gate'] = gate
gate = softmax_nodes(graph, 'gate')
graph.ndata['r'] = feat * gate
readout = sum_nodes(graph, 'r')
return readout
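
# Usage sketch (illustrative): gate_nn must map each node feature to a single
# score; a one-unit Dense layer is one common choice, not a requirement. The
# assumed ``bg`` and ``feat`` are as above.
#
#     gate_nn = gluon.nn.Dense(1)
#     gap = GlobalAttentionPooling(gate_nn)
#     gap.initialize()
#     out = gap(bg, feat)   # shape (B, D): attention-weighted sum per graph
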
class Set2Set(nn.Block):
    r"""Apply Set2Set (`Order Matters: Sequence to sequence for sets
    <https://arxiv.org/pdf/1511.06391.pdf>`__) over the nodes in the graph.

    For each individual graph in the batch, Set2Set computes

    .. math::
        q_t &= \mathrm{LSTM} (q^*_{t-1})

        \alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t)

        r_t &= \sum_{i=1}^N \alpha_{i,t} x_i

        q^*_t &= q_t \Vert r_t

    for this graph.

    Parameters
    ----------
    input_dim : int
        Size of each input sample.
    n_iters : int
        Number of iterations.
    n_layers : int
        Number of recurrent layers.
    """
def __init__(self, input_dim, n_iters, n_layers):
super(Set2Set, self).__init__()
self.input_dim = input_dim
self.output_dim = 2 * input_dim
self.n_iters = n_iters
self.n_layers = n_layers
with self.name_scope():
self.lstm = gluon.rnn.LSTM(
self.input_dim, num_layers=n_layers, input_size=self.output_dim)
    def forward(self, graph, feat):
        r"""Compute set2set pooling.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : mxnet.NDArray
            The input feature with shape :math:`(N, D)` where
            :math:`N` is the number of nodes in the graph.

        Returns
        -------
        mxnet.NDArray
            The output feature with shape :math:`(B, 2 * D)`, where
            :math:`B` refers to the batch size and :math:`D` is the input
            feature size (the LSTM query concatenated with the readout).
        """
with graph.local_scope():
batch_size = graph.batch_size
h = (nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context),
nd.zeros((self.n_layers, batch_size, self.input_dim), ctx=feat.context))
q_star = nd.zeros((batch_size, self.output_dim), ctx=feat.context)
for _ in range(self.n_iters):
q, h = self.lstm(q_star.expand_dims(axis=0), h)
q = q.reshape((batch_size, self.input_dim))
e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
graph.ndata['e'] = e
alpha = softmax_nodes(graph, 'e')
graph.ndata['r'] = feat * alpha
readout = sum_nodes(graph, 'r')
q_star = nd.concat(q, readout, dim=-1)
return q_star
def __repr__(self):
summary = 'Set2Set('
summary += 'in={}, out={}, ' \
'n_iters={}, n_layers={}'.format(self.input_dim,
self.output_dim,
self.n_iters,
self.n_layers)
summary += ')'
return summary
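
# Usage sketch (illustrative): input_dim must match the node feature size D;
# the iteration and layer counts below are arbitrary assumptions, and ``bg``
# and ``feat`` are the assumed batched graph and features from above.
#
#     s2s = Set2Set(input_dim=feat.shape[-1], n_iters=3, n_layers=1)
#     s2s.initialize()
#     out = s2s(bg, feat)   # shape (B, 2 * D): LSTM query concatenated with readout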