RedditDataset

class dgl.data.RedditDataset(self_loop=False, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

Bases: dgl.data.dgl_dataset.DGLBuiltinDataset

Reddit dataset for community detection (node classification)

    Deprecated since version 0.5.0:
  • graph is deprecated, it is replaced by:

    >>> dataset = RedditDataset()
    >>> graph = dataset[0]
    
  • num_labels is deprecated, it is replaced by:

    >>> dataset = RedditDataset()
    >>> num_classes = dataset.num_classes
    
  • train_mask is deprecated, it is replaced by:

    >>> dataset = RedditDataset()
    >>> graph = dataset[0]
    >>> train_mask = graph.ndata['train_mask']
    
  • val_mask is deprecated, it is replaced by:

    >>> dataset = RedditDataset()
    >>> graph = dataset[0]
    >>> val_mask = graph.ndata['val_mask']
    
  • test_mask is deprecated, it is replaced by:

    >>> dataset = RedditDataset()
    >>> graph = dataset[0]
    >>> test_mask = graph.ndata['test_mask']
    
  • features is deprecated, it is replaced by:

    >>> dataset = RedditDataset()
    >>> graph = dataset[0]
    >>> features = graph.ndata['feat']
    
  • labels is deprecated, it is replaced by:

    >>> dataset = RedditDataset()
    >>> graph = dataset[0]
    >>> labels = graph.ndata['label']
    

This is a graph dataset from Reddit posts made in the month of September, 2014. The node label in this case is the community, or “subreddit”, that a post belongs to. The authors sampled 50 large communities and built a post-to-post graph, connecting posts if the same user comments on both. In total this dataset contains 232,965 posts with an average degree of 492. We use the first 20 days for training and the remaining days for testing (with 30% used for validation).

Reference: http://snap.stanford.edu/graphsage/

Statistics

  • Nodes: 232,965

  • Edges: 114,615,892

  • Node feature size: 602

  • Number of training samples: 153,431

  • Number of validation samples: 23,831

  • Number of test samples: 55,703
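The split sizes above are consistent with the description: the training, validation, and test sets partition all 232,965 nodes, and the validation set is roughly 30% of the nodes held out from training. A quick arithmetic check (numbers taken from the table above):

```python
# Sanity-check the published split statistics for the Reddit dataset.
num_nodes = 232_965
num_train, num_val, num_test = 153_431, 23_831, 55_703

# The three splits partition all nodes.
assert num_train + num_val + num_test == num_nodes

# Validation is ~30% of the nodes held out from training.
held_out = num_val + num_test
print(round(num_val / held_out, 3))  # 0.3
```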

Parameters
  • self_loop (bool) – Whether to load the dataset with self-loop connections. Default: False

  • raw_dir (str) – Raw file directory to download and store the input data. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: False

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
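Note that the transform is applied on every access, so dataset[0] returns a freshly transformed graph each time it is called. A minimal sketch of that behavior, using a hypothetical ToyDataset stand-in (plain Python, no DGL) so the mechanics are visible; the real RedditDataset holds DGLGraph objects and a typical transform in recent DGL versions would be a callable such as dgl.AddSelfLoop():

```python
# Hypothetical stand-in for a DGL dataset; `ToyDataset` and its integer
# items are illustrative only -- not the actual RedditDataset implementation.
class ToyDataset:
    def __init__(self, items, transform=None):
        self._items = items
        self._transform = transform

    def __getitem__(self, idx):
        item = self._items[idx]
        # The transform runs on every access, not once at load time.
        return self._transform(item) if self._transform is not None else item

    def __len__(self):
        return len(self._items)

data = ToyDataset([1, 2, 3], transform=lambda x: x * 10)
print(data[0])    # 10
print(len(data))  # 3
```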

num_classes

    Number of classes for each node

    Type: int

graph

    Graph of the dataset

    Type: dgl.DGLGraph

num_labels

    Number of classes for each node

    Type: int

train_mask

    Mask of training nodes

    Type: numpy.ndarray

val_mask

    Mask of validation nodes

    Type: numpy.ndarray

test_mask

    Mask of test nodes

    Type: numpy.ndarray

features

    Node features

    Type: Tensor

labels

    Node labels

    Type: Tensor

Examples

>>> data = RedditDataset()
>>> g = data[0]
>>> num_classes = data.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
>>>
>>> # get data split
>>> train_mask = g.ndata['train_mask']
>>> val_mask = g.ndata['val_mask']
>>> test_mask = g.ndata['test_mask']
>>>
>>> # get labels
>>> label = g.ndata['label']
>>>
>>> # Train, Validation and Test
__getitem__(idx)[source]

Get graph by index

Parameters

idx (int) – Item index

Returns

graph structure, node labels, node features and splitting masks:

  • ndata['label']: node label

  • ndata['feat']: node feature

  • ndata['train_mask']: mask for training node set

  • ndata['val_mask']: mask for validation node set

  • ndata['test_mask']: mask for test node set

Return type

dgl.DGLGraph
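The masks in ndata are boolean per-node arrays, so selecting a split means boolean-indexing node data with the corresponding mask (e.g. feat[train_mask] on tensors). A pure-Python sketch of that selection, with toy values standing in for the real node data:

```python
# Toy per-node data standing in for g.ndata; real values come from the dataset.
labels     = [3, 1, 4, 1, 5]
train_mask = [True, True, True, False, False]
test_mask  = [False, False, False, True, True]

# Boolean-mask selection: the plain-Python analogue of labels[train_mask].
train_labels = [y for y, m in zip(labels, train_mask) if m]
test_labels  = [y for y, m in zip(labels, test_mask) if m]
print(train_labels)  # [3, 1, 4]
print(test_labels)   # [1, 5]
```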

__len__()[source]

Number of graphs in the dataset