AsGraphPredDataset

class dgl.data.AsGraphPredDataset(dataset, split_ratio=None, **kwargs)

Bases: dgl.data.dgl_dataset.DGLDataset

Repurpose a dataset for a standard graph property prediction task.

The created dataset includes the data needed for graph property prediction. Currently, it only supports homogeneous graphs.

The class converts a given dataset into a new dataset object such that:

  • It stores len(dataset) graphs.

  • The i-th graph and its label are accessible from dataset[i].
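The indexing contract described above can be illustrated with a minimal, dependency-free stand-in. ToyGraphPredDataset is a hypothetical class, not part of DGL, and plain (num_nodes, num_edges) tuples stand in for DGLGraph objects; the sketch only mirrors the len(dataset) and dataset[i] -> (graph, label) behaviour.

```python
from typing import Any, List, Sequence, Tuple


class ToyGraphPredDataset:
    """Hypothetical stand-in mirroring the (graph, label) indexing contract."""

    def __init__(self, graphs: Sequence[Any], labels: Sequence[Any]) -> None:
        if len(graphs) != len(labels):
            raise ValueError("every graph needs exactly one label")
        self._graphs: List[Any] = list(graphs)
        self._labels: List[Any] = list(labels)

    def __len__(self) -> int:
        # One example per stored graph.
        return len(self._graphs)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        # The i-th item is the i-th graph paired with its label.
        return self._graphs[idx], self._labels[idx]


# Placeholder "graphs" described only as (num_nodes, num_edges) tuples.
toy = ToyGraphPredDataset(graphs=[(19, 40), (5, 8)], labels=[[0], [1]])
graph, label = toy[0]
print(len(toy), graph, label)  # 2 (19, 40) [0]
```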

If split_ratio is provided, the class generates a train/val/test split. The generated split is cached to disk for fast reloading. If the provided split ratio differs from the cached one, the dataset is re-processed to generate a new split.
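A minimal sketch of ratio-based splitting with a cache keyed by the ratio, assuming a seeded random permutation of graph IDs. generate_split, load_split, and the in-memory cache dict are hypothetical names for illustration only; DGL caches to disk, and its exact shuffling and rounding may differ.

```python
import random
from typing import Any, Dict, List, Tuple

Split = Tuple[List[int], List[int], List[int]]


def generate_split(num_graphs: int, split_ratio: Tuple[float, float, float],
                   seed: int = 0) -> Split:
    """Shuffle graph IDs, then cut them into train/val/test by ratio."""
    train_ratio, val_ratio, _ = split_ratio
    ids = list(range(num_graphs))
    random.Random(seed).shuffle(ids)  # seeded, so the split is reproducible
    num_train = int(num_graphs * train_ratio)
    num_val = int(num_graphs * val_ratio)
    # The test set takes the remainder, so truncation never drops a graph.
    return (ids[:num_train],
            ids[num_train:num_train + num_val],
            ids[num_train + num_val:])


def load_split(num_graphs: int, split_ratio: Tuple[float, float, float],
               cache: Dict[str, Any]) -> Split:
    """Reuse the cached split only if it was built with the same ratio."""
    ratio = tuple(split_ratio)
    if cache.get("split_ratio") != ratio:
        # Missing or stale cache entry: regenerate and overwrite it.
        cache["split_ratio"] = ratio
        cache["split"] = generate_split(num_graphs, ratio)
    return cache["split"]


cache: Dict[str, Any] = {}
train, val, test = load_split(10, (0.8, 0.1, 0.1), cache)
print(len(train), len(val), len(test))  # 8 1 1
```

Keying the cache on the ratio is what makes a changed split_ratio trigger re-processing while a repeated call with the same ratio reuses the stored split.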

Parameters
  • dataset (DGLDataset) – The dataset to be converted.

  • split_ratio ((float, float, float), optional) – Split ratios for training, validation and test sets. They must sum to one.
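Because the ratios are floats, "must sum to one" is best checked with a tolerance rather than ==. check_split_ratio below is a hypothetical helper, not DGL's code; it only illustrates the check under that assumption.

```python
import math
from typing import Sequence


def check_split_ratio(split_ratio: Sequence[float]) -> None:
    """Raise ValueError unless split_ratio is three non-negative values summing to one."""
    if len(split_ratio) != 3:
        raise ValueError("split_ratio needs exactly three values: train, val, test")
    if any(r < 0 for r in split_ratio):
        raise ValueError("split ratios must be non-negative")
    total = sum(split_ratio)
    # sum((0.7, 0.2, 0.1)) is 0.9999999999999999 in binary floating point,
    # so compare against 1.0 with a small absolute tolerance.
    if not math.isclose(total, 1.0, abs_tol=1e-9):
        raise ValueError(f"split ratios must sum to one, got {total}")


check_split_ratio((0.7, 0.2, 0.1))  # accepted despite the rounding error
```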

num_tasks

Number of tasks to predict.

Type

int

num_classes

Number of classes to predict per task, None for regression datasets.

Type

int
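num_tasks and num_classes together determine what a model has to predict. The sketch below shows one common convention for sizing a model's final layer from them; output_width is a hypothetical helper, and binary tasks are often handled with a single logit per task instead.

```python
from typing import Optional


def output_width(num_tasks: int, num_classes: Optional[int]) -> int:
    """Hypothetical rule for the width of a model's output layer."""
    if num_classes is None:
        # Regression dataset: one real-valued prediction per task.
        return num_tasks
    # Classification: one score per class, for every task.
    return num_tasks * num_classes


print(output_width(num_tasks=3, num_classes=None))  # 3
print(output_width(num_tasks=1, num_classes=2))     # 2
```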

train_idx

A 1-D integer tensor of training graph IDs.

Type

Tensor

val_idx

A 1-D integer tensor of validation graph IDs.

Type

Tensor

test_idx

A 1-D integer tensor of test graph IDs.

Type

Tensor
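Each entry of train_idx, val_idx and test_idx is a position in the converted dataset, i.e. one graph. A quick sanity check is that the three index sets are disjoint and, for a split generated from split_ratio, would be expected to cover every graph. A minimal sketch, assuming the tensors have first been converted to Python lists (for a PyTorch tensor, with Tensor.tolist()); is_partition is a hypothetical helper.

```python
from typing import Iterable


def is_partition(train_idx: Iterable[int], val_idx: Iterable[int],
                 test_idx: Iterable[int], num_graphs: int) -> bool:
    """True if the three ID lists are disjoint and together cover 0..num_graphs-1."""
    all_ids = list(train_idx) + list(val_idx) + list(test_idx)
    # No duplicates anywhere, which also rules out overlap between splits ...
    if len(all_ids) != len(set(all_ids)):
        return False
    # ... and every graph ID in the dataset appears in exactly one split.
    return set(all_ids) == set(range(num_graphs))


print(is_partition([0, 3], [1], [2, 4], num_graphs=5))  # True
print(is_partition([0, 3], [3], [2, 4], num_graphs=5))  # False: graph 3 is in two splits
```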

node_feat_size

Input node feature size, None if not applicable.

Type

int

edge_feat_size

Input edge feature size, None if not applicable.

Type

int

Examples

>>> from dgl.data import AsGraphPredDataset
>>> from ogb.graphproppred import DglGraphPropPredDataset
>>> dataset = DglGraphPropPredDataset(name='ogbg-molhiv')
>>> new_dataset = AsGraphPredDataset(dataset)
>>> print(new_dataset)
Dataset("ogbg-molhiv-as-graphpred", num_graphs=41127, save_path=...)
>>> print(len(new_dataset))
41127
>>> print(new_dataset[0])
(Graph(num_nodes=19, num_edges=40,
       ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.int64)}
       edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.int64)}), tensor([0]))

__getitem__(idx)

Gets the graph and its label at index idx.

__len__()

The number of examples in the dataset.