AMDataset

class dgl.data.AMDataset(print_every=10000, insert_reverse=True, raw_dir=None, force_reload=False, verbose=True, transform=None)

Bases: dgl.data.rdf.RDFGraphDataset

AM dataset for node classification tasks.

Namespace convention:

  • Instance: http://purl.org/collections/nl/am/<type>-<id>

  • Relation: http://purl.org/collections/nl/am/<name>
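The instance convention encodes both a node type and an identifier in the URI. The sketch below recovers the two parts. It is a hypothetical helper, not a DGL function, and it assumes that <type> itself contains no hyphen, so the first hyphen after the prefix separates type from id. The id "12345" in the usage line is made up.

```python
from typing import Tuple

# Namespace prefix shared by AM instances and relations (from the convention above).
AM_PREFIX = "http://purl.org/collections/nl/am/"


def parse_am_instance(uri: str) -> Tuple[str, str]:
    """Hypothetical helper: split an AM instance URI into (type, id).

    Assumes <type> contains no '-', so the first hyphen separates type from id.
    """
    if not uri.startswith(AM_PREFIX):
        raise ValueError(f"not an AM URI: {uri}")
    local = uri[len(AM_PREFIX):]
    node_type, sep, node_id = local.partition("-")
    if not sep:
        # No hyphen at all: this looks like a relation name, not an instance.
        raise ValueError(f"not an instance URI: {uri}")
    return node_type, node_id


print(parse_am_instance(AM_PREFIX + "proxy-12345"))  # ('proxy', '12345')
```

Relation URIs (<name> with no <type>-<id> structure) are rejected by this sketch rather than guessed at.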

All literal nodes, and the relations that connect to them, are excluded from the output graph.
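Dropping literals amounts to filtering out every (subject, relation, object) triple whose object is a literal value. The following is a minimal sketch of that filter, assuming triples are plain Python tuples and using a hypothetical Literal class in place of a real RDF library's literal type; the relation names are made up for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Literal:
    """Hypothetical stand-in for an RDF literal value (e.g. a title string)."""
    value: str


def drop_literals(triples):
    """Keep only triples whose object is an entity, not a literal.

    The relation pointing at a literal is discarded together with the literal,
    so neither contributes a node or an edge to the graph.
    """
    return [(s, p, o) for (s, p, o) in triples if not isinstance(o, Literal)]


triples = [
    ("proxy-1", "linksTo", "object-7"),   # entity-to-entity edge: kept
    ("proxy-1", "title", Literal("Vase")),  # points at a literal: dropped
]
print(drop_literals(triples))  # [('proxy-1', 'linksTo', 'object-7')]
```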

AM dataset statistics:

  • Nodes: 881680

  • Edges: 5668682 (including reverse edges)

  • Target Category: proxy

  • Number of Classes: 11

  • Label Split:

    • Train: 802

    • Test: 198
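The statistics above imply a few derived figures, computed below. The forward-edge count rests on an assumption: that insert_reverse=True adds exactly one reverse edge per original edge. The listed total being even is consistent with this but does not by itself confirm it.

```python
# Figures copied from the statistics above.
num_edges_with_reverse = 5_668_682
train, test = 802, 198

labeled = train + test              # labeled "proxy" nodes across both splits
train_fraction = train / labeled    # share of labeled nodes used for training

# Assumption: one reverse edge was added for every forward edge.
forward_edges = num_edges_with_reverse // 2

print(labeled, train_fraction, forward_edges)  # 1000 0.802 2834341
```

Only a train/test split is listed; no validation set is provided.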

Parameters
  • print_every (int) – Print a preprocessing progress log every X tuples. Default: 10000.

  • insert_reverse (bool) – If True, add reverse edges and reverse relations to the final graph. Default: True.

  • raw_dir (str) – Directory to download the raw data into, or that already contains the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False.

  • verbose (bool) – Whether to print out progress information. Default: True.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
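The effect of insert_reverse can be pictured on a list of triples: each edge gains a mirrored edge under a separate reverse relation, which doubles the edge count and the number of relation types. The sketch below is illustrative only; the "rev-" prefix is a hypothetical naming convention chosen for this example, not taken from the DGL source.

```python
def add_reverse(triples, prefix="rev-"):
    """Return the input triples plus one reversed triple for each of them.

    The reverse relation is named prefix + original relation; the 'rev-'
    default is a hypothetical convention for this sketch.
    """
    reversed_triples = [(o, prefix + p, s) for (s, p, o) in triples]
    return list(triples) + reversed_triples


edges = [("proxy-1", "linksTo", "object-7")]
print(add_reverse(edges))
# [('proxy-1', 'linksTo', 'object-7'), ('object-7', 'rev-linksTo', 'proxy-1')]
```

Keeping reverse relations distinct from forward ones lets a relational model learn separate weights for each direction.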

num_classes

Number of classes to predict.

Type

int

predict_category

The entity category (node type) that carries the labels to predict.

Type

str

Examples

>>> import dgl
>>> dataset = dgl.data.rdf.AMDataset()
>>> g = dataset[0]
>>> category = dataset.predict_category
>>> num_classes = dataset.num_classes
>>>
>>> # Masks and labels are stored on the nodes of the target category.
>>> train_mask = g.nodes[category].data['train_mask']
>>> test_mask = g.nodes[category].data['test_mask']
>>> label = g.nodes[category].data['label']

__getitem__(idx)

Get the graph object.

Parameters

idx (int) – Item index. AMDataset contains only one graph, so the only valid index is 0.

Returns

The graph. Nodes of type predict_category carry:

  • ndata['train_mask']: mask for training node set

  • ndata['test_mask']: mask for testing node set

  • ndata['label']: node labels

Return type

dgl.DGLGraph
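The masks select which labeled nodes belong to each split. The sketch below shows that selection with plain Python lists standing in for the label and mask tensors, assuming boolean masks; the toy values are made up. With PyTorch tensors the same selection is typically written as label[train_mask].

```python
def split_by_mask(labels, train_mask, test_mask):
    """Collect (node_index, label) pairs for the training and test sets.

    Plain lists stand in for the tensors stored under 'label', 'train_mask'
    and 'test_mask'; the masks are assumed to be boolean.
    """
    train = [(i, y) for i, (y, m) in enumerate(zip(labels, train_mask)) if m]
    test = [(i, y) for i, (y, m) in enumerate(zip(labels, test_mask)) if m]
    return train, test


labels = [3, 0, 7, 3, 1]
train_mask = [True, False, True, False, False]
test_mask = [False, False, False, True, True]
train, test = split_by_mask(labels, train_mask, test_mask)
print(train)  # [(0, 3), (2, 7)]
print(test)   # [(3, 3), (4, 1)]
```

Node 1 is in neither split, mirroring the fact that only a subset of nodes in the target category is labeled.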

__len__()

The number of graphs in the dataset.

Returns

The number of graphs (1, since AMDataset contains a single graph).

Return type

int