"""Dataset utilities."""
from __future__ import absolute_import
import errno
import hashlib
import os
import pickle
import sys
import warnings
import numpy as np
import requests
from .. import backend as F
from .graph_serialize import load_graphs, load_labels, save_graphs
from .tensor_serialize import load_tensors, save_tensors
# Names exported by ``from <this module> import *``. Besides the helpers
# defined in this file, the list re-exports the graph/tensor serialization
# functions imported above from ``graph_serialize`` and ``tensor_serialize``.
__all__ = [
"loadtxt",
"download",
"check_sha1",
"extract_archive",
"get_download_dir",
"Subset",
"split_dataset",
"save_graphs",
"load_graphs",
"load_labels",
"save_tensors",
"load_tensors",
"add_nodepred_split",
]
def loadtxt(path, delimiter, dtype=None):
    """Load a delimited text file into a NumPy array.

    pandas is used when it is installed; otherwise a warning is issued and
    the file is read with ``numpy.loadtxt``.

    Parameters
    ----------
    path : str
        Path of the text file to read.
    delimiter : str
        Field separator used in the file.
    dtype : data-type, optional
        Element type of the returned array. If None (the default), the
        type chosen by the underlying reader is kept.

    Returns
    -------
    numpy.ndarray
        The file content.
    """
    # Only the import itself is guarded, so an error raised while parsing
    # the file is not mistaken for "pandas is missing".
    try:
        import pandas as pd
    except ImportError:
        warnings.warn(
            "Pandas is not installed, now using numpy.loadtxt to load data, "
            "which could be extremely slow. Accelerate by installing pandas"
        )
        if dtype is None:
            return np.loadtxt(path, delimiter=delimiter)
        return np.loadtxt(path, delimiter=delimiter, dtype=dtype)

    values = pd.read_csv(path, delimiter=delimiter, header=None).values
    # Honor the requested dtype instead of silently ignoring it.
    return values if dtype is None else values.astype(dtype)
def _get_dgl_url(file_url):
    """Get DGL online url for download.

    Parameters
    ----------
    file_url : str
        Path of the file relative to the repository root.

    Returns
    -------
    str
        The repository base URL (taken from the ``DGL_REPO`` environment
        variable, or the default DGL data repository) joined with
        ``file_url``.
    """
    dgl_repo_url = "https://data.dgl.ai/"
    # An empty DGL_REPO is treated like an unset one; indexing an empty
    # string to look at its last character would raise IndexError.
    repo_url = os.environ.get("DGL_REPO") or dgl_repo_url
    if not repo_url.endswith("/"):
        repo_url += "/"
    return repo_url + file_url
def split_dataset(dataset, frac_list=None, shuffle=False, random_state=None):
    """Partition a dataset into consecutive subsets by fraction.

    Parameters
    ----------
    dataset
        Any object for which ``len(dataset)`` is the number of datapoints
        and ``dataset[i]`` is the i-th datapoint.
    frac_list : list or None, optional
        Fractions for the training, validation and test subsets; they must
        sum to 1. ``[0.8, 0.1, 0.1]`` is used when None.
    shuffle : bool, optional
        If True, the datapoint indices are randomly permuted before being
        cut into subsets. Otherwise the split is consecutive.
    random_state : None, int or array_like, optional
        Seed handed to ``numpy.random.RandomState`` when ``shuffle`` is True.

    Returns
    -------
    list of Subset
        One subset per entry of ``frac_list``, in the same order.
    """
    fractions = np.asarray([0.8, 0.1, 0.1] if frac_list is None else frac_list)
    assert np.allclose(
        np.sum(fractions), 1.0
    ), "Expect frac_list sum to 1, got {:.4f}".format(np.sum(fractions))

    num_data = len(dataset)
    # Truncate every share to an integer; the last subset absorbs whatever
    # is left so that the sizes add up to ``num_data`` exactly.
    lengths = (num_data * fractions).astype(int)
    lengths[-1] = num_data - np.sum(lengths[:-1])

    indices = (
        np.random.RandomState(seed=random_state).permutation(num_data)
        if shuffle
        else np.arange(num_data)
    )

    # Subset k covers indices[stops[k] - lengths[k] : stops[k]].
    stops = np.cumsum(lengths)
    starts = stops - lengths
    return [
        Subset(dataset, indices[begin:end]) for begin, end in zip(starts, stops)
    ]
