"""In-subgraph sampler for GraphBolt."""
from torch.utils.data import functional_datapipe
from ..internal import unique_and_compact_csc_formats
from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import SampledSubgraphImpl
# Names exported by a star-import of this module.
__all__ = ["InSubgraphSampler"]
@functional_datapipe("sample_in_subgraph")
class InSubgraphSampler(SubgraphSampler):
    """Sample the subgraph induced on the inbound edges of the given nodes.

    Functional name: :obj:`sample_in_subgraph`.

    In-subgraph sampler is responsible for sampling a subgraph from given data,
    returning an induced subgraph along with compacted information.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    graph : FusedCSCSamplingGraph
        The graph on which to perform in_subgraph sampling.

    Examples
    --------
    >>> import dgl.graphbolt as gb
    >>> import torch
    >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12, 14])
    >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 5, 1, 2, 0, 3, 5, 1, 4])
    >>> graph = gb.fused_csc_sampling_graph(indptr, indices)
    >>> item_set = gb.ItemSet(len(indptr) - 1, names="seed_nodes")
    >>> item_sampler = gb.ItemSampler(item_set, batch_size=2)
    >>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
    >>> for _, data in enumerate(insubgraph_sampler):
    ...     print(data.sampled_subgraphs[0].sampled_csc)
    ...     print(data.sampled_subgraphs[0].original_row_node_ids)
    ...     print(data.sampled_subgraphs[0].original_column_node_ids)
    CSCFormatBase(indptr=tensor([0, 3, 5]),
                indices=tensor([0, 1, 2, 3, 4]),
    )
    tensor([0, 1, 4, 2, 3])
    tensor([0, 1])
    CSCFormatBase(indptr=tensor([0, 2, 4]),
                indices=tensor([2, 3, 4, 0]),
    )
    tensor([2, 3, 0, 5, 1])
    tensor([2, 3])
    CSCFormatBase(indptr=tensor([0, 3, 5]),
                indices=tensor([2, 3, 1, 4, 0]),
    )
    tensor([4, 5, 0, 3, 1])
    tensor([4, 5])
    """

    def __init__(
        self,
        datapipe,
        graph,
    ):
        """Attach the sampler to ``datapipe`` and bind it to ``graph``.

        ``datapipe`` is forwarded to the parent constructor. ``graph`` is kept
        on the instance, and its ``in_subgraph`` attribute is stored as
        ``self.sampler``, the callable used by :meth:`sample_subgraphs`.
        """
        super().__init__(datapipe)
        self.graph = graph
        self.sampler = graph.in_subgraph

    def sample_subgraphs(self, seeds, seeds_timestamp):
        """Sample the in-subgraph of ``seeds`` and compact its node IDs.

        Parameters
        ----------
        seeds
            The seed nodes. They are passed to ``graph.in_subgraph`` and
            reported as the column nodes of the returned subgraph.
        seeds_timestamp
            Not referenced by this method (presumably part of the hook
            signature expected by ``SubgraphSampler`` -- confirm there).

        Returns
        -------
        tuple
            ``(original_row_node_ids, [subgraph])``: the row node IDs produced
            by compaction, and a one-element list holding the resulting
            ``SampledSubgraphImpl``.
        """
        subgraph = self.sampler(seeds)
        # Compact the sampled structure relative to the seeds; the helper
        # returns the original row node IDs next to the compacted CSC data.
        (
            original_row_node_ids,
            compacted_csc_formats,
        ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
        subgraph = SampledSubgraphImpl(
            sampled_csc=compacted_csc_formats,
            original_column_node_ids=seeds,
            original_row_node_ids=original_row_node_ids,
            original_edge_ids=subgraph.original_edge_ids,
        )
        # The row nodes of this subgraph are handed back as the new seeds.
        seeds = original_row_node_ids
        return (seeds, [subgraph])