Source code for dgl.graphbolt.minibatch_transformer
"""Mini-batch transformer"""
from torch.utils.data import functional_datapipe
from torchdata.datapipes.iter import Mapper
from .minibatch import MiniBatch
__all__ = [
"MiniBatchTransformer",
]
[docs]@functional_datapipe("transform")
class MiniBatchTransformer(Mapper):
"""A mini-batch transformer used to manipulate mini-batch.
Functional name: :obj:`transform`.
Parameters
----------
datapipe : DataPipe
The datapipe.
transformer:
The function applied to each minibatch which is responsible for
transforming the minibatch.
"""
def __init__(
self,
datapipe,
transformer=None,
):
super().__init__(datapipe, self._transformer)
self.transformer = transformer or self._identity
def _transformer(self, minibatch):
minibatch = self.transformer(minibatch)
assert isinstance(
minibatch, (MiniBatch,)
), "The transformer output should be an instance of MiniBatch"
return minibatch
@staticmethod
def _identity(minibatch):
return minibatch