"""Subgraph samplers"""
from collections import defaultdict
from typing import Dict
from torch.utils.data import functional_datapipe
from .base import etype_str_to_tuple
from .internal import unique_and_compact
from .minibatch_transformer import MiniBatchTransformer
# Names exported by ``from <this module> import *``; only the sampler base
# class is public.
__all__ = [
"SubgraphSampler",
]
@functional_datapipe("sample_subgraph")
class SubgraphSampler(MiniBatchTransformer):
    """A subgraph sampler used to sample a subgraph from a given set of nodes
    from a larger graph.

    Functional name: :obj:`sample_subgraph`.

    This class is the base class of all subgraph samplers. Any subclass of
    SubgraphSampler should implement the :meth:`sample_subgraphs` method.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    """

    def __init__(
        self,
        datapipe,
    ):
        # Register `_sample` as the per-minibatch transformation.
        super().__init__(datapipe, self._sample)

    def _sample(self, minibatch):
        """Fill in the sampling results of a single minibatch.

        The seeds are derived from ``minibatch.node_pairs`` when it is set
        (in which case the compacted pairs/negatives are written back onto
        the minibatch as well), otherwise from ``minibatch.seed_nodes``.
        The result of :meth:`sample_subgraphs` is then stored in
        ``minibatch.input_nodes`` and ``minibatch.sampled_subgraphs``.

        Parameters
        ----------
        minibatch
            The minibatch to process. It is modified in place.

        Returns
        -------
        The same ``minibatch`` object.

        Raises
        ------
        ValueError
            If both ``node_pairs`` and ``seed_nodes`` are ``None``.
        """
        if minibatch.node_pairs is not None:
            (
                seeds,
                minibatch.compacted_node_pairs,
                minibatch.compacted_negative_srcs,
                minibatch.compacted_negative_dsts,
            ) = self._node_pairs_preprocess(minibatch)
        elif minibatch.seed_nodes is not None:
            seeds = minibatch.seed_nodes
        else:
            raise ValueError(
                f"Invalid minibatch {minibatch}: Either `node_pairs` or "
                "`seed_nodes` should have a value."
            )
        (
            minibatch.input_nodes,
            minibatch.sampled_subgraphs,
        ) = self.sample_subgraphs(seeds)
        return minibatch

    def _node_pairs_preprocess(self, minibatch):
        """Collect, unique and compact all nodes referenced by a minibatch.

        Nodes are gathered from ``minibatch.node_pairs`` and, when present,
        from ``minibatch.negative_srcs`` / ``minibatch.negative_dsts``. They
        are passed to ``unique_and_compact`` and the compacted results are
        mapped back in the same order in which they were collected.

        Parameters
        ----------
        minibatch
            The minibatch; ``node_pairs`` is either a ``(src, dst)`` pair or
            a dict from edge type string to such a pair. The negatives, when
            given, follow the same homogeneous/heterogeneous layout.

        Returns
        -------
        tuple
            ``(seeds, compacted_node_pairs, compacted_negative_srcs,
            compacted_negative_dsts)``. The two negative entries are ``None``
            when the corresponding negatives are absent; otherwise they have
            the same shape as the original negatives.
        """
        node_pairs = minibatch.node_pairs
        neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
        has_neg_src = neg_src is not None
        has_neg_dst = neg_dst is not None
        compacted_negative_srcs = None
        compacted_negative_dsts = None
        is_heterogeneous = isinstance(node_pairs, Dict)
        if is_heterogeneous:
            # For dict inputs, negatives only count as present when every
            # edge type actually carries a value.
            has_neg_src = has_neg_src and all(
                item is not None for item in neg_src.values()
            )
            has_neg_dst = has_neg_dst and all(
                item is not None for item in neg_dst.values()
            )
            # Collect nodes from all types of input, grouped by node type.
            nodes = defaultdict(list)
            for etype, (src, dst) in node_pairs.items():
                src_type, _, dst_type = etype_str_to_tuple(etype)
                nodes[src_type].append(src)
                nodes[dst_type].append(dst)
            # `reshape` (rather than `view`) so that non-contiguous negative
            # tensors are flattened instead of raising.
            if has_neg_src:
                for etype, src in neg_src.items():
                    src_type, _, _ = etype_str_to_tuple(etype)
                    nodes[src_type].append(src.reshape(-1))
            if has_neg_dst:
                for etype, dst in neg_dst.items():
                    _, _, dst_type = etype_str_to_tuple(etype)
                    nodes[dst_type].append(dst.reshape(-1))
            # Unique and compact the collected nodes.
            seeds, compacted = unique_and_compact(nodes)
            compacted_node_pairs = {}
            compacted_negative_srcs = {}
            compacted_negative_dsts = {}
            # Map back in same order as collect: each per-type list is
            # consumed from the front, mirroring the appends above.
            for etype in node_pairs:
                src_type, _, dst_type = etype_str_to_tuple(etype)
                src = compacted[src_type].pop(0)
                dst = compacted[dst_type].pop(0)
                compacted_node_pairs[etype] = (src, dst)
            if has_neg_src:
                for etype in neg_src:
                    src_type, _, _ = etype_str_to_tuple(etype)
                    # Restore the original shape of the negatives.
                    compacted_negative_srcs[etype] = (
                        compacted[src_type].pop(0).reshape(neg_src[etype].shape)
                    )
            if has_neg_dst:
                for etype in neg_dst:
                    _, _, dst_type = etype_str_to_tuple(etype)
                    compacted_negative_dsts[etype] = (
                        compacted[dst_type].pop(0).reshape(neg_dst[etype].shape)
                    )
        else:
            # Collect nodes from all types of input: src, dst, then the
            # flattened negatives (if any).
            nodes = list(node_pairs)
            if has_neg_src:
                nodes.append(neg_src.reshape(-1))
            if has_neg_dst:
                nodes.append(neg_dst.reshape(-1))
            # Unique and compact the collected nodes.
            seeds, compacted = unique_and_compact(nodes)
            # Map back in same order as collect.
            compacted_node_pairs = tuple(compacted[:2])
            compacted = compacted[2:]
            if has_neg_src:
                # Since we need to calculate the neg_ratio according to the
                # compacted_negative_srcs shape, we need to reshape it back.
                compacted_negative_srcs = compacted.pop(0).reshape(
                    neg_src.shape
                )
            if has_neg_dst:
                # Same as above.
                compacted_negative_dsts = compacted.pop(0).reshape(
                    neg_dst.shape
                )
        return (
            seeds,
            compacted_node_pairs,
            compacted_negative_srcs if has_neg_src else None,
            compacted_negative_dsts if has_neg_dst else None,
        )

    def sample_subgraphs(self, seeds, seeds_timestamp=None):
        """Sample subgraphs from the given seeds.

        Any subclass of SubgraphSampler should implement this method.

        Parameters
        ----------
        seeds : Union[torch.Tensor, Dict[str, torch.Tensor]]
            The seed nodes.
        seeds_timestamp : optional
            Optional extra input for subclasses; presumably timestamps
            associated with ``seeds`` (TODO confirm against subclasses).
            The base-class pipeline calls this method with ``seeds`` only,
            so the default ``None`` is what implementations receive from it.

        Returns
        -------
        Union[torch.Tensor, Dict[str, torch.Tensor]]
            The input nodes.
        List[SampledSubgraph]
            The sampled subgraphs.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        Examples
        --------
        >>> @functional_datapipe("my_sample_subgraph")
        >>> class MySubgraphSampler(SubgraphSampler):
        >>>     def __init__(self, datapipe, graph, fanouts):
        >>>         super().__init__(datapipe)
        >>>         self.graph = graph
        >>>         self.fanouts = fanouts
        >>>     def sample_subgraphs(self, seeds, seeds_timestamp=None):
        >>>         # Sample subgraphs from the given seeds.
        >>>         subgraphs = []
        >>>         subgraphs_nodes = []
        >>>         for fanout in reversed(self.fanouts):
        >>>             subgraph = self.graph.sample_neighbors(seeds, fanout)
        >>>             subgraphs.insert(0, subgraph)
        >>>             subgraphs_nodes.append(subgraph.nodes)
        >>>             seeds = subgraph.nodes
        >>>         subgraphs_nodes = torch.unique(torch.cat(subgraphs_nodes))
        >>>         return subgraphs_nodes, subgraphs
        """
        raise NotImplementedError