dgl.ops.segment_mm
- dgl.ops.segment_mm(a, b, seglen_a)
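Performs matrix multiplication according to segments.
Suppose seglen_a == [10, 5, 0, 3]. Then the operator performs four matrix multiplications: a[0:10] @ b[0], a[10:15] @ b[1], a[15:15] @ b[2], a[15:18] @ b[3]. The results are stacked along the first dimension to form the output; the third segment is empty, so it contributes no output rows.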
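To make the segment semantics concrete, the following is a loop-based reference sketch in NumPy. It is not DGL's implementation: the function name segment_mm_reference is hypothetical, and NumPy arrays stand in for framework tensors. Each segment i of a is multiplied by b[i] and written back to the same rows of the output, so an empty segment such as a[15:15] contributes nothing.

```python
import numpy as np


def segment_mm_reference(a, b, seglen_a):
    """Hypothetical loop-based reference for segment matrix multiplication.

    a: (N, D1) array, b: (R, D1, D2) array, seglen_a: (R,) integer array.
    Returns an (N, D2) array whose rows for segment i equal a[segment i] @ b[i].
    """
    n, _ = a.shape
    r, _, d2 = b.shape
    out = np.zeros((n, d2), dtype=np.result_type(a, b))
    start = 0
    for i in range(r):
        end = start + int(seglen_a[i])
        # Rows start..end-1 of a form segment i; an empty segment writes nothing.
        out[start:end] = a[start:end] @ b[i]
        start = end
    return out


# Usage with the segment lengths from the example above.
rng = np.random.default_rng(0)
seglen = np.array([10, 5, 0, 3])
a = rng.standard_normal((18, 4))
b = rng.standard_normal((4, 4, 2))
out = segment_mm_reference(a, b, seglen)
print(out.shape)  # (18, 2)
```

A loop like this can serve as a simple oracle when checking the output of dgl.ops.segment_mm on small inputs.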
- Parameters:
a (Tensor) – The left operand, a 2-D tensor of shape (N, D1).
b (Tensor) – The right operand, a 3-D tensor of shape (R, D1, D2).
seglen_a (Tensor) – An integer tensor of shape (R,). Each element is the length of one segment of input a. The elements must sum to N.
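The parameter constraints above can be checked up front. The sketch below is a hypothetical NumPy helper, check_segment_mm_args, which is not part of DGL; it assumes NumPy arrays stand in for framework tensors and raises ValueError when a documented shape rule is broken. The matching inner dimension D1 follows from the shapes given for a and b.

```python
import numpy as np


def check_segment_mm_args(a, b, seglen_a):
    """Hypothetical helper: raise ValueError if the documented shape rules are violated."""
    if a.ndim != 2:
        raise ValueError(f"a must be 2-D (N, D1), got shape {a.shape}")
    if b.ndim != 3:
        raise ValueError(f"b must be 3-D (R, D1, D2), got shape {b.shape}")
    if seglen_a.shape != (b.shape[0],):
        raise ValueError(f"seglen_a must have shape ({b.shape[0]},), got {seglen_a.shape}")
    if not np.issubdtype(seglen_a.dtype, np.integer):
        raise ValueError("seglen_a must be an integer array")
    # a and b must agree on the contracted dimension D1.
    if a.shape[1] != b.shape[1]:
        raise ValueError(f"a has D1={a.shape[1]} but b has D1={b.shape[1]}")
    # The segments must exactly cover the N rows of a.
    total = int(seglen_a.sum())
    if total != a.shape[0]:
        raise ValueError(f"sum(seglen_a)={total} must equal N={a.shape[0]}")


# Valid arguments pass silently.
check_segment_mm_args(np.zeros((18, 4)), np.zeros((4, 4, 2)), np.array([10, 5, 0, 3]))

# Segment lengths that do not sum to N are rejected.
try:
    check_segment_mm_args(np.zeros((17, 4)), np.zeros((4, 4, 2)), np.array([10, 5, 0, 3]))
except ValueError as err:
    print(err)  # sum(seglen_a)=18 must equal N=17
```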
- Returns:
The output dense matrix of shape (N, D2).
- Return type:
Tensor