Source code for dgl.data.icews18

"""ICEWS18 dataset for temporal graph"""
import numpy as np
import os

from .dgl_dataset import DGLBuiltinDataset
from .utils import loadtxt, _get_dgl_url, save_graphs, load_graphs
from ..convert import graph as dgl_graph
from .. import backend as F


class ICEWS18Dataset(DGLBuiltinDataset):
    r"""ICEWS18 dataset for temporal graph

    Integrated Crisis Early Warning System (ICEWS18)

    Event data consists of coded interactions between socio-political
    actors (i.e., cooperative or hostile actions between individuals,
    groups, sectors and nation states). This Dataset consists of events
    from 1/1/2018 to 10/31/2018 (24 hours time granularity).

    Reference:

        - `Recurrent Event Network for Reasoning over Temporal Knowledge Graphs <https://arxiv.org/abs/1904.05530>`_
        - `ICEWS Coded Event Data <https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/28075>`_

    Statistics:

    - Train examples: 240
    - Valid examples: 30
    - Test examples: 34
    - Nodes per graph: 23033

    Parameters
    ----------
    mode : str
        Load train/valid/test data. Has to be one of
        ['train', 'valid', 'test'] (compared case-insensitively).
        Default: 'train'
    raw_dir : str
        Raw file directory to download/contains the input data directory.
        Default: ~/.dgl/
    force_reload : bool
        Whether to reload the dataset. Default: False
    verbose : bool
        Whether to print out progress information. Default: False

    Attributes
    ----------
    is_temporal : bool
        Whether the dataset contains temporal graphs

    Examples
    --------
    >>> # get train, valid, test set
    >>> train_data = ICEWS18Dataset()
    >>> valid_data = ICEWS18Dataset(mode='valid')
    >>> test_data = ICEWS18Dataset(mode='test')
    >>>
    >>> train_size = len(train_data)
    >>> for g in train_data:
    ...     e_feat = g.edata['rel_type']
    ...     # your code here
    ...
    """

    def __init__(self, mode='train', raw_dir=None, force_reload=False, verbose=False):
        mode = mode.lower()
        assert mode in ['train', 'valid', 'test'], "Mode not valid"
        # ``self.mode`` must be set before the base initializer runs because
        # the other methods of this class build their file names from it.
        self.mode = mode
        _url = _get_dgl_url('dataset/icews18.zip')
        super(ICEWS18Dataset, self).__init__(name='ICEWS18',
                                             url=_url,
                                             raw_dir=raw_dir,
                                             force_reload=force_reload,
                                             verbose=verbose)

    def _graph_path(self):
        """Return the path of the cached graph file for the current mode."""
        return os.path.join(self.save_path,
                            '{}_dgl_graph.bin'.format(self.mode))

    def process(self):
        """Build the list of graphs from the raw ``<mode>.txt`` file.

        The raw file is read as a tab-separated integer table. Columns 0 and
        2 are used as the edge endpoints, column 1 as the edge (relation)
        type and column 3 as a time stamp that is bucketed into units of 24.
        One graph is produced per time bucket; the graph for bucket ``i``
        contains every row whose bucket is ``<= i`` (i.e. it is cumulative).
        """
        raw_path = os.path.join(self.save_path, '{}.txt'.format(self.mode))
        data = loadtxt(raw_path, delimiter='\t').astype(np.int64)
        # Every graph is created with this fixed number of nodes so that the
        # node count matches the documented statistic instead of depending on
        # the largest node ID that happens to appear in the selected rows.
        num_nodes = 23033
        # The source code is not released, but the paper indicates there're
        # totally 137 samples. The cutoff below has exactly 137 samples.
        # NOTE(review): the class docstring lists 240/30/34 graphs per split,
        # which does not match the 137 mentioned above -- confirm which
        # figure is current.
        time_index = np.floor(data[:, 3] / 24).astype(np.int64)
        # Rows whose bucket is -1 are ignored when choosing the first bucket.
        start_time = time_index[time_index != -1].min()
        end_time = time_index.max()
        self._graphs = []
        for i in range(start_time, end_time + 1):
            # Select the rows once and reuse them for both edges and types.
            rows = data[time_index <= i]
            g = dgl_graph((rows[:, 0], rows[:, 2]), num_nodes=num_nodes)
            g.edata['rel_type'] = F.tensor(rows[:, 1].reshape(-1, 1),
                                           dtype=F.data_type_dict['int64'])
            self._graphs.append(g)

    def has_cache(self):
        """Return True if the cached graph file for this mode exists."""
        return os.path.exists(self._graph_path())

    def save(self):
        """Write the processed graphs to the cached graph file."""
        save_graphs(self._graph_path(), self._graphs)

    def load(self):
        """Read the graphs back from the cached graph file."""
        # ``load_graphs`` returns a sequence whose first item is the graph
        # list; only that item is kept.
        self._graphs = load_graphs(self._graph_path())[0]

    def __getitem__(self, idx):
        r"""Get graph by index

        Parameters
        ----------
        idx : int
            Item index

        Returns
        -------
        :class:`dgl.DGLGraph`

            The graph contains:

            - ``edata['rel_type']``: edge type
        """
        return self._graphs[idx]

    def __len__(self):
        r"""Number of graphs in the dataset.

        Returns
        -------
        int
        """
        return len(self._graphs)

    @property
    def is_temporal(self):
        r"""Whether the dataset contains temporal graphs

        Returns
        -------
        bool
        """
        return True
# Alias: ``ICEWS18`` is bound to the same class object as ``ICEWS18Dataset``.
ICEWS18 = ICEWS18Dataset