"""Module for graph partition utilities."""
import time
import numpy as np
from ._ffi.function import _init_api
from .heterograph import DGLHeteroGraph
from . import backend as F
from . import utils
from .base import EID, NID, NTYPE, ETYPE
from .subgraph import edge_subgraph
# Public API of this module for ``from ... import *``. Other module-level names
# (e.g. reorder_nodes, reshuffle_graph, NDArrayPartition) are still importable
# explicitly but are not listed here.
__all__ = ["metis_partition", "metis_partition_assignment",
           "partition_graph_with_halo"]
def reorder_nodes(g, new_node_ids):
    """ Generate a new graph with new node IDs.

    We assign each node in the input graph with a new node ID. This results in
    a new graph.

    Parameters
    ----------
    g : DGLGraph
        The input graph
    new_node_ids : a tensor
        The new node IDs. ``new_node_ids[i]`` is the new ID of node ``i``; the
        values must form a permutation of ``0 .. g.number_of_nodes() - 1``.

    Returns
    -------
    DGLGraph
        The graph with new node IDs. Its node data ``'orig_id'`` holds, for each
        new node ID, the ID the node had in ``g``.

    Raises
    ------
    AssertionError
        If ``new_node_ids`` has the wrong length or is not a permutation of
        ``0 .. g.number_of_nodes() - 1``.
    """
    num_nodes = g.number_of_nodes()
    assert len(new_node_ids) == num_nodes, \
        "The number of new node ids must match #nodes in the graph."
    new_node_ids = utils.toindex(new_node_ids)
    sorted_ids, idx = F.sort_1d(new_node_ids.tousertensor())
    # A valid relabeling is a permutation of 0..N-1, i.e. after sorting the IDs
    # must be exactly 0, 1, ..., N-1. Checking only the first and last element
    # would let duplicate IDs through, and would index out of range when the
    # graph has no nodes.
    assert np.array_equal(F.asnumpy(sorted_ids), np.arange(num_nodes)), \
        "The new node IDs are incorrect."
    new_gidx = _CAPI_DGLReorderGraph_Hetero(
        g._graph, new_node_ids.todgltensor())
    new_g = DGLHeteroGraph(gidx=new_gidx, ntypes=['_N'], etypes=['_E'])
    # Sorting the new IDs gives, at position j, the old ID of the node that now
    # has ID j.
    new_g.ndata['orig_id'] = idx
    return new_g
def _get_halo_heterosubgraph_inner_node(halo_subg):
    """Fetch the per-node inner-node indicator of a HALO subgraph.

    This is a thin wrapper around the native
    ``_CAPI_GetHaloSubgraphInnerNodes_Hetero`` call. Nonzero entries of the
    result mark the nodes that belong to the subgraph's own partition rather
    than being HALO nodes.

    Parameters
    ----------
    halo_subg : object
        A subgraph handle as produced by ``_CAPI_DGLPartitionWithHalo_Hetero``.

    Returns
    -------
    object
        The native array returned by the C API, unconverted.
    """
    inner_flags = _CAPI_GetHaloSubgraphInnerNodes_Hetero(halo_subg)
    return inner_flags
