Source code for dgl.nn.pytorch.conv.sageconv

"""Torch Module for GraphSAGE layer"""
# pylint: disable= no-member, arguments-differ, invalid-name
import torch
from torch import nn
from torch.nn import functional as F

from .... import function as fn
from ....utils import expand_as_pair, check_eq_shape, dgl_warning


class SAGEConv(nn.Module):
    r"""

    Description
    -----------
    GraphSAGE layer from paper `Inductive Representation Learning on
    Large Graphs <https://arxiv.org/pdf/1706.02216.pdf>`__.

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} &= \mathrm{aggregate}
        \left(\{h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

        h_{i}^{(l+1)} &= \sigma \left(W \cdot \mathrm{concat}
        (h_{i}^{l}, h_{\mathcal{N}(i)}^{l+1}) \right)

        h_{i}^{(l+1)} &= \mathrm{norm}(h_{i}^{(l+1)})

    If a weight tensor on each edge is provided, the aggregation becomes:

    .. math::
        h_{\mathcal{N}(i)}^{(l+1)} = \mathrm{aggregate}
        \left(\{e_{ji} h_{j}^{l}, \forall j \in \mathcal{N}(i) \}\right)

    where :math:`e_{ji}` is the scalar weight on the edge from node :math:`j` to node :math:`i`.
    Please make sure that :math:`e_{ji}` is broadcastable with :math:`h_j^{l}`.

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.

        SAGEConv can be applied on homogeneous graph and unidirectional
        `bipartite graph <https://docs.dgl.ai/generated/dgl.bipartite.html?highlight=bipartite>`__.
        If the layer applies on a unidirectional bipartite graph, ``in_feats``
        specifies the input feature size on both the source and destination nodes.  If
        a scalar is given, the source and destination node feature size would take the
        same value.

        If aggregator type is ``gcn``, the feature size of source and destination nodes
        are required to be the same.
    out_feats : int
        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
    feat_drop : float
        Dropout rate on features, default: ``0``.
    aggregator_type : str
        Aggregator type to use (``mean``, ``gcn``, ``pool``, ``lstm``).
    bias : bool
        If True, adds a learnable bias to the output. Default: ``True``.
    norm : callable activation function/layer or None, optional
        If not None, applies normalization to the updated node features.
        Default: ``None``.
    activation : callable activation function/layer or None, optional
        If not None, applies an activation function to the updated node features.
        Default: ``None``.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import SAGEConv

    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> conv = SAGEConv(10, 2, 'pool')
    >>> res = conv(g, feat)
    >>> res
    tensor([[-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099],
            [-1.0888, -2.1099]], grad_fn=<AddBackward0>)

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.bipartite((u, v))
    >>> u_fea = th.rand(2, 5)
    >>> v_fea = th.rand(4, 10)
    >>> conv = SAGEConv((5, 10), 2, 'mean')
    >>> res = conv(g, (u_fea, v_fea))
    >>> res
    tensor([[ 0.3163,  3.1166],
            [ 0.3866,  2.5398],
            [ 0.5873,  1.6597],
            [-0.2502,  2.8068]], grad_fn=<AddBackward0>)
    """

    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None):
        super().__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # aggregator type: mean/pool/lstm/gcn
        if aggregator_type == 'pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        # The 'gcn' aggregator folds the destination node into the neighborhood
        # average, so it has no separate self-transformation.
        if aggregator_type != 'gcn':
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=False)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=False)
        # A single bias shared by the whole layer (the linear layers above are
        # all created with ``bias=False``).
        if bias:
            self.bias = nn.parameter.Parameter(torch.zeros(self._out_feats))
        else:
            self.register_buffer('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The linear weights :math:`W^{(l)}` are initialized using Glorot uniform
        initialization with the gain recommended for ReLU.
        The LSTM module is re-initialized through its own ``reset_parameters``
        method, i.e. with PyTorch's default initialization for recurrent layers.
        The layer-level ``bias`` is not modified by this method.
        """
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    def _compatibility_check(self):
        """Address the backward compatibility issue brought by #2747.

        When the module has no ``bias`` attribute, the biases found on
        ``fc_neigh`` (and ``fc_self`` when present) are removed from those
        linear layers and merged into a single layer-level ``bias``.

        Note
        ----
        The merged bias is registered as a fresh ``nn.Parameter`` so that it
        shows up in ``parameters()`` / ``state_dict()`` and follows the module
        across devices.  An optimizer created *before* the first forward call
        still references the old per-layer biases and has to be rebuilt.
        """
        if hasattr(self, 'bias'):
            return

        dgl_warning("You are loading a GraphSAGE model trained from a old version of DGL, "
                    "DGL automatically convert it to be compatible with latest version.")
        bias = self.fc_neigh.bias
        self.fc_neigh.bias = None
        if bias is not None and hasattr(self, 'fc_self'):
            self_bias = self.fc_self.bias
            if self_bias is not None:
                # ``bias + self_bias`` alone would be a non-leaf tensor that is
                # neither trainable nor serialized; detach the sum and wrap it
                # in a new Parameter instead.
                bias = nn.parameter.Parameter(
                    (bias + self_bias).detach(),
                    requires_grad=bias.requires_grad)
                self.fc_self.bias = None
        self.bias = bias

    def _lstm_reducer(self, nodes):
        """LSTM reducer

        Runs the LSTM over the mailbox messages of each node and returns the
        final hidden state as the ``'neigh'`` feature.

        NOTE(zihao): lstm reducer with default schedule (degree bucketing)
        is slow, we could accelerate this with degree padding in the future.
        """
        m = nodes.mailbox['m']  # (B, L, D)
        batch_size = m.shape[0]
        # Zero initial hidden and cell states on the same dtype/device as ``m``.
        h = (m.new_zeros((1, batch_size, self._in_src_feats)),
             m.new_zeros((1, batch_size, self._in_src_feats)))
        _, (rst, _) = self.lstm(m, h)
        return {'neigh': rst.squeeze(0)}

    def forward(self, graph, feat, edge_weight=None):
        r"""

        Description
        -----------
        Compute GraphSAGE layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, it represents the input feature of shape
            :math:`(N, D_{in})`
            where :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        edge_weight : torch.Tensor, optional
            Optional tensor on the edge. If given, the convolution will weight
            with regard to the message.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N_{dst}, D_{out})`
            where :math:`N_{dst}` is the number of destination nodes in the input graph,
            :math:`D_{out}` is size of output feature.

        Raises
        ------
        KeyError
            If the aggregator type of the layer is not one of
            ``mean``, ``gcn``, ``pool``, ``lstm``.
        """
        self._compatibility_check()
        with graph.local_scope():
            if isinstance(feat, tuple):
                feat_src = self.feat_drop(feat[0])
                feat_dst = self.feat_drop(feat[1])
            else:
                feat_src = feat_dst = self.feat_drop(feat)
                if graph.is_block:
                    # Only the leading rows of the source features are used as
                    # the destination features of a block.
                    feat_dst = feat_src[:graph.number_of_dst_nodes()]
            msg_fn = fn.copy_src('h', 'm')
            if edge_weight is not None:
                assert edge_weight.shape[0] == graph.number_of_edges(), \
                    'edge_weight must have one entry per edge of the graph'
                graph.edata['_edge_weight'] = edge_weight
                msg_fn = fn.u_mul_e('h', '_edge_weight', 'm')

            h_self = feat_dst

            # Determine whether to apply linear transformation before message passing A(XW)
            lin_before_mp = self._in_src_feats > self._out_feats

            # Handle the case of graphs without edges: pre-populate 'neigh' with
            # zeros.  For 'mean'/'gcn' with ``lin_before_mp`` the aggregate is not
            # passed through ``fc_neigh`` afterwards, so the placeholder must
            # already have the output width (``fc_neigh`` has no bias, hence a
            # zero input maps to a zero output).
            # NOTE(review): assumes ``update_all`` may leave 'neigh' untouched on
            # an edgeless graph; if it overwrites the field this is a no-op.
            if graph.number_of_edges() == 0:
                projected_early = lin_before_mp and self._aggre_type in ('mean', 'gcn')
                neigh_dim = self._out_feats if projected_early else self._in_src_feats
                graph.dstdata['neigh'] = torch.zeros(
                    feat_dst.shape[0], neigh_dim).to(feat_dst)

            # Message Passing
            if self._aggre_type == 'mean':
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                graph.update_all(msg_fn, fn.mean('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = self.fc_neigh(feat_src) if lin_before_mp else feat_src
                if isinstance(feat, tuple):  # heterogeneous
                    graph.dstdata['h'] = self.fc_neigh(feat_dst) if lin_before_mp else feat_dst
                else:
                    if graph.is_block:
                        graph.dstdata['h'] = graph.srcdata['h'][:graph.num_dst_nodes()]
                    else:
                        graph.dstdata['h'] = graph.srcdata['h']
                graph.update_all(msg_fn, fn.sum('m', 'neigh'))
                # divide in_degrees (+1 accounts for the destination node itself)
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
                if not lin_before_mp:
                    h_neigh = self.fc_neigh(h_neigh)
            elif self._aggre_type == 'pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(msg_fn, fn.max('m', 'neigh'))
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            elif self._aggre_type == 'lstm':
                graph.srcdata['h'] = feat_src
                graph.update_all(msg_fn, self._lstm_reducer)
                h_neigh = self.fc_neigh(graph.dstdata['neigh'])
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

            # GraphSAGE GCN does not require fc_self.
            if self._aggre_type == 'gcn':
                rst = h_neigh
            else:
                rst = self.fc_self(h_self) + h_neigh

            # bias term
            if self.bias is not None:
                rst = rst + self.bias

            # activation
            if self.activation is not None:
                rst = self.activation(rst)
            # normalization
            if self.norm is not None:
                rst = self.norm(rst)
            return rst