Source code for dgl.graphbolt.dataloader

"""Graph Bolt DataLoaders"""

from concurrent.futures import ThreadPoolExecutor

import torch
import torch.utils.data
import torchdata.dataloader2.graph as dp_utils
import torchdata.datapipes as dp

from .base import CopyTo
from .feature_fetcher import FeatureFetcher
from .impl.neighbor_sampler import SamplePerLayer

from .internal import datapipe_graph_to_adjlist
from .item_sampler import ItemSampler


__all__ = [
    "DataLoader",
]


def _find_and_wrap_parent(datapipe_graph, target_datapipe, wrapper, **kwargs):
    """Find parent of target_datapipe and wrap it with ."""
    datapipes = dp_utils.find_dps(
        datapipe_graph,
        target_datapipe,
    )
    datapipe_adjlist = datapipe_graph_to_adjlist(datapipe_graph)
    for datapipe in datapipes:
        datapipe_id = id(datapipe)
        for parent_datapipe_id in datapipe_adjlist[datapipe_id][1]:
            parent_datapipe, _ = datapipe_adjlist[parent_datapipe_id]
            datapipe_graph = dp_utils.replace_dp(
                datapipe_graph,
                parent_datapipe,
                wrapper(parent_datapipe, **kwargs),
            )
    return datapipe_graph


class MultiprocessingWrapper(dp.iter.IterDataPipe):
    """Wraps a datapipe with multiprocessing.

    Parameters
    ----------
    datapipe : DataPipe
        The data pipeline.
    num_workers : int, optional
        The number of worker processes. Default is 0, meaning that there
        will be no multiprocessing.
    persistent_workers : bool, optional
        If True, the data loader will not shut down the worker processes after a
        dataset has been consumed once. This allows to maintain the workers
        instances alive.
    """

    def __init__(self, datapipe, num_workers=0, persistent_workers=True):
        self.datapipe = datapipe
        self.dataloader = torch.utils.data.DataLoader(
            datapipe,
            batch_size=None,
            num_workers=num_workers,
            persistent_workers=(num_workers > 0) and persistent_workers,
        )

    def __iter__(self):
        yield from self.dataloader


# There needs to be a single instance of the uva_stream, if it is created
# multiple times, it leads to multiple CUDA memory pools and memory leaks.
def _get_uva_stream():
    if not hasattr(_get_uva_stream, "stream"):
        _get_uva_stream.stream = torch.cuda.Stream(priority=-1)
    return _get_uva_stream.stream


[docs]class DataLoader(torch.utils.data.DataLoader): """Multiprocessing DataLoader. Iterates over the data pipeline with everything before feature fetching (i.e. :class:`dgl.graphbolt.FeatureFetcher`) in subprocesses, and everything after feature fetching in the main process. The datapipe is modified in-place as a result. When the copy_to operation is placed earlier in the data pipeline, the num_workers argument is required to be 0 as utilizing CUDA in multiple worker processes is not supported. Parameters ---------- datapipe : DataPipe The data pipeline. num_workers : int, optional Number of worker processes. Default is 0. persistent_workers : bool, optional If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers instances alive. overlap_feature_fetch : bool, optional If True, the data loader will overlap the UVA feature fetcher operations with the rest of operations by using an alternative CUDA stream. Default is True. overlap_graph_fetch : bool, optional If True, the data loader will overlap the UVA graph fetching operations with the rest of operations by using an alternative CUDA stream. Default is False. max_uva_threads : int, optional Limits the number of CUDA threads used for UVA copies so that the rest of the computations can run simultaneously with it. Setting it to a too high value will limit the amount of overlap while setting it too low may cause the PCI-e bandwidth to not get fully utilized. Manually tuned default is 6144, meaning around 3-4 Streaming Multiprocessors. """ def __init__( self, datapipe, num_workers=0, persistent_workers=True, overlap_feature_fetch=True, overlap_graph_fetch=False, max_uva_threads=6144, ): # Multiprocessing requires two modifications to the datapipe: # # 1. Insert a stage after ItemSampler to distribute the # minibatches evenly across processes. # 2. Cut the datapipe at FeatureFetcher, and wrap the inner datapipe # of the FeatureFetcher with a multiprocessing PyTorch DataLoader. datapipe = datapipe.mark_end() datapipe_graph = dp_utils.traverse_dps(datapipe) # (1) Insert minibatch distribution. # TODO(BarclayII): Currently I'm using sharding_filter() as a # concept demonstration. Later on minibatch distribution should be # merged into ItemSampler to maximize efficiency. item_samplers = dp_utils.find_dps( datapipe_graph, ItemSampler, ) for item_sampler in item_samplers: datapipe_graph = dp_utils.replace_dp( datapipe_graph, item_sampler, item_sampler.sharding_filter(), ) # (2) Cut datapipe at FeatureFetcher and wrap. datapipe_graph = _find_and_wrap_parent( datapipe_graph, FeatureFetcher, MultiprocessingWrapper, num_workers=num_workers, persistent_workers=persistent_workers, ) # (3) Overlap UVA feature fetching by buffering and using an alternative # stream. if ( overlap_feature_fetch and num_workers == 0 and torch.cuda.is_available() ): torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) feature_fetchers = dp_utils.find_dps( datapipe_graph, FeatureFetcher, ) for feature_fetcher in feature_fetchers: feature_fetcher.stream = _get_uva_stream() datapipe_graph = dp_utils.replace_dp( datapipe_graph, feature_fetcher, feature_fetcher.buffer(1).wait(), ) if ( overlap_graph_fetch and num_workers == 0 and torch.cuda.is_available() ): torch.ops.graphbolt.set_max_uva_threads(max_uva_threads) samplers = dp_utils.find_dps( datapipe_graph, SamplePerLayer, ) executor = ThreadPoolExecutor(max_workers=1) for sampler in samplers: datapipe_graph = dp_utils.replace_dp( datapipe_graph, sampler, sampler.fetch_and_sample(_get_uva_stream(), executor, 1), ) # (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the # data pipeline up to the CopyTo operation to run in a separate thread. datapipe_graph = _find_and_wrap_parent( datapipe_graph, CopyTo, dp.iter.Prefetcher, buffer_size=2, ) # The stages after feature fetching is still done in the main process. # So we set num_workers to 0 here. super().__init__(datapipe, batch_size=None, num_workers=0)