EGTLayer

class dgl.nn.pytorch.gt.EGTLayer(feat_size, edge_feat_size, num_heads, num_virtual_nodes, dropout=0, attn_dropout=0, activation=ELU(alpha=1.0), edge_update=True)

Bases: Module

EGTLayer for the Edge-augmented Graph Transformer (EGT), as introduced in "Global Self-Attention as a Replacement for Graph Convolution" (https://arxiv.org/pdf/2108.03348.pdf).

Parameters:
  • feat_size (int) – Node feature size.

  • edge_feat_size (int) – Edge feature size.

  • num_heads (int) – Number of attention heads. feat_size must be divisible by num_heads.

  • num_virtual_nodes (int) – Number of virtual nodes.

  • dropout (float, optional) – Dropout probability. Default: 0.0.

  • attn_dropout (float, optional) – Attention dropout probability. Default: 0.0.

  • activation (callable activation layer, optional) – Activation function. Default: nn.ELU().

  • edge_update (bool, optional) – Whether to update the edge embedding. Default: True.
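Combining these parameters with the input shapes documented for forward gives the shapes a single call expects. The helper below, egt_shapes, is hypothetical and not part of DGL; it is a minimal sketch that checks the divisibility requirement on num_heads and, assuming the usual multi-head convention, reports feat_size // num_heads as the per-head width.

```python
def egt_shapes(batch_size, max_num_nodes, num_virtual_nodes,
               feat_size, edge_feat_size, num_heads):
    """Hypothetical helper: expected tensor shapes for one EGTLayer call."""
    if feat_size % num_heads != 0:
        raise ValueError(
            f"feat_size ({feat_size}) must be divisible by num_heads ({num_heads})"
        )
    # N counts both the (padded) real nodes and the virtual nodes.
    n = max_num_nodes + num_virtual_nodes
    return {
        "head_dim": feat_size // num_heads,  # per-head width (assumed convention)
        "nfeat": (batch_size, n, feat_size),
        "efeat": (batch_size, n, n, edge_feat_size),
        "mask": (batch_size, n, n),
    }


# 96 real nodes + 4 virtual nodes gives N = 100.
print(egt_shapes(16, 96, 4, 128, 32, 8))
# {'head_dim': 16, 'nfeat': (16, 100, 128), 'efeat': (16, 100, 100, 32), 'mask': (16, 100, 100)}
```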

Examples

>>> import torch as th
>>> from dgl.nn import EGTLayer
>>> batch_size = 16
>>> num_nodes = 100
>>> feat_size, edge_feat_size = 128, 32
>>> nfeat = th.rand(batch_size, num_nodes, feat_size)
>>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
>>> net = EGTLayer(
...     feat_size=feat_size,
...     edge_feat_size=edge_feat_size,
...     num_heads=8,
...     num_virtual_nodes=4,
... )
>>> nfeat_out, efeat_out = net(nfeat, efeat)
forward(nfeat, efeat, mask=None)

Forward computation. Note: if num_virtual_nodes > 0, nfeat and efeat should be padded with the virtual-node embeddings, and mask should be padded with 0 values at the virtual-node positions. The padding should be placed at the beginning, so the virtual nodes occupy the first num_virtual_nodes positions of each node dimension.
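The padding rule can be sketched in PyTorch as follows. pad_virtual_nodes is a hypothetical helper, not a DGL function, and it assumes that every edge involving a virtual node uses one shared embedding vector vn_efeat; a real model may learn separate embeddings for these edges. Virtual nodes are prepended, so the real node and edge features end up in the trailing positions.

```python
import torch as th


def pad_virtual_nodes(nfeat, efeat, vn_nfeat, vn_efeat):
    """Prepend virtual nodes to batched node and edge features.

    nfeat:    (B, n, feat_size)          real node features
    efeat:    (B, n, n, edge_feat_size)  real edge features
    vn_nfeat: (V, feat_size)             one embedding per virtual node
    vn_efeat: (edge_feat_size,)          assumed: one shared embedding for
                                         every pair involving a virtual node
    """
    B, n, _ = nfeat.shape
    V = vn_nfeat.shape[0]
    N = V + n
    # Node features: virtual nodes occupy positions 0..V-1.
    nfeat_pad = th.cat([vn_nfeat.unsqueeze(0).expand(B, -1, -1), nfeat], dim=1)
    # Edge features: fill everything with the virtual-node edge embedding,
    # then copy the real (n x n) block into the bottom-right corner.
    efeat_pad = vn_efeat.expand(B, N, N, -1).clone()
    efeat_pad[:, V:, V:, :] = efeat
    return nfeat_pad, efeat_pad


nfeat, efeat = th.rand(2, 5, 8), th.rand(2, 5, 5, 3)
nfeat_pad, efeat_pad = pad_virtual_nodes(nfeat, efeat, th.rand(4, 8), th.rand(3))
print(nfeat_pad.shape, efeat_pad.shape)  # torch.Size([2, 9, 8]) torch.Size([2, 9, 9, 3])
```

Filling the whole edge tensor first and then overwriting the real block keeps the indexing simple: every row or column index below V automatically carries the virtual-node embedding.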

Parameters:
  • nfeat (torch.Tensor) – A 3D input tensor. Shape: (batch_size, N, feat_size), where N is the sum of the maximum number of nodes and the number of virtual nodes.

  • efeat (torch.Tensor) – Edge embedding used for attention computation and self update. Shape: (batch_size, N, N, edge_feat_size).

  • mask (torch.Tensor, optional) – Attention mask used to exclude invalid positions from attention: valid positions are indicated by 0 and invalid positions by -inf. Shape: (batch_size, N, N). Default: None.
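A mask in this 0 / -inf convention can be built from the number of real nodes per graph. The sketch below assumes the mask is added to the attention logits, which is what the 0 / -inf convention suggests; build_attn_mask is a hypothetical helper, not part of DGL. Virtual nodes sit at the beginning and are always valid (0), matching the padding rule for forward.

```python
import torch as th


def build_attn_mask(num_nodes, num_virtual_nodes):
    """Additive attention mask: 0 for valid keys, -inf for padded keys.

    num_nodes: list with the number of real nodes in each graph of the batch.
    Virtual nodes come first and are always valid (mask value 0).
    """
    B = len(num_nodes)
    N = num_virtual_nodes + max(num_nodes)
    counts = th.tensor(num_nodes)
    # key_valid[b, j] is True if position j is a virtual node or a real node of graph b.
    key_valid = th.arange(N).unsqueeze(0) < (num_virtual_nodes + counts).unsqueeze(1)
    mask = th.zeros(B, N, N)
    # Block every query row from attending to padded key columns.
    return mask.masked_fill(~key_valid.unsqueeze(1), float("-inf"))


mask = build_attn_mask([3, 5], num_virtual_nodes=2)
print(mask.shape)  # torch.Size([2, 7, 7])
```

This sketch masks only key columns, never whole rows: a row made entirely of -inf would turn into NaN under softmax. Outputs at padded query positions are still computed and should be ignored by whatever consumes the layer's output.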

Returns:

  • nfeat (torch.Tensor) – The output node embedding. Shape: (batch_size, N, feat_size).

  • efeat (torch.Tensor, optional) – The output edge embedding. Shape: (batch_size, N, N, edge_feat_size). It is returned only if edge_update is True.
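Because the return value depends on edge_update, code that stacks several layers has to handle both cases. The helper below, apply_egt_stack, is hypothetical. It assumes that a layer built with edge_update=True returns a (nfeat, efeat) tuple and one built with edge_update=False returns nfeat alone, as described above. When a layer does not update edges, the previous edge embedding is passed on unchanged.

```python
def apply_egt_stack(layers, nfeat, efeat, mask=None):
    """Run EGT-style layers in sequence, threading edge features through.

    Each layer is called as layer(nfeat, efeat, mask) and returns either
    nfeat alone or a (nfeat, efeat) tuple (when built with edge_update=True).
    """
    for layer in layers:
        out = layer(nfeat, efeat, mask)
        if isinstance(out, tuple):
            nfeat, efeat = out  # the layer updated the edge embedding
        else:
            nfeat = out         # keep the previous edge embedding
    return nfeat, efeat


# Hypothetical usage with DGL (not run here):
# layers = [EGTLayer(128, 32, 8, 4), EGTLayer(128, 32, 8, 4, edge_update=False)]
# nfeat, efeat = apply_egt_stack(layers, nfeat, efeat, mask)
```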