def reshuffle_graph(g, node_part=None):
    '''Reshuffle node ids and edge IDs of a graph.

    This function reshuffles nodes and edges in a graph so that all nodes/edges of the same type
    have contiguous IDs. If a graph is partitioned and nodes are assigned to different partitions,
    all nodes/edges in a partition should get contiguous IDs; within a partition, all nodes/edges
    of the same type have contiguous IDs.

    Parameters
    ----------
    g : DGLGraph
        The input graph.
    node_part : Tensor, optional
        This is a vector whose length is the same as the number of nodes in the input graph.
        Each element indicates the partition ID the corresponding node is assigned to.
        If None, no reshuffling is performed.

    Returns
    -------
    (DGLGraph, Tensor or None)
        The graph whose nodes and edges are reshuffled. Its ``'orig_id'`` node/edge
        data stores the node/edge IDs in the input graph.
        The 1D tensor that indicates the partition IDs of the nodes in the reshuffled
        graph, or None when ``node_part`` is None.
    '''
    # Without a partition assignment there is nothing to reshuffle: the graph is
    # returned as-is (mutated in place) with an identity 'orig_id' mapping.
    if node_part is None:
        g.ndata['orig_id'] = F.arange(0, g.number_of_nodes())
        g.edata['orig_id'] = F.arange(0, g.number_of_edges())
        return g, None

    start = time.time()
    node_part = utils.toindex(node_part).tousertensor()
    is_hetero = NTYPE in g.ndata and len(F.unique(g.ndata[NTYPE])) > 1
    if is_hetero:
        num_node_types = F.max(g.ndata[NTYPE], 0) + 1
        # Sort by (partition ID, node type) using a single composite key so that
        # nodes of the same type inside one partition end up adjacent.
        sorted_part, new2old_map = F.sort_1d(node_part * num_node_types + g.ndata[NTYPE])
        # Recover the partition ID from the composite key.
        sorted_part = F.floor_div(sorted_part, num_node_types)
    else:
        sorted_part, new2old_map = F.sort_1d(node_part)

    # Invert the new->old permutation to get each old node's new ID.
    new_node_ids = np.zeros((g.number_of_nodes(),), dtype=np.int64)
    new_node_ids[F.asnumpy(new2old_map)] = np.arange(0, g.number_of_nodes())
    # If the input graph is homogeneous, we only need to create an empty array, so that
    # _CAPI_DGLReassignEdges_Hetero knows how to handle it.
    etype = g.edata[ETYPE] if ETYPE in g.edata else F.zeros((0,), F.dtype(sorted_part), F.cpu())
    g = reorder_nodes(g, new_node_ids)
    node_part = utils.toindex(sorted_part)
    # We reassign edges in in-CSR. In this way, after partitioning, we can ensure
    # that all edges in a partition are in the contiguous ID space.
    etype_idx = utils.toindex(etype)
    orig_eids = _CAPI_DGLReassignEdges_Hetero(g._graph, etype_idx.todgltensor(),
                                              node_part.todgltensor(), True)
    orig_eids = utils.toindex(orig_eids)
    orig_eids = orig_eids.tousertensor()
    g.edata['orig_id'] = orig_eids

    print('Reshuffle nodes and edges: {:.3f} seconds'.format(time.time() - start))
    return g, node_part.tousertensor()
def partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle=False):
    '''Partition a graph.

    Based on the given node assignments for each partition, the function splits
    the input graph into subgraphs. A subgraph may contain HALO nodes which do
    not belong to the partition of a subgraph but are connected to the nodes
    in the partition within a fixed number of hops.

    If `reshuffle` is turned on, the function reshuffles node IDs and edge IDs
    of the input graph before partitioning. After reshuffling, all nodes and edges
    in a partition fall in a contiguous ID range in the input graph.
    The partitioned subgraphs have node data 'orig_id', which stores the node IDs
    in the original input graph.

    Each returned subgraph carries node data ``'inner_node'``, ``'part_id'`` and
    ``NID``, and edge data ``'inner_edge'`` and ``EID``; with ``reshuffle`` it also
    carries ``'orig_id'`` node/edge data.

    Parameters
    ------------
    g: DGLGraph
        The graph to be partitioned
    node_part: 1D tensor
        Specify which partition a node is assigned to. The length of this tensor
        needs to be the same as the number of nodes of the graph. Each element
        indicates the partition ID of a node.
    extra_cached_hops: int
        The number of hops a HALO node can be accessed.
    reshuffle : bool
        Reshuffle nodes so that nodes in the same partition are in the same ID range.

    Returns
    --------
    a dict of DGLGraphs
        The key is the partition ID and the value is the DGLGraph of the partition.
    Tensor
        1D tensor that stores the mapping between the reshuffled node IDs and
        the original node IDs if 'reshuffle=True'. Otherwise, return None.
    Tensor
        1D tensor that stores the mapping between the reshuffled edge IDs and
        the original edge IDs if 'reshuffle=True'. Otherwise, return None.
    '''
    assert len(node_part) == g.number_of_nodes(), \
        "The length of node_part must match #nodes in the graph."
    if reshuffle:
        g, node_part = reshuffle_graph(g, node_part)
        orig_nids = g.ndata['orig_id']
        orig_eids = g.edata['orig_id']

    node_part = utils.toindex(node_part)
    start = time.time()
    subgs = _CAPI_DGLPartitionWithHalo_Hetero(
        g._graph, node_part.todgltensor(), extra_cached_hops)
    # g is no longer needed. Free memory.
    g = None
    print('Split the graph: {:.3f} seconds'.format(time.time() - start))
    subg_dict = {}
    node_part = node_part.tousertensor()
    start = time.time()

    # This function determines whether an edge belongs to a partition.
    # An edge is assigned to a partition based on its destination node. If its destination node
    # is assigned to a partition, we assign the edge to the partition as well.
    def get_inner_edge(subg, inner_node):
        inner_edge = F.zeros((subg.number_of_edges(),), F.int8, F.cpu())
        inner_nids = F.nonzero_1d(inner_node)
        # TODO(zhengda) we need to fix utils.toindex() to avoid the dtype cast below.
        inner_nids = F.astype(inner_nids, F.int64)
        inner_eids = subg.in_edges(inner_nids, form='eid')
        inner_edge = F.scatter_row(inner_edge, inner_eids,
                                   F.ones((len(inner_eids),), F.dtype(inner_edge), F.cpu()))
        return inner_edge

    # This creates a subgraph from subgraphs returned from the CAPI above.
    def create_subgraph(subg, induced_nodes, induced_edges, inner_node):
        subg1 = DGLHeteroGraph(gidx=subg.graph, ntypes=['_N'], etypes=['_E'])
        # If IDs are shuffled, we should shuffle edges. This will help us collect edge data
        # from the distributed graph after training.
        # A subgraph without edges has nothing to reorder (and F.max would fail on the
        # empty edge-ID tensor), so it takes the plain path below.
        if reshuffle and subg1.number_of_edges() > 0:
            # When we shuffle edges, we need to make sure that the inner edges are assigned with
            # contiguous edge IDs and their ID range starts with 0. In other words, we want to
            # place these edge IDs in the front of the edge list. To ensure that, we add the IDs
            # of outer edges with a large value, so we will get the sorted list as we want.
            max_eid = F.max(induced_edges[0], 0) + 1
            inner_edge = get_inner_edge(subg1, inner_node)
            eid = F.astype(induced_edges[0], F.int64) + max_eid * F.astype(inner_edge == 0, F.int64)

            _, index = F.sort_1d(eid)
            subg1 = edge_subgraph(subg1, index, relabel_nodes=False)
            subg1.ndata[NID] = induced_nodes[0]
            subg1.edata[EID] = F.gather_row(induced_edges[0], index)
        else:
            subg1.ndata[NID] = induced_nodes[0]
            subg1.edata[EID] = induced_edges[0]
        return subg1

    for i, subg in enumerate(subgs):
        inner_node = _get_halo_heterosubgraph_inner_node(subg)
        inner_node = F.zerocopy_from_dlpack(inner_node.to_dlpack())
        subg = create_subgraph(subg, subg.induced_nodes, subg.induced_edges, inner_node)
        subg.ndata['inner_node'] = inner_node
        subg.ndata['part_id'] = F.gather_row(node_part, subg.ndata[NID])
        if reshuffle:
            subg.ndata['orig_id'] = F.gather_row(orig_nids, subg.ndata[NID])
            subg.edata['orig_id'] = F.gather_row(orig_eids, subg.edata[EID])

        # With no extra hops the subgraph only holds in-edges of inner nodes, so
        # every edge is an inner edge.
        if extra_cached_hops >= 1:
            inner_edge = get_inner_edge(subg, inner_node)
        else:
            inner_edge = F.ones((subg.number_of_edges(),), F.int8, F.cpu())
        subg.edata['inner_edge'] = inner_edge
        subg_dict[i] = subg
    print('Construct subgraphs: {:.3f} seconds'.format(time.time() - start))
    if reshuffle:
        return subg_dict, orig_nids, orig_eids
    else:
        return subg_dict, None, None
