dgl.nn.pytorch

dgl.nn.pytorch.conv

Torch modules for graph convolutions.

class dgl.nn.pytorch.conv.GraphConv(in_feats, out_feats, norm=True, bias=True, activation=None)[source]

Bases: torch.nn.modules.module.Module

Apply graph convolution over an input signal.

Graph convolution was introduced in GCN and can be described as follows:

\[h_i^{(l+1)} = \sigma(b^{(l)} + \sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)})\]

where \(\mathcal{N}(i)\) is the neighbor set of node \(i\), \(c_{ij}\) is the product of the square roots of the node degrees, \(c_{ij} = \sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}\), and \(\sigma\) is an activation function.
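
For example, if node \(i\) has 4 neighbors and node \(j\) has 9, the message from \(j\) to \(i\) is scaled by \(1/c_{ij} = 1/(\sqrt{4}\cdot\sqrt{9}) = 1/6\).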

The model parameters are initialized as in the original implementation where the weight \(W^{(l)}\) is initialized using Glorot uniform initialization and the bias is initialized to be zero.

Notes

Zero in-degree nodes can lead to an invalid normalizer. A common practice to avoid this is to add a self-loop for each node in the graph, which can be achieved by:

>>> g = ... # some DGLGraph
>>> g.add_edges(g.nodes(), g.nodes())
Parameters:
  • in_feats (int) – Number of input features.
  • out_feats (int) – Number of output features.
  • norm (bool, optional) – If True, the normalizer \(c_{ij}\) is applied. Default: True.
  • bias (bool, optional) – If True, adds a learnable bias to the output. Default: True.
  • activation (callable activation function/layer or None, optional) – If not None, applies an activation function to the updated node features. Default: None.
weight

torch.Tensor – The learnable weight tensor.

bias

torch.Tensor – The learnable bias tensor.

forward(feat, graph)[source]

Compute graph convolution.

Notes

  • Input shape: \((N, *, \text{in_feats})\) where * means any number of additional dimensions, \(N\) is the number of nodes.
  • Output shape: \((N, *, \text{out_feats})\) where all but the last dimension are the same shape as the input.
Parameters:
  • feat (torch.Tensor) – The input feature.
  • graph (DGLGraph) – The graph.
Returns:

The output feature.

Return type:

torch.Tensor

reset_parameters()[source]

Reinitialize learnable parameters.
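
A minimal usage sketch (the toy graph, feature sizes, and variable names below are illustrative assumptions, not part of the API; graph construction uses the same DGLGraph calls as the self-loop note above, and the module is called with the forward(feat, graph) argument order documented above):

>>> import torch
>>> import dgl
>>> from dgl.nn.pytorch.conv import GraphConv
>>>
>>> # Toy graph with 4 nodes and a short chain of edges.
>>> g = dgl.DGLGraph()
>>> g.add_nodes(4)
>>> g.add_edges([0, 1, 2], [1, 2, 3])
>>> # Add self-loops so no node has zero in-degree (see the note above).
>>> g.add_edges(g.nodes(), g.nodes())
>>>
>>> feat = torch.randn(4, 5)                                  # shape (N, in_feats)
>>> conv = GraphConv(in_feats=5, out_feats=2, activation=torch.relu)
>>> out = conv(feat, g)                                       # shape (N, out_feats)
>>> out.shape
torch.Size([4, 2])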

dgl.nn.pytorch.softmax

Torch modules for graph related softmax.

class dgl.nn.pytorch.softmax.EdgeSoftmax[source]

Bases: torch.nn.modules.module.Module

Apply softmax over signals of incoming edges.

For a node \(i\), edge softmax is an operation that computes

\[a_{ij} = \frac{\exp(z_{ij})}{\sum_{k\in\mathcal{N}(i)}\exp(z_{ik})}\]

where \(z_{ij}\) is a signal of edge \(j\rightarrow i\), also called logits in the context of softmax. \(\mathcal{N}(i)\) is the set of nodes that have an edge to \(i\).
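
For example, if node \(i\) has two incoming edges with logits \(z_{i1} = 0\) and \(z_{i2} = \ln 3\), then \(a_{i1} = 1/(1+3) = 0.25\) and \(a_{i2} = 3/(1+3) = 0.75\).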

An example of using edge softmax is the Graph Attention Network, where the attention weights are computed with such an edge softmax operation.

forward(logits, graph)[source]

Compute edge softmax.

Parameters:
  • logits (torch.Tensor) – The input edge feature.
  • graph (DGLGraph) – The graph.
Returns:

  • Unnormalized scores (torch.Tensor) – This part gives the \(\exp(z_{ij})\)'s
  • Normalizer (torch.Tensor) – This part gives \(\sum_{k\in\mathcal{N}(i)}\exp(z_{ik})\)

Notes

  • Input shape: \((N, *, 1)\) where * means any number of additional dimensions, \(N\) is the number of edges.
  • Unnormalized scores shape: \((N, *, 1)\) where all but the last dimension are the same shape as the input.
  • Normalizer shape: \((M, *, 1)\) where \(M\) is the number of nodes and all but the first and the last dimensions are the same as the input.

Note that this computation is still one step away from the actual softmax result. The last step can be performed as follows:

>>> import dgl.function as fn
>>>
>>> # Unnormalized scores live on the edges; normalizers live on the destination nodes.
>>> scores, normalizer = EdgeSoftmax(...).forward(logits, graph)
>>> graph.edata['a'] = scores
>>> graph.ndata['normalizer'] = normalizer
>>> # Divide each edge score by the normalizer of its destination node.
>>> graph.apply_edges(lambda edges: {'a': edges.data['a'] / edges.dst['normalizer']})

We leave this last step to users because, depending on the particular use case, it can be combined with other computation.
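
For instance, once the normalized scores are stored under graph.edata['a'], they can be used as message weights in the next aggregation step. The sketch below is an illustrative continuation, not part of this module: it assumes node features are stored under graph.ndata['h'] and uses DGL's built-in fn.src_mul_edge and fn.sum message/reduce functions.

>>> # Weight each source node's feature 'h' by the normalized score 'a'
>>> # on the connecting edge, then sum the messages over incoming edges.
>>> graph.update_all(fn.src_mul_edge('h', 'a', 'm'), fn.sum('m', 'h_new'))
>>> h_new = graph.ndata['h_new']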