SSTDataset

class dgl.data.SSTDataset(mode='train', glove_embed_file=None, vocab_file=None, raw_dir=None, force_reload=False, verbose=False, transform=None)[source]

Bases: dgl.data.dgl_dataset.DGLBuiltinDataset

Stanford Sentiment Treebank dataset.

Each sample is the constituency tree of a sentence. The leaf nodes represent words. Each word is an int value stored in the x feature field. Non-leaf nodes carry a special value PAD_WORD in the x field. Each node also has a sentiment annotation with 5 classes (very negative, negative, neutral, positive and very positive). The sentiment label is an int value stored in the y feature field. Official site: http://nlp.stanford.edu/sentiment/index.html
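For example, a minimal sketch of reading these fields back from a single sample (assuming the dataset has been downloaded and a PyTorch backend is used; the variable names are illustrative):

>>> from dgl.data import SSTDataset
>>> tree = SSTDataset(mode='tiny')[0]
>>> is_leaf = tree.ndata['mask'] == 1          # mask is 1 on leaf nodes, 0 elsewhere
>>> leaf_word_ids = tree.ndata['x'][is_leaf]   # int word ids; non-leaf nodes hold PAD_WORD
>>> node_labels = tree.ndata['y']              # per-node sentiment labels in 0..4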

Statistics:

  • Train examples: 8,544

  • Dev examples: 1,101

  • Test examples: 2,210

  • Number of classes for each node: 5

Parameters
  • mode (str, optional) – Should be one of [‘train’, ‘dev’, ‘test’, ‘tiny’]. Default: ‘train’

  • glove_embed_file (str, optional) – The path to the pretrained GloVe embedding file. Default: None

  • vocab_file (str, optional) – Optional vocabulary file. If not given, the default vocabulary file is used. Default: None

  • raw_dir (str) – Raw file directory to download/contain the input data directory. Default: ~/.dgl/

  • force_reload (bool) – Whether to reload the dataset. Default: False

  • verbose (bool) – Whether to print out progress information. Default: False.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The DGLGraph object will be transformed before every access.
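As an illustration of the transform hook, the sketch below builds the dev split with a self-loop transform applied on every access (AddSelfLoop is only one example; any callable that maps a DGLGraph to a DGLGraph works):

>>> from dgl import AddSelfLoop
>>> from dgl.data import SSTDataset
>>> dev_data = SSTDataset(mode='dev', transform=AddSelfLoop())
>>> dev_data[0]  # the returned tree now contains self-loop edges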

vocab

Vocabulary of the dataset

Type

OrderedDict

num_classes

Number of classes for each node

Type

int

pretrained_emb

Pretrained GloVe embedding with respect to the vocabulary (see the sketch after this attribute list).

Type

Tensor

vocab_size

The size of the vocabulary

Type

int
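A common use of pretrained_emb and vocab_size, sketched here under the assumption of a PyTorch backend and that pretrained_emb has shape (vocab_size, embedding_dim), is to initialize an embedding layer for the leaf word ids stored in ndata['x'] (the GloVe file path below is illustrative):

>>> import torch.nn as nn
>>> from dgl.data import SSTDataset
>>> train_data = SSTDataset(glove_embed_file='glove.840B.300d.txt')
>>> embed = nn.Embedding(train_data.vocab_size, train_data.pretrained_emb.shape[1])
>>> embed.weight.data.copy_(train_data.pretrained_emb)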

Notes

All the samples are loaded and preprocessed in memory first.

Examples

>>> # get dataset
>>> train_data = SSTDataset()
>>> dev_data = SSTDataset(mode='dev')
>>> test_data = SSTDataset(mode='test')
>>> tiny_data = SSTDataset(mode='tiny')
>>>
>>> len(train_data)
8544
>>> train_data.num_classes
5
>>> glove_embed = train_data.pretrained_emb
>>> train_data.vocab_size
19536
>>> train_data[0]
Graph(num_nodes=71, num_edges=70,
  ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
  edata_schemes={})
>>> for tree in train_data:
...     input_ids = tree.ndata['x']
...     labels = tree.ndata['y']
...     mask = tree.ndata['mask']
...     # your code here
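For mini-batch training, the trees can be collated into a single batched graph with dgl.batch. A minimal sketch using a standard PyTorch DataLoader (the batch size and loop body are illustrative; this relies on each sample being a single DGLGraph with the labels stored in ndata):

>>> import dgl
>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=dgl.batch)
>>> for batched_trees in loader:
...     input_ids = batched_trees.ndata['x']
...     labels = batched_trees.ndata['y']
...     mask = batched_trees.ndata['mask']
...     # feed the batched graph to a tree-structured model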
__getitem__(idx)[source]

Get graph by index

Parameters

idx (int) – The index of the sample.

Returns

graph structure, word id for each node, node labels and masks.

  • ndata['x']: word id of the node

  • ndata['y']: label of the node

  • ndata['mask']: 1 if the node is a leaf, otherwise 0

Return type

dgl.DGLGraph

__len__()[source]

Number of graphs in the dataset.