dgl.graphbolt.compact_csc_format

dgl.graphbolt.compact_csc_format(csc_formats: CSCFormatBase | Dict[str, CSCFormatBase], dst_nodes: Tensor | Dict[str, Tensor], dst_timestamps: Tensor | Dict[str, Tensor] | None = None)

Relabel the row (source) IDs in the CSC formats into a contiguous range starting from 0, and return the original row node IDs per type.

Note that:
  1. The column (destination) IDs are included in the relabeled row IDs.
  2. Repeated row IDs are not deduplicated; each occurrence is treated as a different node.
  3. If dst_timestamps is given, the timestamp of each destination node is broadcast to its corresponding source nodes.
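
The relabeling rules in the notes above can be modeled for the homogeneous case with plain Python lists. This is only an illustrative sketch, not DGL's implementation: the helper name compact_homogeneous is hypothetical, and lists stand in for tensors.

```python
def compact_homogeneous(indptr, indices, dst_nodes):
    """Illustrative model of the relabeling (hypothetical helper, not DGL code)."""
    # Destination nodes take the first compact IDs: 0 .. len(dst_nodes) - 1.
    offset = len(dst_nodes)
    # Every row (source) entry gets its own new ID, even when its original ID
    # repeats or already appears among the destination nodes.
    compacted_indices = list(range(offset, offset + len(indices)))
    # Original IDs are listed in compact-ID order: destinations, then sources.
    original_row_node_ids = list(dst_nodes) + list(indices)
    return original_row_node_ids, (list(indptr), compacted_indices)


original_ids, (new_indptr, new_indices) = compact_homogeneous(
    indptr=[0, 1, 3], indices=[1, 2, 3], dst_nodes=[2, 4]
)
print(original_ids)  # [2, 4, 1, 2, 3]
print(new_indices)   # [2, 3, 4]
```

Node 2 appears twice in the original IDs (once as a destination, once as a source) and receives two different compact IDs, which is the behavior described in note 2.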

Parameters:
  • csc_formats (Union[CSCFormatBase, Dict[str, CSCFormatBase]]) – CSC formats representing source-destination edges.
      - If csc_formats is a CSCFormatBase: the graph is homogeneous. Its indptr and indices should be torch.Tensor objects representing source-destination pairs in CSC format, and the IDs inside are homogeneous IDs.
      - If csc_formats is a Dict[str, CSCFormatBase]: the keys should be edge types and the values should be node pairs in CSC format. The IDs inside are heterogeneous IDs.

  • dst_nodes (Union[torch.Tensor, Dict[str, torch.Tensor]]) – All destination nodes in the node pairs.
      - If dst_nodes is a tensor: the graph is homogeneous.
      - If dst_nodes is a dictionary: the keys are node types and the values are the corresponding nodes. The IDs inside are heterogeneous IDs.

  • dst_timestamps (Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]) – Timestamps of all destination nodes in the CSC formats. If given, the timestamp of each destination node is broadcast to its corresponding source nodes.
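
The timestamp broadcast described for dst_timestamps can be sketched as follows, assuming that destination i owns the entries indptr[i] through indptr[i + 1] - 1 of indices, as in standard CSC layout. The helper name broadcast_dst_timestamps is hypothetical and plain lists stand in for tensors.

```python
def broadcast_dst_timestamps(indptr, dst_timestamps):
    """Repeat each destination timestamp once per incoming source entry.

    Hypothetical helper for illustration; not part of the DGL API.
    """
    src_timestamps = []
    for i, timestamp in enumerate(dst_timestamps):
        # Number of source entries attached to destination i.
        in_degree = indptr[i + 1] - indptr[i]
        src_timestamps.extend([timestamp] * in_degree)
    return src_timestamps


dst_ts = [10, 20]
broadcast = broadcast_dst_timestamps([0, 1, 3], dst_ts)
print(broadcast)           # [10, 20, 20]
# When the sources share the destinations' node type, the destinations'
# own timestamps come first, mirroring the order of the original row IDs.
print(dst_ts + broadcast)  # [10, 20, 10, 20, 20]
```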

Returns:

  • A tensor of original row node IDs (per type) of all nodes in the input.
  • The compacted CSC formats, where node IDs are replaced with mapped node IDs ranging from 0 to N.
  • The source timestamps (per type) of all nodes in the input if dst_timestamps is given.

Return type:

Tuple[original_row_node_ids, compacted_csc_formats, …]

Examples

>>> import torch
>>> import dgl.graphbolt as gb
>>> csc_formats = {
...     "n2:e2:n1": gb.CSCFormatBase(
...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([5, 4, 6])
...     ),
...     "n1:e1:n1": gb.CSCFormatBase(
...         indptr=torch.tensor([0, 1, 3]), indices=torch.tensor([1, 2, 3])
...     ),
... }
>>> dst_nodes = {"n1": torch.LongTensor([2, 4])}
>>> original_row_node_ids, compacted_csc_formats = gb.compact_csc_format(
...     csc_formats, dst_nodes
... )
>>> original_row_node_ids
{'n1': tensor([2, 4, 1, 2, 3]), 'n2': tensor([5, 4, 6])}
>>> compacted_csc_formats
{'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
            indices=tensor([0, 1, 2]),
), 'n1:e1:n1': CSCFormatBase(indptr=tensor([0, 1, 3]),
            indices=tensor([2, 3, 4]),
)}
>>> dst_timestamps = {"n1": torch.LongTensor([10, 20])}
>>> (
...     original_row_node_ids,
...     compacted_csc_formats,
...     src_timestamps,
... ) = gb.compact_csc_format(csc_formats, dst_nodes, dst_timestamps)
>>> src_timestamps
{'n1': tensor([10, 20, 10, 20, 20]), 'n2': tensor([10, 20, 20])}
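
To make the per-type numbering in the heterogeneous example easier to follow, here is a pure-Python model that reproduces the IDs shown above. It is a sketch under assumptions, not DGL's implementation: compact_heterogeneous is a hypothetical helper, (indptr, indices) tuples of lists stand in for CSCFormatBase, and the order in which edge types sharing a source type are laid out is assumed to follow dictionary order.

```python
def compact_heterogeneous(csc_formats, dst_nodes):
    """Illustrative model of per-type relabeling (hypothetical helper).

    csc_formats: {"src_type:edge_type:dst_type": (indptr, indices)}
    dst_nodes:   {node_type: [original IDs]}
    """
    # Destination nodes of each type take that type's first compact IDs.
    original = {ntype: list(ids) for ntype, ids in dst_nodes.items()}
    compacted = {}
    for etype, (indptr, indices) in csc_formats.items():
        src_type = etype.split(":")[0]
        rows = original.setdefault(src_type, [])
        # New source IDs continue after everything already assigned
        # to this source node type.
        offset = len(rows)
        compacted[etype] = (
            list(indptr),
            list(range(offset, offset + len(indices))),
        )
        rows.extend(indices)
    return original, compacted


original, compacted = compact_heterogeneous(
    {
        "n2:e2:n1": ([0, 1, 3], [5, 4, 6]),
        "n1:e1:n1": ([0, 1, 3], [1, 2, 3]),
    },
    {"n1": [2, 4]},
)
print(original)               # {'n1': [2, 4, 1, 2, 3], 'n2': [5, 4, 6]}
print(compacted["n2:e2:n1"])  # ([0, 1, 3], [0, 1, 2])
print(compacted["n1:e1:n1"])  # ([0, 1, 3], [2, 3, 4])
```

Type n2 has no destination nodes, so its sources start at 0, while the n1 sources start at 2 because the two n1 destinations occupy IDs 0 and 1.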