Source code for dgl.graphbolt.subgraph_sampler

"""Subgraph samplers"""

from collections import defaultdict
from typing import Dict

import torch
from torch.utils.data import functional_datapipe

from .base import seed_type_str_to_ntypes
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch_transformer import MiniBatchTransformer

__all__ = [
    "SubgraphSampler",
]


@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):
    """A subgraph sampler used to sample a subgraph from a given set of nodes
    from a larger graph.

    Functional name: :obj:`sample_subgraph`.

    This class is the base class of all subgraph samplers. Any subclass of
    SubgraphSampler should implement either the :meth:`sample_subgraphs`
    method or the :meth:`sampling_stages` method to define the fine-grained
    sampling stages to take advantage of optimizations provided by the
    GraphBolt DataLoader.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    args : Non-Keyword Arguments
        Arguments to be passed into sampling_stages.
    kwargs : Keyword Arguments
        Arguments to be passed into sampling_stages.
    """

    def __init__(
        self,
        datapipe,
        *args,
        **kwargs,
    ):
        datapipe = datapipe.transform(self._preprocess)
        datapipe = self.sampling_stages(datapipe, *args, **kwargs)
        datapipe = datapipe.transform(self._postprocess)
        super().__init__(datapipe)

    @staticmethod
    def _postprocess(minibatch):
        delattr(minibatch, "_seed_nodes")
        delattr(minibatch, "_seeds_timestamp")
        return minibatch

    @staticmethod
    def _preprocess(minibatch):
        if minibatch.seeds is not None:
            (
                seeds,
                seeds_timestamp,
                minibatch.compacted_seeds,
            ) = SubgraphSampler._seeds_preprocess(minibatch)
        else:
            raise ValueError(
                f"Invalid minibatch {minibatch}: `seeds` should have a value."
            )
        minibatch._seed_nodes = seeds
        minibatch._seeds_timestamp = seeds_timestamp
        return minibatch

    def _sample(self, minibatch):
        (
            minibatch.input_nodes,
            minibatch.sampled_subgraphs,
        ) = self.sample_subgraphs(
            minibatch._seed_nodes, minibatch._seeds_timestamp
        )
        return minibatch
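    # A minimal usage sketch (assuming ``item_set`` and ``graph`` were built
    # elsewhere) of chaining a concrete SubgraphSampler subclass through its
    # functional name, here ``sample_neighbor`` from the built-in
    # NeighborSampler:
    #
    #     import dgl.graphbolt as gb
    #
    #     datapipe = gb.ItemSampler(item_set, batch_size=1024, shuffle=True)
    #     datapipe = datapipe.sample_neighbor(graph, [10, 10])
    #     dataloader = gb.DataLoader(datapipe)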
    def sampling_stages(self, datapipe):
        """The sampling stages are defined here by chaining to the datapipe.

        The default implementation expects :meth:`sample_subgraphs` to be
        implemented. To define fine-grained stages, this method should be
        overridden.
        """
        return datapipe.transform(self._sample)
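    # A minimal sketch of overriding ``sampling_stages`` to split sampling
    # into fine-grained per-hop stages that the GraphBolt DataLoader can
    # pipeline. ``_sample_per_layer`` is a hypothetical helper, not part of
    # this module:
    #
    #     from functools import partial
    #
    #     class LayeredSampler(SubgraphSampler):
    #         def __init__(self, datapipe, graph, fanouts):
    #             # Attributes must be set before super().__init__(), which
    #             # invokes sampling_stages.
    #             self.graph = graph
    #             self.fanouts = fanouts
    #             super().__init__(datapipe)
    #
    #         def sampling_stages(self, datapipe):
    #             # One datapipe stage per hop, innermost hop first.
    #             for fanout in reversed(self.fanouts):
    #                 datapipe = datapipe.transform(
    #                     partial(self._sample_per_layer, fanout)
    #                 )
    #             return datapipe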
    @staticmethod
    def _seeds_preprocess(minibatch):
        """Preprocess `seeds` in a minibatch to construct `unique_seeds`,
        `nodes_timestamp` and `compacted_seeds` for further sampling. It
        optionally incorporates timestamps for temporal graphs, organizing
        and compacting seeds based on their types and timestamps. In a
        heterogeneous graph, `seeds` with the same node type will be uniqued
        together.

        Parameters
        ----------
        minibatch: MiniBatch
            The minibatch.

        Returns
        -------
        unique_seeds: torch.Tensor or Dict[str, torch.Tensor]
            A tensor or a dictionary of tensors representing the unique
            seeds. In heterogeneous graphs, seeds are returned for each node
            type.
        nodes_timestamp: None or a torch.Tensor or Dict[str, torch.Tensor]
            Containing timestamps for each seed. This is only returned if
            `minibatch` includes timestamps and the graph is temporal.
        compacted_seeds: torch.Tensor or a Dict[str, torch.Tensor]
            Representation of compacted seeds corresponding to `seeds`, where
            all node ids inside are compacted.
        """
        use_timestamp = hasattr(minibatch, "timestamp")
        seeds = minibatch.seeds
        is_heterogeneous = isinstance(seeds, Dict)
        if is_heterogeneous:
            # Collect nodes from all types of input.
            nodes = defaultdict(list)
            nodes_timestamp = None
            if use_timestamp:
                nodes_timestamp = defaultdict(list)
            for seed_type, typed_seeds in seeds.items():
                # When typed_seeds is a one-dimensional tensor, it represents
                # seed nodes, which does not need to do unique and compact.
                if typed_seeds.ndim == 1:
                    nodes_timestamp = (
                        minibatch.timestamp
                        if hasattr(minibatch, "timestamp")
                        else None
                    )
                    return seeds, nodes_timestamp, None
                assert typed_seeds.ndim == 2, (
                    "Only tensor with shape 1*N and N*M is "
                    + f"supported now, but got {typed_seeds.shape}."
                )
                ntypes = seed_type_str_to_ntypes(
                    seed_type, typed_seeds.shape[1]
                )
                if use_timestamp:
                    negative_ratio = (
                        typed_seeds.shape[0]
                        // minibatch.timestamp[seed_type].shape[0]
                        - 1
                    )
                    neg_timestamp = minibatch.timestamp[
                        seed_type
                    ].repeat_interleave(negative_ratio)
                for i, ntype in enumerate(ntypes):
                    nodes[ntype].append(typed_seeds[:, i])
                    if use_timestamp:
                        nodes_timestamp[ntype].append(
                            minibatch.timestamp[seed_type]
                        )
                        nodes_timestamp[ntype].append(neg_timestamp)
            # Unique and compact the collected nodes.
            if use_timestamp:
                (
                    unique_seeds,
                    nodes_timestamp,
                    compacted,
                ) = compact_temporal_nodes(nodes, nodes_timestamp)
            else:
                unique_seeds, compacted = unique_and_compact(nodes)
                nodes_timestamp = None
            compacted_seeds = {}
            # Map back in same order as collect.
            for seed_type, typed_seeds in seeds.items():
                ntypes = seed_type_str_to_ntypes(
                    seed_type, typed_seeds.shape[1]
                )
                compacted_seed = []
                for ntype in ntypes:
                    compacted_seed.append(compacted[ntype].pop(0))
                compacted_seeds[seed_type] = (
                    torch.cat(compacted_seed).view(len(ntypes), -1).T
                )
        else:
            # When seeds is a one-dimensional tensor, it represents seed
            # nodes, which does not need to do unique and compact.
            if seeds.ndim == 1:
                nodes_timestamp = (
                    minibatch.timestamp
                    if hasattr(minibatch, "timestamp")
                    else None
                )
                return seeds, nodes_timestamp, None
            # Collect nodes from all types of input.
            nodes = [seeds.view(-1)]
            nodes_timestamp = None
            if use_timestamp:
                # Timestamps for source and destination nodes are the same.
                negative_ratio = (
                    seeds.shape[0] // minibatch.timestamp.shape[0] - 1
                )
                neg_timestamp = minibatch.timestamp.repeat_interleave(
                    negative_ratio
                )
                seeds_timestamp = torch.cat(
                    (minibatch.timestamp, neg_timestamp)
                )
                nodes_timestamp = [
                    seeds_timestamp for _ in range(seeds.shape[1])
                ]
            # Unique and compact the collected nodes.
            if use_timestamp:
                (
                    unique_seeds,
                    nodes_timestamp,
                    compacted,
                ) = compact_temporal_nodes(nodes, nodes_timestamp)
            else:
                unique_seeds, compacted = unique_and_compact(nodes)
                nodes_timestamp = None
            # Map back in same order as collect.
            compacted_seeds = compacted[0].view(seeds.shape)
        return (
            unique_seeds,
            nodes_timestamp,
            compacted_seeds,
        )
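    # Illustrative sketch of the "unique and compact" idea used above,
    # written with plain torch. The real helpers live in
    # ``dgl.graphbolt.internal`` and additionally handle heterogeneous and
    # temporal inputs:
    #
    #     import torch
    #
    #     seeds = torch.tensor([[10, 20], [10, 30]])  # N*2 node pairs
    #     unique_seeds, inverse = torch.unique(
    #         seeds.view(-1), return_inverse=True
    #     )
    #     # unique_seeds: tensor([10, 20, 30])
    #     compacted_seeds = inverse.view(seeds.shape)
    #     # compacted_seeds: tensor([[0, 1], [0, 2]])
    #     # i.e. ids local to ``unique_seeds``, in the shape of ``seeds``.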
    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        """Sample subgraphs from the given seeds, possibly with temporal
        constraints.

        Any subclass of SubgraphSampler should implement this method.

        Parameters
        ----------
        seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The seed nodes.
        seeds_timestamp : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The timestamps of the seed nodes. If given, the sampled subgraphs
            should not contain any nodes or edges that are newer than the
            timestamps of the seed nodes. Default: None.

        Returns
        -------
        Union[torch.Tensor, Dict[str, torch.Tensor]]
            The input nodes.
        List[SampledSubgraph]
            The sampled subgraphs.

        Examples
        --------
        >>> @functional_datapipe("my_sample_subgraph")
        >>> class MySubgraphSampler(SubgraphSampler):
        >>>     def __init__(self, datapipe, graph, fanouts):
        >>>         super().__init__(datapipe)
        >>>         self.graph = graph
        >>>         self.fanouts = fanouts
        >>>     def sample_subgraphs(self, seeds, seeds_timestamp=None):
        >>>         # Sample subgraphs from the given seeds.
        >>>         subgraphs = []
        >>>         subgraphs_nodes = []
        >>>         for fanout in reversed(self.fanouts):
        >>>             subgraph = self.graph.sample_neighbors(seeds, fanout)
        >>>             subgraphs.insert(0, subgraph)
        >>>             subgraphs_nodes.append(subgraph.nodes)
        >>>             seeds = subgraph.nodes
        >>>         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
        >>>         return subgraphs_nodes, subgraphs
        """
        raise NotImplementedError
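# Once registered via ``functional_datapipe`` as in the example above, a
# custom sampler can be chained by its functional name (a sketch;
# ``item_set``, ``graph`` and ``fanouts`` are assumed to exist):
#
#     datapipe = gb.ItemSampler(item_set, batch_size=512)
#     datapipe = datapipe.my_sample_subgraph(graph, fanouts)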