ItemSetDict

class dgl.graphbolt.ItemSetDict(itemsets: Dict[str, ItemSet])

Bases: object

A dictionary wrapper around ItemSet.

Items are retrieved by iterating over each itemset in turn. Each item is returned as a single-entry dict that maps the itemset's key to the item.
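A minimal pure-Python sketch of this iteration order follows. It assumes plain lists stand in for `ItemSet` objects and that keys are visited in dict insertion order, as the examples below suggest. `MiniItemSetDict` is a hypothetical name, not part of graphbolt.

```python
from typing import Any, Dict, Iterator, List


class MiniItemSetDict:
    """Hypothetical stand-in for ItemSetDict; plain lists replace ItemSet."""

    def __init__(self, itemsets: Dict[str, List[Any]]) -> None:
        self._itemsets = itemsets

    def __len__(self) -> int:
        # Total number of items across all keys.
        return sum(len(itemset) for itemset in self._itemsets.values())

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        # Exhaust one itemset before moving on to the next key, wrapping
        # each item in a single-entry dict keyed by its itemset's key.
        for key, itemset in self._itemsets.items():
            for item in itemset:
                yield {key: item}


items = MiniItemSetDict({"user": [0, 1, 2], "item": [5, 6]})
print(list(items))
# [{'user': 0}, {'user': 1}, {'user': 2}, {'item': 5}, {'item': 6}]
```

Because every yielded dict has exactly one key, a consumer can tell which itemset an item came from without extra bookkeeping.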

Parameters:

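itemsets (Dict[str, ItemSet]) – A mapping from a key, such as a node type ("user") or an edge type ("user:like:item"), to the ItemSet that holds the items for that key.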

Examples

>>> import torch
>>> from dgl import graphbolt as gb

  1. Single iterable: seed nodes.

>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
...     "user": gb.ItemSet(node_ids_user, names="seed_nodes"),
...     "item": gb.ItemSet(node_ids_item, names="seed_nodes")})
>>> list(item_set)
[{"user": tensor(0)}, {"user": tensor(1)}, {"user": tensor(2)},
 {"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
 {"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
 {"item": tensor(9)}]
>>> item_set[:]
{"user": tensor([0, 1, 2, 3, 4]), "item": tensor([5, 6, 7, 8, 9])}
>>> item_set.names
('seed_nodes',)

  2. Tuple of iterables with the same shape: seed nodes and labels.

>>> node_ids_user = torch.arange(0, 2)
>>> labels_user = torch.arange(0, 2)
>>> node_ids_item = torch.arange(2, 5)
>>> labels_item = torch.arange(2, 5)
>>> item_set = gb.ItemSetDict({
...     "user": gb.ItemSet(
...         (node_ids_user, labels_user),
...         names=("seed_nodes", "labels")),
...     "item": gb.ItemSet(
...         (node_ids_item, labels_item),
...         names=("seed_nodes", "labels"))})
>>> list(item_set)
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
 {"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
 {"item": (tensor(4), tensor(4))}]
>>> item_set[:]
{"user": (tensor([0, 1]), tensor([0, 1])),
 "item": (tensor([2, 3, 4]), tensor([2, 3, 4]))}
>>> item_set.names
('seed_nodes', 'labels')

  3. Tuple of iterables with different shapes: node pairs and negative destination nodes.

>>> node_pairs_like = torch.arange(0, 4).reshape(-1, 2)
>>> neg_dsts_like = torch.arange(4, 10).reshape(-1, 3)
>>> node_pairs_follow = torch.arange(0, 6).reshape(-1, 2)
>>> neg_dsts_follow = torch.arange(6, 24).reshape(-1, 6)
>>> item_set = gb.ItemSetDict({
...     "user:like:item": gb.ItemSet(
...         (node_pairs_like, neg_dsts_like),
...         names=("node_pairs", "negative_dsts")),
...     "user:follow:user": gb.ItemSet(
...         (node_pairs_follow, neg_dsts_follow),
...         names=("node_pairs", "negative_dsts"))})
>>> list(item_set)
[{"user:like:item": (tensor([0, 1]), tensor([4, 5, 6]))},
 {"user:like:item": (tensor([2, 3]), tensor([7, 8, 9]))},
 {"user:follow:user": (tensor([0, 1]), tensor([ 6,  7,  8,  9, 10, 11]))},
 {"user:follow:user": (tensor([2, 3]), tensor([12, 13, 14, 15, 16, 17]))},
 {"user:follow:user": (tensor([4, 5]), tensor([18, 19, 20, 21, 22, 23]))}]
>>> item_set[:]
{"user:like:item": (tensor([[0, 1], [2, 3]]),
                    tensor([[4, 5, 6], [7, 8, 9]])),
 "user:follow:user": (tensor([[0, 1], [2, 3], [4, 5]]),
                      tensor([[ 6,  7,  8,  9, 10, 11],
                              [12, 13, 14, 15, 16, 17],
                              [18, 19, 20, 21, 22, 23]]))}
>>> item_set.names
('node_pairs', 'negative_dsts')
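In the tuple-based examples above (seed nodes with labels, node pairs with negative destinations), the i-th item is the tuple formed by the i-th entry, or row, of every tensor, so only the first dimension has to agree across the tuple. The sketch below shows that pairing with nested lists in place of tensors. `reshape_rows` mimics `torch.arange(start, stop).reshape(-1, width)`; it and `zip_items` are hypothetical helpers, and rejecting mismatched lengths is an assumption about the real validation.

```python
from typing import Any, List, Sequence, Tuple


def reshape_rows(start: int, stop: int, width: int) -> List[List[int]]:
    """Mimic torch.arange(start, stop).reshape(-1, width) with nested lists."""
    flat = list(range(start, stop))
    if len(flat) % width != 0:
        raise ValueError("range length must be divisible by width")
    return [flat[i:i + width] for i in range(0, len(flat), width)]


def zip_items(*columns: Sequence[Any]) -> List[Tuple[Any, ...]]:
    """Pair the i-th row of every column; columns must share their first dimension."""
    if len({len(column) for column in columns}) > 1:
        raise ValueError("all iterables must have the same first dimension")
    return list(zip(*columns))


# Mirrors the "user:like:item" itemset: 2 node pairs, 3 negatives per pair.
node_pairs_like = reshape_rows(0, 4, 2)   # [[0, 1], [2, 3]]
neg_dsts_like = reshape_rows(4, 10, 3)    # [[4, 5, 6], [7, 8, 9]]
print(zip_items(node_pairs_like, neg_dsts_like))
# [([0, 1], [4, 5, 6]), ([2, 3], [7, 8, 9])]
```

The same helper covers the equal-shape case: `zip_items([0, 1], [0, 1])` gives `[(0, 0), (1, 1)]`, matching the per-item tuples of the "user" itemset in the seed-nodes-and-labels example.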
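The examples above only show the full slice `item_set[:]`. One plausible way to support arbitrary slices is to treat the keys as laid out back to back in iteration order and map a global index range onto per-key local ranges through cumulative offsets. This is sketched here as an assumption, not as graphbolt's documented behavior. `locate` and `slice_itemset_dict` are hypothetical helpers that operate on plain lists.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import Dict, List, Tuple


def _offsets(itemsets: Dict[str, List[int]]) -> List[int]:
    # offsets[i] is the global index at which the i-th key's items begin;
    # the last entry is the total number of items.
    return [0] + list(accumulate(len(items) for items in itemsets.values()))


def locate(itemsets: Dict[str, List[int]], index: int) -> Tuple[str, int]:
    """Map a global item index to (key, local index)."""
    offsets = _offsets(itemsets)
    if not 0 <= index < offsets[-1]:
        raise IndexError(index)
    # bisect_right skips keys with no items (repeated offsets).
    position = bisect_right(offsets, index) - 1
    return list(itemsets)[position], index - offsets[position]


def slice_itemset_dict(
    itemsets: Dict[str, List[int]], start: int, stop: int
) -> Dict[str, List[int]]:
    """Return the items in the global range [start, stop), grouped by key."""
    offsets = _offsets(itemsets)
    result = {}
    for position, (key, items) in enumerate(itemsets.items()):
        lo = max(start, offsets[position])
        hi = min(stop, offsets[position + 1])
        if lo < hi:  # assumption: keys outside the range are omitted
            result[key] = items[lo - offsets[position]:hi - offsets[position]]
    return result


itemsets = {"user": [0, 1, 2, 3, 4], "item": [5, 6, 7, 8, 9]}
print(slice_itemset_dict(itemsets, 0, 10))
# {'user': [0, 1, 2, 3, 4], 'item': [5, 6, 7, 8, 9]}
print(slice_itemset_dict(itemsets, 3, 7))
# {'user': [3, 4], 'item': [5, 6]}
print(locate(itemsets, 6))
# ('item', 1)
```

Only non-negative `start`/`stop` values are handled; negative indices and steps are left out to keep the offset arithmetic visible.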
property names: Tuple[str]

Return the names of the items.
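`ItemSetDict` exposes a single `names` tuple even though it wraps several itemsets, and every example above gives all of its itemsets identical names. The sketch below shows how such a property could be derived. It assumes, without confirmation from this page, that itemsets with differing names are rejected; `common_names` is a hypothetical helper.

```python
from typing import Dict, Tuple


def common_names(names_per_key: Dict[str, Tuple[str, ...]]) -> Tuple[str, ...]:
    """Return the names shared by every itemset, or raise if they differ."""
    distinct = set(names_per_key.values())
    # Assumption: an empty mapping or mismatched names is treated as an error.
    if len(distinct) != 1:
        raise ValueError(f"itemsets must share the same names, got {distinct}")
    return distinct.pop()


print(common_names({
    "user:like:item": ("node_pairs", "negative_dsts"),
    "user:follow:user": ("node_pairs", "negative_dsts"),
}))
# ('node_pairs', 'negative_dsts')
```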