EdgeWeightNorm

class dgl.nn.pytorch.conv.EdgeWeightNorm(norm='both', eps=0.0)

Bases: torch.nn.modules.module.Module

This module normalizes positive scalar edge weights on a graph using the normalization form of GCN.

Mathematically, setting norm='both' yields the following normalization term:

\[c_{ji} = (\sqrt{\sum_{k\in\mathcal{N}(j)}e_{jk}}\sqrt{\sum_{k\in\mathcal{N}(i)}e_{ki}})\]

Setting norm='right' instead yields the following normalization term:

\[c_{ji} = (\sum_{k\in\mathcal{N}(i)}e_{ki})\]

where \(e_{ji}\) is the scalar weight on the edge from node \(j\) to node \(i\).

The module returns the normalized weight \(e_{ji} / c_{ji}\).
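The two normalizations can be sketched in plain Python over an edge list. This is an illustrative re-derivation of the formulas above, not DGL's implementation. The helper `edge_weight_norm_sketch` is hypothetical and omits both `eps` and DGL's input checks. The edge list reproduces the example graph further down, assuming its self-loops are appended after the original six edges.

```python
import math
from collections import defaultdict


def edge_weight_norm_sketch(src, dst, weight, norm="both"):
    """Hypothetical helper: return e_ji / c_ji for each edge src[k] -> dst[k].

    Mirrors the formulas above in plain Python; it is not DGL code and skips
    the eps offset and DGL's input checks.
    """
    out_sum = defaultdict(float)  # sum_k e_jk: total weight leaving node j
    in_sum = defaultdict(float)   # sum_k e_ki: total weight entering node i
    for j, i, w in zip(src, dst, weight):
        out_sum[j] += w
        in_sum[i] += w

    normalized = []
    for j, i, w in zip(src, dst, weight):
        if norm == "both":
            # Symmetric form: sqrt(out-weight of j) * sqrt(in-weight of i)
            c = math.sqrt(out_sum[j]) * math.sqrt(in_sum[i])
        elif norm == "right":
            # Only the destination's incoming weight sum
            c = in_sum[i]
        else:
            raise ValueError(f"this sketch only handles 'both' and 'right', got {norm!r}")
        normalized.append(w / c)
    return normalized


# Example graph from the Examples section, self-loops appended last (assumed order).
src = [0, 1, 2, 3, 2, 5, 0, 1, 2, 3, 4, 5]
dst = [1, 2, 3, 4, 0, 3, 0, 1, 2, 3, 4, 5]
weights = [0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1]

both = edge_weight_norm_sketch(src, dst, weights, norm="both")
right = edge_weight_norm_sketch(src, dst, weights, norm="right")
print(round(both[0], 4))  # edge 0 -> 1, prints 0.3333
```

With norm='right', the normalized weights entering each node sum to 1, so aggregation becomes a weighted mean over in-neighbors. norm='both' is the symmetric variant that also scales by the source node's outgoing weight.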

Parameters
  • norm (str, optional) – The normalizer as specified above. Default is 'both'.

  • eps (float, optional) – A small offset value added in the denominator to avoid division by zero. Default is 0.0.

Examples

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import EdgeWeightNorm, GraphConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> g = dgl.add_self_loop(g)
>>> feat = th.ones(6, 10)
>>> edge_weight = th.tensor([0.5, 0.6, 0.4, 0.7, 0.9, 0.1, 1, 1, 1, 1, 1, 1])
>>> norm = EdgeWeightNorm(norm='both')
>>> norm_edge_weight = norm(g, edge_weight)
>>> conv = GraphConv(10, 2, norm='none', weight=True, bias=True)
>>> res = conv(g, feat, edge_weight=norm_edge_weight)
>>> print(res)
tensor([[-1.1849, -0.7525],
        [-1.3514, -0.8582],
        [-1.2384, -0.7865],
        [-1.9949, -1.2669],
        [-1.3658, -0.8674],
        [-0.8323, -0.5286]], grad_fn=<AddBackward0>)
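As a hand check of the normalization, consider edge 0 → 1 with weight 0.5. This does not check the printed GraphConv output, whose values depend on randomly initialized weights. The calculation assumes dgl.add_self_loop appends one self-loop per node after the original six edges, which matches the layout of the twelve-entry edge_weight above. Under that assumption, node 0's outgoing weights are 0.5 and 1, and node 1's incoming weights are 0.5 and 1:

\[c_{01} = \sqrt{e_{01} + e_{00}}\,\sqrt{e_{01} + e_{11}} = \sqrt{0.5 + 1}\,\sqrt{0.5 + 1} = 1.5, \qquad \frac{e_{01}}{c_{01}} = \frac{0.5}{1.5} = \frac{1}{3}\]

With norm='right', this edge gives the same value, since \(c_{01} = e_{01} + e_{11} = 1.5\). The two modes differ on edges whose source out-weight and destination in-weight are not equal.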

forward(graph, edge_weight)

Compute normalized edge weight for the GCN model.

Parameters
  • graph (DGLGraph) – The graph.

  • edge_weight (torch.Tensor) – Unnormalized scalar weights on the edges. The shape is expected to be \((|E|)\).

Returns

The normalized edge weight.

Return type

torch.Tensor

Raises

DGLError – Raised in two cases:

  • The edge weight is multi-dimensional. Currently this module only supports a scalar weight on each edge.

  • The edge weight has non-positive values with norm='both'. This can lead to taking the square root of, or dividing by, a non-positive number.
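The two error conditions can be mirrored in a small dependency-free check. This is a hypothetical sketch, not DGL's code. `check_edge_weight_sketch` is an invented name, the tensor is modeled as a shape tuple plus a flat list of values, and ValueError stands in for DGLError so the snippet runs without DGL installed.

```python
def check_edge_weight_sketch(shape, values, norm="both"):
    """Hypothetical mirror of the two documented DGLError cases.

    `shape` is the weight tensor's shape as a tuple, `values` its flat list of
    weights. ValueError stands in for DGLError.
    """
    # Case 1: more than one dimension means more than one scalar per edge.
    if len(shape) > 1:
        raise ValueError(f"expected one scalar per edge, shape (|E|,), got {shape}")
    # Case 2: norm='both' takes square roots of weight sums, so every weight
    # must be strictly positive. norm='right' is not checked here.
    if norm == "both" and any(v <= 0 for v in values):
        raise ValueError("non-positive edge weight detected with norm='both'")


check_edge_weight_sketch((3,), [0.5, 1.0, 2.0], norm="both")  # accepted
```

Under the first rule, a weight column of shape (|E|, 1) would be rejected. In PyTorch it can be flattened first, for example with edge_weight.squeeze(-1) or edge_weight.view(-1).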