"""Datasets used in How Powerful Are Graph Neural Networks?
Datasets include:
MUTAG, COLLAB, IMDBBINARY, IMDBMULTI, NCI1, PROTEINS, PTC, REDDITBINARY, REDDITMULTI5K
https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip
"""
import os
import numpy as np
from .. import backend as F
from .dgl_dataset import DGLBuiltinDataset
from .utils import (download, extract_archive, load_graphs, load_info,
                    save_graphs, save_info)
from ..convert import graph as dgl_graph


class GINDataset(DGLBuiltinDataset):
"""Dataset Class for `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_.
This is adapted from `<https://github.com/weihua916/powerful-gnns/blob/master/dataset.zip>`_.
The class provides an interface for nine datasets used in the paper along with the paper-specific
settings. The datasets are ``'MUTAG'``, ``'COLLAB'``, ``'IMDBBINARY'``, ``'IMDBMULTI'``,
``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, ``'REDDITBINARY'``, ``'REDDITMULTI5K'``.
    If ``degree_as_nlabel`` is set to ``False``, then ``ndata['label']`` stores the
    provided node labels; otherwise ``ndata['label']`` stores the node in-degrees.

    For graphs that have node attributes, ``ndata['attr']`` stores the node attributes.
    For graphs that have no node attributes, ``ndata['attr']`` stores the one-hot
    encoding of ``ndata['label']``.
Parameters
    ----------
    name : str
        Dataset name, one of
        (``'MUTAG'``, ``'COLLAB'``, \
        ``'IMDBBINARY'``, ``'IMDBMULTI'``, \
        ``'NCI1'``, ``'PROTEINS'``, ``'PTC'``, \
        ``'REDDITBINARY'``, ``'REDDITMULTI5K'``).
    self_loop : bool
        Whether to add a self-loop to every node in the graphs.
    degree_as_nlabel : bool
        Whether to use the node in-degree as the node label and feature.
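
    Attributes
    ----------
    gclasses : int
        Number of graph classes.
    nclasses : int
        Number of node classes.
    dim_nfeats : int
        Dimension of the node feature ``ndata['attr']``.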
Examples
--------
>>> data = GINDataset(name='MUTAG', self_loop=False)
The dataset instance is an iterable
>>> len(data)
188
>>> g, label = data[128]
>>> g
Graph(num_nodes=13, num_edges=26,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float32)}
edata_schemes={})
>>> label
tensor(1)
Batch the graphs and labels for mini-batch training
>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=330, num_edges=748,
ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float32)}
edata_schemes={})
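
    Alternatively, mini-batching can be delegated to
    :class:`~dgl.dataloading.GraphDataLoader` (a usage sketch; the batch size
    here is arbitrary)

    >>> from dgl.dataloading import GraphDataLoader
    >>> dataloader = GraphDataLoader(data, batch_size=16, shuffle=True)
    >>> for batched_graphs, batched_labels in dataloader:
    ...     pass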
"""
def __init__(self, name, self_loop, degree_as_nlabel=False,
raw_dir=None, force_reload=False, verbose=False):
        self._name = name  # e.g. 'MUTAG'
gin_url = 'https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip'
        self.ds_name = 'gin'
self.self_loop = self_loop
self.graphs = []
self.labels = []
# relabel
self.glabel_dict = {}
self.nlabel_dict = {}
self.elabel_dict = {}
self.ndegree_dict = {}
# global num
self.N = 0 # total graphs number
self.n = 0 # total nodes number
self.m = 0 # total edges number
# global num of classes
self.gclasses = 0
self.nclasses = 0
self.eclasses = 0
self.dim_nfeats = 0
# flags
self.degree_as_nlabel = degree_as_nlabel
self.nattrs_flag = False
self.nlabels_flag = False
super(GINDataset, self).__init__(name=name, url=gin_url, hash_key=(name, self_loop, degree_as_nlabel),
raw_dir=raw_dir, force_reload=force_reload, verbose=verbose)
@property
def raw_path(self):
return os.path.join(self.raw_dir, 'GINDataset')
def download(self):
r""" Automatically download data and extract it.
"""
zip_file_path = os.path.join(self.raw_dir, 'GINDataset.zip')
download(self.url, path=zip_file_path)
extract_archive(zip_file_path, self.raw_path)
    def __len__(self):
"""Return the number of graphs in the dataset."""
return len(self.graphs)
    def __getitem__(self, idx):
"""Get the idx-th sample.
Parameters
        ----------
idx : int
The sample index.
Returns
-------
(:class:`dgl.Graph`, Tensor)
The graph and its label.
"""
return self.graphs[idx], self.labels[idx]
def _file_path(self):
return os.path.join(self.raw_dir, "GINDataset", 'dataset', self.name, "{}.txt".format(self.name))
def process(self):
""" Loads input dataset from dataset/NAME/NAME.txt file
"""
if self.verbose:
print('loading data...')
self.file = self._file_path()
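        # File format (inferred from the parsing below):
        #   line 1: N, the total number of graphs;
        #   then, for each graph, one line "n_nodes graph_label" followed by
        #   n_nodes lines of the form
        #   "node_label n_edges neighbor_1 ... neighbor_k [attr_1 ... attr_d]"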
with open(self.file, 'r') as f:
# line_1 == N, total number of graphs
self.N = int(f.readline().strip())
for i in range(self.N):
                if (i + 1) % 10 == 0 and self.verbose:
print('processing graph {}...'.format(i + 1))
grow = f.readline().strip().split()
# line_2 == [n_nodes, l] is equal to
# [node number of a graph, class label of a graph]
n_nodes, glabel = [int(w) for w in grow]
# relabel graphs
if glabel not in self.glabel_dict:
mapped = len(self.glabel_dict)
self.glabel_dict[glabel] = mapped
self.labels.append(self.glabel_dict[glabel])
g = dgl_graph(([], []))
g.add_nodes(n_nodes)
nlabels = [] # node labels
nattrs = [] # node attributes if it has
m_edges = 0
for j in range(n_nodes):
nrow = f.readline().strip().split()
                    # handle edges and node attributes (if any); a row without
                    # attributes has exactly 2 + #edges fields
                    tmp = int(nrow[1]) + 2  # tmp == 2 + #edges
                    if tmp == len(nrow):
                        # no node attributes
                        nrow = [int(w) for w in nrow]
                    elif tmp < len(nrow):
                        # the trailing fields are node attributes; parse them
                        # before truncating nrow to the edge fields
                        nattr = [float(w) for w in nrow[tmp:]]
                        nrow = [int(w) for w in nrow[:tmp]]
                        nattrs.append(nattr)
                    else:
                        raise Exception('edge number is incorrect!')
# relabel nodes if it has labels
# if it doesn't have node labels, then every nrow[0]==0
                    if nrow[0] not in self.nlabel_dict:
mapped = len(self.nlabel_dict)
self.nlabel_dict[nrow[0]] = mapped
nlabels.append(self.nlabel_dict[nrow[0]])
m_edges += nrow[1]
g.add_edges(j, nrow[2:])
# add self loop
if self.self_loop:
m_edges += 1
g.add_edges(j, j)
                    if (j + 1) % 10 == 0 and self.verbose:
                        print(
                            'processing node {} of graph {}...'.format(
                                j + 1, i + 1))
                        print('this node has {} edges.'.format(
                            nrow[1]))
                if nattrs:
nattrs = np.stack(nattrs)
g.ndata['attr'] = F.tensor(nattrs, F.float32)
self.nattrs_flag = True
g.ndata['label'] = F.tensor(nlabels)
if len(self.nlabel_dict) > 1:
self.nlabels_flag = True
assert g.number_of_nodes() == n_nodes
# update statistics of graphs
self.n += n_nodes
self.m += m_edges
self.graphs.append(g)
self.labels = F.tensor(self.labels)
# if no attr
if not self.nattrs_flag:
if self.verbose:
print('there are no node features in this dataset!')
# generate node attr by node degree
if self.degree_as_nlabel:
if self.verbose:
print('generate node features by node degree...')
for g in self.graphs:
                    # ideally the provided node labels would be preserved in
                    # case users want to keep them, but datasets without node
                    # features usually have no meaningful node labels either
                    g.ndata['label'] = g.in_degrees()
            # collect the unique node labels, in case the labels/degrees
            # are not contiguous integers
            nlabel_set = set()
            for g in self.graphs:
                nlabel_set = nlabel_set.union(
                    set([F.as_scalar(nl) for nl in g.ndata['label']]))
            nlabel_set = list(nlabel_set)
is_label_valid = all([label in self.nlabel_dict for label in nlabel_set])
if is_label_valid and len(nlabel_set) == np.max(nlabel_set) + 1 and np.min(nlabel_set) == 0:
                # Note: this deviates from the author's implementation, which
                # always relabels the nodes (weihua916's code). Here labels that
                # are already contiguous are kept as-is, consistent with the
                # original dataset.
label2idx = self.nlabel_dict
else:
label2idx = {
nlabel_set[i]: i
for i in range(len(nlabel_set))
}
            # generate one-hot node attributes from the node labels
            for g in self.graphs:
                attr = np.zeros((g.number_of_nodes(), len(label2idx)))
                attr[range(g.number_of_nodes()),
                     [label2idx[nl] for nl in F.asnumpy(g.ndata['label']).tolist()]] = 1
                g.ndata['attr'] = F.tensor(attr, F.float32)
# after load, get the #classes and #dim
self.gclasses = len(self.glabel_dict)
self.nclasses = len(self.nlabel_dict)
self.eclasses = len(self.elabel_dict)
self.dim_nfeats = len(self.graphs[0].ndata['attr'][0])
if self.verbose:
print('Done.')
print(
"""
-------- Data Statistics --------
#Graphs: %d
#Graph Classes: %d
#Nodes: %d
#Node Classes: %d
#Node Features Dim: %d
#Edges: %d
#Edge Classes: %d
Avg. of #Nodes: %.2f
Avg. of #Edges: %.2f
Graph Relabeled: %s
Node Relabeled: %s
Degree Relabeled (if degree_as_nlabel=True): %s \n """ % (
self.N, self.gclasses, self.n, self.nclasses,
self.dim_nfeats, self.m, self.eclasses,
self.n / self.N, self.m / self.N, self.glabel_dict,
self.nlabel_dict, self.ndegree_dict))
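    # Caching: the graphs and labels are serialized to a .bin file via
    # save_graphs and the scalar metadata to a .pkl file via save_info.
    # The hash derived from (name, self_loop, degree_as_nlabel) keeps
    # caches for different configurations separate.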
def save(self):
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
label_dict = {'labels': self.labels}
info_dict = {'N': self.N,
'n': self.n,
'm': self.m,
'self_loop': self.self_loop,
'gclasses': self.gclasses,
'nclasses': self.nclasses,
'eclasses': self.eclasses,
'dim_nfeats': self.dim_nfeats,
'degree_as_nlabel': self.degree_as_nlabel,
'glabel_dict': self.glabel_dict,
'nlabel_dict': self.nlabel_dict,
'elabel_dict': self.elabel_dict,
'ndegree_dict': self.ndegree_dict}
save_graphs(str(graph_path), self.graphs, label_dict)
save_info(str(info_path), info_dict)
def load(self):
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
graphs, label_dict = load_graphs(str(graph_path))
info_dict = load_info(str(info_path))
self.graphs = graphs
self.labels = label_dict['labels']
self.N = info_dict['N']
self.n = info_dict['n']
self.m = info_dict['m']
self.self_loop = info_dict['self_loop']
self.gclasses = info_dict['gclasses']
self.nclasses = info_dict['nclasses']
self.eclasses = info_dict['eclasses']
self.dim_nfeats = info_dict['dim_nfeats']
self.glabel_dict = info_dict['glabel_dict']
self.nlabel_dict = info_dict['nlabel_dict']
self.elabel_dict = info_dict['elabel_dict']
self.ndegree_dict = info_dict['ndegree_dict']
self.degree_as_nlabel = info_dict['degree_as_nlabel']
def has_cache(self):
graph_path = os.path.join(
self.save_path, 'gin_{}_{}.bin'.format(self.name, self.hash))
info_path = os.path.join(
self.save_path, 'gin_{}_{}.pkl'.format(self.name, self.hash))
if os.path.exists(graph_path) and os.path.exists(info_path):
return True
return False
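

if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the public API):
    # the first run downloads and processes MUTAG; subsequent runs load the
    # cache written by save().
    dataset = GINDataset(name='MUTAG', self_loop=False, verbose=True)
    graph, label = dataset[0]
    print(graph)
    print('label:', label)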