# Source code for dgl.graphbolt.minibatch_transformer

"""Mini-batch transformer"""

from torch.utils.data import functional_datapipe

from torchdata.datapipes.iter import Mapper

from .minibatch import MiniBatch

__all__ = [
    "MiniBatchTransformer",
]


@functional_datapipe("transform")
class MiniBatchTransformer(Mapper):
    """A mini-batch transformer used to manipulate mini-batches.

    Functional name: :obj:`transform`.

    The bound method ``_transformer`` is registered as the ``Mapper``
    function. It applies ``transformer`` to each minibatch and checks that
    the result is a :class:`MiniBatch`.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    transformer : callable, optional
        The function applied to each minibatch which is responsible for
        transforming the minibatch. Its return value must be an instance of
        :class:`MiniBatch`. If ``None`` (the default), minibatches are
        passed through unchanged.
    """

    def __init__(
        self,
        datapipe,
        transformer=None,
    ):
        # ``self._transformer`` only reads ``self.transformer`` when it is
        # invoked, so it is safe to set the attribute after handing the bound
        # method to ``Mapper``.
        super().__init__(datapipe, self._transformer)
        # Compare against ``None`` explicitly rather than using ``or``: a
        # user-supplied callable that happens to be falsy (e.g. an object
        # defining ``__len__`` returning 0) must not be silently replaced by
        # the identity transform.
        self.transformer = (
            transformer if transformer is not None else self._identity
        )

    def _transformer(self, minibatch):
        """Apply ``self.transformer`` and validate that it returned a MiniBatch."""
        minibatch = self.transformer(minibatch)
        # NOTE(review): ``assert`` is removed when Python runs with ``-O``,
        # so this validation is skipped in optimized mode. It is kept as an
        # ``assert`` so callers relying on ``AssertionError`` are unaffected.
        assert isinstance(
            minibatch, MiniBatch
        ), "The transformer output should be an instance of MiniBatch"
        return minibatch

    @staticmethod
    def _identity(minibatch):
        """Return ``minibatch`` unchanged; used when no transformer is given."""
        return minibatch