# Source code for dgl.graphbolt.feature_fetcher

"""Feature fetchers"""

from typing import Dict

import torch
from torch.utils.data import functional_datapipe

from .base import etype_tuple_to_str
from .minibatch_transformer import MiniBatchTransformer

# Public API of this module: the feature-fetching datapipe defined below.
__all__ = [
    "FeatureFetcher",
]

@functional_datapipe("fetch_feature")
class FeatureFetcher(MiniBatchTransformer):
    """A feature fetcher used to fetch features for node/edge in graphbolt.

    Functional name: :obj:`fetch_feature`.

    Parameters
    ----------
    datapipe : DataPipe
        The datapipe.
    feature_store : FeatureStore
        A storage for features, supporting read and update.
    node_feature_keys : List[str] or Dict[str, List[str]]
        Keys of the node features that need to be read.

        - If `node_feature_keys` is a list: the graph is homogeneous and the
          strings inside are feature names.
        - If `node_feature_keys` is a dictionary: the keys are node types and
          the values are lists of feature names.
    edge_feature_keys : List[str] or Dict[str, List[str]]
        Keys of the edge features that need to be read.

        - If `edge_feature_keys` is a list: the graph is homogeneous and the
          strings inside are feature names.
        - If `edge_feature_keys` is a dictionary: the keys are edge types in
          the format 'str:str:str', and the values are lists of feature
          names.
    """

    def __init__(
        self,
        datapipe,
        feature_store,
        node_feature_keys=None,
        edge_feature_keys=None,
    ):
        super().__init__(datapipe, self._read)
        self.feature_store = feature_store
        self.node_feature_keys = node_feature_keys
        self.edge_feature_keys = edge_feature_keys
        # Optional side CUDA stream on which `_read` issues the feature reads.
        # None by default; torch.cuda.stream(None) is a no-op, so the reads
        # then run on the current stream.
        # NOTE(review): presumably assigned by the loader that enables
        # overlapped feature fetching -- confirm against callers.
        self.stream = None

    def _read_data(self, data, stream):
        """Fill in the node/edge features field in data.

        Parameters
        ----------
        data : MiniBatch
            An instance of :class:`MiniBatch`. Even if 'node_feature' or
            'edge_feature' is already filled, it will be overwritten for
            overlapping features.
        stream : torch.cuda.Stream or None
            The stream that will consume the fetched features. When not None,
            every fetched CUDA tensor is recorded on it so the caching
            allocator does not reuse its memory while that stream may still
            be using it.

        Returns
        -------
        MiniBatch
            An instance of :class:`MiniBatch` filled with required features.
        """
        node_features = {}
        num_layers = data.num_layers()
        # One feature dict per layer; filled below when edge keys are given.
        edge_features = [{} for _ in range(num_layers)]
        is_heterogeneous = isinstance(
            self.node_feature_keys, Dict
        ) or isinstance(self.edge_feature_keys, Dict)

        def record_stream(tensor):
            # Fetched tensors are produced on the reading stream but consumed
            # on `stream`; tell the allocator about that cross-stream use.
            if stream is not None and tensor.is_cuda:
                tensor.record_stream(stream)
            return tensor

        def fetch(domain, type_name, feature_name, ids):
            # Read one feature ("node" or "edge") for `ids` from the store.
            return record_stream(
                self.feature_store.read(domain, type_name, feature_name, ids)
            )

        def record_on_current_stream(ids):
            # The id tensor is used by the stream performing the reads, which
            # may differ from the stream that allocated it.
            if ids.is_cuda:
                ids.record_stream(torch.cuda.current_stream())

        # Read node features.
        input_nodes = data.node_ids()
        if self.node_feature_keys and input_nodes is not None:
            if is_heterogeneous:
                for type_name, nodes in input_nodes.items():
                    if type_name not in self.node_feature_keys or nodes is None:
                        continue
                    record_on_current_stream(nodes)
                    for feature_name in self.node_feature_keys[type_name]:
                        node_features[(type_name, feature_name)] = fetch(
                            "node", type_name, feature_name, nodes
                        )
            else:
                record_on_current_stream(input_nodes)
                for feature_name in self.node_feature_keys:
                    node_features[feature_name] = fetch(
                        "node", None, feature_name, input_nodes
                    )

        # Read edge features, layer by layer.
        if self.edge_feature_keys:
            for layer, layer_features in enumerate(edge_features):
                original_edge_ids = data.edge_ids(layer)
                if original_edge_ids is None:
                    continue
                if is_heterogeneous:
                    # Feature keys use 'str:str:str' edge types, so convert
                    # any tuple edge types returned by the minibatch.
                    original_edge_ids = {
                        (
                            etype_tuple_to_str(key)
                            if isinstance(key, tuple)
                            else key
                        ): value
                        for key, value in original_edge_ids.items()
                    }
                    for type_name, edges in original_edge_ids.items():
                        if (
                            type_name not in self.edge_feature_keys
                            or edges is None
                        ):
                            continue
                        record_on_current_stream(edges)
                        for feature_name in self.edge_feature_keys[type_name]:
                            layer_features[(type_name, feature_name)] = fetch(
                                "edge", type_name, feature_name, edges
                            )
                else:
                    record_on_current_stream(original_edge_ids)
                    for feature_name in self.edge_feature_keys:
                        layer_features[feature_name] = fetch(
                            "edge", None, feature_name, original_edge_ids
                        )

        data.set_node_features(node_features)
        data.set_edge_features(edge_features)
        return data

    def _read(self, data):
        """Fetch the features for ``data``, optionally on ``self.stream``.

        When ``self.stream`` is set, the reads are enqueued on that stream
        after it waits for the caller's current stream. ``data.wait`` is then
        set to the ``wait`` method of an event recorded after the reads;
        calling it makes the then-current stream wait for the reads to
        finish.
        """
        current_stream = None
        if self.stream is not None:
            current_stream = torch.cuda.current_stream()
            # The ids in `data` may still be produced by work queued on the
            # current stream; the side stream must not read them earlier.
            self.stream.wait_stream(current_stream)
        # torch.cuda.stream(None) is a no-op context manager.
        with torch.cuda.stream(self.stream):
            data = self._read_data(data, current_stream)
            if self.stream is not None:
                data.wait = torch.cuda.current_stream().record_event().wait
        return data