add_node_property_split
dgl.data.utils.add_node_property_split(dataset, part_ratios, property_name, ascending=True, random_seed=None)
Create a node split with distributional shift based on a given node property, as proposed in "Evaluating Robustness and Uncertainty of Graph Models Under Structural Distributional Shifts".
It splits the nodes of each graph in the given dataset into 5 non-intersecting parts based on their structural properties. This can be used for transductive node prediction tasks with distributional shifts.
It considers the in-distribution (ID) and out-of-distribution (OOD) subsets of nodes. The ID subset includes training, validation and testing parts, while the OOD subset includes validation and testing parts. As a result, it creates 5 associated node mask arrays for each graph:
3 for the ID nodes: 'in_train_mask', 'in_valid_mask', 'in_test_mask', and 2 for the OOD nodes: 'out_valid_mask', 'out_test_mask'.
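As a rough illustration of what "non-intersecting" means for these masks, the sketch below checks that every node is selected by exactly one of the five masks. It uses plain Python lists in a dictionary as a stand-in for a graph's node data; the helper name `masks_form_partition` and the toy mask values are hypothetical and not part of DGL. Full coverage of the nodes is an assumption that follows from the part ratios summing to 1.0.

```python
MASK_NAMES = [
    "in_train_mask", "in_valid_mask", "in_test_mask",
    "out_valid_mask", "out_test_mask",
]


def masks_form_partition(node_data):
    """Return True if each node is selected by exactly one of the five masks."""
    masks = [node_data[name] for name in MASK_NAMES]
    num_nodes = len(masks[0])
    # All masks must describe the same set of nodes.
    if any(len(mask) != num_nodes for mask in masks):
        return False
    # Count how many masks select each node; a partition needs exactly one.
    return all(
        sum(bool(mask[i]) for mask in masks) == 1 for i in range(num_nodes)
    )


# Toy stand-in for the node data of a 6-node graph (hypothetical values).
toy_ndata = {
    "in_train_mask":  [True, False, False, False, False, True],
    "in_valid_mask":  [False, True, False, False, False, False],
    "in_test_mask":   [False, False, True, False, False, False],
    "out_valid_mask": [False, False, False, True, False, False],
    "out_test_mask":  [False, False, False, False, True, False],
}
print(masks_form_partition(toy_ndata))
```

With real DGL graphs the same check could be applied after converting each mask tensor to a list.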
This function implements 3 particular strategies for inducing distributional shifts in a graph, based on popularity, locality or density.
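The splitting idea described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not DGL's implementation: the function name `split_by_property` is hypothetical, the property values are supplied directly instead of being computed from a graph, and the rounding of part sizes is one possible choice. Nodes are ordered by the property, the first three parts (ID) are shuffled among themselves, and the last two parts (OOD) keep the property order.

```python
import random

MASK_NAMES = [
    "in_train_mask", "in_valid_mask", "in_test_mask",
    "out_valid_mask", "out_test_mask",
]


def split_by_property(property_values, part_ratios, ascending=True, random_seed=None):
    """Hypothetical sketch: build five boolean masks from per-node property values."""
    num_nodes = len(property_values)
    # Turn ratios into part sizes; the last part absorbs any rounding error.
    sizes = [round(num_nodes * ratio) for ratio in part_ratios]
    sizes[-1] = num_nodes - sum(sizes[:-1])

    rng = random.Random(random_seed)
    order = list(range(num_nodes))
    # Shuffle first so that nodes with equal property values are ordered randomly;
    # Python's sort is stable, so ties keep the shuffled order.
    rng.shuffle(order)
    order.sort(key=lambda node: property_values[node], reverse=not ascending)

    # The ID subset is the leading block; its nodes are assigned to parts at random.
    num_id = sum(sizes[:3])
    id_nodes = order[:num_id]
    rng.shuffle(id_nodes)
    ordered = id_nodes + order[num_id:]

    masks = {}
    start = 0
    for name, size in zip(MASK_NAMES, sizes):
        members = set(ordered[start:start + size])
        masks[name] = [node in members for node in range(num_nodes)]
        start += size
    return masks


# Ten nodes whose property value grows with the node index (toy data).
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
masks = split_by_property(values, [0.3, 0.1, 0.1, 0.3, 0.2], random_seed=0)
print([node for node, flag in enumerate(masks["out_test_mask"]) if flag])
```

With `ascending=True`, the nodes with the largest property values land in the OOD parts, so the shift between ID and OOD nodes grows with the gap in the property.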
- Parameters
dataset (DGLDataset or list of DGLGraph) – The dataset to induce structural distributional shift.
part_ratios (list) – A list of 5 ratio values for training, ID validation, ID test, OOD validation and OOD test parts. The values must sum to 1.0.
property_name (str) – The name of the node property to be used, which must be 'popularity', 'locality' or 'density'.
ascending (bool, optional) – Whether to sort nodes in the ascending order of the node property, so that nodes with greater values of the property are considered to be OOD (default: True)
random_seed (int, optional) – Random seed to fix for the initial permutation of nodes. It is used to create a random order for the nodes that have the same property values or belong to the ID subset. (default: None)
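Since `part_ratios` must hold 5 values that sum to 1.0, it can help to check a candidate list before calling the function. The helper below is a hypothetical sketch, not a DGL function, and whether DGL performs a similar check internally is not stated here. It compares the sum with a small tolerance, on the assumption that floating-point sums may differ slightly from exactly 1.0.

```python
import math


def check_part_ratios(part_ratios):
    """Hypothetical helper: raise ValueError unless there are 5 ratios summing to 1.0."""
    if len(part_ratios) != 5:
        raise ValueError("part_ratios must contain exactly 5 values")
    # Compare with a tolerance instead of == to allow for floating-point error.
    if not math.isclose(sum(part_ratios), 1.0, rel_tol=0.0, abs_tol=1e-9):
        raise ValueError("part_ratios must sum to 1.0")


# Order: training, ID validation, ID test, OOD validation, OOD test.
check_part_ratios([0.3, 0.1, 0.1, 0.3, 0.2])
```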
Examples
>>> import dgl
>>> dataset = dgl.data.AmazonCoBuyComputerDataset()
>>> print('in_valid_mask' in dataset[0].ndata)
False
>>> part_ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
>>> property_name = 'popularity'
>>> dgl.data.utils.add_node_property_split(dataset, part_ratios, property_name)
>>> print('in_valid_mask' in dataset[0].ndata)
True