Source code for dgl.dataloading.dist_dataloader

"""Distributed dataloaders.
"""
import inspect
from ..distributed import DistDataLoader
# Still depends on the legacy NodeCollator...
from .._dataloading.dataloader import NodeCollator, EdgeCollator

def _remove_kwargs_dist(kwargs):
    """Remove keyword arguments that the distributed DataLoader does not support."""
    if 'num_workers' in kwargs:
        del kwargs['num_workers']
    if 'pin_memory' in kwargs:
        del kwargs['pin_memory']
        print('Distributed DataLoaders do not support pin_memory.')
    return kwargs
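# Illustrative call (a minimal sketch, not part of the original module): kwargs
# that the distributed loader cannot honour are dropped in place before being
# forwarded to DistDataLoader. The argument values below are hypothetical.
#
#   >>> _remove_kwargs_dist({'batch_size': 32, 'num_workers': 4, 'pin_memory': True})
#   Distributed DataLoaders do not support pin_memory.
#   {'batch_size': 32}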

class DistNodeDataLoader(DistDataLoader):
    """Sampled graph data loader over nodes for distributed graph storage.

    It wraps an iterable over a set of nodes, generating the list of
    message flow graphs (MFGs) as computation dependency of the said minibatch,
    on a distributed graph.

    All the arguments have the same meaning as the single-machine counterpart
    :class:`dgl.dataloading.DataLoader` except the first argument
    :attr:`g` which must be a :class:`dgl.distributed.DistGraph`.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.

    nids, graph_sampler, device, kwargs :
        See :class:`dgl.dataloading.DataLoader`.

    See also
    --------
    dgl.dataloading.DataLoader
    """
    def __init__(self, g, nids, graph_sampler, device=None, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        _collator_arglist = inspect.getfullargspec(NodeCollator).args
        for k, v in kwargs.items():
            if k in _collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if device is None:
            # for the distributed case default to the CPU
            device = 'cpu'
        assert device == 'cpu', \
            'Only cpu is supported in the case of a DistGraph.'
        # Distributed DataLoader currently does not support heterogeneous graphs
        # and does not copy features.  Fallback to normal solution.
        self.collator = NodeCollator(g, nids, graph_sampler, **collator_kwargs)
        _remove_kwargs_dist(dataloader_kwargs)
        super().__init__(self.collator.dataset,
                         collate_fn=self.collator.collate,
                         **dataloader_kwargs)
        self.device = device
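# Illustrative usage (a minimal sketch, not part of the original module): this
# assumes ``dgl.distributed.initialize()`` has been called, ``g`` is an existing
# :class:`dgl.distributed.DistGraph`, and ``train_nids`` holds the seed node
# IDs; the sampler fan-outs and batch size are arbitrary choices.
#
#   sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
#   dataloader = dgl.dataloading.DistNodeDataLoader(
#       g, train_nids, sampler, batch_size=1024, shuffle=True)
#   for input_nodes, output_nodes, blocks in dataloader:
#       ...  # feed the sampled MFGs (``blocks``) to the model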
class DistEdgeDataLoader(DistDataLoader):
    """Sampled graph data loader over edges for distributed graph storage.

    It wraps an iterable over a set of edges, generating the list of
    message flow graphs (MFGs) as computation dependency of the said minibatch
    for edge classification, edge regression, and link prediction, on a
    distributed graph.

    All the arguments have the same meaning as the single-machine counterpart
    :class:`dgl.dataloading.EdgeDataLoader` except the first argument
    :attr:`g` which must be a :class:`dgl.distributed.DistGraph`.

    Parameters
    ----------
    g : DistGraph
        The distributed graph.

    eids, graph_sampler, device, kwargs :
        See :class:`dgl.dataloading.EdgeDataLoader`.

    See also
    --------
    dgl.dataloading.EdgeDataLoader
    """
    def __init__(self, g, eids, graph_sampler, device=None, **kwargs):
        collator_kwargs = {}
        dataloader_kwargs = {}
        _collator_arglist = inspect.getfullargspec(EdgeCollator).args
        for k, v in kwargs.items():
            if k in _collator_arglist:
                collator_kwargs[k] = v
            else:
                dataloader_kwargs[k] = v

        if device is None:
            # for the distributed case default to the CPU
            device = 'cpu'
        assert device == 'cpu', \
            'Only cpu is supported in the case of a DistGraph.'
        # Distributed DataLoader currently does not support heterogeneous graphs
        # and does not copy features.  Fallback to normal solution.
        self.collator = EdgeCollator(g, eids, graph_sampler, **collator_kwargs)
        _remove_kwargs_dist(dataloader_kwargs)
        super().__init__(self.collator.dataset,
                         collate_fn=self.collator.collate,
                         **dataloader_kwargs)
        self.device = device
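# Illustrative usage (a minimal sketch, not part of the original module): this
# assumes a :class:`dgl.distributed.DistGraph` ``g`` and a tensor of seed edge
# IDs ``train_eids``; the negative sampler and fan-outs are arbitrary choices
# for a link-prediction setting.
#
#   sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 25])
#   dataloader = dgl.dataloading.DistEdgeDataLoader(
#       g, train_eids, sampler, batch_size=1024, shuffle=True,
#       negative_sampler=dgl.dataloading.negative_sampler.Uniform(5))
#   for input_nodes, pos_graph, neg_graph, blocks in dataloader:
#       ...  # score node pairs in ``pos_graph`` and ``neg_graph``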