Source code for dgl.nn.pytorch.gt.lap_pos_encoder

"""Laplacian Positional Encoder"""

import torch as th
import torch.nn as nn


class LapPosEncoder(nn.Module):
    r"""Laplacian Positional Encoder (LPE), as introduced in
    `GraphGPS: General Powerful Scalable Graph Transformers
    <https://arxiv.org/abs/2205.12454>`__

    This module is a learned laplacian positional encoding module using
    Transformer or DeepSet.

    Parameters
    ----------
    model_type : str
        Encoder model type for LPE, can only be "Transformer" or "DeepSet".
    num_layer : int
        Number of layers in Transformer/DeepSet Encoder.
    k : int
        Number of smallest non-trivial eigenvectors.
    dim : int
        Output size of final laplacian encoding.
    n_head : int, optional
        Number of heads in Transformer Encoder.
        Default : 1.
    batch_norm : bool, optional
        If True, apply batch normalization on raw laplacian positional
        encoding. Default : False.
    num_post_layer : int, optional
        If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers after
        pooling. Default : 0.

    Raises
    ------
    ValueError
        If ``model_type`` is neither "Transformer" nor "DeepSet".

    Example
    -------
    >>> import dgl
    >>> from dgl import LapPE
    >>> from dgl.nn import LapPosEncoder

    >>> transform = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
    >>> g = transform(g)
    >>> eigvals, eigvecs = g.ndata['eigval'], g.ndata['eigvec']
    >>> transformer_encoder = LapPosEncoder(
    ...     model_type="Transformer", num_layer=3, k=5, dim=16, n_head=4
    ... )
    >>> pos_encoding = transformer_encoder(eigvals, eigvecs)
    >>> deepset_encoder = LapPosEncoder(
    ...     model_type="DeepSet", num_layer=3, k=5, dim=16, num_post_layer=2
    ... )
    >>> pos_encoding = deepset_encoder(eigvals, eigvecs)
    """

    def __init__(
        self,
        model_type,
        num_layer,
        k,
        dim,
        n_head=1,
        batch_norm=False,
        num_post_layer=0,
    ):
        super().__init__()
        self.model_type = model_type
        # Projects each (eigenvector entry, eigenvalue) pair to ``dim``.
        # The multi-layer DeepSet branch below replaces it with a wider one.
        self.linear = nn.Linear(2, dim)

        if self.model_type == "Transformer":
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=dim, nhead=n_head, batch_first=True
            )
            self.pe_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layer
            )
        elif self.model_type == "DeepSet":
            layers = []
            if num_layer == 1:
                # ``self.linear`` is the single layer; only add activation.
                layers.append(nn.ReLU())
            else:
                # ``self.linear`` acts as the first (widening) layer.
                self.linear = nn.Linear(2, 2 * dim)
                layers.append(nn.ReLU())
                layers.extend(self._hidden_and_output_layers(dim, num_layer))
            self.pe_encoder = nn.Sequential(*layers)
        else:
            raise ValueError(
                f"model_type '{model_type}' is not allowed, must be "
                "'Transformer' or 'DeepSet'."
            )

        # BatchNorm1d over the ``k`` eigen-index channels of the raw
        # (N, k, 2) encoding.
        self.raw_norm = nn.BatchNorm1d(k) if batch_norm else None

        if num_post_layer > 0:
            if num_post_layer == 1:
                layers = [nn.Linear(dim, dim), nn.ReLU()]
            else:
                layers = [nn.Linear(dim, 2 * dim), nn.ReLU()]
                layers.extend(
                    self._hidden_and_output_layers(dim, num_post_layer)
                )
            self.post_mlp = nn.Sequential(*layers)
        else:
            self.post_mlp = None

    @staticmethod
    def _hidden_and_output_layers(dim, num_layer):
        """Return the tail of an MLP whose first layer already maps to 2*dim.

        Builds ``num_layer - 2`` hidden ``Linear(2*dim, 2*dim) + ReLU`` pairs
        followed by a final ``Linear(2*dim, dim) + ReLU`` pair.
        """
        tail = []
        for _ in range(num_layer - 2):
            tail.append(nn.Linear(2 * dim, 2 * dim))
            tail.append(nn.ReLU())
        tail.append(nn.Linear(2 * dim, dim))
        tail.append(nn.ReLU())
        return tail

    def forward(self, eigvals, eigvecs):
        r"""
        Parameters
        ----------
        eigvals : Tensor
            Laplacian Eigenvalues of shape :math:`(N, k)`, k different
            eigenvalues repeat N times, can be obtained by using `LaplacianPE`.
        eigvecs : Tensor
            Laplacian Eigenvectors of shape :math:`(N, k)`, can be obtained by
            using `LaplacianPE`.

        Returns
        -------
        Tensor
            Return the laplacian positional encodings of shape :math:`(N, d)`,
            where :math:`N` is the number of nodes in the input graph,
            :math:`d` is :attr:`dim`.
        """
        # Pair every eigenvector entry with its eigenvalue: (N, k, 2).
        pos_encoding = th.cat(
            (eigvecs.unsqueeze(2), eigvals.unsqueeze(2)), dim=2
        ).float()
        # NaN entries mark padded (non-existent) eigen-pairs.
        empty_mask = th.isnan(pos_encoding)
        # (N, k) mask of padded slots, taken from the eigenvalue channel.
        padding_mask = empty_mask[:, :, 1]

        pos_encoding = pos_encoding.masked_fill(empty_mask, 0.0)
        if self.raw_norm is not None:
            pos_encoding = self.raw_norm(pos_encoding)
        pos_encoding = self.linear(pos_encoding)

        if self.model_type == "Transformer":
            pos_encoding = self.pe_encoder(
                src=pos_encoding, src_key_padding_mask=padding_mask
            )
        else:
            pos_encoding = self.pe_encoder(pos_encoding)

        # Remove masked sequences. This must be out-of-place: the DeepSet
        # encoder ends in ReLU, whose output is saved for backward, so an
        # in-place write here would make ``backward()`` raise.
        pos_encoding = pos_encoding.masked_fill(padding_mask.unsqueeze(2), 0.0)

        # Sum pooling.
        pos_encoding = th.sum(pos_encoding, 1, keepdim=False)

        # MLP post pooling.
        if self.post_mlp is not None:
            pos_encoding = self.post_mlp(pos_encoding)

        return pos_encoding