# Source code for dgl.sparse.matmul

```"""Matmul ops for SparseMatrix"""
# pylint: disable=invalid-name
from typing import Union

import torch

from .sparse_matrix import SparseMatrix

# Public names of this module, i.e. what ``from ... import *`` exposes.
__all__ = ["spmm", "bspmm", "spspmm", "matmul"]

def spmm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:
    """Multiplies a sparse matrix by a dense matrix, equivalent to ``A @ X``.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix of shape ``(L, M)`` with scalar values
    X : torch.Tensor
        Dense matrix of shape ``(M, N)`` or ``(M)``

    Returns
    -------
    torch.Tensor
        The dense matrix of shape ``(L, N)`` or ``(L)``

    Raises
    ------
    AssertionError
        If ``A`` is not a ``SparseMatrix`` or ``X`` is not a ``torch.Tensor``.

    Examples
    --------

    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val = torch.randn(indices.shape[1])
    >>> A = dglsp.spmatrix(indices, val)
    >>> X = torch.randn(2, 3)
    >>> result = dglsp.spmm(A, X)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([2, 3])
    """
    assert isinstance(
        A, SparseMatrix
    ), f"Expect arg1 to be a SparseMatrix object, got {type(A)}."
    assert isinstance(
        X, torch.Tensor
    ), f"Expect arg2 to be a torch.Tensor, got {type(X)}."

    # NOTE(review): dispatches to the ``dgl_sparse.spmm`` custom op, the
    # sibling of ``dgl_sparse.spspmm`` used in this module; confirm the op
    # name/signature against the C++ registration.
    if X.dim() == 1:
        # A vector ``(M)`` is multiplied as a single-column matrix
        # ``(M, 1)``; the unit column is dropped again so the result has
        # the documented shape ``(L)``.
        out = torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X.view(-1, 1))
        return out.view(-1)

    return torch.ops.dgl_sparse.spmm(A.c_sparse_matrix, X)

def bspmm(A: SparseMatrix, X: torch.Tensor) -> torch.Tensor:
    """Multiplies a sparse matrix by a dense matrix by batches, equivalent to
    ``A @ X``.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix of shape ``(L, M)`` with vector values of length ``K``
    X : torch.Tensor
        Dense matrix of shape ``(M, N, K)``

    Returns
    -------
    torch.Tensor
        Dense matrix of shape ``(L, N, K)``

    Raises
    ------
    AssertionError
        If ``A`` is not a ``SparseMatrix`` or ``X`` is not a ``torch.Tensor``.

    Examples
    --------

    >>> indices = torch.tensor([[0, 1, 1], [1, 0, 2]])
    >>> val = torch.randn(indices.shape[1], 2)
    >>> A = dglsp.spmatrix(indices, val, shape=(3, 3))
    >>> X = torch.randn(3, 3, 2)
    >>> result = dglsp.bspmm(A, X)
    >>> type(result)
    <class 'torch.Tensor'>
    >>> result.shape
    torch.Size([3, 3, 2])
    """
    assert isinstance(
        A, SparseMatrix
    ), f"Expect arg1 to be a SparseMatrix object, got {type(A)}."
    assert isinstance(
        X, torch.Tensor
    ), f"Expect arg2 to be a torch.Tensor, got {type(X)}."
    # The batched case shares its implementation with the plain
    # sparse-dense product; this entry point exists for API clarity.
    return spmm(A, X)

def spspmm(A: SparseMatrix, B: SparseMatrix) -> SparseMatrix:
    """Multiplies a sparse matrix by a sparse matrix, equivalent to ``A @ B``.

    The non-zero values of the two sparse matrices must be 1D.

    Parameters
    ----------
    A : SparseMatrix
        Sparse matrix of shape ``(L, M)``
    B : SparseMatrix
        Sparse matrix of shape ``(M, N)``

    Returns
    -------
    SparseMatrix
        Sparse matrix of shape ``(L, N)``.

    Raises
    ------
    AssertionError
        If ``A`` or ``B`` is not a ``SparseMatrix``.

    Examples
    --------

    >>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
    >>> val1 = torch.ones(indices1.shape[1])
    >>> A = dglsp.spmatrix(indices1, val1)
    >>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
    >>> val2 = torch.ones(indices2.shape[1])
    >>> B = dglsp.spmatrix(indices2, val2)
    >>> dglsp.spspmm(A, B)
    SparseMatrix(indices=tensor([[0, 0, 1, 1, 1],
                                 [1, 2, 0, 1, 2]]),
                 values=tensor([1., 1., 1., 1., 1.]),
                 shape=(2, 3), nnz=5)
    """
    # Messages name the arguments positionally, matching spmm/bspmm.
    assert isinstance(
        A, SparseMatrix
    ), f"Expect arg1 to be a SparseMatrix object, got {type(A)}."
    assert isinstance(
        B, SparseMatrix
    ), f"Expect arg2 to be a SparseMatrix object, got {type(B)}."

    # The product is computed by the C++ op on the underlying handles and
    # re-wrapped as a Python-level SparseMatrix.
    return SparseMatrix(
        torch.ops.dgl_sparse.spspmm(A.c_sparse_matrix, B.c_sparse_matrix)
    )

[docs]def matmul(
A: Union[torch.Tensor, SparseMatrix], B: Union[torch.Tensor, SparseMatrix]
) -> Union[torch.Tensor, SparseMatrix]:
"""Multiplies two dense/sparse matrices, equivalent to ``A @ B``.

This function does not support the case where :attr:`A` is a \
``torch.Tensor`` and :attr:`B` is a ``SparseMatrix``.

* If both matrices are torch.Tensor, it calls \
:func:`torch.matmul()`. The result is a dense matrix.

* If both matrices are sparse, it calls :func:`dgl.sparse.spspmm`. The \
result is a sparse matrix.

* If :attr:`A` is sparse while :attr:`B` is dense, it calls \
:func:`dgl.sparse.spmm`. The result is a dense matrix.

* The operator supports batched sparse-dense matrix multiplication. In \
this case, the sparse matrix :attr:`A` should have shape ``(L, M)``, \
where the non-zero values have a batch dimension ``K``. The dense \
matrix :attr:`B` should have shape ``(M, N, K)``. The output \
is a dense matrix of shape ``(L, N, K)``.

* Sparse-sparse matrix multiplication does not support batched computation.

Parameters
----------
A : torch.Tensor or SparseMatrix
The first matrix.
B : torch.Tensor or SparseMatrix
The second matrix.

Returns
-------
torch.Tensor or SparseMatrix
The result matrix

Examples
--------

Multiplies a diagonal matrix with a dense matrix.

>>> val = torch.randn(3)
>>> A = dglsp.diag(val)
>>> B = torch.randn(3, 2)
>>> result = dglsp.matmul(A, B)
>>> type(result)
<class 'torch.Tensor'>
>>> result.shape
torch.Size([3, 2])

Multiplies a sparse matrix with a dense matrix.

>>> indices = torch.tensor([[0, 1, 1], [1, 0, 1]])
>>> val = torch.randn(indices.shape[1])
>>> A = dglsp.spmatrix(indices, val)
>>> X = torch.randn(2, 3)
>>> result = dglsp.matmul(A, X)
>>> type(result)
<class 'torch.Tensor'>
>>> result.shape
torch.Size([2, 3])

Multiplies a sparse matrix with a sparse matrix.

>>> indices1 = torch.tensor([[0, 1, 1], [1, 0, 1]])
>>> val1 = torch.ones(indices1.shape[1])
>>> A = dglsp.spmatrix(indices1, val1)
>>> indices2 = torch.tensor([[0, 1, 1], [0, 2, 1]])
>>> val2 = torch.ones(indices2.shape[1])
>>> B = dglsp.spmatrix(indices2, val2)
>>> result = dglsp.matmul(A, B)
>>> type(result)
<class 'dgl.sparse.sparse_matrix.SparseMatrix'>
>>> result.shape
(2, 3)
"""
assert isinstance(
A, (torch.Tensor, SparseMatrix)
), f"Expect arg1 to be a torch.Tensor or SparseMatrix, got {type(A)}."
assert isinstance(B, (torch.Tensor, SparseMatrix)), (
f"Expect arg2 to be a torch Tensor or SparseMatrix"
f"object, got {type(B)}."
)
if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):