DropNode

class dgl.transforms.DropNode(p=0.5)

Bases: BaseTransform

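Randomly drop nodes, as described in Graph Contrastive Learning with Augmentations. Each node is dropped independently with probability p, and every edge incident to a dropped node is removed along with it. The transform returns a new graph in which the remaining nodes and edges keep their features.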

Parameters:

p (float, optional) – Probability that each node is dropped. Default: 0.5.
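To make the dropping rule concrete, here is a minimal pure-Python sketch that applies it to a plain edge list instead of a DGLGraph: each node is dropped independently with probability p, and any edge touching a dropped node is dropped with it. The function name `drop_nodes`, its signature, and the injectable `rng` argument are hypothetical illustrations; this is not DGL's implementation and is not part of the DGL API.

```python
import random
from typing import List, Optional, Set, Tuple

Edge = Tuple[int, int]


def drop_nodes(
    num_nodes: int,
    edges: List[Edge],
    p: float = 0.5,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch of DropNode-style augmentation on an edge list.

    Returns (kept node IDs, kept edge indices into ``edges``), both given
    as original IDs in ascending order.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must be in [0, 1]")
    rng = rng or random.Random()
    # One Bernoulli(p) draw per node: rng.random() is in [0, 1), so
    # p=0 never drops a node and p=1 always does.
    dropped: Set[int] = {v for v in range(num_nodes) if rng.random() < p}
    kept_nodes = [v for v in range(num_nodes) if v not in dropped]
    # An edge survives only if neither of its endpoints was dropped.
    kept_edges = [
        eid
        for eid, (u, v) in enumerate(edges)
        if u not in dropped and v not in dropped
    ]
    return kept_nodes, kept_edges


if __name__ == "__main__":
    ring = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)]
    # Seeding the injected RNG makes a single draw reproducible.
    print(drop_nodes(5, ring, p=0.5, rng=random.Random(0)))
```

Returning original IDs mirrors the trick used in the Example, where `ndata['h']` and `edata['h']` are filled with `torch.arange(...)` so the transformed graph reveals which original nodes and edges survived.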

Example

>>> import dgl
>>> import torch
>>> from dgl import DropNode
>>> transform = DropNode()
>>> g = dgl.rand_graph(5, 20)
>>> g.ndata['h'] = torch.arange(g.num_nodes())
>>> g.edata['h'] = torch.arange(g.num_edges())
>>> new_g = transform(g)  # nodes are sampled at random, so the output below varies between runs
>>> print(new_g)
Graph(num_nodes=3, num_edges=7,
      ndata_schemes={'h': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'h': Scheme(shape=(), dtype=torch.int64)})
>>> print(new_g.ndata['h'])
tensor([0, 1, 2])
>>> print(new_g.edata['h'])
tensor([0, 6, 14, 5, 17, 3, 11])
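The size of the augmented graph can be estimated in advance. Assuming each node is dropped independently with probability p and dropping a node removes every edge incident to it, a graph with n nodes and m edges, s of which are self-loops (a self-loop survives whenever its single endpoint survives), gives:

```latex
\mathbb{E}[\text{nodes kept}] = n(1-p), \qquad
\mathbb{E}[\text{edges kept}] = (m - s)(1-p)^2 + s(1-p)
```

For the example above (n = 5, m = 20, p = 0.5) this means 2.5 nodes on average, and 5 edges if none of the random edges is a self-loop, with each self-loop adding 0.25 to that expectation. The printed 3 nodes and 7 edges are a single random draw.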