dgl.unbatch

dgl.unbatch(g, node_split=None, edge_split=None)

Revert the batch operation by splitting the given graph into a list of smaller ones.

This is the reverse operation of dgl.batch(). If node_split or edge_split is not given, it calls DGLGraph.batch_num_nodes() and DGLGraph.batch_num_edges() on the input graph to obtain the split information.

If node_split or edge_split is given, the graph is partitioned according to the given segments. One must ensure that the partition is valid: edges of the i-th graph may only connect nodes belonging to the i-th graph. Otherwise, DGL will throw an error.
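The partition rule can be illustrated without DGL. The sketch below is a hypothetical, pure-Python model of the bookkeeping, not DGL's implementation: it treats a batched graph as flat `src`/`dst` edge lists plus per-graph node and edge counts, slices the edges segment by segment, shifts node IDs back to start at 0, and raises `ValueError` (standing in for DGL's error) when an edge leaves its segment. The name `split_edge_lists` is made up for illustration.

```python
from itertools import accumulate


def split_edge_lists(src, dst, node_split, edge_split):
    """Hypothetical model of unbatching a homogeneous graph given as edge lists.

    Returns one (num_nodes, src, dst) tuple per segment, with node IDs
    relabeled so that every segment starts at node 0.
    """
    if len(node_split) != len(edge_split):
        raise ValueError("node_split and edge_split must have the same length")
    if not (len(src) == len(dst) == sum(edge_split)):
        raise ValueError("edge_split must sum to the number of edges")
    # Starting offset of each segment in the batched node / edge ID space.
    node_offsets = [0] + list(accumulate(node_split))[:-1]
    edge_offsets = [0] + list(accumulate(edge_split))[:-1]
    graphs = []
    for n, e, n_off, e_off in zip(node_split, edge_split, node_offsets, edge_offsets):
        seg_src = [u - n_off for u in src[e_off:e_off + e]]
        seg_dst = [v - n_off for v in dst[e_off:e_off + e]]
        # Validity check: both endpoints must lie inside this segment's nodes.
        if any(not 0 <= x < n for x in seg_src + seg_dst):
            raise ValueError("an edge connects nodes from different segments")
        graphs.append((n, seg_src, seg_dst))
    return graphs
```

For the batch of g1 and g2 in the first example below, the batched edge lists are src = [0, 1, 2, 4, 4, 4, 5] and dst = [1, 2, 3, 4, 5, 6, 4]. Splitting with node counts [4, 3] and edge counts [3, 4] recovers both original edge lists, while node counts [3, 4] fail because the edge 2 → 3 would cross the segment boundary.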

The function supports heterograph input, in which case both split arguments must be dictionaries (keyed by node type and by edge type, respectively), mirroring the per-type values that DGLGraph.batch_num_nodes() and DGLGraph.batch_num_edges() return for a heterograph.
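To show how the two dictionaries line up, the sketch below uses plain Python lists in place of tensors and a hypothetical helper, `per_graph_counts`, that regroups the per-type counts into one pair of dictionaries per output graph. The check that every type supplies the same number of counts is an assumption about what a valid split must satisfy, not a quote of DGL's validation code; keying `edge_split` by canonical `(src_type, edge_type, dst_type)` tuples is likewise assumed.

```python
def per_graph_counts(node_split, edge_split):
    """Hypothetical helper: regroup dict-valued splits into one
    (node_counts, edge_counts) pair of dicts per output graph."""
    # Every node type and edge type must describe the same number of graphs.
    lengths = {len(v) for v in node_split.values()}
    lengths |= {len(v) for v in edge_split.values()}
    if len(lengths) != 1:
        raise ValueError("every node and edge type needs one count per output graph")
    num_graphs = lengths.pop()
    return [
        (
            {ntype: counts[i] for ntype, counts in node_split.items()},
            {etype: counts[i] for etype, counts in edge_split.items()},
        )
        for i in range(num_graphs)
    ]
```

With the heterograph example below, node_split = {'user': [2, 1], 'game': [1, 3]} and edge_split = {('user', 'plays', 'game'): [2, 3]} regroup into the node and edge counts reported for f1 and f2.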

Parameters
  • g (DGLGraph) – Input graph to unbatch.

  • node_split (Tensor, dict[str, Tensor], optional) – Number of nodes in each resulting graph.

  • edge_split (Tensor, dict[str, Tensor], optional) – Number of edges in each resulting graph.

Returns

Unbatched list of graphs.

Return type

list[DGLGraph]

Examples

Unbatch a batched graph

>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> # add features
>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> f1, f2 = dgl.unbatch(bg)
>>> f1
Graph(num_nodes=4, num_edges=3,
      ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)})
>>> f2
Graph(num_nodes=3, num_edges=4,
      ndata_schemes={'x' : Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={'w' : Scheme(shape=(2,), dtype=torch.float32)})

With provided split arguments:

>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> g3 = dgl.graph((th.tensor([0]), th.tensor([1])))
>>> bg = dgl.batch([g1, g2, g3])
>>> bg.batch_num_nodes()
tensor([4, 3, 2])
>>> bg.batch_num_edges()
tensor([3, 4, 1])
>>> # unbatch but merge g2 and g3
>>> f1, f2 = dgl.unbatch(bg, th.tensor([4, 5]), th.tensor([3, 5]))
>>> f1
Graph(num_nodes=4, num_edges=3,
      ndata_schemes={}
      edata_schemes={})
>>> f2
Graph(num_nodes=5, num_edges=5,
      ndata_schemes={}
      edata_schemes={})
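The split tensors in this example are the per-graph counts from bg.batch_num_nodes() and bg.batch_num_edges() with adjacent entries summed. Because each split entry selects a contiguous range of node and edge IDs, only adjacent graphs can be merged this way. A minimal pure-Python sketch, using a hypothetical helper `merge_counts`, where `group_sizes[k]` is the number of consecutive original graphs that make up output graph k:

```python
def merge_counts(counts, group_sizes):
    """Hypothetical helper: sum runs of adjacent per-graph counts."""
    if sum(group_sizes) != len(counts):
        raise ValueError("group_sizes must cover every original graph exactly once")
    merged, start = [], 0
    for size in group_sizes:
        # Sum the counts of the next `size` consecutive graphs.
        merged.append(sum(counts[start:start + size]))
        start += size
    return merged


# Keep g1 alone, merge g2 and g3 (counts taken from the example above).
node_split = merge_counts([4, 3, 2], [1, 2])  # [4, 5]
edge_split = merge_counts([3, 4, 1], [1, 2])  # [3, 5]
```

In the DGL example these lists would then be wrapped back into tensors, e.g. th.tensor(merge_counts(bg.batch_num_nodes().tolist(), [1, 2])), before being passed to dgl.unbatch.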

Heterograph input

>>> hg1 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> f1, f2 = dgl.unbatch(bhg)
>>> f1
Graph(num_nodes={'user': 2, 'game': 1},
      num_edges={('user', 'plays', 'game'): 2},
      metagraph=[('user', 'game')])
>>> f2
Graph(num_nodes={'user': 1, 'game': 3},
      num_edges={('user', 'plays', 'game'): 3},
      metagraph=[('user', 'game')])

See also

batch()