# Source code for dgl.ops.gather_mm

"""dgl gather_mm operator module."""
from .. import backend as F

__all__ = ["gather_mm"]


def gather_mm(a, b, *, idx_b):
    r"""Gather data according to the given indices and perform matrix multiplication.

    Let the result tensor be ``c``, the operator conducts the following
    computation:

      c[i] = a[i] @ b[idx_b[i]]

    where ``len(c) == len(idx_b)``.

    Parameters
    ----------
    a : Tensor
        A 2-D tensor of shape ``(N, D1)``
    b : Tensor
        A 3-D tensor of shape ``(R, D1, D2)``
    idx_b : Tensor
        An 1-D integer tensor of shape ``(N,)``. Required, keyword-only;
        ``idx_b[i]`` selects which matrix of ``b`` multiplies row ``i`` of ``a``.

    Returns
    -------
    Tensor
        The output dense matrix of shape ``(N, D2)``
    """
    N, D1 = F.shape(a)
    R, _, D2 = F.shape(b)

    # Small workloads go straight to the backend gather_mm kernel.
    if not (N > 1000000 or D1 > 8 or D2 > 8):
        return F.gather_mm(a, b, None, idx_b)

    # Use segment_mm for large workload.
    # NOTE: this path calls torch directly, so it is specific to the PyTorch
    # backend; the import is kept local so other paths do not require torch.
    import torch

    # Sort the rows of ``a`` by the matrix they select, so rows sharing the
    # same entry of ``b`` become contiguous before the segment_mm call.
    sorted_idx_b, perm = torch.sort(idx_b)
    # Inverse permutation: used at the end to restore the caller's row order.
    rev_perm = torch.argsort(perm)
    sorted_a = torch.index_select(a, 0, perm)

    # pos_l[r] is the first position in sorted_idx_b holding a value >= r,
    # i.e. the start offset of segment r.
    pos_l = torch.searchsorted(
        sorted_idx_b, torch.arange(R, device=a.device)
    )
    # Each segment ends where the next one starts; the last one ends at
    # len(idx_b).
    pos_r = torch.cat(
        [pos_l[1:], torch.tensor([len(idx_b)], device=a.device)]
    )
    seglen = (pos_r - pos_l).cpu()  # XXX(minjie): cause device synchronize

    return torch.index_select(
        F.segment_mm(sorted_a, b, seglen), 0, rev_perm
    )