SpatialEncoder
class dgl.nn.pytorch.gt.SpatialEncoder(max_dist, num_heads=1)[source]

Bases: torch.nn.modules.module.Module
Spatial Encoder, as introduced in Do Transformers Really Perform Bad for Graph Representation?
This module is a learnable spatial embedding that encodes the shortest path distance between each node pair as an attention bias.
- Parameters
    - max_dist (int) – Upper bound of the shortest path distance between each node pair to be encoded. All distances are clamped into the range [0, max_dist].
    - num_heads (int, optional) – Number of attention heads to which the bias applies. Default: 1.
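Conceptually, each clamped distance value indexes a learnable embedding table whose rows hold per-head bias scalars. The following is a minimal sketch of this idea, not DGL's exact implementation; the shift-by-one trick for the -1 padding is an assumption made here so that unreachable pairs get a dedicated embedding row:

>>> import torch as th
>>> import torch.nn as nn
>>> class SpatialEncoderSketch(nn.Module):
...     """Sketch: map shortest-path distances to per-head attention biases."""
...     def __init__(self, max_dist, num_heads=1):
...         super().__init__()
...         self.max_dist = max_dist
...         # One row per distance in [0, max_dist], plus one row
...         # (index 0 after shifting) for the -1 padding of unreachable pairs.
...         self.table = nn.Embedding(max_dist + 2, num_heads, padding_idx=0)
...     def forward(self, dist):
...         # Clamp long distances to max_dist, then shift by +1 so that
...         # -1 (unreachable / padding) maps to index 0.
...         idx = th.clamp(dist, min=-1, max=self.max_dist) + 1
...         return self.table(idx)  # (B, N, N) -> (B, N, N, num_heads)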
Examples
>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder
>>> from dgl import shortest_dist

>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> n1, n2 = g1.num_nodes(), g2.num_nodes()
>>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
>>> dist = -th.ones((2, 4, 4), dtype=th.long)
>>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
>>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
>>> out = spatial_encoder(dist)
>>> print(out.shape)
torch.Size([2, 4, 4, 8])
forward(dist)[source]

- Parameters
dist (Tensor) – Shortest path distance of the batched graph with -1 padding, a tensor of shape \((B, N, N)\), where \(B\) is the batch size of the batched graph, and \(N\) is the maximum number of nodes.
- Returns
    Attention bias as spatial encoding, a tensor of shape \((B, N, N, H)\), where \(H\) is num_heads.
- Return type
    torch.Tensor
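In a Graphormer-style attention layer, this bias is typically added to the pre-softmax attention logits. The sketch below shows that wiring under stated assumptions: the scaled dot-product attention and the random q/k tensors are written here purely for illustration and are not a DGL API.

>>> import torch as th
>>> from dgl.nn import SpatialEncoder
>>> B, N, H, d = 2, 4, 8, 16  # batch size, nodes, heads, per-head dim
>>> dist = -th.ones((B, N, N), dtype=th.long)  # -1-padded distances, as above
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=H)
>>> q = th.randn(B, H, N, d)  # illustrative query tensor
>>> k = th.randn(B, H, N, d)  # illustrative key tensor
>>> bias = spatial_encoder(dist)                 # (B, N, N, H)
>>> scores = (q @ k.transpose(-1, -2)) / d**0.5  # raw logits, (B, H, N, N)
>>> scores = scores + bias.permute(0, 3, 1, 2)   # add per-head spatial bias
>>> attn = th.softmax(scores, dim=-1)            # attention weights
>>> print(attn.shape)
torch.Size([2, 8, 4, 4])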