MiniGCDataset

class dgl.data.MiniGCDataset(num_graphs, min_num_v, max_num_v, seed=0, save_graph=True, force_reload=False, verbose=False, transform=None)

Bases: dgl.data.dgl_dataset.DGLDataset

The synthetic graph classification dataset class.

The dataset contains 8 types of graphs, one per class:

  • class 0 : cycle graph

  • class 1 : star graph

  • class 2 : wheel graph

  • class 3 : lollipop graph

  • class 4 : hypercube graph

  • class 5 : grid graph

  • class 6 : clique graph

  • class 7 : circular ladder graph
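
For reference, the eight graph families can be described as plain edge lists. The sketch below uses textbook definitions in pure Python and is not DGL's generator code. The sizing rules DGL uses (for example, how a grid's dimensions or a lollipop's clique/path split are derived from the node count) are not documented on this page, so these hypothetical helpers take those parameters explicitly.

```python
from itertools import combinations

def cycle(n):
    # n nodes joined in a ring: n edges (assumes n >= 3)
    return [(i, (i + 1) % n) for i in range(n)]

def star(n):
    # node 0 is the hub, nodes 1..n-1 are leaves: n - 1 edges
    return [(0, i) for i in range(1, n)]

def wheel(n):
    # a star plus a cycle over the leaves 1..n-1: 2 * (n - 1) edges
    rim = [(1 + i, 1 + (i + 1) % (n - 1)) for i in range(n - 1)]
    return star(n) + rim

def clique(n):
    # every pair of nodes is connected: n * (n - 1) / 2 edges
    return list(combinations(range(n), 2))

def lollipop(m, k):
    # clique on nodes 0..m-1, plus a k-node path hanging off node m - 1
    path = [(i, i + 1) for i in range(m - 1, m + k - 1)]
    return clique(m) + path

def hypercube(d):
    # 2**d nodes; two nodes are adjacent iff their labels differ in one bit
    return [(u, u ^ (1 << b)) for u in range(2 ** d) for b in range(d)
            if u < (u ^ (1 << b))]

def grid(rows, cols):
    # rows x cols lattice with node id r * cols + c
    edges = []
    for r in range(rows):
        for c in range(cols):
            v = r * cols + c
            if c + 1 < cols:
                edges.append((v, v + 1))      # horizontal neighbour
            if r + 1 < rows:
                edges.append((v, v + cols))   # vertical neighbour
    return edges

def circular_ladder(k):
    # two k-cycles (nodes 0..k-1 and k..2k-1) joined by k rungs: 3k edges
    inner = cycle(k)
    outer = [(u + k, v + k) for u, v in inner]
    rungs = [(i, i + k) for i in range(k)]
    return inner + outer + rungs

print(len(wheel(10)), len(grid(4, 5)), len(hypercube(3)))  # 18 31 12
```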

Parameters
  • num_graphs (int) – Number of graphs in this dataset.

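  • min_num_v (int) – Minimum number of nodes per graph.

  • max_num_v (int) – Maximum number of nodes per graph.

  • seed (int, default is 0) – Random seed for data generation.

  • save_graph (bool, default is True) – Whether to save the graphs to disk.

  • force_reload (bool, default is False) – Whether to reload the dataset.

  • verbose (bool, default is False) – Whether to print out progress information.

  • transform (callable, optional) – A transform that takes in a DGLGraph object and returns a transformed version. The transform is applied to the DGLGraph on every access.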
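
This page does not say how num_graphs is divided among the 8 classes. The sketch below assumes that classes 0 to 6 each receive num_graphs // 8 graphs, that class 7 absorbs the remainder, and that graphs are stored class by class. This assumption agrees with the example below (data[64] from a 100-graph dataset has label 5, since 64 // 12 = 5), but it is inferred from that output rather than documented. The helpers class_sizes and label_of are hypothetical.

```python
NUM_CLASSES = 8  # cycle, star, wheel, lollipop, hypercube, grid, clique, circular ladder

def class_sizes(num_graphs):
    # ASSUMPTION: classes 0-6 each receive num_graphs // 8 graphs,
    # and class 7 (circular ladder) absorbs whatever is left over.
    base = num_graphs // NUM_CLASSES
    return [base] * (NUM_CLASSES - 1) + [num_graphs - base * (NUM_CLASSES - 1)]

def label_of(idx, num_graphs):
    # ASSUMPTION: graphs are stored class by class in class order,
    # so the label follows from the index alone.
    if not 0 <= idx < num_graphs:
        raise IndexError(idx)
    base = num_graphs // NUM_CLASSES
    if base == 0:
        return NUM_CLASSES - 1  # fewer than 8 graphs: all land in the last class
    return min(idx // base, NUM_CLASSES - 1)

print(class_sizes(100))   # [12, 12, 12, 12, 12, 12, 12, 16]
print(label_of(64, 100))  # 5
```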

num_graphs

Number of graphs

Type

int

min_num_v

The minimum number of nodes

Type

int

max_num_v

The maximum number of nodes

Type

int

num_classes

The number of classes

Type

int

Examples

>>> import dgl
>>> import torch
>>> from dgl.data import MiniGCDataset
>>> data = MiniGCDataset(100, 16, 32, seed=0)

The dataset instance supports len() and integer indexing; each item is a (graph, label) pair:

>>> len(data)
100
>>> g, label = data[64]
>>> g
Graph(num_nodes=20, num_edges=82,
      ndata_schemes={}
      edata_schemes={})
>>> label
tensor(5)
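
The edge count in this output can be reconciled with label 5 (grid graph) under an assumption this page does not state: every undirected edge is stored in both directions and each node carries one self-loop. A 4 x 5 grid has 20 nodes and 4 * 4 + 3 * 5 = 31 undirected edges, giving 2 * 31 + 20 = 82. The sketch checks two 20-node grid shapes; only 4 x 5 matches, but that shape is inferred, not documented.

```python
def stored_edges(num_nodes, undirected_edges):
    # ASSUMPTION: both directions of every undirected edge, plus one self-loop per node
    return 2 * undirected_edges + num_nodes

def grid_edge_count(rows, cols):
    # horizontal edges + vertical edges in a rows x cols lattice
    return rows * (cols - 1) + (rows - 1) * cols

for rows, cols in [(2, 10), (4, 5)]:  # hypothetical 20-node grid shapes
    print(rows, cols, stored_edges(rows * cols, grid_edge_count(rows, cols)))
# 2 10 76
# 4 5 82
```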

Batch the graphs and labels for mini-batch training:

>>> graphs, labels = zip(*[data[i] for i in range(16)])
>>> batched_graphs = dgl.batch(graphs)
>>> batched_labels = torch.tensor(labels)
>>> batched_graphs
Graph(num_nodes=356, num_edges=1060,
      ndata_schemes={}
      edata_schemes={})
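
dgl.batch merges the input graphs into one larger graph whose connected pieces are the original graphs, so node and edge counts add up across the batch. The sketch below shows the idea in plain Python on hypothetical (num_nodes, edge_list) pairs rather than DGL objects: each graph's node IDs are shifted by the number of nodes added before it, and the per-graph node counts are kept so the batch can be split again (DGL exposes similar bookkeeping through batch_num_nodes()).

```python
def batch(graphs):
    """Disjoint union of (num_nodes, edge_list) pairs."""
    offset = 0
    merged_edges = []
    batch_num_nodes = []
    for num_nodes, edges in graphs:
        # shift this graph's node IDs past all previously added nodes
        merged_edges.extend((u + offset, v + offset) for u, v in edges)
        batch_num_nodes.append(num_nodes)
        offset += num_nodes
    return offset, merged_edges, batch_num_nodes

triangle = (3, [(0, 1), (1, 2), (2, 0)])
path = (2, [(0, 1)])
num_nodes, edges, sizes = batch([triangle, path])
print(num_nodes)  # 5
print(edges)      # [(0, 1), (1, 2), (2, 0), (3, 4)]
print(sizes)      # [3, 2]
```

If the first 16 graphs are 12 cycles and 4 stars, each stored with bidirected edges and one self-loop per node (both inferred, not documented here), a cycle with n nodes contributes 3n edges and a star 3n - 2, so the batch holds 3 * 356 - 2 * 4 = 1060 edges, matching the output above.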

__getitem__(idx)

Get the idx-th sample.

Parameters

idx (int) – The sample index.

Returns

The graph and its label.

Return type

(dgl.Graph, Tensor)

__len__()

Return the number of graphs in the dataset.
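
The indexing contract described above (integer index in, (graph, label) out, optional transform applied on every access) can be mirrored by a small pure-Python class. ToyGraphDataset and add_self_loops are hypothetical stand-ins with edge lists in place of DGLGraph objects, not DGL's implementation. Because __getitem__ raises IndexError past the end, a plain for loop over the dataset also works.

```python
from typing import Any, Callable, Optional, Sequence, Tuple

class ToyGraphDataset:
    """Hypothetical stand-in mirroring the __getitem__/__len__ contract."""

    def __init__(self, graphs: Sequence[Any], labels: Sequence[int],
                 transform: Optional[Callable[[Any], Any]] = None):
        if len(graphs) != len(labels):
            raise ValueError("graphs and labels must have the same length")
        self.graphs = list(graphs)
        self.labels = list(labels)
        self.transform = transform

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        g = self.graphs[idx]  # raises IndexError past the end, which ends iteration
        # the transform runs on every access; the stored graph is not reassigned
        if self.transform is not None:
            g = self.transform(g)
        return g, self.labels[idx]

    def __len__(self) -> int:
        return len(self.graphs)

def add_self_loops(edges):
    # hypothetical transform: return a new list with (v, v) for every node seen
    nodes = sorted({v for e in edges for v in e})
    return edges + [(v, v) for v in nodes]

data = ToyGraphDataset([[(0, 1)], [(0, 1), (1, 2)]], [0, 1], transform=add_self_loops)
g, label = data[1]
print(len(data), label)  # 2 1
print(g)                 # [(0, 1), (1, 2), (0, 0), (1, 1), (2, 2)]
```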