Source code for dgl.data.csv_dataset

import os

import numpy as np

from .. import backend as F
from ..base import DGLError
from .dgl_dataset import DGLDataset
from .utils import load_graphs, save_graphs, Subset


class CSVDataset(DGLDataset):
    """Dataset that builds DGL graphs by parsing a directory of CSV files.

    The following additional packages are required:

    - pyyaml >= 5.4.1
    - pandas >= 1.1.5
    - pydantic >= 1.9.0

    Parsed graphs and feature data are cached so later loads are faster.
    If the source CSV files are modified, pass ``force_reload=True`` to
    re-parse from them.

    Parameters
    ----------
    data_path : str
        Directory which contains 'meta.yaml' and CSV files.
    force_reload : bool, optional
        Whether to reload the dataset. Default: False.
    verbose : bool, optional
        Whether to print out progress information. Default: True.
    ndata_parser : dict[str, callable] or callable, optional
        Callable which receives the ``pandas.DataFrame`` created from a node
        CSV file, parses the node data and returns a dictionary of parsed
        data. When a dictionary is given, each key is a node type and each
        value is the callable used for that node type. When a single
        callable is given, it is used for every node type. Default: None,
        in which case a default data parser is applied that loads the data
        directly and tries to convert lists into arrays.
    edata_parser : dict[(str, str, str), callable] or callable, optional
        Callable which receives the ``pandas.DataFrame`` created from an
        edge CSV file, parses the edge data and returns a dictionary of
        parsed data. When a dictionary is given, each key is an edge type
        and each value is the callable used for that edge type. When a
        single callable is given, it is used for every edge type.
        Default: None, in which case a default data parser is applied that
        loads the data directly and tries to convert lists into arrays.
    gdata_parser : callable, optional
        Callable which receives the ``pandas.DataFrame`` created from the
        graph CSV file, parses the graph data and returns a dictionary of
        parsed data. Default: None, in which case a default data parser is
        applied that loads the data directly and tries to convert lists
        into arrays.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and
        returns a transformed version. The :class:`~dgl.DGLGraph` object
        will be transformed before every access.

    Attributes
    ----------
    graphs : :class:`dgl.DGLGraph`
        Graphs of the dataset.
    data : dict
        Any available graph-level data such as graph-level features and
        labels.

    Examples
    --------
    Please refer to :ref:`guide-data-pipeline-loadcsv`.
    """

    META_YAML_NAME = "meta.yaml"

    def __init__(
        self,
        data_path,
        force_reload=False,
        verbose=True,
        ndata_parser=None,
        edata_parser=None,
        gdata_parser=None,
        transform=None,
    ):
        from .csv_dataset_base import (
            DefaultDataParser,
            load_yaml_with_sanity_check,
        )

        # NOTE(review): everything the process/save/load hooks read is
        # assigned before ``super().__init__`` runs; presumably the base
        # initializer drives those hooks -- confirm against DGLDataset
        # before reordering.
        self.graphs = None
        self.data = None
        self.ndata_parser = ndata_parser if ndata_parser is not None else {}
        self.edata_parser = edata_parser if edata_parser is not None else {}
        self.gdata_parser = gdata_parser
        self.default_data_parser = DefaultDataParser()

        yaml_file = os.path.join(data_path, CSVDataset.META_YAML_NAME)
        if not os.path.exists(yaml_file):
            raise DGLError(
                "'{}' cannot be found under {}.".format(
                    CSVDataset.META_YAML_NAME, data_path
                )
            )
        self.meta_yaml = load_yaml_with_sanity_check(yaml_file)

        super().__init__(
            self.meta_yaml.dataset_name,
            raw_dir=os.path.dirname(yaml_file),
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    def _cache_file(self):
        """Return the path of the ``<name>.bin`` cache file under ``save_path``."""
        return os.path.join(self.save_path, self.name + ".bin")

    def _refresh_labels(self):
        """Expose the only graph-level field as ``self.labels``, if exactly one exists."""
        if len(self.data) == 1:
            self.labels = next(iter(self.data.values()))

    def _choose_parser(self, user_parser, type_key):
        """Pick the parser for ``type_key`` from a callable or a per-type dict.

        A callable is used as is; otherwise ``user_parser`` is queried with
        ``type_key`` and the default data parser is the fallback.
        """
        if callable(user_parser):
            return user_parser
        return user_parser.get(type_key, self.default_data_parser)

    def process(self):
        """Parse node/edge data from CSV files and construct DGL.Graphs"""
        from .csv_dataset_base import (
            DGLGraphConstructor,
            EdgeData,
            GraphData,
            NodeData,
        )

        meta = self.meta_yaml
        csv_dir = self.raw_dir

        # ``None`` entries in the metadata lists are skipped.
        node_data = [
            NodeData.load_from_csv(
                node_meta,
                base_dir=csv_dir,
                separator=meta.separator,
                data_parser=self._choose_parser(
                    self.ndata_parser, node_meta.ntype
                ),
            )
            for node_meta in meta.node_data
            if node_meta is not None
        ]

        # Edge types are looked up as tuples in a per-type parser dict.
        edge_data = [
            EdgeData.load_from_csv(
                edge_meta,
                base_dir=csv_dir,
                separator=meta.separator,
                data_parser=self._choose_parser(
                    self.edata_parser, tuple(edge_meta.etype)
                ),
            )
            for edge_meta in meta.edge_data
            if edge_meta is not None
        ]

        graph_data = None
        graph_meta = meta.graph_data
        if graph_meta is not None:
            if self.gdata_parser is None:
                graph_parser = self.default_data_parser
            else:
                graph_parser = self.gdata_parser
            graph_data = GraphData.load_from_csv(
                graph_meta,
                base_dir=csv_dir,
                separator=meta.separator,
                data_parser=graph_parser,
            )

        # construct graphs
        self.graphs, self.data = DGLGraphConstructor.construct_graphs(
            node_data, edge_data, graph_data
        )
        self._refresh_labels()

    def has_cache(self):
        """Return True if the ``<name>.bin`` cache file exists."""
        return os.path.exists(self._cache_file())

    def save(self):
        """Write the graphs and graph-level data to the cache file.

        Raises
        ------
        DGLError
            If ``self.graphs`` is None.
        """
        if self.graphs is None:
            raise DGLError("No graphs available in dataset")
        save_graphs(self._cache_file(), self.graphs, labels=self.data)

    def load(self):
        """Read the graphs and graph-level data back from the cache file."""
        self.graphs, self.data = load_graphs(self._cache_file())
        self._refresh_labels()

    def __getitem__(self, i):
        """Return the item(s) selected by ``i``.

        A 1-D tensor index yields a ``Subset`` of this dataset. Any other
        index yields the (optionally transformed) graph, paired with
        ``self.labels[i]`` when there is exactly one graph-level field,
        with a dict of every field's ``i``-th entry when there are several,
        or alone when there is none.
        """
        if F.is_tensor(i) and F.ndim(i) == 1:
            return Subset(self, F.copy_to(i, F.cpu()))

        graph = self.graphs[i]
        if self._transform is not None:
            graph = self._transform(graph)

        field_count = len(self.data)
        if field_count == 0:
            return graph
        if field_count == 1:
            return graph, self.labels[i]
        return graph, {key: values[i] for key, values in self.data.items()}

    def __len__(self):
        """Return ``len(self.graphs)``."""
        return len(self.graphs)