add_node_property_split

dgl.data.utils.add_node_property_split(dataset, part_ratios, property_name, ascending=True, random_seed=None)

Create a node split with distributional shift based on a given node property, as proposed in "Evaluating Robustness and Uncertainty of Graph Models Under Structural Distributional Shifts".

It splits the nodes of each graph in the given dataset into 5 non-intersecting parts based on their structural properties. The resulting split can be used for transductive node prediction tasks with distributional shifts.

It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes. The ID subset includes training, validation and testing parts, while the OOD subset includes validation and testing parts. As a result, it creates 5 associated node mask arrays for each graph:

  • 3 for the ID nodes: 'in_train_mask', 'in_valid_mask', 'in_test_mask',

  • and 2 for the OOD nodes: 'out_valid_mask', 'out_test_mask'.

This function implements 3 particular strategies for inducing distributional shifts in a graph, based on popularity, locality or density.
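
The splitting idea described above can be illustrated with a small, library-free sketch. This is not DGL's implementation: the helper name `split_by_property`, the rounding rule for part sizes, and the use of Python's `random` module are all assumptions made for illustration. Nodes are ordered by a precomputed property value, the first three parts form the ID subset (shuffled among themselves), and the tail of the ordering forms the OOD subset.

```python
import random

MASK_NAMES = [
    "in_train_mask", "in_valid_mask", "in_test_mask",
    "out_valid_mask", "out_test_mask",
]


def split_by_property(values, part_ratios, ascending=True, random_seed=None):
    """Hypothetical helper: build five boolean masks from per-node property values."""
    num_nodes = len(values)
    rng = random.Random(random_seed)

    # Part sizes from the ratios; the last part absorbs any rounding remainder
    # (an assumed rule, chosen so the sizes always sum to num_nodes).
    sizes = [round(num_nodes * r) for r in part_ratios]
    sizes[-1] += num_nodes - sum(sizes)

    # Random initial permutation, so that ties in the property are broken randomly.
    order = list(range(num_nodes))
    rng.shuffle(order)
    # Stable sort by property: with ascending=True the largest values come last (OOD).
    order.sort(key=lambda i: values[i], reverse=not ascending)

    # Nodes in the ID subset (first three parts) are shuffled among themselves.
    num_id = sum(sizes[:3])
    id_nodes = order[:num_id]
    rng.shuffle(id_nodes)
    order[:num_id] = id_nodes

    # Cut the ordering into five consecutive, non-intersecting parts.
    masks, start = {}, 0
    for name, size in zip(MASK_NAMES, sizes):
        mask = [False] * num_nodes
        for i in order[start:start + size]:
            mask[i] = True
        masks[name] = mask
        start += size
    return masks


# Toy usage: ten nodes whose property value equals their index.
toy_values = [float(i) for i in range(10)]
toy_masks = split_by_property(toy_values, [0.3, 0.1, 0.1, 0.3, 0.2], random_seed=0)
```

In this toy case the two nodes with the largest property values (indices 8 and 9) fall into 'out_test_mask', and every node belongs to exactly one mask.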

Parameters:
  • dataset (DGLDataset or list of DGLGraph) – The dataset in which to induce a structural distributional shift.

  • part_ratios (list) – A list of 5 ratio values for the training, ID validation, ID test, OOD validation and OOD test parts. The values must sum to 1.0.

  • property_name (str) – The name of the node property to be used, which must be 'popularity', 'locality' or 'density'.

  • ascending (bool, optional) – Whether to sort nodes in ascending order of the node property, so that nodes with greater values of the property are considered to be OOD (default: True).

  • random_seed (int, optional) – Random seed for the initial permutation of nodes. It is used to create a random order for the nodes that have the same property value or belong to the ID subset (default: None).
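
The constraints on `part_ratios` and `property_name` listed above can be checked before calling the function. The helper below is a hypothetical pre-flight check, not part of DGL; the name `check_split_args`, the floating-point tolerance, and the error messages are assumptions.

```python
import math

ALLOWED_PROPERTIES = ("popularity", "locality", "density")


def check_split_args(part_ratios, property_name):
    """Hypothetical check: raise ValueError if the arguments violate the documented constraints."""
    if len(part_ratios) != 5:
        raise ValueError("part_ratios must contain exactly 5 values")
    # Compare with a tolerance, since ratios such as 0.1 are not exact in floating point.
    if not math.isclose(sum(part_ratios), 1.0, abs_tol=1e-9):
        raise ValueError("part_ratios must sum to 1.0")
    if property_name not in ALLOWED_PROPERTIES:
        raise ValueError(f"property_name must be one of {ALLOWED_PROPERTIES}")


# Passes silently for a valid combination.
check_split_args([0.3, 0.1, 0.1, 0.3, 0.2], "popularity")
```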

Examples

>>> import dgl
>>> dataset = dgl.data.AmazonCoBuyComputerDataset()
>>> print('in_valid_mask' in dataset[0].ndata)
False
>>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
>>> property_name = 'popularity'
>>> dgl.data.utils.add_node_property_split(dataset, part_ratios, property_name)
>>> print('in_valid_mask' in dataset[0].ndata)
True