RowFeatNormalizer

class dgl.transforms.RowFeatNormalizer(subtract_min=False, node_feat_names=None, edge_feat_names=None)

Bases: dgl.transforms.module.BaseTransform
Row-normalizes the features stored under the names given in node_feat_names and edge_feat_names. The row normalization formula is:

\[x = \frac{x}{\sum_i x_i}\]

where \(x\) denotes a row of the feature tensor and \(x_i\) is its \(i\)-th entry.
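The formula divides every entry of a row by that row's sum. A minimal pure-Python sketch of the same arithmetic, using nested lists in place of framework tensors (the helper name `row_normalize` is hypothetical and not part of DGL's API):

```python
def row_normalize(rows):
    """Divide each entry of every row by the sum of that row.

    Hypothetical helper illustrating x = x / sum_i x_i; not DGL's code.
    Assumes no row sums to zero.
    """
    out = []
    for row in rows:
        total = sum(row)  # the denominator sum_i x_i for this row
        out.append([v / total for v in row])
    return out


feat = [[1.0, 3.0], [2.0, 2.0], [0.5, 1.5]]
normalized = row_normalize(feat)
print(normalized)                    # [[0.25, 0.75], [0.5, 0.5], [0.25, 0.75]]
print([sum(r) for r in normalized])  # [1.0, 1.0, 1.0]
```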
- Parameters
subtract_min (bool) – If True, the minimum value of the whole feature tensor is subtracted before normalization, which makes all values non-negative. If all values are non-negative, each row of the feature tensor sums to 1 after normalization. Default: False.
node_feat_names (list[str], optional) – The names of the node feature tensors to be row-normalized. Default: None, which will not normalize any node feature tensor.
edge_feat_names (list[str], optional) – The names of the edge feature tensors to be row-normalized. Default: None, which will not normalize any edge feature tensor.
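Why subtract_min matters: features drawn from, say, a standard normal distribution contain negative entries, so a row sum can be zero or negative and the quotient becomes meaningless. For instance, the row [-1.0, 1.0] below sums to 0 and cannot be normalized at all without the shift. The plain-Python sketch below subtracts the minimum of the whole tensor (not of each row) when requested and normalizes only the listed names. The helpers `row_normalize` and `normalize_features` and the dict standing in for g.ndata are hypothetical; this illustrates the documented behavior and is not DGL's implementation.

```python
def row_normalize(rows, subtract_min=False):
    if subtract_min:
        # Global minimum over the whole tensor, not a per-row minimum.
        m = min(v for row in rows for v in row)
        rows = [[v - m for v in row] for row in rows]
    out = []
    for row in rows:
        total = sum(row)  # assumes no row sums to zero after the optional shift
        out.append([v / total for v in row])
    return out


def normalize_features(feats, names, subtract_min=False):
    # feats: dict of feature name -> 2-D list (stand-in for g.ndata / g.edata).
    # names=None normalizes nothing, mirroring the "Default: None" behavior.
    result = dict(feats)
    for name in names or []:
        result[name] = row_normalize(feats[name], subtract_min)
    return result


ndata = {'h': [[-1.0, 1.0], [0.0, 2.0]], 'label': [[3.0, 5.0]]}
out = normalize_features(ndata, ['h'], subtract_min=True)
print(out['h'])      # [[0.0, 1.0], [0.25, 0.75]]
print(out['label'])  # [[3.0, 5.0]] (not listed, so left unchanged)
```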
Example
The following example uses PyTorch backend.
>>> import dgl
>>> import torch
>>> from dgl import RowFeatNormalizer
Case 1: Row-normalize features of a homogeneous graph.
>>> transform = RowFeatNormalizer(subtract_min=True,
...                               node_feat_names=['h'], edge_feat_names=['w'])
>>> g = dgl.rand_graph(5, 20)
>>> g.ndata['h'] = torch.randn((g.num_nodes(), 5))
>>> g.edata['w'] = torch.randn((g.num_edges(), 5))
>>> g = transform(g)
>>> print(g.ndata['h'].sum(1))
tensor([1., 1., 1., 1., 1.])
>>> print(g.edata['w'].sum(1))
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
Case 2: Row-normalize features of a heterogeneous graph.
>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): (torch.tensor([1, 2]), torch.tensor([3, 4])),
...     ('player', 'plays', 'game'): (torch.tensor([2, 2]), torch.tensor([1, 1]))
... })
>>> g.ndata['h'] = {'game': torch.randn(2, 5), 'player': torch.randn(3, 5)}
>>> g.edata['w'] = {
...     ('user', 'follows', 'user'): torch.randn(2, 5),
...     ('player', 'plays', 'game'): torch.randn(2, 5)
... }
>>> g = transform(g)
>>> print(g.ndata['h']['game'].sum(1), g.ndata['h']['player'].sum(1))
tensor([1., 1.]) tensor([1., 1., 1.])
>>> print(g.edata['w'][('user', 'follows', 'user')].sum(1),
...       g.edata['w'][('player', 'plays', 'game')].sum(1))
tensor([1., 1.]) tensor([1., 1.])
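In the heterogeneous case, g.ndata['h'] and g.edata['w'] are dicts keyed by node type or canonical edge type, and the printed row sums equal 1 within every type. A plain-Python sketch of that per-type pattern, assuming each type's tensor is row-normalized on its own (which is what the per-type sums printed above suggest). The helpers `row_normalize` and `normalize_per_type` are hypothetical, and the sketch omits subtract_min for brevity:

```python
def row_normalize(rows):
    out = []
    for row in rows:
        total = sum(row)  # assumes no row sums to zero
        out.append([v / total for v in row])
    return out


def normalize_per_type(typed_feat):
    # typed_feat: {type key -> 2-D list}; keys may be node-type strings or
    # canonical edge-type tuples such as ('user', 'follows', 'user').
    return {key: row_normalize(rows) for key, rows in typed_feat.items()}


w = {
    ('user', 'follows', 'user'): [[1.0, 1.0], [2.0, 6.0]],
    ('player', 'plays', 'game'): [[4.0, 0.0], [1.0, 3.0]],
}
w_norm = normalize_per_type(w)
for etype, rows in w_norm.items():
    print(etype, [sum(r) for r in rows])
# ('user', 'follows', 'user') [1.0, 1.0]
# ('player', 'plays', 'game') [1.0, 1.0]
```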