ItemSampler

class dgl.graphbolt.ItemSampler(item_set: ~dgl.graphbolt.itemset.ItemSet | ~dgl.graphbolt.itemset.ItemSetDict, batch_size: int, minibatcher: ~typing.Callable | None = <function minibatcher_default>, drop_last: bool | None = False, shuffle: bool | None = False, use_indexing: bool | None = True, buffer_size: int | None = -1)[source]

Bases: IterDataPipe

A sampler to iterate over input items and create subsets.

Input items can be node IDs; node pairs, with or without labels; node pairs with negative sources/destinations; DGLGraphs; or the heterogeneous counterparts of any of these.

Note: This class ItemSampler is deliberately not decorated with torchdata.datapipes.functional_datapipe, which means it does not support function-like (chained) invocation. However, any iterable datapipe from torchdata can still be appended to it.
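
For instance, the sampler is constructed directly rather than chained off an upstream datapipe, while datapipes that do have a functional form (such as torchdata.datapipes.iter.Mapper, via map) still compose onto it. A minimal sketch (see also example 6 below):

>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> # ItemSampler must be instantiated directly ...
>>> data_pipe = gb.ItemSampler(item_set, batch_size=4)
>>> # ... but registered datapipe methods such as map can be chained on.
>>> def double(batch):
...     return batch * 2
>>> data_pipe = data_pipe.map(double)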

Parameters:
  • item_set (Union[ItemSet, ItemSetDict]) – Data to be sampled.

  • batch_size (int) – The size of each batch.

  • minibatcher (Optional[Callable]) – A callable that takes in a list of items and returns a MiniBatch.

  • drop_last (bool) – Option to drop the last batch if it is not full (see the sketch after this list).

  • shuffle (bool) – Option to shuffle the items before sampling.

  • use_indexing (bool) – Option to use indexing to slice items from the item set, an optimization that avoids time-consuming iteration over the item set. If the item set does not support indexing, this option is disabled automatically; if the item set supports indexing but the user wants to disable it anyway, this option can be set to False. By default, it is set to True.

  • buffer_size (int) – The size of the buffer that stores items sliced from the ItemSet or ItemSetDict. By default, it is set to -1, which means the buffer size equals the total number of items in the item set if indexing is supported; otherwise, it is set to 10 * batch_size. If the item set is very large, a smaller buffer size is recommended to avoid out-of-memory errors. However, because items are shuffled only within each buffer, a smaller buffer size reduces randomness, and that reduced randomness can in turn hurt training performance such as convergence speed and accuracy. Therefore, it is recommended to set the buffer size as large as memory allows.
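
A minimal sketch of how drop_last and shuffle behave on an unnamed ItemSet, where each batch is a plain tensor (the shuffled result is nondeterministic, so its output is not shown):

>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> # 10 items with batch_size=4 leave a final partial batch of 2;
>>> # drop_last=True discards it, so only the two full batches remain.
>>> list(gb.ItemSampler(item_set, batch_size=4, drop_last=True))
[tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7])]
>>> # With shuffle=True, items are permuted before batching; a small
>>> # buffer_size (here 4) confines the permutation to one buffer at a
>>> # time, which is why a larger buffer yields more randomness.
>>> sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=True, buffer_size=4)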

Examples

  1. Node IDs.

>>> import torch
>>> from dgl import graphbolt as gb
>>> item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=tensor([0, 1, 2, 3]), node_pairs=None, labels=None,
    negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
    input_nodes=None, node_features=None, edge_features=None,
    compacted_node_pairs=None, compacted_negative_srcs=None,
    compacted_negative_dsts=None)

  2. Node pairs.

>>> item_set = gb.ItemSet(torch.arange(0, 20).reshape(-1, 2),
...     names="node_pairs")
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
    node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
    labels=None, negative_srcs=None, negative_dsts=None,
    sampled_subgraphs=None, input_nodes=None, node_features=None,
    edge_features=None, compacted_node_pairs=None,
    compacted_negative_srcs=None, compacted_negative_dsts=None)

  3. Node pairs and labels.

>>> item_set = gb.ItemSet(
...     (torch.arange(0, 20).reshape(-1, 2), torch.arange(10, 20)),
...     names=("node_pairs", "labels")
... )
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
    node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
    labels=tensor([10, 11, 12, 13]), negative_srcs=None,
    negative_dsts=None, sampled_subgraphs=None, input_nodes=None,
    node_features=None, edge_features=None, compacted_node_pairs=None,
    compacted_negative_srcs=None, compacted_negative_dsts=None)

  4. Node pairs and negative destinations.

>>> node_pairs = torch.arange(0, 20).reshape(-1, 2)
>>> negative_dsts = torch.arange(10, 30).reshape(-1, 2)
>>> item_set = gb.ItemSet((node_pairs, negative_dsts), names=("node_pairs",
...     "negative_dsts"))
>>> item_sampler = gb.ItemSampler(
...     item_set, batch_size=4, shuffle=False, drop_last=False
... )
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
    node_pairs=(tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7])),
    labels=None, negative_srcs=None,
    negative_dsts=tensor([[10, 11],
    [12, 13],
    [14, 15],
    [16, 17]]), sampled_subgraphs=None, input_nodes=None,
    node_features=None, edge_features=None, compacted_node_pairs=None,
    compacted_negative_srcs=None, compacted_negative_dsts=None)

  5. DGLGraphs.

>>> import dgl
>>> graphs = [dgl.rand_graph(10, 20) for _ in range(5)]
>>> item_set = gb.ItemSet(graphs)
>>> item_sampler = gb.ItemSampler(item_set, 3)
>>> list(item_sampler)
[Graph(num_nodes=30, num_edges=60,
  ndata_schemes={}
  edata_schemes={}),
 Graph(num_nodes=20, num_edges=40,
  ndata_schemes={}
  edata_schemes={})]
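
Note that the graphs in each minibatch are merged into a single batched graph, which is why the first batch above reports 30 nodes and 60 edges: three 10-node, 20-edge graphs combined, in the same way dgl.batch combines graphs.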

  6. Further process batches with other datapipes such as torchdata.datapipes.iter.Mapper.

>>> item_set = gb.ItemSet(torch.arange(0, 10))
>>> data_pipe = gb.ItemSampler(item_set, 4)
>>> def add_one(batch):
...     return batch + 1
>>> data_pipe = data_pipe.map(add_one)
>>> list(data_pipe)
[tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10])]

  7. Heterogeneous node IDs.

>>> ids = {
...     "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
...     "item": gb.ItemSet(torch.arange(0, 6), names="seed_nodes"),
... }
>>> item_set = gb.ItemSetDict(ids)
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes={'user': tensor([0, 1, 2, 3])}, node_pairs=None,
    labels=None, negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
    input_nodes=None, node_features=None, edge_features=None,
    compacted_node_pairs=None, compacted_negative_srcs=None,
    compacted_negative_dsts=None)
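
Only the first batch is shown above. Because the item sets in an ItemSetDict are laid out one after another, a batch that crosses the boundary between two sets would be expected to carry both keys, e.g. {'user': tensor([4]), 'item': tensor([0, 1, 2])} for the second batch here.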

  8. Heterogeneous node pairs.

>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
...     "user:like:item": gb.ItemSet(
...         node_pairs_like, names="node_pairs"),
...     "user:follow:user": gb.ItemSet(
...         node_pairs_follow, names="node_pairs"),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
    node_pairs={'user:like:item':
        (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
    labels=None, negative_srcs=None, negative_dsts=None,
    sampled_subgraphs=None, input_nodes=None, node_features=None,
    edge_features=None, compacted_node_pairs=None,
    compacted_negative_srcs=None, compacted_negative_dsts=None)

  9. Heterogeneous node pairs and labels.

>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> labels_like = torch.arange(0, 10)
>>> node_pairs_follow = torch.arange(10, 20).reshape(-1, 2)
>>> labels_follow = torch.arange(10, 20)
>>> item_set = gb.ItemSetDict({
...     "user:like:item": gb.ItemSet((node_pairs_like, labels_like),
...         names=("node_pairs", "labels")),
...     "user:follow:user": gb.ItemSet((node_pairs_follow, labels_follow),
...         names=("node_pairs", "labels")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
    node_pairs={'user:like:item':
        (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
    labels={'user:like:item': tensor([0, 1, 2, 3])},
    negative_srcs=None, negative_dsts=None, sampled_subgraphs=None,
    input_nodes=None, node_features=None, edge_features=None,
    compacted_node_pairs=None, compacted_negative_srcs=None,
    compacted_negative_dsts=None)

  10. Heterogeneous node pairs and negative destinations.

>>> node_pairs_like = torch.arange(0, 10).reshape(-1, 2)
>>> negative_dsts_like = torch.arange(10, 20).reshape(-1, 2)
>>> node_pairs_follow = torch.arange(20, 30).reshape(-1, 2)
>>> negative_dsts_follow = torch.arange(30, 40).reshape(-1, 2)
>>> item_set = gb.ItemSetDict({
...     "user:like:item": gb.ItemSet((node_pairs_like, negative_dsts_like),
...         names=("node_pairs", "negative_dsts")),
...     "user:follow:user": gb.ItemSet((node_pairs_follow,
...         negative_dsts_follow), names=("node_pairs", "negative_dsts")),
... })
>>> item_sampler = gb.ItemSampler(item_set, batch_size=4)
>>> next(iter(item_sampler))
MiniBatch(seed_nodes=None,
    node_pairs={'user:like:item':
        (tensor([0, 2, 4, 6]), tensor([1, 3, 5, 7]))},
    labels=None, negative_srcs=None,
    negative_dsts={'user:like:item': tensor([[10, 11],
    [12, 13],
    [14, 15],
    [16, 17]])}, sampled_subgraphs=None, input_nodes=None,
    node_features=None, edge_features=None, compacted_node_pairs=None,
    compacted_negative_srcs=None, compacted_negative_dsts=None)