"""GraphBolt Itemset."""
import textwrap
from typing import Dict, Iterable, Tuple, Union
import torch
__all__ = ["ItemSet", "ItemSetDict"]
def is_scalar(x):
"""Checks if the input is a scalar."""
return (
len(x.shape) == 0 if isinstance(x, torch.Tensor) else isinstance(x, int)
)
class ItemSet:
r"""A wrapper of a tensor or tuple of tensors.
Parameters
----------
items: Union[int, torch.Tensor, Tuple[torch.Tensor]]
The tensors to be wrapped.
        - If it is a single scalar (an integer or a zero-dimensional tensor
          holding a single value), the items are treated as a range of
          integers from 0 to that value, as if created by `torch.arange`.
- If it is a multi-dimensional tensor, the indexing will be performed
along the first dimension.
- If it is a tuple, each item in the tuple must be a tensor.
names: Union[str, Tuple[str]], optional
        The names of the items. If it is a tuple, each name must correspond to
an item in the `items` parameter. The naming is arbitrary, but in
general practice, the names should be chosen from ['labels', 'seeds',
'indexes'] to align with the attributes of class
`dgl.graphbolt.MiniBatch`.
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Integer: number of nodes.
>>> num = 10
>>> item_set = gb.ItemSet(num, names="seeds")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4), tensor(5),
tensor(6), tensor(7), tensor(8), tensor(9)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> item_set.names
('seeds',)
    2. Torch scalar: number of nodes, with a customizable dtype (unlike a
       plain Python integer).
>>> num = torch.tensor(10, dtype=torch.int32)
>>> item_set = gb.ItemSet(num, names="seeds")
>>> list(item_set)
[tensor(0, dtype=torch.int32), tensor(1, dtype=torch.int32),
tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32),
tensor(4, dtype=torch.int32), tensor(5, dtype=torch.int32),
tensor(6, dtype=torch.int32), tensor(7, dtype=torch.int32),
tensor(8, dtype=torch.int32), tensor(9, dtype=torch.int32)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
>>> item_set.names
('seeds',)
3. Single tensor: seed nodes.
>>> node_ids = torch.arange(0, 5)
>>> item_set = gb.ItemSet(node_ids, names="seeds")
>>> list(item_set)
[tensor(0), tensor(1), tensor(2), tensor(3), tensor(4)]
>>> item_set[:]
tensor([0, 1, 2, 3, 4])
>>> item_set.names
('seeds',)
4. Tuple of tensors with same shape: seed nodes and labels.
>>> node_ids = torch.arange(0, 5)
>>> labels = torch.arange(5, 10)
>>> item_set = gb.ItemSet(
... (node_ids, labels), names=("seeds", "labels"))
>>> list(item_set)
[(tensor(0), tensor(5)), (tensor(1), tensor(6)), (tensor(2), tensor(7)),
(tensor(3), tensor(8)), (tensor(4), tensor(9))]
>>> item_set[:]
(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9]))
>>> item_set.names
('seeds', 'labels')
5. Tuple of tensors with different shape: seeds and labels.
>>> seeds = torch.arange(0, 10).reshape(-1, 2)
>>> labels = torch.tensor([1, 1, 0, 0, 0])
>>> item_set = gb.ItemSet(
    ...     (seeds, labels), names=("seeds", "labels"))
>>> list(item_set)
    [(tensor([0, 1]), tensor(1)),
     (tensor([2, 3]), tensor(1)),
     (tensor([4, 5]), tensor(0)),
     (tensor([6, 7]), tensor(0)),
     (tensor([8, 9]), tensor(0))]
>>> item_set[:]
(tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]),
tensor([1, 1, 0, 0, 0]))
>>> item_set.names
('seeds', 'labels')
6. Tuple of tensors with different shape: hyperlink and labels.
>>> seeds = torch.arange(0, 10).reshape(-1, 5)
>>> labels = torch.tensor([1, 0])
>>> item_set = gb.ItemSet(
    ...     (seeds, labels), names=("seeds", "labels"))
>>> list(item_set)
    [(tensor([0, 1, 2, 3, 4]), tensor(1)),
     (tensor([5, 6, 7, 8, 9]), tensor(0))]
>>> item_set[:]
(tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
tensor([1, 0]))
>>> item_set.names
('seeds', 'labels')
"""
def __init__(
self,
items: Union[int, torch.Tensor, Tuple[torch.Tensor]],
names: Union[str, Tuple[str]] = None,
) -> None:
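        # ``items`` may take three forms: a scalar (treated as the length of
        # an implicit ``torch.arange`` range), a tuple of tensors that share
        # the same length along the first dimension, or a single tensor.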
if is_scalar(items):
self._length = int(items)
self._items = items
elif isinstance(items, tuple):
self._length = len(items[0])
if any(self._length != len(item) for item in items):
raise ValueError("Size mismatch between items.")
self._items = items
else:
self._length = len(items)
self._items = (items,)
self._num_items = (
len(self._items) if isinstance(self._items, tuple) else 1
)
if names is not None:
if isinstance(names, tuple):
self._names = names
else:
self._names = (names,)
assert self._num_items == len(self._names), (
f"Number of items ({self._num_items}) and "
f"names ({len(self._names)}) don't match."
)
else:
self._names = None
def __len__(self) -> int:
return self._length
def __getitem__(self, index: Union[int, slice, Iterable[int]]):
if is_scalar(self._items):
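            # The wrapped item is a scalar ``N`` standing for the implicit
            # range ``[0, N)``; materialize only the requested values instead
            # of storing the whole range.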
dtype = getattr(self._items, "dtype", torch.int64)
if isinstance(index, slice):
start, stop, step = index.indices(self._length)
return torch.arange(start, stop, step, dtype=dtype)
elif isinstance(index, int):
if index < 0:
index += self._length
if index < 0 or index >= self._length:
raise IndexError(
f"{type(self).__name__} index out of range."
)
return torch.tensor(index, dtype=dtype)
elif isinstance(index, Iterable):
return torch.tensor(index, dtype=dtype)
else:
raise TypeError(
f"{type(self).__name__} indices must be int, slice, or "
f"iterable of int, not {type(index)}."
)
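        # Otherwise delegate indexing to the wrapped tensor(s) along the
        # first dimension.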
elif self._num_items == 1:
return self._items[0][index]
else:
return tuple(item[index] for item in self._items)
@property
def names(self) -> Tuple[str]:
"""Return the names of the items."""
return self._names
@property
def num_items(self) -> int:
"""Return the number of the items."""
return self._num_items
def __repr__(self) -> str:
ret = (
f"{self.__class__.__name__}(\n"
f" items={self._items},\n"
f" names={self._names},\n"
f")"
)
return ret
class ItemSetDict:
r"""Dictionary wrapper of **ItemSet**.
    This class assembles existing itemsets under different tags, for example,
    the seed nodes of different node types in a heterogeneous graph.
Parameters
----------
itemsets: Dict[str, ItemSet]
Examples
--------
>>> import torch
>>> from dgl import graphbolt as gb
1. Each itemset is a single tensor: seed nodes.
>>> node_ids_user = torch.arange(0, 5)
>>> node_ids_item = torch.arange(5, 10)
>>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet(node_ids_user, names="seeds"),
... "item": gb.ItemSet(node_ids_item, names="seeds")})
>>> list(item_set)
[{"user": tensor(0)}, {"user": tensor(1)}, {"user": tensor(2)},
{"user": tensor(3)}, {"user": tensor(4)}, {"item": tensor(5)},
{"item": tensor(6)}, {"item": tensor(7)}, {"item": tensor(8)},
{"item": tensor(9)}}]
>>> item_set[:]
{"user": tensor([0, 1, 2, 3, 4]), "item": tensor([5, 6, 7, 8, 9])}
>>> item_set.names
('seeds',)
2. Each itemset is a tuple of tensors with same shape: seed nodes and
labels.
>>> node_ids_user = torch.arange(0, 2)
>>> labels_user = torch.arange(0, 2)
>>> node_ids_item = torch.arange(2, 5)
>>> labels_item = torch.arange(2, 5)
>>> item_set = gb.ItemSetDict({
... "user": gb.ItemSet(
... (node_ids_user, labels_user),
... names=("seeds", "labels")),
... "item": gb.ItemSet(
... (node_ids_item, labels_item),
... names=("seeds", "labels"))})
>>> list(item_set)
[{"user": (tensor(0), tensor(0))}, {"user": (tensor(1), tensor(1))},
{"item": (tensor(2), tensor(2))}, {"item": (tensor(3), tensor(3))},
{"item": (tensor(4), tensor(4))}}]
>>> item_set[:]
{"user": (tensor([0, 1]), tensor([0, 1])),
"item": (tensor([2, 3, 4]), tensor([2, 3, 4]))}
>>> item_set.names
('seeds', 'labels')
3. Each itemset is a tuple of tensors with different shape: seeds and
labels.
>>> seeds_like = torch.arange(0, 4).reshape(-1, 2)
>>> labels_like = torch.tensor([1, 0])
>>> seeds_follow = torch.arange(0, 6).reshape(-1, 2)
>>> labels_follow = torch.tensor([1, 1, 0])
>>> item_set = gb.ItemSetDict({
... "user:like:item": gb.ItemSet(
... (seeds_like, labels_like),
... names=("seeds", "labels")),
... "user:follow:user": gb.ItemSet(
... (seeds_follow, labels_follow),
... names=("seeds", "labels"))})
>>> list(item_set)
[{'user:like:item': (tensor([0, 1]), tensor(1))},
{'user:like:item': (tensor([2, 3]), tensor(0))},
{'user:follow:user': (tensor([0, 1]), tensor(1))},
{'user:follow:user': (tensor([2, 3]), tensor(1))},
{'user:follow:user': (tensor([4, 5]), tensor(0))}]
>>> item_set[:]
{'user:like:item': (tensor([[0, 1], [2, 3]]),
tensor([1, 0])),
'user:follow:user': (tensor([[0, 1], [2, 3], [4, 5]]),
tensor([1, 1, 0]))}
>>> item_set.names
('seeds', 'labels')
4. Each itemset is a tuple of tensors with different shape: hyperlink and
labels.
>>> first_seeds = torch.arange(0, 6).reshape(-1, 3)
>>> first_labels = torch.tensor([1, 0])
>>> second_seeds = torch.arange(0, 2).reshape(-1, 1)
>>> second_labels = torch.tensor([1, 0])
>>> item_set = gb.ItemSetDict({
... "query:user:item": gb.ItemSet(
... (first_seeds, first_labels),
... names=("seeds", "labels")),
... "user": gb.ItemSet(
... (second_seeds, second_labels),
... names=("seeds", "labels"))})
>>> list(item_set)
[{'query:user:item': (tensor([0, 1, 2]), tensor(1))},
{'query:user:item': (tensor([3, 4, 5]), tensor(0))},
{'user': (tensor([0]), tensor(1))},
{'user': (tensor([1]), tensor(0))}]
>>> item_set[:]
{'query:user:item': (tensor([[0, 1, 2], [3, 4, 5]]),
tensor([1, 0])),
     'user': (tensor([[0], [1]]), tensor([1, 0]))}
>>> item_set.names
('seeds', 'labels')
"""
def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
self._itemsets = itemsets
self._names = next(iter(itemsets.values())).names
assert all(
self._names == itemset.names for itemset in itemsets.values()
), "All itemsets must have the same names."
offset = [0] + [len(itemset) for itemset in self._itemsets.values()]
self._offsets = torch.tensor(offset).cumsum(0)
self._length = int(self._offsets[-1])
self._keys = list(self._itemsets.keys())
def __len__(self) -> int:
return self._length
def __getitem__(self, index: Union[int, slice, Iterable[int]]):
if isinstance(index, int):
if index < 0:
index += self._length
if index < 0 or index >= self._length:
raise IndexError(f"{type(self).__name__} index out of range.")
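            # Locate the itemset containing this global index, then shift the
            # index into that itemset's local coordinates.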
offset_idx = torch.searchsorted(self._offsets, index, right=True)
offset_idx -= 1
index -= self._offsets[offset_idx]
key = self._keys[offset_idx]
return {key: self._itemsets[key][index]}
elif isinstance(index, slice):
start, stop, step = index.indices(self._length)
if step != 1:
return self.__getitem__(torch.arange(start, stop, step))
assert start < stop, "Start must be smaller than stop."
data = {}
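            # Visit the itemsets that overlap ``[start, stop)`` in order and
            # slice each one with start/stop shifted into local coordinates.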
offset_idx_start = max(
1, torch.searchsorted(self._offsets, start, right=False)
)
for offset_idx in range(offset_idx_start, len(self._offsets)):
key = self._keys[offset_idx - 1]
data[key] = self._itemsets[key][
max(0, start - self._offsets[offset_idx - 1]) : stop
- self._offsets[offset_idx - 1]
]
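                # Stop once ``stop`` does not extend past the current itemset;
                # later itemsets contribute nothing.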
if stop <= self._offsets[offset_idx]:
break
return data
elif isinstance(index, Iterable):
# TODO[Mingbang]: Might have performance issue. Tests needed.
data = {key: [] for key in self._keys}
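            # Bucket the requested global indices by key, converting each to a
            # local index, so every itemset is indexed in one call below.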
for idx in index:
if idx < 0:
idx += self._length
if idx < 0 or idx >= self._length:
raise IndexError(
f"{type(self).__name__} index out of range."
)
offset_idx = torch.searchsorted(self._offsets, idx, right=True)
offset_idx -= 1
idx -= self._offsets[offset_idx]
key = self._keys[offset_idx]
data[key].append(int(idx))
for key in self._keys:
indices = data[key]
if len(indices) == 0:
del data[key]
continue
item_set = self._itemsets[key]
try:
value = item_set[indices]
except TypeError:
# In case the itemset doesn't support list indexing.
value = tuple(item_set[idx] for idx in indices)
finally:
data[key] = value
return data
else:
raise TypeError(
f"{type(self).__name__} indices must be int, slice, or "
f"iterable of int, not {type(index)}."
)
@property
def names(self) -> Tuple[str]:
"""Return the names of the items."""
return self._names
def __repr__(self) -> str:
ret = (
"{Classname}(\n"
" itemsets={itemsets},\n"
" names={names},\n"
")"
)
itemsets_str = textwrap.indent(
repr(self._itemsets), " " * len(" itemsets=")
).strip()
return ret.format(
Classname=self.__class__.__name__,
itemsets=itemsets_str,
names=self._names,
)