ClusterGCNSampler
class dgl.dataloading.ClusterGCNSampler(g, k, cache_path='cluster_gcn.pkl', balance_ntypes=None, balance_edges=False, mode='k-way', prefetch_ndata=None, prefetch_edata=None, output_device=None)
Bases: dgl.dataloading.base.Sampler
Cluster sampler from Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks
This sampler first partitions the graph with METIS partitioning, then caches the nodes of each partition to a file within the given cache directory.
The sampler then selects the graph partitions according to the provided partition IDs, takes the union of all nodes in those partitions, and returns an induced subgraph in its sample method.
Parameters
g (DGLGraph) – The original graph. Must be homogeneous and on CPU.
k (int) – The number of partitions.
cache_path (str) – The path to the cache directory for storing the partition result.
balance_ntypes – Passed to dgl.metis_partition_assignment().
balance_edges – Passed to dgl.metis_partition_assignment().
mode – Passed to dgl.metis_partition_assignment().
prefetch_ndata (list[str], optional) –
The node data to prefetch for the subgraph.
See 6.8 Feature Prefetching for a detailed explanation of prefetching.
prefetch_edata (list[str], optional) –
The edge data to prefetch for the subgraph.
See 6.8 Feature Prefetching for a detailed explanation of prefetching.
output_device (device, optional) – The device of the output subgraphs or MFGs. Default is the same as the minibatch of partition indices.
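The caching behaviour tied to cache_path can be pictured as a compute-once, reload-later pattern. The following is a minimal sketch of that general pattern using only the standard library; it is not DGL's implementation, and the helper names load_or_compute_partitions and fake_partitioner, as well as the dictionary layout of the cached data, are assumptions made for illustration.

```python
import os
import pickle
import tempfile


def load_or_compute_partitions(cache_path, compute_fn):
    """Return cached partitions if cache_path exists, else compute and store them.

    Hypothetical helper: illustrates the caching idea only, not DGL internals.
    """
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    partitions = compute_fn()
    with open(cache_path, "wb") as f:
        pickle.dump(partitions, f)
    return partitions


calls = []


def fake_partitioner():
    """Stand-in for an expensive partitioning step such as METIS."""
    calls.append(1)
    # Assumed layout: partition ID -> list of node IDs.
    return {0: [0, 1], 1: [2, 3], 2: [4, 5]}


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cluster_gcn.pkl")
    first = load_or_compute_partitions(path, fake_partitioner)   # computes and writes
    second = load_or_compute_partitions(path, fake_partitioner)  # reads from the file
```

In this sketch the second call never runs the partitioner, so an existing cache file is reused as-is; if the partitioning inputs change, the file has to be removed for the result to be recomputed.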
Examples
Node classification
With this sampler, the data loader will accept the list of partition IDs as indices to iterate over. For instance, the following code first splits the graph into 1000 partitions using METIS, and at each iteration it gets a subgraph induced by the nodes covered by 20 randomly selected partitions.
>>> num_parts = 1000
>>> sampler = dgl.dataloading.ClusterGCNSampler(g, num_parts)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, torch.arange(num_parts), sampler,
...     batch_size=20, shuffle=True, drop_last=False, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)
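To make the sampling step concrete, the sketch below reproduces the logic described above in plain Python: partition IDs are shuffled and grouped into minibatches, the nodes of the selected partitions are unioned, and the subgraph induced by those nodes is returned. It is an illustration only, not DGL's code; the function names are hypothetical, and the partition assignment is written by hand where METIS would normally compute it.

```python
import random


def batches_of_partitions(num_parts, batch_size, seed=0):
    """Shuffle partition IDs and group them into minibatches (hypothetical helper)."""
    ids = list(range(num_parts))
    random.Random(seed).shuffle(ids)
    return [ids[i:i + batch_size] for i in range(0, num_parts, batch_size)]


def sample_cluster_subgraph(edges, part_of, part_ids):
    """Return the nodes and induced edges for the selected partitions.

    edges:    list of (src, dst) pairs of a homogeneous graph
    part_of:  part_of[v] is the partition ID of node v
    part_ids: partition IDs selected for this minibatch
    """
    chosen = set(part_ids)
    # Union of all nodes that belong to any selected partition.
    nodes = sorted(v for v, p in enumerate(part_of) if p in chosen)
    keep = set(nodes)
    # Induced subgraph: keep only edges with both endpoints selected.
    sub_edges = [(u, v) for u, v in edges if u in keep and v in keep]
    return nodes, sub_edges


# A 6-node cycle with a hand-written assignment to 3 partitions.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0)]
part_of = [0, 0, 1, 1, 2, 2]

nodes, sub_edges = sample_cluster_subgraph(edges, part_of, [0, 2])
print(nodes)      # [0, 1, 4, 5]
print(sub_edges)  # [(0, 1), (4, 5), (5, 0)]

# 1000 partitions with batch_size=20 give 50 minibatches per epoch.
batches = batches_of_partitions(1000, 20)
print(len(batches))  # 50
```

Edges that cross into an unselected partition, such as (1, 2) here, are dropped from the minibatch, which is why combining several partitions per batch recovers some of the between-cluster edges.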