def metis_partition_assignment(g, k, balance_ntypes=None, balance_edges=False, mode="k-way"):
    ''' This assigns nodes to different partitions with Metis partitioning algorithm.

    When performing Metis partitioning, we can put some constraint on the partitioning.
    Currently, it supports two constraints to balance the partitioning. By default, Metis
    always tries to balance the number of nodes in each partition.

    * `balance_ntypes` balances the number of nodes of different types in each partition.
    * `balance_edges` balances the number of edges in each partition.

    To balance the node types, a user needs to pass a vector of N elements to indicate
    the type of each node. N is the number of nodes in the input graph.

    After the partition assignment, we construct partitions.

    Parameters
    ----------
    g : DGLGraph
        The graph to be partitioned. Its IdType must be int64.
    k : int
        The number of partitions.
    balance_ntypes : tensor
        Node type of each node
    balance_edges : bool
        Indicate whether to balance the edges.
    mode : str, "k-way" or "recursive"
        Whether to use multilevel recursive bisection or multilevel k-way partitioning.

    Returns
    -------
    a 1-D tensor or None
        A vector with each element that indicates the partition ID of a vertex,
        or None if the Metis call returned an empty assignment.
    '''
    assert mode in ("k-way", "recursive"), "'mode' can only be 'k-way' or 'recursive'"
    assert g.idtype == F.int64, "IdType of graph is required to be int64 for now."
    # METIS works only on symmetric graphs.
    # The METIS runs on the symmetric graph to generate the node assignment to partitions.
    start = time.time()
    sym_gidx = _CAPI_DGLMakeSymmetric_Hetero(g._graph)
    sym_g = DGLHeteroGraph(gidx=sym_gidx)
    print('Convert a graph into a bidirected graph: {:.3f} seconds'.format(
        time.time() - start))
    vwgt = []
    # To balance the node types in each partition, we can take advantage of the vertex weights
    # in Metis. When vertex weights are provided, Metis will try to generate partitions with
    # balanced vertex weights. A vertex can be assigned with multiple weights. The vertex weights
    # are stored in a vector of N * w elements, where N is the number of vertices and w
    # is the number of weights per vertex. Metis tries to balance the first weight, and then
    # the second weight, and so on.
    # When balancing node types, we use the first weight to indicate the first node type.
    # if a node belongs to the first node type, its weight is set to 1; otherwise, 0.
    # Similarly, we set the second weight for the second node type and so on. The number
    # of weights is the same as the number of node types.
    start = time.time()
    if balance_ntypes is not None:
        assert len(balance_ntypes) == g.number_of_nodes(), \
            "The length of balance_ntypes should be equal to #nodes in the graph"
        balance_ntypes = F.tensor(balance_ntypes)
        uniq_ntypes = F.unique(balance_ntypes)
        for ntype in uniq_ntypes:
            vwgt.append(F.astype(balance_ntypes == ntype, F.int64))

    # When balancing edges in partitions, we use in-degree as one of the weights.
    if balance_edges:
        if balance_ntypes is None:
            vwgt.append(F.astype(g.in_degrees(), F.int64))
        else:
            # The in-degrees do not depend on the node type, so compute them once
            # and keep only the entries of the current type in each weight vector.
            all_in_degs = F.asnumpy(g.in_degrees())
            for ntype in uniq_ntypes:
                nids = F.asnumpy(F.nonzero_1d(balance_ntypes == ntype))
                degs = np.zeros((g.number_of_nodes(),), np.int64)
                degs[nids] = all_in_degs[nids]
                vwgt.append(F.zerocopy_from_numpy(degs))

    # The vertex weights have to be stored in a vector.
    if len(vwgt) > 0:
        # Stack to (N, w) and flatten row-major so the w weights of a vertex are adjacent.
        vwgt = F.stack(vwgt, 1)
        shape = (np.prod(F.shape(vwgt),),)
        vwgt = F.reshape(vwgt, shape)
        vwgt = F.to_dgl_nd(vwgt)
        print(
            'Construct multi-constraint weights: {:.3f} seconds'.format(time.time() - start))
    else:
        vwgt = F.zeros((0,), F.int64, F.cpu())
        vwgt = F.to_dgl_nd(vwgt)

    start = time.time()
    node_part = _CAPI_DGLMetisPartition_Hetero(sym_g._graph, k, vwgt, mode)
    print('Metis partitioning: {:.3f} seconds'.format(time.time() - start))
    if len(node_part) == 0:
        return None
    else:
        node_part = utils.toindex(node_part)
        return node_part.tousertensor()
def metis_partition(g, k, extra_cached_hops=0, reshuffle=False,
                    balance_ntypes=None, balance_edges=False, mode="k-way"):
    ''' This is to partition a graph with Metis partitioning.

    Metis assigns vertices to partitions. This API constructs subgraphs with the vertices assigned
    to the partitions and their incoming edges. A subgraph may contain HALO nodes which do
    not belong to the partition of a subgraph but are connected to the nodes
    in the partition within a fixed number of hops.

    When performing Metis partitioning, we can put some constraint on the partitioning.
    Currently, it supports two constraints to balance the partitioning. By default, Metis
    always tries to balance the number of nodes in each partition.

    * `balance_ntypes` balances the number of nodes of different types in each partition.
    * `balance_edges` balances the number of edges in each partition.

    To balance the node types, a user needs to pass a vector of N elements to indicate
    the type of each node. N is the number of nodes in the input graph.

    If `reshuffle` is turned on, the function reshuffles node IDs and edge IDs
    of the input graph before partitioning. After reshuffling, all nodes and edges
    in a partition fall in a contiguous ID range in the input graph.
    The partitioned subgraphs have node data 'orig_id', which stores the node IDs
    in the original input graph.

    The partitioned subgraph is stored in DGLGraph. The DGLGraph has the `part_id`
    node data that indicates the partition a node belongs to. The subgraphs do not contain
    the node/edge data in the input graph.

    Parameters
    ------------
    g: DGLGraph
        The graph to be partitioned
    k: int
        The number of partitions.
    extra_cached_hops: int
        The number of hops a HALO node can be accessed.
    reshuffle : bool
        Reshuffle nodes so that nodes in the same partition are in the same ID range.
    balance_ntypes : tensor
        Node type of each node
    balance_edges : bool
        Indicate whether to balance the edges.
    mode : str, "k-way" or "recursive"
        Whether to use multilevel recursive bisection or multilevel k-way partitioning.

    Returns
    --------
    a dict of DGLGraphs, or None
        The key is the partition ID and the value is the DGLGraph of the partition.
        None is returned when the Metis assignment step produced no result.
    '''
    assert mode in ("k-way", "recursive"), "'mode' can only be 'k-way' or 'recursive'"
    node_part = metis_partition_assignment(g, k, balance_ntypes, balance_edges, mode)
    if node_part is None:
        return None

    # Then we split the original graph into parts based on the METIS partitioning results.
    # Only the subgraph dict is returned; the ID mappings are dropped.
    return partition_graph_with_halo(g, node_part, extra_cached_hops, reshuffle)[0]
