Source code for dgl.graphbolt.negative_sampler

"""Negative samplers."""

from _collections_abc import Mapping

from torch.utils.data import functional_datapipe

from .minibatch_transformer import MiniBatchTransformer

__all__ = [
    "NegativeSampler",
]


@functional_datapipe("sample_negative")
class NegativeSampler(MiniBatchTransformer):
    """
    A negative sampler used to generate negative samples and return
    a mix of positive and negative samples.

    Functional name: :obj:`sample_negative`.

    This base class only dispatches over the ``seeds`` of a minibatch
    (either a single value or a mapping keyed by edge type); subclasses
    must implement :meth:`_sample_with_etype` to do the actual sampling.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    negative_ratio : int
        The proportion of negative samples to positive samples. Must be
        positive.

    Raises
    ------
    AssertionError
        If ``negative_ratio`` is not positive.
    """

    def __init__(
        self,
        datapipe,
        negative_ratio,
    ):
        # Validate explicitly instead of with a bare ``assert`` so the check
        # survives ``python -O``. AssertionError is kept as the exception
        # type for backward compatibility with existing callers.
        if not negative_ratio > 0:
            raise AssertionError("Negative_ratio should be positive Integer.")
        super().__init__(datapipe, self._sample)
        self.negative_ratio = negative_ratio

    def _sample(self, minibatch):
        """
        Generate a mix of positive and negative samples.

        The ``seeds``, ``labels`` and ``indexes`` fields of ``minibatch``
        are overwritten with the three values returned by
        :meth:`_sample_with_etype`. When ``seeds`` is a mapping keyed by
        edge type, each entry is sampled separately and its results are
        stored under the same key; ``labels`` and ``indexes`` are
        initialized to empty dicts first if they are ``None``.

        Parameters
        ----------
        minibatch : MiniBatch
            An instance of 'MiniBatch' class requires the 'seeds' field.
            This function is responsible for generating negative edges
            corresponding to the positive edges defined by the 'seeds'.

        Returns
        -------
        MiniBatch
            The same ``minibatch`` object, updated in place, now
            encompassing both positive and negative samples.
        """
        seeds = minibatch.seeds
        if isinstance(seeds, Mapping):
            if minibatch.indexes is None:
                minibatch.indexes = {}
            if minibatch.labels is None:
                minibatch.labels = {}
            # ``seeds`` was read from ``minibatch.seeds``, which is written
            # inside the loop; iterate over a snapshot of its items so the
            # loop never walks a dict it is concurrently updating.
            for etype, pos_pairs in list(seeds.items()):
                (
                    minibatch.seeds[etype],
                    minibatch.labels[etype],
                    minibatch.indexes[etype],
                ) = self._sample_with_etype(pos_pairs, etype)
        else:
            (
                minibatch.seeds,
                minibatch.labels,
                minibatch.indexes,
            ) = self._sample_with_etype(seeds)
        return minibatch

    def _sample_with_etype(self, seeds, etype=None):
        """Generate negative pairs from the positive pairs of a given etype.

        Implementations return the sampled node pairs together with the
        corresponding labels and indexes.

        Parameters
        ----------
        seeds : Tensor
            An N*2 tensor that represents source-destination node pairs of
            positive edges, where positive means the edge must exist in
            the graph.
        etype : str, optional
            Canonical edge type. ``None`` when the minibatch ``seeds`` is
            not a mapping keyed by edge type.

        Returns
        -------
        Tensor
            A collection of positive and negative node pairs.
        Tensor
            Corresponding labels. If label is True, corresponding edge is
            positive. If label is False, corresponding edge is negative.
        Tensor
            Corresponding indexes, indicates to which query an edge belongs.

        Raises
        ------
        NotImplementedError
            Always, in this base class; subclasses must override it.
        """
        raise NotImplementedError