dgl.DGLGraph.set_batch_num_nodes
DGLGraph.set_batch_num_nodes(val)
Manually set the number of nodes of each graph in the batch, for every node type.
Parameters
val (Tensor or Mapping[str, Tensor]) – A dictionary mapping each node type to a tensor holding the number of nodes of each graph in the batch. If the graph has only one node type, val can also be a single tensor giving the number of nodes per graph in the batch.
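Because val holds counts rather than node IDs, graph i in the batch owns a contiguous block of node IDs that starts where the block of graph i - 1 ends. The pure-Python sketch below (no DGL required) computes those blocks with a prefix sum. node_id_ranges is a hypothetical helper, not part of DGL, and it assumes the consecutive layout that dgl.batch produces and dgl.unbatch relies on.

```python
from itertools import accumulate


def node_id_ranges(batch_num_nodes):
    """Return the (start, stop) node-ID range owned by each graph in the batch.

    Hypothetical helper for illustration only. Assumes the nodes of graph i
    occupy the consecutive IDs that directly follow those of graph i - 1.
    """
    stops = list(accumulate(batch_num_nodes))  # running totals = exclusive ends
    starts = [0] + stops[:-1]                  # each block starts where the previous ended
    return list(zip(starts, stops))


# Six nodes split into two graphs of three nodes each.
print(node_id_ranges([3, 3]))  # [(0, 3), (3, 6)]

# For a heterogeneous graph, the ranges are computed independently per node type.
print({nt: node_id_ranges(c) for nt, c in {'user': [3, 3], 'game': [2, 2]}.items()})
```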
Notes
This API should always be used together with set_batch_num_edges to specify the batching information of a graph. It does not check the correspondence between the graph structure and the batching information, so the user must guarantee that there are no cross-graph edges in the batch.
Examples
The following example uses the PyTorch backend.
>>> import dgl
>>> import torch
Create a homogeneous graph.
>>> g = dgl.graph(([0, 1, 2, 3, 4, 5], [1, 2, 0, 4, 5, 3]))
Manually set batch information.
>>> g.set_batch_num_nodes(torch.tensor([3, 3]))
>>> g.set_batch_num_edges(torch.tensor([3, 3]))
Unbatch the graph.
>>> dgl.unbatch(g)
[Graph(num_nodes=3, num_edges=3, ndata_schemes={} edata_schemes={}), Graph(num_nodes=3, num_edges=3, ndata_schemes={} edata_schemes={})]
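Since set_batch_num_nodes and set_batch_num_edges do not verify the batching information, you may want to check it yourself first. The pure-Python sketch below (no DGL required) walks the edge-ID block of each graph and reports every edge whose source or destination falls outside that graph's node-ID block. find_cross_graph_edges is a hypothetical helper, not a DGL API, and it assumes node and edge IDs are laid out in consecutive per-graph blocks.

```python
def find_cross_graph_edges(src, dst, batch_num_nodes, batch_num_edges):
    """Return the IDs of edges whose endpoints leave their own graph.

    Hypothetical helper, not part of DGL. Assumes graph i owns the block of
    node IDs and the block of edge IDs that directly follow those of graph i - 1.
    """
    if len(batch_num_nodes) != len(batch_num_edges):
        raise ValueError("node and edge counts must describe the same number of graphs")
    if sum(batch_num_edges) != len(src):
        raise ValueError("edge counts must sum to the total number of edges")
    bad = []
    node_start = edge_start = 0
    for n_nodes, n_edges in zip(batch_num_nodes, batch_num_edges):
        node_stop = node_start + n_nodes
        for eid in range(edge_start, edge_start + n_edges):
            # Both endpoints must fall inside this graph's node-ID block.
            if not (node_start <= src[eid] < node_stop
                    and node_start <= dst[eid] < node_stop):
                bad.append(eid)
        node_start = node_stop
        edge_start += n_edges
    return bad


# Edges from the homogeneous example above.
src = [0, 1, 2, 3, 4, 5]
dst = [1, 2, 0, 4, 5, 3]
print(find_cross_graph_edges(src, dst, [3, 3], [3, 3]))  # []
print(find_cross_graph_edges(src, dst, [2, 4], [3, 3]))  # [1, 2]
```

With the counts used in the example every edge stays inside its graph; splitting the nodes as [2, 4] instead makes edges 1 and 2 cross from the first graph into the second.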
Create a heterogeneous graph.
>>> hg = dgl.heterograph({
...     ('user', 'plays', 'game') : ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]),
...     ('developer', 'develops', 'game') : ([0, 1, 2, 3], [1, 0, 3, 2])})
Manually set batch information.
>>> hg.set_batch_num_nodes({
...     'user': torch.tensor([3, 3]),
...     'game': torch.tensor([2, 2]),
...     'developer': torch.tensor([2, 2])})
>>> hg.set_batch_num_edges({
...     ('user', 'plays', 'game'): torch.tensor([3, 3]),
...     ('developer', 'develops', 'game'): torch.tensor([2, 2])})
Unbatch the graph.
>>> g1, g2 = dgl.unbatch(hg)
>>> g1
Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},
      num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},
      metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])
>>> g2
Graph(num_nodes={'developer': 2, 'game': 2, 'user': 3},
      num_edges={('developer', 'develops', 'game'): 2, ('user', 'plays', 'game'): 3},
      metagraph=[('developer', 'game', 'develops'), ('user', 'game', 'plays')])
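In the heterogeneous case, the source and destination of an edge type may belong to different node types, so each endpoint has to be checked against the per-graph blocks of its own node type. The pure-Python sketch below (no DGL required) does this for every canonical edge type. find_cross_graph_edges_hetero is a hypothetical helper, not a DGL API; it assumes IDs are laid out in consecutive per-graph blocks separately for every node type and edge type, and that all count lists describe the same number of graphs.

```python
def find_cross_graph_edges_hetero(edges, batch_num_nodes, batch_num_edges):
    """Map each canonical edge type to the IDs of edges that leave their graph.

    Hypothetical helper, not part of DGL. `edges` maps (src_type, etype, dst_type)
    to a (src, dst) pair of ID lists. IDs are assumed to be laid out in
    contiguous per-graph blocks, separately for every node type and edge type.
    """
    def blocks(counts):
        # Turn per-graph counts into (start, stop) ID ranges via a prefix sum.
        out, start = [], 0
        for c in counts:
            out.append((start, start + c))
            start += c
        return out

    result = {}
    for canonical, (src, dst) in edges.items():
        stype, _, dtype = canonical
        src_blocks = blocks(batch_num_nodes[stype])
        dst_blocks = blocks(batch_num_nodes[dtype])
        bad = []
        for g, (e_lo, e_hi) in enumerate(blocks(batch_num_edges[canonical])):
            (s_lo, s_hi), (d_lo, d_hi) = src_blocks[g], dst_blocks[g]
            for eid in range(e_lo, e_hi):
                # Source and destination are checked against their own node types.
                if not (s_lo <= src[eid] < s_hi and d_lo <= dst[eid] < d_hi):
                    bad.append(eid)
        result[canonical] = bad
    return result


# Data from the heterogeneous example above.
edges = {
    ('user', 'plays', 'game'): ([0, 1, 2, 3, 4, 5], [0, 1, 1, 3, 3, 2]),
    ('developer', 'develops', 'game'): ([0, 1, 2, 3], [1, 0, 3, 2]),
}
num_edges = {('user', 'plays', 'game'): [3, 3],
             ('developer', 'develops', 'game'): [2, 2]}
ok = {'user': [3, 3], 'game': [2, 2], 'developer': [2, 2]}
wrong = {'user': [3, 3], 'game': [3, 1], 'developer': [2, 2]}
print(find_cross_graph_edges_hetero(edges, ok, num_edges))     # every list is empty
print(find_cross_graph_edges_hetero(edges, wrong, num_edges))  # one bad edge per type
```

Splitting the games as [3, 1] instead of [2, 2] leaves game 2 in the first graph, so the edges that point to it from the second graph (plays edge 5 and develops edge 3) become cross-graph edges.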
See also