ClusterGCNSampler

class dgl.dataloading.ClusterGCNSampler(g, k, cache_path='cluster_gcn.pkl', balance_ntypes=None, balance_edges=False, mode='k-way', prefetch_ndata=None, prefetch_edata=None, output_device=None)

Bases: dgl.dataloading.base.Sampler

Cluster sampler from Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks

This sampler first partitions the graph into k parts with METIS, then caches the nodes of each partition to the file given by cache_path so that the partitioning can be reused instead of recomputed.
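A minimal pure-Python sketch of the partition-and-cache step described above. It does not call METIS. The hypothetical toy_partition helper assigns node i to partition i % k purely as a stand-in. The cache layout is also an illustrative assumption and not necessarily DGL's on-disk format: node IDs are grouped by partition, plus an offsets list, and the pair is pickled to cache_path.

```python
import os
import pickle
import tempfile


def toy_partition(num_nodes, k):
    """Hypothetical stand-in for METIS: assign node i to partition i % k."""
    return [i % k for i in range(num_nodes)]


def build_or_load_partitions(num_nodes, k, cache_path):
    """Return (offsets, node_ids), where node_ids[offsets[p]:offsets[p + 1]]
    are the nodes of partition p. Reuse the cache file if it already exists."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    assignment = toy_partition(num_nodes, k)
    # Group node IDs by partition; the stable sort keeps IDs ascending within a part.
    node_ids = sorted(range(num_nodes), key=lambda n: assignment[n])
    # offsets[p] is the number of nodes in partitions 0..p-1.
    counts = [0] * k
    for part in assignment:
        counts[part] += 1
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)

    # Cache the result so a later call skips partitioning entirely.
    with open(cache_path, "wb") as f:
        pickle.dump((offsets, node_ids), f)
    return offsets, node_ids


cache = os.path.join(tempfile.mkdtemp(), "cluster_gcn.pkl")
offsets, node_ids = build_or_load_partitions(num_nodes=10, k=3, cache_path=cache)
print(offsets)   # [0, 4, 7, 10]
print(node_ids)  # [0, 3, 6, 9, 1, 4, 7, 2, 5, 8]
```

Calling build_or_load_partitions again with the same cache_path loads the pickled pair instead of repartitioning, which is the point of caching an expensive METIS run.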

In its sample method, the sampler then selects the graph partitions given by the provided partition IDs, takes the union of all nodes in those partitions, and returns the subgraph induced by those nodes.
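The selection step can be sketched in plain Python on an edge list. This assumes an illustrative storage layout in which partition p owns node_ids[offsets[p]:offsets[p + 1]]. The hypothetical sample_subgraph helper takes the union of the chosen partitions' nodes and keeps only the edges whose endpoints both lie in that set, relabeling nodes to 0..n-1 as an induced subgraph does. In DGL itself, dgl.node_subgraph builds induced subgraphs; this sketch does not use DGL.

```python
def sample_subgraph(edges, offsets, node_ids, partition_ids):
    """Return (nodes, sub_edges): the sorted union of nodes in the chosen
    partitions and the edges induced by them, relabeled to 0..len(nodes)-1."""
    chosen = set()
    for p in partition_ids:
        chosen.update(node_ids[offsets[p]:offsets[p + 1]])
    nodes = sorted(chosen)
    new_id = {old: new for new, old in enumerate(nodes)}  # original ID -> subgraph ID
    # Keep an edge only if both endpoints fall inside the selected node set.
    sub_edges = [(new_id[u], new_id[v]) for u, v in edges
                 if u in new_id and v in new_id]
    return nodes, sub_edges


# Toy graph: a path 0-1-2-3-4-5 split into 3 partitions of 2 nodes each.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]
offsets = [0, 2, 4, 6]
node_ids = [0, 1, 2, 3, 4, 5]  # partition 0: {0, 1}, 1: {2, 3}, 2: {4, 5}
nodes, sub_edges = sample_subgraph(edges, offsets, node_ids, [0, 2])
print(nodes)      # [0, 1, 4, 5]
print(sub_edges)  # [(0, 1), (2, 3)]
```

Edges that connect a selected partition to an unselected one, such as (1, 2) here, are dropped. Selecting several partitions per batch keeps the edges that run between the selected partitions.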

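Parameters

g (DGLGraph) – The original graph.
k (int) – The number of partitions.
cache_path (str) – The path of the file used to cache the partition result.
balance_ntypes, balance_edges, mode – Passed to dgl.metis_partition_assignment().
prefetch_ndata (list[str], optional) – The node data to prefetch for the subgraph.
prefetch_edata (list[str], optional) – The edge data to prefetch for the subgraph.
output_device (device, optional) – The device of the output subgraphs.
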
Examples

Node classification

With this sampler, the data loader will accept the list of partition IDs as indices to iterate over. For instance, the following code first splits the graph into 1000 partitions using METIS, and at each iteration it gets a subgraph induced by the nodes covered by 20 randomly selected partitions.

>>> import dgl
>>> import torch
>>> # g is the DGLGraph to train on
>>> num_parts = 1000
>>> sampler = dgl.dataloading.ClusterGCNSampler(g, num_parts)
>>> dataloader = dgl.dataloading.DataLoader(
...     g, torch.arange(num_parts), sampler,
...     batch_size=20, shuffle=True, drop_last=False, num_workers=4)
>>> for subg in dataloader:
...     train_on(subg)  # train_on is a placeholder for your own training step
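Each subgraph the loop above receives is built from one batch of partition IDs. The plain-Python approximation below does not use DGL or PyTorch, and iterate_partition_batches is a hypothetical helper. It shows what batch_size=20, shuffle=True, drop_last=False mean when the indices are torch.arange(num_parts). In each epoch, every one of the 1000 partitions is visited exactly once, in 50 randomly ordered groups of 20, and each group would be passed to the sampler to build one subgraph.

```python
import random


def iterate_partition_batches(num_parts, batch_size, shuffle=True, drop_last=False, seed=None):
    """Yield lists of partition IDs, approximating how a data loader batches
    the indices 0..num_parts-1 (a sketch, not DGL's implementation)."""
    ids = list(range(num_parts))
    if shuffle:
        random.Random(seed).shuffle(ids)  # new random order each epoch
    for start in range(0, num_parts, batch_size):
        batch = ids[start:start + batch_size]
        if drop_last and len(batch) < batch_size:
            break  # discard the incomplete final batch
        yield batch


batches = list(iterate_partition_batches(1000, 20, seed=0))
print(len(batches))                        # 50
print(all(len(b) == 20 for b in batches))  # True
```

With 1000 partitions and a batch size of 20 there is no remainder, so drop_last has no effect here. It would only matter if num_parts were not a multiple of batch_size.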