def download(
    url,
    path=None,
    overwrite=True,
    sha1_hash=None,
    retries=5,
    verify_ssl=True,
    log=True,
):
    """Download a given URL.

    Codes borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    url : str
        URL to download.
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with the same name as in url. If ``path`` is an
        existing directory, the file is stored inside it under the name
        taken from the url.
    overwrite : bool, optional
        Whether to overwrite the destination file if it already exists.
        By default always overwrites the downloaded file.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits. Will ignore existing file
        when hash is specified but doesn't match.
    retries : integer, default 5
        The number of times to attempt downloading in case of failure or
        non 200 return codes. Must be non-negative; at least one attempt is
        always made.
    verify_ssl : bool, default True
        Verify SSL certificates.
    log : bool, default True
        Whether to print the progress for download

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    Exception
        Whatever the last failed attempt raised (for example a
        ``RuntimeError`` for a non-200 status or a ``UserWarning`` for a
        hash mismatch) once no attempts are left.
    """
    if path is None:
        fname = url.split("/")[-1]
        # Empty filenames are invalid
        assert fname, (
            "Can't construct file-name from this URL. "
            "Please set the `path` option manually."
        )
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split("/")[-1])
        else:
            fname = path
    assert retries >= 0, "Number of retries should be at least 0"

    if not verify_ssl:
        warnings.warn(
            "Unverified HTTPS request is being made (verify_ssl=False). "
            "Adding certificate verification is strongly advised."
        )

    if (
        overwrite
        or not os.path.exists(fname)
        or (sha1_hash and not check_sha1(fname, sha1_hash))
    ):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids failing when another process creates the
        # directory between an existence check and the creation.
        os.makedirs(dirname, exist_ok=True)
        while retries + 1 > 0:
            # Disable pylint too broad Exception
            # pylint: disable=W0703
            try:
                if log:
                    print("Downloading %s from %s..." % (fname, url))
                r = requests.get(url, stream=True, verify=verify_ssl)
                try:
                    if r.status_code != 200:
                        raise RuntimeError("Failed downloading url %s" % url)
                    with open(fname, "wb") as f:
                        for chunk in r.iter_content(chunk_size=1024):
                            if chunk:  # filter out keep-alive new chunks
                                f.write(chunk)
                finally:
                    # A streamed response holds its connection until it is
                    # closed; release it on success and on failure alike.
                    r.close()
                if sha1_hash and not check_sha1(fname, sha1_hash):
                    raise UserWarning(
                        "File {} is downloaded but the content hash does not match."
                        " The repo may be outdated or download may be incomplete. "
                        'If the "repo_url" is overridden, consider switching to '
                        "the default repo.".format(fname)
                    )
                break
            except Exception:
                retries -= 1
                if retries <= 0:
                    # Out of attempts: propagate the last error unchanged.
                    raise
                if log:
                    print(
                        "download failed, retrying, {} attempt{} left".format(
                            retries, "s" if retries > 1 else ""
                        )
                    )

    return fname
def check_sha1(filename, sha1_hash):
    """Tell whether a file's sha1 digest equals the expected one.

    Codes borrowed from mxnet/gluon/utils.py

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        True if the hexadecimal digest of the file content equals
        ``sha1_hash``, False otherwise.
    """
    digest = hashlib.sha1()
    with open(filename, "rb") as stream:
        # Read in 1 MiB pieces; iter() stops at the first empty read (EOF).
        for block in iter(lambda: stream.read(1048576), b""):
            digest.update(block)
    return digest.hexdigest() == sha1_hash
def get_download_dir():
    """Get the absolute path to the download directory.

    The directory is taken from the ``DGL_DOWNLOAD_DIR`` environment
    variable and defaults to ``~/.dgl`` when the variable is unset or
    empty. It is created if it does not exist yet.

    Returns
    -------
    dirname : str
        Absolute path to the download directory
    """
    default_dir = os.path.join(os.path.expanduser("~"), ".dgl")
    # Make the path absolute so the result keeps pointing at the same
    # directory even if the working directory changes later.
    dirname = os.path.abspath(os.environ.get("DGL_DOWNLOAD_DIR") or default_dir)
    # exist_ok avoids a race between checking for and creating the directory.
    os.makedirs(dirname, exist_ok=True)
    return dirname
def makedirs(path):
    """Create a directory, including missing parents, if it does not exist.

    Parameters
    ----------
    path : str
        Directory to create. The path is normalized and a leading ``~`` is
        expanded before creation.

    Raises
    ------
    OSError
        If the directory cannot be created, including the case where the
        path already exists but is not a directory.
    """
    target = os.path.expanduser(os.path.normpath(path))
    # exist_ok tolerates only "already exists as a directory"; every other
    # failure (permission denied, a file in the way, ...) is propagated
    # instead of being silently swallowed.
    os.makedirs(target, exist_ok=True)
def save_info(path, info):
    """Pickle dataset related information to a file.

    Parameters
    ----------
    path : str
        File to write; it is opened in binary write mode, so existing
        content is replaced.
    info : dict
        A python dict storing the information to save on disk.
    """
    with open(path, "wb") as sink:
        pickle.Pickler(sink).dump(info)
def load_info(path):
    """Read back dataset related information stored as a pickle file.

    Parameters
    ----------
    path : str
        File to load information from.

    Returns
    -------
    info : dict
        A python dict storing information loaded from disk.

    Notes
    -----
    The file is unpickled, and unpickling can execute arbitrary code, so
    only load files that come from a trusted source.
    """
    with open(path, "rb") as source:
        return pickle.Unpickler(source).load()