class NDArrayPartition(object):
    """ Create a new partition of an NDArray. That is, an object which assigns
    each row of an NDArray to a specific partition.

    Parameters
    ----------
    array_size : int
        The first dimension of the array being partitioned.
    num_parts : int
        The number of parts to divide the array into. Must be > 0.
    mode : String
        The type of partition. Currently, the only valid value is 'remainder',
        which assigns rows based on remainder when dividing the row id by the
        number of parts (e.g., i % num_parts).
    part_ranges : List
        Currently unused; must be None in 'remainder' mode.

    Raises
    ------
    ValueError
        If ``num_parts`` is not positive, ``mode`` is unknown, or
        ``part_ranges`` is given together with ``mode='remainder'``.

    Examples
    --------

    A partition of a homogeneous graph `g`, where the vertices are
    striped across processes can be generated via:

    >>> from dgl.partition import NDArrayPartition
    >>> part = NDArrayPartition(g.num_nodes(), num_parts, mode='remainder' )
    """
    def __init__(self, array_size, num_parts, mode='remainder', part_ranges=None):
        # Validate with explicit raises rather than `assert`, which is stripped
        # under `python -O` and would leave `_partition` unset for bad input.
        if not num_parts > 0:
            raise ValueError('Invalid "num_parts", must be > 0.')
        if mode == 'remainder':
            if part_ranges is not None:
                raise ValueError('When using remainder-based '
                                 'partitioning, "part_ranges" should not be specified.')
            self._partition = _CAPI_DGLNDArrayPartitionCreateRemainderBased(
                array_size, num_parts)
        else:
            raise ValueError('Unknown partition mode "{}"'.format(mode))
        self._array_size = array_size
        self._num_parts = num_parts

    def num_parts(self):
        """ Get the number of partitions.
        """
        return self._num_parts

    def array_size(self):
        """ Get the total size of the first dimension of the partitioned array.
        """
        return self._array_size

    def get(self):
        """ Get the C-handle for this object.
        """
        return self._partition

    def get_local_indices(self, part, ctx):
        """ Get the set of global indices in the given partition.

        Parameters
        ----------
        part : int
            The partition ID.
        ctx : context
            The device on which the local index range is created.
        """
        return self.map_to_global(F.arange(0, self.local_size(part), ctx=ctx), part)

    def local_size(self, part):
        """ Get the number of rows/items assigned to the given part.
        """
        return _CAPI_DGLNDArrayPartitionGetPartSize(self._partition, part)

    def map_to_local(self, idxs):
        """ Convert the set of global indices to local indices.
        """
        return F.zerocopy_from_dgl_ndarray(_CAPI_DGLNDArrayPartitionMapToLocal(
            self._partition,
            F.zerocopy_to_dgl_ndarray(idxs)))

    def map_to_global(self, idxs, part_id):
        """ Convert the set of local indices of partition ``part_id`` to global indices.
        """
        return F.zerocopy_from_dgl_ndarray(_CAPI_DGLNDArrayPartitionMapToGlobal(
            self._partition, F.zerocopy_to_dgl_ndarray(idxs), part_id))
# Bind the native FFI functions for the "dgl.partition" namespace. The
# ``_CAPI_*`` names used throughout this module (e.g.
# ``_CAPI_DGLPartitionWithHalo_Hetero``) are not defined in Python here, so this
# call presumably injects them into the module globals -- confirm in _ffi.function.
_init_api("dgl.partition")