"""Tree-structured data.
Including:
- Stanford Sentiment Treebank
"""
from __future__ import absolute_import
from collections import OrderedDict
import networkx as nx
import numpy as np
import os
from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F
from .utils import _get_dgl_url, save_graphs, save_info, load_graphs, \
load_info, deprecate_property
from ..convert import from_networkx
# Public names exported by this module; ``SST`` is an alias of ``SSTDataset``.
__all__ = ['SST', 'SSTDataset']
class SSTDataset(DGLBuiltinDataset):
    r"""Stanford Sentiment Treebank dataset.

    Each sample is the constituency tree of a sentence. The leaf nodes
    represent words. The word is an int value stored in the ``x`` feature field.
    The non-leaf node has a special value ``PAD_WORD`` in the ``x`` field.

    Each node also has a sentiment annotation: 5 classes (very negative,
    negative, neutral, positive and very positive). The sentiment label is an
    int value stored in the ``y`` feature field.

    Official site: `<http://nlp.stanford.edu/sentiment/index.html>`_

    Statistics:

    - Train examples: 8,544
    - Dev examples: 1,101
    - Test examples: 2,210
    - Number of classes for each node: 5

    Parameters
    ----------
    mode : str, optional
        Should be one of ['train', 'dev', 'test', 'tiny']
        Default: train
    glove_embed_file : str, optional
        The path to a pretrained GloVe embedding file. It is only used when
        ``mode='train'`` and ignored for every other mode. If the path does
        not exist, no embedding is built.
        Default: None
    vocab_file : str, optional
        Optional vocabulary file (one token per line). If not given,
        ``vocab.txt`` from the raw dataset directory is used.
        Default: None
    raw_dir : str
        Directory to download the raw data into, or that already contains it.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    vocab : OrderedDict
        Vocabulary of the dataset
    num_classes : int
        Number of classes for each node
    pretrained_emb: Tensor or None
        Pretrained GloVe embedding with respect to the vocabulary (one row per
        vocabulary entry), or ``None`` when no GloVe file was used.
    vocab_size : int
        The size of the vocabulary

    Notes
    -----
    All the samples will be loaded and preprocessed in the memory first.

    Examples
    --------
    >>> # get dataset
    >>> train_data = SSTDataset()
    >>> dev_data = SSTDataset(mode='dev')
    >>> test_data = SSTDataset(mode='test')
    >>> tiny_data = SSTDataset(mode='tiny')
    >>>
    >>> len(train_data)
    8544
    >>> train_data.num_classes
    5
    >>> glove_embed = train_data.pretrained_emb
    >>> train_data.vocab_size
    19536
    >>> train_data[0]
    Graph(num_nodes=71, num_edges=70,
          ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(), dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
          edata_schemes={})
    >>> for tree in train_data:
    ...     input_ids = tree.ndata['x']
    ...     labels = tree.ndata['y']
    ...     mask = tree.ndata['mask']
    ...     # your code here
    """

    PAD_WORD = -1  # special pad word id
    UNK_WORD = -1  # out-of-vocabulary word id

    def __init__(self,
                 mode='train',
                 glove_embed_file=None,
                 vocab_file=None,
                 raw_dir=None,
                 force_reload=False,
                 verbose=False,
                 transform=None):
        assert mode in ['train', 'dev', 'test', 'tiny']
        _url = _get_dgl_url('dataset/sst.zip')
        # GloVe embeddings are only built for the training split.
        self._glove_embed_file = glove_embed_file if mode == 'train' else None
        self.mode = mode
        self._vocab_file = vocab_file
        super(SSTDataset, self).__init__(name='sst',
                                         url=_url,
                                         raw_dir=raw_dir,
                                         force_reload=force_reload,
                                         verbose=verbose,
                                         transform=transform)

    def process(self):
        r"""Build the vocabulary, the optional GloVe embedding and the trees
        of the selected ``mode`` from the raw files."""
        from nltk.corpus.reader import BracketParseCorpusReader

        # Load the vocabulary: one token per line, ids assigned in file order.
        self._vocab = OrderedDict()
        vocab_file = self._vocab_file if self._vocab_file is not None \
            else os.path.join(self.raw_path, 'vocab.txt')
        with open(vocab_file, encoding='utf-8') as vf:
            for line in vf:
                line = line.strip()
                self._vocab[line] = len(self._vocab)

        # GloVe is used only when a file was given (train mode) and it exists.
        use_glove = self._glove_embed_file is not None and \
            os.path.exists(self._glove_embed_file)

        # Keep only the GloVe vectors of (lower-cased) vocabulary words.
        glove_emb = {}
        if use_glove:
            with open(self._glove_embed_file, 'r', encoding='utf-8') as pf:
                # Stream line by line: GloVe files can be several GB.
                for line in pf:
                    # rstrip() so trailing whitespace does not leave a '\n' /
                    # empty token that float() cannot parse.
                    sp = line.rstrip().split(' ')
                    word = sp[0].lower()
                    if word in self._vocab:
                        glove_emb[word] = np.asarray([float(x) for x in sp[1:]])

        files = ['{}.txt'.format(self.mode)]
        corpus = BracketParseCorpusReader(self.raw_path, files)
        sents = corpus.parsed_sents(files[0])

        # Build the embedding matrix; words missing from GloVe get a small
        # random vector.
        self._pretrained_emb = None
        if use_glove:
            # Infer the vector size from the file instead of assuming 300-d
            # GloVe; fall back to 300 when no vocabulary word was found.
            emb_dim = len(next(iter(glove_emb.values()))) if glove_emb else 300
            pretrained_emb = []
            fail_cnt = 0
            for word in self._vocab:
                key = word.lower()
                if key not in glove_emb:
                    fail_cnt += 1
                # The random default is drawn for every word, so the NumPy RNG
                # stream is consumed exactly as before for a given seed.
                pretrained_emb.append(
                    glove_emb.get(key, np.random.uniform(-0.05, 0.05, emb_dim)))
            self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
            print('Miss word in GloVe {0:.4f}'.format(
                1.0 * fail_cnt / len(pretrained_emb)))

        # build trees
        self._trees = [self._build_tree(sent) for sent in sents]

    def _build_tree(self, root):
        r"""Convert one parsed constituency tree into a DGL graph.

        Nodes are numbered in pre-order (the root is 0) and every edge points
        from a child to its parent. Leaf word nodes get the vocabulary id of
        the lower-cased word (``UNK_WORD`` if unknown) and ``mask=1``; all
        other nodes get ``PAD_WORD`` and ``mask=0``. ``y`` holds the integer
        sentiment label of each node.
        """
        g = nx.DiGraph()

        def _rec_build(nid, node):
            for child in node:
                cid = g.number_of_nodes()
                if isinstance(child[0], (str, bytes)):
                    # leaf node: its only child is the word token itself
                    word = self.vocab.get(child[0].lower(), self.UNK_WORD)
                    g.add_node(cid, x=word, y=int(child.label()), mask=1)
                else:
                    g.add_node(cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0)
                    _rec_build(cid, child)
                g.add_edge(cid, nid)

        # add root
        g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
        _rec_build(0, root)
        ret = from_networkx(g, node_attrs=['x', 'y', 'mask'])
        return ret

    def has_cache(self):
        r"""Return True if the cached graphs of this mode and the cached
        vocabulary both exist under ``save_path``."""
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        return os.path.exists(graph_path) and os.path.exists(vocab_path)

    def save(self):
        r"""Save the trees, the vocabulary and (if any) the embedding."""
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        save_graphs(graph_path, self._trees)
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        save_info(vocab_path, {'vocab': self.vocab})
        # Explicit None check: the truth value of a multi-element tensor is
        # ambiguous and raises, so ``if self.pretrained_emb:`` would crash.
        if self.pretrained_emb is not None:
            emb_path = os.path.join(self.save_path, 'emb.pkl')
            save_info(emb_path, {'embed': self.pretrained_emb})

    def load(self):
        r"""Load the trees, the vocabulary and, if ``emb.pkl`` exists, the
        embedding from ``save_path``."""
        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
        vocab_path = os.path.join(self.save_path, 'vocab.pkl')
        emb_path = os.path.join(self.save_path, 'emb.pkl')
        self._trees = load_graphs(graph_path)[0]
        self._vocab = load_info(vocab_path)['vocab']
        self._pretrained_emb = None
        if os.path.exists(emb_path):
            self._pretrained_emb = load_info(emb_path)['embed']

    @property
    def vocab(self):
        r""" Vocabulary

        Returns
        -------
        OrderedDict
        """
        return self._vocab

    @property
    def pretrained_emb(self):
        r"""Pre-trained word embedding if a GloVe file was used, else None."""
        return self._pretrained_emb

    def __getitem__(self, idx):
        r""" Get graph by index

        Parameters
        ----------
        idx : int

        Returns
        -------
        :class:`dgl.DGLGraph`
            graph structure, word id for each node, node labels and masks.

            - ``ndata['x']``: word id of the node
            - ``ndata['y']``: label of the node
            - ``ndata['mask']``: 1 if the node is a leaf, otherwise 0
        """
        if self._transform is None:
            return self._trees[idx]
        else:
            return self._transform(self._trees[idx])

    def __len__(self):
        r"""Number of graphs in the dataset."""
        return len(self._trees)

    @property
    def vocab_size(self):
        r"""Vocabulary size."""
        return len(self._vocab)

    @property
    def num_classes(self):
        r"""Number of classes for each node."""
        return 5
# Short alias: ``SST`` refers to the same class as ``SSTDataset``.
SST = SSTDataset