dgl.batch
dgl.batch(graphs, ndata='__ALL__', edata='__ALL__')

Batch a collection of DGLGraphs into one graph for more efficient graph computation.

Each input graph becomes one disjoint component of the batched graph. The nodes and edges are relabeled to be disjoint segments:
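If the i-th input graph has N_i nodes (original IDs 0 ~ N_i-1), its node IDs are shifted by the total number of nodes in all graphs before it:

Original node ID    New node ID
0 ~ N_0-1           0 ~ N_0-1
0 ~ N_1-1           N_0 ~ N_0+N_1-1
…                   …
0 ~ N_k-1           sum_{i=0}^{k-1} N_i ~ sum_{i=0}^{k} N_i - 1

Edge IDs are assigned to consecutive segments in the same manner, using the edge counts of the input graphs.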
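Put differently, each graph's node IDs are shifted by an offset equal to the number of nodes in the graphs before it, i.e. an exclusive prefix sum of the node counts. The following pure-Python sketch reproduces that arithmetic on plain lists; it does not call DGL, and the helper names node_offsets and relabel_edges are hypothetical.

```python
from itertools import accumulate


def node_offsets(num_nodes):
    """Return the first new node ID of each input graph (exclusive prefix sum)."""
    # accumulate(..., initial=0) yields 0, N_0, N_0+N_1, ...; the final total is dropped.
    return list(accumulate(num_nodes, initial=0))[:-1]


def relabel_edges(edge_lists, num_nodes):
    """Concatenate per-graph (src, dst) lists, shifting IDs by each graph's offset."""
    src_all, dst_all = [], []
    for (src, dst), off in zip(edge_lists, node_offsets(num_nodes)):
        src_all.extend(u + off for u in src)
        dst_all.extend(v + off for v in dst)
    return src_all, dst_all


# Same inputs as the homogeneous example below: g1 has 4 nodes, g2 has 3 nodes.
edges = [([0, 1, 2], [1, 2, 3]), ([0, 0, 0, 1], [0, 1, 2, 0])]
print(node_offsets([4, 3]))          # [0, 4]
print(relabel_edges(edges, [4, 3]))  # ([0, 1, 2, 4, 4, 4, 5], [1, 2, 3, 4, 5, 6, 4])
```

The result matches the edge list that dgl.batch reports for the same two graphs in the examples.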
Because of this, many computations on a batched graph produce the same results as performing them on each graph individually, but they become much more efficient because they can be parallelized easily. This makes dgl.batch very useful for tasks that deal with many graph samples, such as graph classification.

For heterograph inputs, the graphs must share the same set of relations (i.e., node types and edge types), and the function batches each relation one by one. Thus, the result is also a heterograph with the same set of relations as the inputs.
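For heterographs, the offsets are tracked separately per node type: for each relation (srctype, etype, dsttype), source IDs are shifted by the running count of srctype nodes and destination IDs by the running count of dsttype nodes. The sketch below illustrates this with a hypothetical plain-dict stand-in for a heterograph; it is not DGL's implementation, and batch_hetero_edges is an illustrative name.

```python
def batch_hetero_edges(graphs):
    """Batch heterographs relation by relation.

    Each graph is a plain dict (hypothetical layout):
      {"num_nodes": {ntype: count},
       "edges": {(srctype, etype, dsttype): (src_ids, dst_ids)}}
    """
    ntypes = list(graphs[0]["num_nodes"])
    etypes = list(graphs[0]["edges"])
    totals = dict.fromkeys(ntypes, 0)  # running node-ID offset per node type
    batched = {et: ([], []) for et in etypes}
    for g in graphs:
        # Every input must have exactly the same node types and edge types.
        if set(g["num_nodes"]) != set(ntypes) or set(g["edges"]) != set(etypes):
            raise ValueError("all graphs must share the same set of relations")
        for (srctype, etype, dsttype), (src, dst) in g["edges"].items():
            out_src, out_dst = batched[(srctype, etype, dsttype)]
            # Shift source IDs by the source type's offset, destination IDs by the destination type's.
            out_src.extend(u + totals[srctype] for u in src)
            out_dst.extend(v + totals[dsttype] for v in dst)
        for nt in ntypes:
            totals[nt] += g["num_nodes"][nt]
    return totals, batched


# Same relation and edges as the heterograph example below.
rel = ("user", "plays", "game")
hg1 = {"num_nodes": {"user": 2, "game": 1}, "edges": {rel: ([0, 1], [0, 0])}}
hg2 = {"num_nodes": {"user": 1, "game": 3}, "edges": {rel: ([0, 0, 0], [1, 0, 2])}}
totals, batched = batch_hetero_edges([hg1, hg2])
print(totals)        # {'user': 3, 'game': 4}
print(batched[rel])  # ([0, 1, 2, 2, 2], [0, 0, 2, 1, 3])
```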
The numbers of nodes and edges of the input graphs are accessible via the DGLGraph.batch_num_nodes() and DGLGraph.batch_num_edges() methods of the resulting graph. For homogeneous graphs, they return 1D integer tensors, with each element being the number of nodes/edges of the corresponding input graph. For heterographs, they return dictionaries of 1D integer tensors, keyed by node type or edge type.

The function supports batching batched graphs. The batch size of the result graph is the sum of the batch sizes of all the input graphs.
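The bookkeeping for nested batching can be illustrated without DGL. In this sketch, plain Python lists stand in for the 1D tensors returned by batch_num_nodes(), an unbatched graph is treated as a one-element list (batch size 1), and the helper names batch_counts and segments are hypothetical. Batching concatenates the count lists, so the batch size is the sum of the inputs' batch sizes, and the counts determine each component's node-ID range.

```python
def batch_counts(per_input_counts):
    """Concatenate the per-component node counts of each input graph."""
    merged = []
    for counts in per_input_counts:
        merged.extend(counts)
    return merged


def segments(counts):
    """Return the [start, end) range of new node IDs for each component."""
    bounds, start = [], 0
    for n in counts:
        bounds.append((start, start + n))
        start += n
    return bounds


bg_counts = batch_counts([[4], [3]])               # like dgl.batch([g1, g2])
bbg_counts = batch_counts([bg_counts, bg_counts])  # like dgl.batch([bg, bg])
print(bbg_counts, len(bbg_counts))  # [4, 3, 4, 3] 4
print(segments(bbg_counts))         # [(0, 4), (4, 7), (7, 11), (11, 14)]
```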
By default, node/edge features are batched by concatenating the feature tensors of all input graphs. This requires features of the same name to have the same data type and feature size in every input graph. Pass None as the ndata or edata argument to skip feature batching, or pass a list of strings to specify which features to batch. To unbatch the graph back into a list of graphs, use the dgl.unbatch() function.

Parameters
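- graphs (list[DGLGraph]): Input graphs.
- ndata (list[str], None, optional): Node features to batch. The default '__ALL__' batches all node features; None batches none.
- edata (list[str], None, optional): Edge features to batch. The default '__ALL__' batches all edge features; None batches none.
Returns
Batched graph.
Return type
DGLGraph
Examples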
Batch homogeneous graphs
>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> bg = dgl.batch([g1, g2])
>>> bg
Graph(num_nodes=7, num_edges=7, ndata_schemes={} edata_schemes={})
>>> bg.batch_size
2
>>> bg.batch_num_nodes()
tensor([4, 3])
>>> bg.batch_num_edges()
tensor([3, 4])
>>> bg.edges()
(tensor([0, 1, 2, 4, 4, 4, 5]), tensor([1, 2, 3, 4, 5, 6, 4]))
Batch batched graphs
>>> bbg = dgl.batch([bg, bg])
>>> bbg.batch_size
4
>>> bbg.batch_num_nodes()
tensor([4, 3, 4, 3])
>>> bbg.batch_num_edges()
tensor([3, 4, 3, 4])
Batch graphs with feature data
>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> bg.ndata['x']
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
>>> bg.edata['w']
tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])
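Feature batching is, in effect, a per-name concatenation of rows in graph order, and unbatching slices the rows back apart using the per-graph counts. The sketch below loosely mimics the ndata/edata options on plain lists; it is not DGL's implementation, batch_features and split_rows are hypothetical names, and only feature size (not data type) is checked.

```python
def batch_features(feature_dicts, names="__ALL__"):
    """Concatenate per-graph features by name.

    feature_dicts: one dict per input graph, mapping feature name -> list of rows.
    names: "__ALL__" (every feature of the first graph), None (no features),
           or a list of feature names to batch.
    """
    if names is None:
        return {}
    if names == "__ALL__":
        names = list(feature_dicts[0])
    batched = {}
    for name in names:
        rows, width = [], None
        for feats in feature_dicts:
            for row in feats[name]:
                # Only the feature size is checked here; data types are not.
                if width is None:
                    width = len(row)
                elif len(row) != width:
                    raise ValueError(f"feature {name!r} has inconsistent sizes")
                rows.append(row)
        batched[name] = rows
    return batched


def split_rows(rows, counts):
    """Undo the concatenation by slicing with per-graph counts (the idea behind unbatching)."""
    parts, start = [], 0
    for n in counts:
        parts.append(rows[start:start + n])
        start += n
    return parts


# Node features as in the example above: g1 has 4 rows of zeros, g2 has 3 rows of ones.
g1_feats = {"x": [[0.0] * 3 for _ in range(4)]}
g2_feats = {"x": [[1.0] * 3 for _ in range(3)]}
x = batch_features([g1_feats, g2_feats])["x"]
print(len(x), x[3], x[4])  # 7 [0.0, 0.0, 0.0] [1.0, 1.0, 1.0]
print(batch_features([g1_feats, g2_feats], names=None))  # {}
print([len(p) for p in split_rows(x, [4, 3])])  # [4, 3]
```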
Batch heterographs
>>> hg1 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> bhg
Graph(num_nodes={'user': 3, 'game': 4},
      num_edges={('user', 'plays', 'game'): 5},
      metagraph=[('user', 'game', 'plays')])
>>> bhg.batch_size
2
>>> bhg.batch_num_nodes()
{'user': tensor([2, 1]), 'game': tensor([1, 3])}
>>> bhg.batch_num_edges()
{('user', 'plays', 'game'): tensor([2, 3])}
See also
dgl.unbatch()