Source code for dgl.nn.pytorch.gt.egt

"""EGT Layer"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class EGTLayer(nn.Module):
    r"""EGTLayer for Edge-augmented Graph Transformer (EGT), as introduced in
    `Global Self-Attention as a Replacement for Graph Convolution
    <https://arxiv.org/pdf/2108.03348.pdf>`_

    Parameters
    ----------
    feat_size : int
        Node feature size.
    edge_feat_size : int
        Edge feature size.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    num_virtual_nodes : int
        Number of virtual nodes.
    dropout : float, optional
        Dropout probability. Default: 0.0.
    attn_dropout : float, optional
        Attention dropout probability. Default: 0.0.
    activation : callable activation layer, optional
        Activation function. Default: nn.ELU().
    edge_update : bool, optional
        Whether to update the edge embedding. Default: True.

    Examples
    --------
    >>> import torch as th
    >>> from dgl.nn import EGTLayer

    >>> batch_size = 16
    >>> num_nodes = 100
    >>> feat_size, edge_feat_size = 128, 32
    >>> nfeat = th.rand(batch_size, num_nodes, feat_size)
    >>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
    >>> net = EGTLayer(
    ...     feat_size=feat_size,
    ...     edge_feat_size=edge_feat_size,
    ...     num_heads=8,
    ...     num_virtual_nodes=4,
    ... )
    >>> out = net(nfeat, efeat)
    """

    def __init__(
        self,
        feat_size,
        edge_feat_size,
        num_heads,
        num_virtual_nodes,
        dropout=0,
        attn_dropout=0,
        activation=nn.ELU(),
        edge_update=True,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.num_virtual_nodes = num_virtual_nodes
        self.edge_update = edge_update

        assert (
            feat_size % num_heads == 0
        ), "feat_size must be divisible by num_heads"
        self.dot_dim = feat_size // num_heads

        self.mha_ln_h = nn.LayerNorm(feat_size)
        self.mha_ln_e = nn.LayerNorm(edge_feat_size)
        self.edge_input = nn.Linear(edge_feat_size, num_heads)
        self.qkv_proj = nn.Linear(feat_size, feat_size * 3)
        self.gate = nn.Linear(edge_feat_size, num_heads)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.node_output = nn.Linear(feat_size, feat_size)
        self.mha_dropout_h = nn.Dropout(dropout)
        self.node_ffn = nn.Sequential(
            nn.LayerNorm(feat_size),
            nn.Linear(feat_size, feat_size),
            activation,
            nn.Linear(feat_size, feat_size),
            nn.Dropout(dropout),
        )

        if self.edge_update:
            self.edge_output = nn.Linear(num_heads, edge_feat_size)
            self.mha_dropout_e = nn.Dropout(dropout)
            self.edge_ffn = nn.Sequential(
                nn.LayerNorm(edge_feat_size),
                nn.Linear(edge_feat_size, edge_feat_size),
                activation,
                nn.Linear(edge_feat_size, edge_feat_size),
                nn.Dropout(dropout),
            )
    def forward(self, nfeat, efeat, mask=None):
        """Forward computation.

        Note: :attr:`nfeat` and :attr:`efeat` should be padded with embeddings
        of the virtual nodes if :attr:`num_virtual_nodes` > 0, while
        :attr:`mask` should be padded with `0` values for the virtual nodes.
        The padding should be put at the beginning.

        Parameters
        ----------
        nfeat : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
            N is the sum of the maximum number of nodes and the number of
            virtual nodes.
        efeat : torch.Tensor
            Edge embedding used for attention computation and self update.
            Shape: (batch_size, N, N, :attr:`edge_feat_size`).
        mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions, where valid positions are indicated by `0` and invalid
            positions are indicated by `-inf`. Shape: (batch_size, N, N).
            Default: None.

        Returns
        -------
        nfeat : torch.Tensor
            The output node embedding. Shape: (batch_size, N, :attr:`feat_size`).
        efeat : torch.Tensor, optional
            The output edge embedding.
            Shape: (batch_size, N, N, :attr:`edge_feat_size`).
            It is returned only if :attr:`edge_update` is True.
        """
        nfeat_r1 = nfeat
        efeat_r1 = efeat

        # Multi-head attention with edge-derived bias and gates.
        nfeat_ln = self.mha_ln_h(nfeat)
        efeat_ln = self.mha_ln_e(efeat)
        qkv = self.qkv_proj(nfeat_ln)
        e_bias = self.edge_input(efeat_ln)
        gates = self.gate(efeat_ln)
        bsz, N, _ = qkv.shape
        q_h, k_h, v_h = qkv.view(bsz, N, -1, self.num_heads).split(
            self.dot_dim, dim=2
        )
        attn_hat = torch.einsum("bldh,bmdh->blmh", q_h, k_h)
        attn_hat = attn_hat.clamp(-5, 5) + e_bias

        if mask is None:
            gates = torch.sigmoid(gates)
            attn_tild = F.softmax(attn_hat, dim=2) * gates
        else:
            gates = torch.sigmoid(gates + mask.unsqueeze(-1))
            attn_tild = (
                F.softmax(attn_hat + mask.unsqueeze(-1), dim=2) * gates
            )

        attn_tild = self.attn_dropout(attn_tild)
        v_attn = torch.einsum("blmh,bmkh->blkh", attn_tild, v_h)

        # Scale the aggregated values by degree.
        degrees = torch.sum(gates, dim=2, keepdim=True)
        degree_scalers = torch.log(1 + degrees)
        degree_scalers[:, : self.num_virtual_nodes] = 1.0
        v_attn = v_attn * degree_scalers

        v_attn = v_attn.reshape(bsz, N, self.num_heads * self.dot_dim)
        nfeat = self.node_output(v_attn)

        # Residual connection, then node-wise feed-forward network.
        nfeat = self.mha_dropout_h(nfeat)
        nfeat.add_(nfeat_r1)

        nfeat_r2 = nfeat
        nfeat = self.node_ffn(nfeat)
        nfeat.add_(nfeat_r2)

        if self.edge_update:
            # Update edge embeddings from the pre-softmax attention scores.
            efeat = self.edge_output(attn_hat)
            efeat = self.mha_dropout_e(efeat)
            efeat.add_(efeat_r1)

            efeat_r2 = efeat
            efeat = self.edge_ffn(efeat)
            efeat.add_(efeat_r2)

            return nfeat, efeat

        return nfeat
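
The snippet below is a minimal usage sketch, not part of the module above. It feeds a padded batch through EGTLayer with an additive attention mask (`0` at valid positions, `-inf` at padded key positions), placing the virtual nodes at the beginning as required by `forward`. The helper names `valid` and `num_real_nodes`, and the way the mask is assembled, are illustrative assumptions rather than the exact batching code used in DGL's EGT examples.

import torch

from dgl.nn import EGTLayer

batch_size, max_nodes, num_virtual_nodes = 4, 50, 2
feat_size, edge_feat_size = 128, 32
N = num_virtual_nodes + max_nodes  # virtual nodes are padded at the beginning

layer = EGTLayer(
    feat_size=feat_size,
    edge_feat_size=edge_feat_size,
    num_heads=8,
    num_virtual_nodes=num_virtual_nodes,
)

nfeat = torch.rand(batch_size, N, feat_size)
efeat = torch.rand(batch_size, N, N, edge_feat_size)

# Hypothetical per-graph node counts for a padded batch.
num_real_nodes = [50, 30, 42, 17]

# Mark virtual nodes and real nodes as valid; the rest is padding.
valid = torch.zeros(batch_size, N, dtype=torch.bool)
valid[:, :num_virtual_nodes] = True
for i, n in enumerate(num_real_nodes):
    valid[i, num_virtual_nodes : num_virtual_nodes + n] = True

# Additive attention mask: 0 at valid key positions, -inf at padded keys,
# so no node attends to padding (every row keeps at least the virtual nodes).
mask = torch.zeros(batch_size, N, N)
mask[~valid.unsqueeze(1).expand(-1, N, -1)] = float("-inf")

nfeat_out, efeat_out = layer(nfeat, efeat, mask=mask)
print(nfeat_out.shape)  # torch.Size([4, 52, 128])
print(efeat_out.shape)  # torch.Size([4, 52, 52, 32])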