def deprecate_property(old, new):
    """Warn that property ``old`` will be deprecated in favor of ``new``."""
    message = f"Property {old} will be deprecated, please use {new} instead."
    warnings.warn(message)
def deprecate_function(old, new):
    """Warn that function ``old`` will be deprecated in favor of ``new``."""
    message = f"Function {old} will be deprecated, please use {new} instead."
    warnings.warn(message)
def deprecate_class(old, new):
    """Warn that class ``old`` will be deprecated in favor of ``new``."""
    message = f"Class {old} will be deprecated, please use {new} instead."
    warnings.warn(message)
def idx2mask(idx, len):
    """Build a 0/1 mask array with ones at the given positions.

    Parameters
    ----------
    idx
        Index (or indices) of the entries to set to 1.
    len : int
        Size of the mask passed to ``numpy.zeros``.

    Returns
    -------
    numpy.ndarray
        Array of zeros with the entries selected by ``idx`` set to 1.
    """
    flags = np.zeros(len)
    flags[idx] = 1
    return flags
def generate_mask_tensor(mask):
    """Generate mask tensor according to different backend

    For the mxnet backend a float32 tensor is created; for every other
    backend (e.g. torch and tensorflow) a bool tensor is created.

    Parameters
    ----------
    mask: numpy ndarray
        input mask tensor

    Returns
    -------
    Tensor
        The mask converted with the backend's ``tensor`` function.

    Raises
    ------
    AssertionError
        If ``mask`` is not a numpy ndarray.
    """
    # The two message pieces used to be joined without a separating space
    # ("generate_mask_tensorshould"); keep the message as a single literal.
    assert isinstance(
        mask, np.ndarray
    ), "input for generate_mask_tensor should be a numpy ndarray"
    if F.backend_name == "mxnet":
        return F.tensor(mask, dtype=F.data_type_dict["float32"])
    else:
        return F.tensor(mask, dtype=F.data_type_dict["bool"])
class Subset:
    """View of a dataset restricted to the given indices.

    Code adapted from PyTorch.

    Parameters
    ----------
    dataset
        Underlying dataset; ``dataset[i]`` should return the ith datapoint.
    indices : list
        Indices into ``dataset`` of the datapoints that make up the subset.
    """

    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    def __getitem__(self, item):
        """Return the datapoint at position ``item`` of the subset.

        Returns
        -------
        tuple
            datapoint
        """
        source_index = self.indices[item]
        return self.dataset[source_index]

    def __len__(self):
        """Return the subset size.

        Returns
        -------
        int
            Number of datapoints in the subset
        """
        return len(self.indices)
def add_nodepred_split(dataset, ratio, ntype=None):
    """Split the given dataset into training, validation and test sets for
    transductive node prediction task.

    It adds three node mask arrays ``'train_mask'``, ``'val_mask'`` and ``'test_mask'``,
    to each graph in the dataset. Each sample in the dataset thus must be a :class:`DGLGraph`.

    Fix the random seed of NumPy to make the result deterministic::

        numpy.random.seed(42)

    Parameters
    ----------
    dataset : DGLDataset
        The dataset to modify.
    ratio : (float, float, float)
        Split ratios for training, validation and test sets. Must sum to one.
        The training and validation sizes are the node count times the
        respective ratio, truncated to an integer; the test set receives
        all remaining nodes.
    ntype : str, optional
        The node type to add mask for.

    Raises
    ------
    ValueError
        If ``ratio`` does not have exactly three entries.

    Examples
    --------
    >>> dataset = dgl.data.AmazonCoBuyComputerDataset()
    >>> print('train_mask' in dataset[0].ndata)
    False
    >>> dgl.data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    >>> print('train_mask' in dataset[0].ndata)
    True
    """
    if len(ratio) != 3:
        raise ValueError(
            f"Split ratio must be a float triplet but got {ratio}."
        )
    # Only the first two ratios size a split; the test split is the remainder.
    train_ratio, val_ratio = ratio[0], ratio[1]
    for i in range(len(dataset)):
        g = dataset[i]
        n = g.num_nodes(ntype)
        idx = np.arange(0, n)
        # Uses NumPy's global random state, hence the seeding advice above.
        np.random.shuffle(idx)
        n_train = int(n * train_ratio)
        val_end = n_train + int(n * val_ratio)
        train_mask = generate_mask_tensor(idx2mask(idx[:n_train], n))
        val_mask = generate_mask_tensor(idx2mask(idx[n_train:val_end], n))
        test_mask = generate_mask_tensor(idx2mask(idx[val_end:], n))
        g.nodes[ntype].data["train_mask"] = train_mask
        g.nodes[ntype].data["val_mask"] = val_mask
        g.nodes[ntype].data["test_mask"] = test_mask