Source code for dgl.data.bitcoinotc

""" BitcoinOTC dataset for fraud detection """
import datetime
import gzip
import os
import shutil

import numpy as np

from .. import backend as F
from ..convert import graph as dgl_graph
from .dgl_dataset import DGLBuiltinDataset
from .utils import check_sha1, download, load_graphs, makedirs, save_graphs


[docs]class BitcoinOTCDataset(DGLBuiltinDataset): r"""BitcoinOTC dataset for fraud detection This is who-trusts-whom network of people who trade using Bitcoin on a platform called Bitcoin OTC. Since Bitcoin users are anonymous, there is a need to maintain a record of users' reputation to prevent transactions with fraudulent and risky users. Offical website: `<https://snap.stanford.edu/data/soc-sign-bitcoin-otc.html>`_ Bitcoin OTC dataset statistics: - Nodes: 5,881 - Edges: 35,592 - Range of edge weight: -10 to +10 - Percentage of positive edges: 89% Parameters ---------- raw_dir : str Raw file directory to download/contains the input data directory. Default: ~/.dgl/ force_reload : bool Whether to reload the dataset. Default: False verbose: bool Whether to print out progress information. Default: True. transform : callable, optional A transform that takes in a :class:`~dgl.DGLGraph` object and returns a transformed version. The :class:`~dgl.DGLGraph` object will be transformed before every access. Attributes ---------- graphs : list A list of DGLGraph objects is_temporal : bool Indicate whether the graphs are temporal graphs Raises ------ UserWarning If the raw data is changed in the remote server by the author. Examples -------- >>> dataset = BitcoinOTCDataset() >>> len(dataset) 136 >>> for g in dataset: .... # get edge feature .... edge_weights = g.edata['h'] .... # your code here >>> """ _url = "https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz" _sha1_str = "c14281f9e252de0bd0b5f1c6e2bae03123938641" def __init__( self, raw_dir=None, force_reload=False, verbose=False, transform=None ): super(BitcoinOTCDataset, self).__init__( name="bitcoinotc", url=self._url, raw_dir=raw_dir, force_reload=force_reload, verbose=verbose, transform=transform, ) def download(self): gz_file_path = os.path.join(self.raw_dir, self.name + ".csv.gz") download(self.url, path=gz_file_path) if not check_sha1(gz_file_path, self._sha1_str): raise UserWarning( "File {} is downloaded but the content hash does not match." "The repo may be outdated or download may be incomplete. " "Otherwise you can create an issue for it.".format( self.name + ".csv.gz" ) ) self._extract_gz(gz_file_path, self.raw_path) def process(self): filename = os.path.join(self.save_path, self.name + ".csv") data = np.loadtxt(filename, delimiter=",").astype(np.int64) data[:, 0:2] = data[:, 0:2] - data[:, 0:2].min() delta = datetime.timedelta(days=14).total_seconds() # The source code is not released, but the paper indicates there're # totally 137 samples. The cutoff below has exactly 137 samples. time_index = np.around((data[:, 3] - data[:, 3].min()) / delta).astype( np.int64 ) self._graphs = [] for i in range(time_index.max()): row_mask = time_index <= i edges = data[row_mask][:, 0:2] rate = data[row_mask][:, 2] g = dgl_graph((edges[:, 0], edges[:, 1])) g.edata["h"] = F.tensor( rate.reshape(-1, 1), dtype=F.data_type_dict["int64"] ) self._graphs.append(g) def has_cache(self): graph_path = os.path.join(self.save_path, "dgl_graph.bin") return os.path.exists(graph_path) def save(self): graph_path = os.path.join(self.save_path, "dgl_graph.bin") save_graphs(graph_path, self.graphs) def load(self): graph_path = os.path.join(self.save_path, "dgl_graph.bin") self._graphs = load_graphs(graph_path)[0] @property def graphs(self): return self._graphs
[docs] def __len__(self): r"""Number of graphs in the dataset. Return ------- int """ return len(self.graphs)
[docs] def __getitem__(self, item): r"""Get graph by index Parameters ---------- item : int Item index Returns ------- :class:`dgl.DGLGraph` The graph contains: - ``edata['h']`` : edge weights """ if self._transform is None: return self.graphs[item] else: return self._transform(self.graphs[item])
@property def is_temporal(self): r"""Are the graphs temporal graphs Returns ------- bool """ return True def _extract_gz(self, file, target_dir, overwrite=False): if os.path.exists(target_dir) and not overwrite: return print("Extracting file to {}".format(target_dir)) fname = os.path.basename(file) makedirs(target_dir) out_file_path = os.path.join(target_dir, fname[:-3]) with gzip.open(file, "rb") as f_in: with open(out_file_path, "wb") as f_out: shutil.copyfileobj(f_in, f_out)
BitcoinOTC = BitcoinOTCDataset