Source code for dgl.geometry.edge_coarsening

"""Edge coarsening procedure used in Metis and Graclus, for pytorch"""
# pylint: disable=no-member, invalid-name, W0613
from .. import remove_self_loop
from .capi import _neighbor_matching

__all__ = ["neighbor_matching"]


[docs]def neighbor_matching(graph, e_weights=None, relabel_idx=True): r""" Description ----------- The neighbor matching procedure of edge coarsening in `Metis <http://cacs.usc.edu/education/cs653/Karypis-METIS-SIAMJSC98.pdf>`__ and `Graclus <https://www.cs.utexas.edu/users/inderjit/public_papers/multilevel_pami.pdf>`__ for homogeneous graph coarsening. This procedure keeps picking an unmarked vertex and matching it with one its unmarked neighbors (that maximizes its edge weight) until no match can be done. If no edge weight is given, this procedure will randomly pick neighbor for each vertex. The GPU implementation is based on `A GPU Algorithm for Greedy Graph Matching <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`__ NOTE: The input graph must be bi-directed (undirected) graph. Call :obj:`dgl.to_bidirected` if you are not sure your graph is bi-directed. Parameters ---------- graph : DGLGraph The input homogeneous graph. edge_weight : torch.Tensor, optional The edge weight tensor holding non-negative scalar weight for each edge. default: :obj:`None` relabel_idx : bool, optional If true, relabel resulting node labels to have consecutive node ids. default: :obj:`True` Examples -------- The following example uses PyTorch backend. >>> import torch, dgl >>> from dgl.geometry import neighbor_matching >>> >>> g = dgl.graph(([0, 1, 1, 2], [1, 0, 2, 1])) >>> res = neighbor_matching(g) tensor([0, 1, 1]) """ assert ( graph.is_homogeneous ), "The graph used in graph node matching must be homogeneous" if e_weights is not None: graph.edata["e_weights"] = e_weights graph = remove_self_loop(graph) e_weights = graph.edata["e_weights"] graph.edata.pop("e_weights") else: graph = remove_self_loop(graph) return _neighbor_matching( graph._graph, graph.num_nodes(), e_weights, relabel_idx )