"""Basic DGL Dataset
"""
from __future__ import absolute_import
import abc
import hashlib
import os
import traceback
from ..utils import retry_method_with_fix
from .utils import download, extract_archive, get_download_dir, makedirs
[docs]class DGLDataset(object):
r"""The basic DGL dataset for creating graph datasets.
This class defines a basic template class for DGL Dataset.
The following steps will be executed automatically:
1. Check whether there is a dataset cache on disk
(already processed and stored on the disk) by
invoking ``has_cache()``. If true, goto 5.
2. Call ``download()`` to download the data if ``url`` is not None.
3. Call ``process()`` to process the data.
4. Call ``save()`` to save the processed dataset on disk and goto 6.
5. Call ``load()`` to load the processed dataset from disk.
6. Done.
Users can overwite these functions with their
own data processing logic.
Parameters
----------
name : str
Name of the dataset
url : str
Url to download the raw dataset. Default: None
raw_dir : str
Specifying the directory that will store the
downloaded data or the directory that
already stores the input data.
Default: ~/.dgl/
save_dir : str
Directory to save the processed dataset.
Default: same as raw_dir
hash_key : tuple
A tuple of values as the input for the hash function.
Users can distinguish instances (and their caches on the disk)
from the same dataset class by comparing the hash values.
Default: (), the corresponding hash value is ``'f9065fa7'``.
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
Attributes
----------
url : str
The URL to download the dataset
name : str
The dataset name
raw_dir : str
Directory to store all the downloaded raw datasets.
raw_path : str
Path to the downloaded raw dataset folder. An alias for
``os.path.join(self.raw_dir, self.name)``.
save_dir : str
Directory to save all the processed datasets.
save_path : str
Path to the processed dataset folder. An alias for
``os.path.join(self.save_dir, self.name)``.
verbose : bool
Whether to print more runtime information.
hash : str
Hash value for the dataset and the setting.
"""
def __init__(
self,
name,
url=None,
raw_dir=None,
save_dir=None,
hash_key=(),
force_reload=False,
verbose=False,
transform=None,
):
self._name = name
self._url = url
self._force_reload = force_reload
self._verbose = verbose
self._hash_key = hash_key
self._hash = self._get_hash()
self._transform = transform
# if no dir is provided, the default dgl download dir is used.
if raw_dir is None:
self._raw_dir = get_download_dir()
else:
self._raw_dir = raw_dir
if save_dir is None:
self._save_dir = self._raw_dir
else:
self._save_dir = save_dir
self._load()
def download(self):
r"""Overwite to realize your own logic of downloading data.
It is recommended to download the to the :obj:`self.raw_dir`
folder. Can be ignored if the dataset is
already in :obj:`self.raw_dir`.
"""
pass
def save(self):
r"""Overwite to realize your own logic of
saving the processed dataset into files.
It is recommended to use ``dgl.data.utils.save_graphs``
to save dgl graph into files and use
``dgl.data.utils.save_info`` to save extra
information into files.
"""
pass
def load(self):
r"""Overwite to realize your own logic of
loading the saved dataset from files.
It is recommended to use ``dgl.data.utils.load_graphs``
to load dgl graph from files and use
``dgl.data.utils.load_info`` to load extra information
into python dict object.
"""
pass
@abc.abstractmethod
def process(self):
r"""Overwrite to realize your own logic of processing the input data."""
pass
def has_cache(self):
r"""Overwrite to realize your own logic of
deciding whether there exists a cached dataset.
By default False.
"""
return False
@retry_method_with_fix(download)
def _download(self):
"""Download dataset by calling ``self.download()``
if the dataset does not exists under ``self.raw_path``.
By default ``self.raw_path = os.path.join(self.raw_dir, self.name)``
One can overwrite ``raw_path()`` function to change the path.
"""
if os.path.exists(self.raw_path): # pragma: no cover
return
makedirs(self.raw_dir)
self.download()
def _load(self):
"""Entry point from __init__ to load the dataset.
If cache exists:
- Load the dataset from saved dgl graph and information files.
- If loadin process fails, re-download and process the dataset.
else:
- Download the dataset if needed.
- Process the dataset and build the dgl graph.
- Save the processed dataset into files.
"""
load_flag = not self._force_reload and self.has_cache()
if load_flag:
try:
self.load()
if self.verbose:
print("Done loading data from cached files.")
except KeyboardInterrupt:
raise
except:
load_flag = False
if self.verbose:
print(traceback.format_exc())
print("Loading from cache failed, re-processing.")
if not load_flag:
self._download()
self.process()
self.save()
if self.verbose:
print("Done saving data into cached files.")
def _get_hash(self):
"""Compute the hash of the input tuple
Example
-------
Assume `self._hash_key = (10, False, True)`
>>> hash_value = self._get_hash()
>>> hash_value
'a770b222'
"""
hash_func = hashlib.sha1()
hash_func.update(str(self._hash_key).encode("utf-8"))
return hash_func.hexdigest()[:8]
def _get_hash_url_suffix(self):
"""Get the suffix based on the hash value of the url."""
if self._url is None:
return ""
else:
hash_func = hashlib.sha1()
hash_func.update(str(self._url).encode("utf-8"))
return "_" + hash_func.hexdigest()[:8]
@property
def url(self):
r"""Get url to download the raw dataset."""
return self._url
@property
def name(self):
r"""Name of the dataset."""
return self._name
@property
def raw_dir(self):
r"""Raw file directory contains the input data folder."""
return self._raw_dir
@property
def raw_path(self):
r"""Directory contains the input data files.
By default raw_path = os.path.join(self.raw_dir, self.name)
"""
return os.path.join(
self.raw_dir, self.name + self._get_hash_url_suffix()
)
@property
def save_dir(self):
r"""Directory to save the processed dataset."""
return self._save_dir
@property
def save_path(self):
r"""Path to save the processed dataset."""
return os.path.join(
self.save_dir, self.name + self._get_hash_url_suffix()
)
@property
def verbose(self):
r"""Whether to print information."""
return self._verbose
@property
def hash(self):
r"""Hash value for the dataset and the setting."""
return self._hash
[docs] @abc.abstractmethod
def __getitem__(self, idx):
r"""Gets the data object at index."""
pass
[docs] @abc.abstractmethod
def __len__(self):
r"""The number of examples in the dataset."""
pass
def __repr__(self):
return (
f'Dataset("{self.name}", num_graphs={len(self)},'
+ f" save_path={self.save_path})"
)
class DGLBuiltinDataset(DGLDataset):
r"""The Basic DGL Builtin Dataset.
Parameters
----------
name : str
Name of the dataset.
url : str
Url to download the raw dataset.
raw_dir : str
Specifying the directory that will store the
downloaded data or the directory that
already stores the input data.
Default: ~/.dgl/
hash_key : tuple
A tuple of values as the input for the hash function.
Users can distinguish instances (and their caches on the disk)
from the same dataset class by comparing the hash values.
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: False
transform : callable, optional
A transform that takes in a :class:`~dgl.DGLGraph` object and returns
a transformed version. The :class:`~dgl.DGLGraph` object will be
transformed before every access.
"""
def __init__(
self,
name,
url,
raw_dir=None,
hash_key=(),
force_reload=False,
verbose=False,
transform=None,
):
super(DGLBuiltinDataset, self).__init__(
name,
url=url,
raw_dir=raw_dir,
save_dir=None,
hash_key=hash_key,
force_reload=force_reload,
verbose=verbose,
transform=transform,
)
def download(self):
r"""Automatically download data and extract it."""
if self.url is not None:
zip_file_path = os.path.join(self.raw_dir, self.name + ".zip")
download(self.url, path=zip_file_path)
extract_archive(zip_file_path, self.raw_path)