Source code for dgl.nn.pytorch.conv.gatedgcnconv

"""Torch Module for GatedGCN layer"""
# pylint: disable= no-member, arguments-differ, invalid-name, cell-var-from-loop
import torch
import torch.nn.functional as F
from torch import nn

from .... import function as fn


class GatedGCNConv(nn.Module):
    r"""Gated graph convolutional layer from `Benchmarking Graph Neural
    Networks <https://arxiv.org/abs/2003.00982>`__

    .. math::
        e_{ij}^{l+1} = D^l h_{i}^{l} + E^l h_{j}^{l} + C^l e_{ij}^{l}

        norm_{ij} = \Sigma_{j\in N_{i}} \sigma\left(e_{ij}^{l+1}\right) + \varepsilon

        \hat{e}_{ij}^{l+1} = \sigma(e_{ij}^{l+1}) / norm_{ij}

        h_{i}^{l+1} = A^l h_{i}^{l} + \Sigma_{j \in N_{i}} \hat{e}_{ij}^{l+1} \odot B^l h_{j}^{l}

    where :math:`h_{i}^{l}` is the feature of node :math:`i` at layer
    :math:`l`, :math:`e_{ij}^{l}` is the feature of edge :math:`ij` at layer
    :math:`l`, :math:`\sigma` is the sigmoid function, :math:`\varepsilon` is
    a small fixed constant for numerical stability, and
    :math:`A^l, B^l, C^l, D^l, E^l` are linear layers.

    Parameters
    ----------
    input_feats : int
        Input feature size; i.e., the number of dimensions of :math:`h_{i}^{l}`.
    edge_feats : int
        Edge feature size; i.e., the number of dimensions of :math:`e_{ij}^{l}`.
    output_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_{i}^{l+1}`.
    dropout : float, optional
        Dropout rate on the node and edge features. Default: ``0``.
    batch_norm : bool, optional
        Whether to include batch normalization on the node and edge features.
        Default: ``True``.
    residual : bool, optional
        Whether to include residual connections. Default: ``True``.
    activation : callable activation function/layer or None, optional
        If not None, apply an activation function to the updated node features.
        Default: ``F.relu``.

    Example
    -------
    >>> import dgl
    >>> import torch as th
    >>> import torch.nn.functional as F
    >>> from dgl.nn import GatedGCNConv

    >>> num_nodes, num_edges = 8, 30
    >>> graph = dgl.rand_graph(num_nodes, num_edges)
    >>> node_feats = th.rand(num_nodes, 20)
    >>> edge_feats = th.rand(num_edges, 12)
    >>> gatedGCN = GatedGCNConv(20, 12, 20)
    >>> new_node_feats, new_edge_feats = gatedGCN(graph, node_feats, edge_feats)
    >>> new_node_feats.shape, new_edge_feats.shape
    (torch.Size([8, 20]), torch.Size([30, 20]))
    """

    def __init__(
        self,
        input_feats,
        edge_feats,
        output_feats,
        dropout=0,
        batch_norm=True,
        residual=True,
        activation=F.relu,
    ):
        super(GatedGCNConv, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.batch_norm = batch_norm
        self.residual = residual

        # Residual connections require matching input and output sizes.
        if input_feats != output_feats or edge_feats != output_feats:
            self.residual = False

        # Linearly transform the node features.
        self.A = nn.Linear(input_feats, output_feats, bias=True)
        self.B = nn.Linear(input_feats, output_feats, bias=True)
        self.D = nn.Linear(input_feats, output_feats, bias=True)
        self.E = nn.Linear(input_feats, output_feats, bias=True)

        # Linearly transform the edge features.
        self.C = nn.Linear(edge_feats, output_feats, bias=True)

        # Batch normalization on the node/edge features.
        self.bn_node = nn.BatchNorm1d(output_feats)
        self.bn_edge = nn.BatchNorm1d(output_feats)
        self.activation = activation
    def forward(self, graph, feat, edge_feat):
        """
        Description
        -----------
        Compute the gated graph convolution layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor
            The input node feature of shape :math:`(N, D_{in})` where
            :math:`N` is the number of nodes of the graph and :math:`D_{in}`
            is the input feature size.
        edge_feat : torch.Tensor
            The input edge feature of shape :math:`(E, D_{edge})`, where
            :math:`E` is the number of edges and :math:`D_{edge}` is the size
            of the edge features.

        Returns
        -------
        torch.Tensor
            The output node feature of shape :math:`(N, D_{out})` where
            :math:`D_{out}` is the output feature size.
        torch.Tensor
            The output edge feature of shape :math:`(E, D_{out})` where
            :math:`D_{out}` is the output feature size.
        """
        with graph.local_scope():
            # Keep the inputs for the residual connections.
            h_in = feat
            e_in = edge_feat

            graph.ndata["Ah"] = self.A(feat)
            graph.ndata["Bh"] = self.B(feat)
            graph.ndata["Dh"] = self.D(feat)
            graph.ndata["Eh"] = self.E(feat)
            graph.edata["Ce"] = self.C(edge_feat)

            graph.apply_edges(fn.u_add_v("Dh", "Eh", "DEh"))
            # Get the updated edge feature: e = Dh_i + Eh_j + Ce.
            graph.edata["e"] = graph.edata["DEh"] + graph.edata["Ce"]
            graph.edata["sigma"] = torch.sigmoid(graph.edata["e"])
            graph.update_all(
                fn.u_mul_e("Bh", "sigma", "m"), fn.sum("m", "sum_sigma_h")
            )
            graph.update_all(fn.copy_e("sigma", "m"), fn.sum("m", "sum_sigma"))
            graph.ndata["h"] = graph.ndata["Ah"] + graph.ndata[
                "sum_sigma_h"
            ] / (graph.ndata["sum_sigma"] + 1e-6)

            # Result of graph convolution.
            feat = graph.ndata["h"]
            edge_feat = graph.edata["e"]

            # Batch normalization.
            if self.batch_norm:
                feat = self.bn_node(feat)
                edge_feat = self.bn_edge(edge_feat)

            # Non-linear activation.
            if self.activation:
                feat = self.activation(feat)
                edge_feat = self.activation(edge_feat)

            # Residual connection.
            if self.residual:
                feat = h_in + feat
                edge_feat = e_in + edge_feat

            feat = self.dropout(feat)
            edge_feat = self.dropout(edge_feat)

            return feat, edge_feat
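
# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the DGL source: it shows one way to
# stack two GatedGCNConv layers. The ``GatedGCNNet`` class, its hidden size,
# and the random inputs below are illustrative assumptions.
if __name__ == "__main__":
    import dgl

    class GatedGCNNet(nn.Module):
        """Hypothetical two-layer GatedGCN stack for demonstration."""

        def __init__(self, in_feats, edge_feats, hidden_feats):
            super().__init__()
            # The first layer changes the feature sizes, so GatedGCNConv
            # disables its residual connection automatically.
            self.layer1 = GatedGCNConv(in_feats, edge_feats, hidden_feats)
            # The second layer keeps all sizes equal, so residuals stay on.
            self.layer2 = GatedGCNConv(
                hidden_feats, hidden_feats, hidden_feats
            )

        def forward(self, graph, h, e):
            h, e = self.layer1(graph, h, e)
            h, e = self.layer2(graph, h, e)
            return h, e

    graph = dgl.rand_graph(8, 30)
    model = GatedGCNNet(20, 12, 16)
    h, e = model(graph, torch.rand(8, 20), torch.rand(30, 12))
    # Both node and edge features end up with the hidden size.
    assert h.shape == (8, 16) and e.shape == (30, 16)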