Set2Set

class dgl.nn.pytorch.glob.Set2Set(input_dim, n_iters, n_layers)

Bases: torch.nn.modules.module.Module

Set2Set operator from "Order Matters: Sequence to sequence for sets"

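For each graph in the batch, Set2Set repeats the following updates once per iteration \(t\):

\[
\begin{aligned}
q_t &= \mathrm{LSTM}(q^*_{t-1}) \\
\alpha_{i,t} &= \mathrm{softmax}(x_i \cdot q_t) \\
r_t &= \sum_{i=1}^N \alpha_{i,t} x_i \\
q^*_t &= q_t \Vert r_t
\end{aligned}
\]

where \(x_i\) is the feature of node \(i\), \(N\) is the number of nodes in the graph, the softmax is taken over the nodes of that graph, and \(\Vert\) denotes concatenation.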
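The attention readout in the last three equations can be sketched in plain Python for a single graph. This is an illustration only, not DGL's implementation: the function name `set2set_readout_step` and the toy values are hypothetical, and the query `q` is supplied by hand instead of being produced by an LSTM.

```python
import math

def set2set_readout_step(x, q):
    """One attention readout step for a single graph.

    x: list of N node feature vectors, each of length D.
    q: query vector of length D (in Set2Set this comes from the LSTM).
    Returns (alpha, r, q_star).
    """
    # Score each node by the dot product x_i . q
    scores = [sum(x_k * q_k for x_k, q_k in zip(x_i, q)) for x_i in x]
    # Softmax over the nodes of this graph (shifted by the max for stability)
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    # Readout r = sum_i alpha_i * x_i
    dim = len(q)
    r = [sum(a * x_i[k] for a, x_i in zip(alpha, x)) for k in range(dim)]
    # q* = q || r, so the result has length 2 * D
    q_star = list(q) + r
    return alpha, r, q_star

# Toy graph with N = 3 nodes and D = 2 features (placeholder values)
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
q = [0.5, -0.5]
alpha, r, q_star = set2set_readout_step(x, q)
```

Because \(q^*_t\) concatenates the query with the readout, its length is twice the feature size, which matches the 10-column outputs in the examples for 5-dimensional inputs.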

Parameters
  • input_dim (int) – The size of each input sample.

  • n_iters (int) – The number of iterations.

  • n_layers (int) – The number of recurrent layers.

Examples

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch as th
>>> from dgl.nn import Set2Set
>>>
>>> g1 = dgl.rand_graph(3, 4)  # g1 is a random graph with 3 nodes and 4 edges
>>> g1_node_feats = th.rand(3, 5)  # feature size is 5
>>> g1_node_feats
tensor([[0.8948, 0.0699, 0.9137, 0.7567, 0.3637],
        [0.8137, 0.8938, 0.8377, 0.4249, 0.6118],
        [0.5197, 0.9030, 0.6825, 0.5725, 0.4755]])
>>>
>>> g2 = dgl.rand_graph(4, 6)  # g2 is a random graph with 4 nodes and 6 edges
>>> g2_node_feats = th.rand(4, 5)  # feature size is 5
>>> g2_node_feats
tensor([[0.2053, 0.2426, 0.4111, 0.9028, 0.5658],
        [0.5278, 0.6365, 0.9990, 0.2351, 0.8945],
        [0.3134, 0.0580, 0.4349, 0.7949, 0.3891],
        [0.0142, 0.2709, 0.3330, 0.8521, 0.6925]])
>>>
>>> s2s = Set2Set(5, 2, 1)  # create a Set2Set layer(n_iters=2, n_layers=1)

Case 1: Input a single graph

>>> s2s(g1, g1_node_feats)
tensor([[-0.0235, -0.2291,  0.2654,  0.0376,  0.1349,  0.7560,  0.5822,  0.8199,
          0.5960,  0.4760]], grad_fn=<CatBackward>)

Case 2: Input a batch of graphs

Build a batch of DGL graphs and concatenate all graphs’ node features into one tensor.

>>> batch_g = dgl.batch([g1, g2])
>>> batch_f = th.cat([g1_node_feats, g2_node_feats], 0)
>>>
>>> s2s(batch_g, batch_f)
tensor([[-0.0235, -0.2291,  0.2654,  0.0376,  0.1349,  0.7560,  0.5822,  0.8199,
          0.5960,  0.4760],
        [-0.0483, -0.2010,  0.2324,  0.0145,  0.1361,  0.2703,  0.3078,  0.5529,
          0.6876,  0.6399]], grad_fn=<CatBackward>)
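In the batched call, rows 0 to 2 of batch_f belong to g1 and rows 3 to 6 belong to g2, and the readout is computed separately for each graph. The plain-Python sketch below shows that segmentation under stated assumptions: `split_by_graph` is a hypothetical helper, the feature values are placeholders, and DGL performs this step internally with its own segment operations (the per-graph node counts are available from `batch_g.batch_num_nodes()`).

```python
def split_by_graph(feat, num_nodes_per_graph):
    """Split concatenated node features (N rows) into one list of rows per graph."""
    if sum(num_nodes_per_graph) != len(feat):
        raise ValueError("node counts must sum to the number of feature rows")
    segments, start = [], 0
    for n in num_nodes_per_graph:
        segments.append(feat[start:start + n])
        start += n
    return segments

# 3 + 4 nodes with feature size 5, mirroring g1 and g2 (placeholder values)
feat = [[float(i)] * 5 for i in range(7)]
segments = split_by_graph(feat, [3, 4])

# One output row per graph, each of size 2 * D (query concatenated with readout)
out_shape = (len(segments), 2 * len(feat[0]))
```

Here `out_shape` matches the two rows of ten values returned by `s2s(batch_g, batch_f)`.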

Notes

Set2Set is widely used in molecular property prediction. See dgl-lifesci’s MPNN example for how to use DGL’s Set2Set layer in graph property prediction applications.

forward(graph, feat)

Compute set2set pooling.

Parameters
  • graph (DGLGraph) – The input graph, or a batch of graphs.

  • feat (torch.Tensor) – The input node features with shape \((N, D)\), where \(N\) is the total number of nodes in the graph (or batch) and \(D\) is the feature size.

Returns

The output feature with shape \((B, 2D)\), where \(B\) is the batch size and \(D\) is the input feature size. Each row is the concatenation \(q_t \Vert r_t\), so it is twice as wide as the input features.

Return type

torch.Tensor

reset_parameters()

Reinitialize learnable parameters.