BiasedMHA
class dgl.nn.pytorch.gt.BiasedMHA(feat_size, num_heads, bias=True, attn_bias_type='add', attn_drop=0.1)
Bases: torch.nn.modules.module.Module
Dense Multi-Head Attention Module with Graph Attention Bias.
Computes attention between nodes with an attention bias obtained from graph structure, as introduced in "Do Transformers Really Perform Bad for Graph Representation?".
\[\text{Attn}=\text{softmax}\left(\dfrac{QK^T}{\sqrt{d}} \circ b\right)\]
\(Q\) and \(K\) are feature representations of nodes. \(d\) is the corresponding feat_size. \(b\) is the attention bias, which can be additive or multiplicative according to the operator \(\circ\).
Parameters
feat_size (int) – Feature size.
num_heads (int) – Number of attention heads. feat_size must be divisible by num_heads.
bias (bool, optional) – If True, the linear projections use a bias term. Default: True.
attn_bias_type (str, optional) –
The type of attention bias used to modify attention. One of 'add' or 'mul'. Default: 'add'.
'add' is for additive attention bias.
'mul' is for multiplicative attention bias.
attn_drop (float, optional) – Dropout probability on attention weights. Default: 0.1.
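To make the formula and the two attn_bias_type options concrete, here is a minimal plain-PyTorch sketch of the score computation. It is not DGL's implementation: it omits the learned Q/K/V and output projections and the dropout, using the head-split input directly as Q, K and V. It also assumes that \(d\) is the per-head dimension feat_size // num_heads, as in standard multi-head attention. The name biased_mha_sketch is hypothetical.

```python
import math
import torch

def biased_mha_sketch(ndata, attn_bias, num_heads, attn_bias_type="add"):
    """Compute softmax(QK^T / sqrt(d) o b) per head, with Q = K = V = ndata.

    ndata:     (batch_size, N, feat_size)
    attn_bias: (batch_size, N, N, num_heads)
    Returns (weights, out) with shapes (batch_size, num_heads, N, N)
    and (batch_size, N, feat_size).
    """
    bsz, n, feat_size = ndata.shape
    if feat_size % num_heads != 0:
        raise ValueError("feat_size must be divisible by num_heads")
    head_dim = feat_size // num_heads  # assumed meaning of d in this sketch

    # Split the feature axis into heads: (batch_size, num_heads, N, head_dim).
    h = ndata.reshape(bsz, n, num_heads, head_dim).transpose(1, 2)
    q, k, v = h, h, h  # no learned projections in this sketch

    # Scaled dot-product scores: (batch_size, num_heads, N, N).
    scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)

    # Move the head axis of the bias next to the batch axis to match the scores.
    b = attn_bias.permute(0, 3, 1, 2)
    if attn_bias_type == "add":
        scores = scores + b
    elif attn_bias_type == "mul":
        scores = scores * b
    else:
        raise ValueError("attn_bias_type must be 'add' or 'mul'")

    # Normalize over the key axis, then mix the values and merge the heads.
    weights = torch.softmax(scores, dim=-1)
    out = (weights @ v).transpose(1, 2).reshape(bsz, n, feat_size)
    return weights, out

ndata = torch.rand(2, 5, 8)
bias = torch.rand(2, 5, 5, 4)
weights, out = biased_mha_sketch(ndata, bias, num_heads=4)
print(weights.shape, out.shape)  # torch.Size([2, 4, 5, 5]) torch.Size([2, 5, 8])
```

With 'mul', a bias of all ones leaves the scores unchanged, just as a bias of all zeros does with 'add'. The multiplicative bias rescales the logits, while the additive bias shifts them.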
Examples
>>> import torch as th
>>> from dgl.nn import BiasedMHA
>>> ndata = th.rand(16, 100, 512)
>>> bias = th.rand(16, 100, 100, 8)
>>> net = BiasedMHA(feat_size=512, num_heads=8)
>>> out = net(ndata, bias)
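The example above fills the bias with random values. In practice the bias comes from graph structure. The cited paper, for instance, learns one scalar per shortest-path distance (its spatial encoding). Below is a minimal plain-PyTorch sketch of that idea, assuming a small dense adjacency matrix. hop_distances and HopDistanceBias are hypothetical helpers written for illustration, not DGL APIs.

```python
import torch
import torch.nn as nn

def hop_distances(adj: torch.Tensor, max_dist: int) -> torch.Tensor:
    """Shortest-path hop counts for a dense (N, N) adjacency, clipped at max_dist.

    Pairs that are unreachable, or at least max_dist hops apart, get max_dist.
    """
    n = adj.shape[0]
    a = adj.bool().float()
    dist = torch.full((n, n), max_dist, dtype=torch.long)
    dist.fill_diagonal_(0)
    reached = torch.eye(n, dtype=torch.bool)   # pairs whose distance is known
    frontier = torch.eye(n, dtype=torch.bool)  # pairs at exactly distance d - 1
    for d in range(1, max_dist):
        # (i, j) is reachable in d hops if some k at distance d - 1 from i has an edge k -> j.
        step = (frontier.float() @ a) > 0
        new = step & ~reached
        dist[new] = d
        reached |= new
        frontier = new
    return dist

class HopDistanceBias(nn.Module):
    """Hypothetical helper: one learnable scalar per (hop distance, head)."""

    def __init__(self, max_dist: int, num_heads: int):
        super().__init__()
        # Indices 0..max_dist, each mapped to a vector of num_heads scalars.
        self.embed = nn.Embedding(max_dist + 1, num_heads)

    def forward(self, dist: torch.Tensor) -> torch.Tensor:
        # dist: (batch_size, N, N) integer hop counts in [0, max_dist]
        return self.embed(dist)  # (batch_size, N, N, num_heads)

# Path graph 0 - 1 - 2 - 3 (undirected, so the adjacency is symmetric).
adj = torch.tensor([[0, 1, 0, 0],
                    [1, 0, 1, 0],
                    [0, 1, 0, 1],
                    [0, 0, 1, 0]])
dist = hop_distances(adj, max_dist=5)
attn_bias = HopDistanceBias(max_dist=5, num_heads=8)(dist.unsqueeze(0))
print(attn_bias.shape)  # torch.Size([1, 4, 4, 8])
```

The result has the (batch_size, N, N, num_heads) layout that attn_bias expects, so a tensor built this way could take the place of the random bias.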
forward(ndata, attn_bias=None, attn_mask=None)
Forward computation.
Parameters
ndata (torch.Tensor) – A 3D input tensor. Shape: (batch_size, N, feat_size), where N is the maximum number of nodes.
attn_bias (torch.Tensor, optional) – The attention bias used to modify attention. Shape: (batch_size, N, N, num_heads).
attn_mask (torch.Tensor, optional) – The attention mask used to exclude invalid positions from attention; invalid positions are indicated by True values. Shape: (batch_size, N, N). Note: for rows corresponding to nonexistent (padded) nodes, make sure at least one entry is set to False; otherwise the softmax over that row produces NaNs.
Returns
y – The output tensor. Shape: (batch_size, N, feat_size).
Return type
torch.Tensor
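The note on attn_mask can be illustrated with a padded batch of graphs. The plain-PyTorch sketch below builds a (batch_size, N, N) mask that uses the same True-means-invalid convention but marks only the padded key positions (columns). Rows of padded nodes therefore still contain False entries. It then contrasts this with a mask that also blanks those rows. key_padding_mask and masked_softmax are hypothetical helpers. Filling masked scores with -inf is the usual way masked attention is implemented, and is an assumption here rather than a description of BiasedMHA's internals.

```python
import torch

def key_padding_mask(num_nodes):
    """Return a (batch_size, N, N) bool mask where True marks padded key positions."""
    num_nodes = torch.as_tensor(num_nodes)
    n = int(num_nodes.max())
    # valid[b, j] is True if node j exists in graph b: (batch_size, N)
    valid = torch.arange(n).unsqueeze(0) < num_nodes.unsqueeze(1)
    # Mask padded keys (columns) in every query row, padded rows included,
    # so each row keeps at least one False entry when the graph is non-empty.
    return (~valid).unsqueeze(1).expand(-1, n, -1)

def masked_softmax(scores, mask):
    # Masked positions get -inf so they receive zero weight after softmax.
    return torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)

num_nodes = [2, 4]                 # two graphs padded to N = 4
scores = torch.zeros(2, 4, 4)      # stand-in attention logits
mask = key_padding_mask(num_nodes)
weights = masked_softmax(scores, mask)
print(torch.isnan(weights).any().item())  # False

# Counter-example: also masking the rows of padded nodes leaves those rows
# entirely True, and a softmax over all -inf values produces NaN.
pad = ~(torch.arange(4).unsqueeze(0) < torch.tensor(num_nodes).unsqueeze(1))
row_and_col = pad.unsqueeze(2) | pad.unsqueeze(1)
bad = masked_softmax(scores, row_and_col)
print(torch.isnan(bad).any().item())  # True
```

The outputs for padded rows are meaningless either way, but keeping them finite avoids NaNs propagating into later layers or the loss.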