"""Utilities for batching/unbatching graphs."""
from collections.abc import Mapping
from . import backend as F, convert, utils
from .base import ALL, DGLError, EID, is_all, NID
from .heterograph import DGLGraph
from .heterograph_index import disjoint_union, slice_gidx
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["batch", "unbatch", "slice_batch"]
def batch(graphs, ndata=ALL, edata=ALL):
    r"""Batch a collection of :class:`DGLGraph` s into one graph for more efficient
    graph computation.

    Each input graph becomes one disjoint component of the batched graph. The nodes
    and edges are relabeled to be disjoint segments:

    ================= ========= ================= === ========================
                      graphs[0] graphs[1]         ... graphs[k]
    ================= ========= ================= === ========================
    Original node ID  0 ~ N_0   0 ~ N_1           ... 0 ~ N_k
    New node ID       0 ~ N_0   N_0 ~ N_0+N_1     ... \sum_{i=0}^{k-1} N_i ~
                                                      \sum_{i=0}^k N_i
    ================= ========= ================= === ========================

    Because of this, many of the computations on a batched graph are the same as if
    performed on each graph individually, but become much more efficient
    since they can be parallelized easily. This makes ``dgl.batch`` very useful
    for tasks dealing with many graph samples such as graph classification tasks.

    For heterograph inputs, they must share the same set of relations (i.e., node types
    and edge types) and the function will perform batching on each relation one by one.
    Thus, the result is also a heterograph and has the same set of relations as the inputs.

    The numbers of nodes and edges of the input graphs are accessible via the
    :func:`DGLGraph.batch_num_nodes` and :func:`DGLGraph.batch_num_edges` attributes
    of the resulting graph. For homogeneous graphs, they are 1D integer tensors,
    with each element being the number of nodes/edges of the corresponding input graph. For
    heterographs, they are dictionaries of 1D integer tensors, with node
    type or edge type as the keys.

    The function supports batching batched graphs. The batch size of the result
    graph is the sum of the batch sizes of all the input graphs.

    By default, node/edge features are batched by concatenating the feature tensors
    of all input graphs. This thus requires features of the same name to have
    the same data type and feature size. One can pass ``None`` to the ``ndata``
    or ``edata`` argument to prevent feature batching, or pass a list of strings
    to specify which features to batch.

    To unbatch the graph back to a list, use the :func:`dgl.unbatch` function.

    Parameters
    ----------
    graphs : list[DGLGraph]
        Input graphs.
    ndata : list[str], None, optional
        Node features to batch.
    edata : list[str], None, optional
        Edge features to batch.

    Returns
    -------
    DGLGraph
        Batched graph.

    Raises
    ------
    DGLError
        If ``graphs`` is empty, if ``ndata``/``edata`` is neither ``ALL``, a
        list nor ``None``, or if any input graph is a MFG (block).

    Examples
    --------

    Batch homogeneous graphs

    >>> import dgl
    >>> import torch as th
    >>> # 4 nodes, 3 edges
    >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
    >>> # 3 nodes, 4 edges
    >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
    >>> bg = dgl.batch([g1, g2])
    >>> bg
    Graph(num_nodes=7, num_edges=7,
          ndata_schemes={}
          edata_schemes={})
    >>> bg.batch_size
    2
    >>> bg.batch_num_nodes()
    tensor([4, 3])
    >>> bg.batch_num_edges()
    tensor([3, 4])
    >>> bg.edges()
    (tensor([0, 1, 2, 4, 4, 4, 5]), tensor([1, 2, 3, 4, 5, 6, 4]))

    Batch batched graphs

    >>> bbg = dgl.batch([bg, bg])
    >>> bbg.batch_size
    4
    >>> bbg.batch_num_nodes()
    tensor([4, 3, 4, 3])
    >>> bbg.batch_num_edges()
    tensor([3, 4, 3, 4])

    Batch graphs with feature data

    >>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
    >>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
    >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
    >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
    >>> bg = dgl.batch([g1, g2])
    >>> bg.ndata['x']
    tensor([[0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.],
            [0., 0., 0.],
            [1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])
    >>> bg.edata['w']
    tensor([[1., 1.],
            [1., 1.],
            [1., 1.],
            [0., 0.],
            [0., 0.],
            [0., 0.],
            [0., 0.]])

    Batch heterographs

    >>> hg1 = dgl.heterograph({
    ...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
    >>> hg2 = dgl.heterograph({
    ...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
    >>> bhg = dgl.batch([hg1, hg2])
    >>> bhg
    Graph(num_nodes={'game': 4, 'user': 3},
          num_edges={('user', 'plays', 'game'): 5},
          metagraph=[('user', 'game', 'plays')])
    >>> bhg.batch_size
    2
    >>> bhg.batch_num_nodes()
    {'game': tensor([1, 3]), 'user': tensor([2, 1])}
    >>> bhg.batch_num_edges()
    {('user', 'plays', 'game'): tensor([2, 3])}

    See Also
    --------
    unbatch
    """
    if len(graphs) == 0:
        raise DGLError("The input list of graphs cannot be empty.")
    if not (is_all(ndata) or isinstance(ndata, list) or ndata is None):
        raise DGLError(
            "Invalid argument ndata: must be a string list but got {}.".format(
                type(ndata)
            )
        )
    if not (is_all(edata) or isinstance(edata, list) or edata is None):
        raise DGLError(
            "Invalid argument edata: must be a string list but got {}.".format(
                type(edata)
            )
        )
    if any(g.is_block for g in graphs):
        raise DGLError("Batching a MFG is not supported.")

    # The schema (node types, relations and their integer IDs) is taken from
    # the first graph; the other graphs are expected to share it (see the
    # docstring). This function does not compare schemas itself.
    relations = list(graphs[0].canonical_etypes)
    relation_ids = [graphs[0].get_etype_id(r) for r in relations]
    ntypes = list(graphs[0].ntypes)
    ntype_ids = [graphs[0].get_ntype_id(n) for n in ntypes]
    etypes = [etype for _, etype, _ in relations]

    # Merge all graph structures into one graph index of disjoint components.
    gidx = disjoint_union(
        graphs[0]._graph.metagraph, [g._graph for g in graphs]
    )
    retg = DGLGraph(gidx, ntypes, etypes)

    # Compute batch num nodes: concatenating each input's own batch_num_nodes
    # is what makes batching already-batched graphs work.
    bnn = {}
    for ntype in ntypes:
        bnn[ntype] = F.cat([g.batch_num_nodes(ntype) for g in graphs], 0)
    retg.set_batch_num_nodes(bnn)

    # Compute batch num edges.
    bne = {}
    for etype in relations:
        bne[etype] = F.cat([g.batch_num_edges(etype) for g in graphs], 0)
    retg.set_batch_num_edges(bne)

    # Batch node features.
    if ndata is not None:
        for ntype_id, ntype in zip(ntype_ids, ntypes):
            # Graphs without nodes of this type are skipped so they need not
            # carry a matching feature schema; if every graph is empty, all
            # frames are kept so the schema is still derived from the inputs.
            all_empty = all(g._graph.num_nodes(ntype_id) == 0 for g in graphs)
            frames = [
                g._node_frames[ntype_id]
                for g in graphs
                if g._graph.num_nodes(ntype_id) > 0 or all_empty
            ]
            # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
            # we allow empty graphs to have no features during batching.
            ret_feat = _batch_feat_dicts(
                frames, ndata, 'nodes["{}"].data'.format(ntype)
            )
            retg.nodes[ntype].data.update(ret_feat)

    # Batch edge features (same empty-graph policy as for nodes).
    if edata is not None:
        for etype_id, etype in zip(relation_ids, relations):
            all_empty = all(g._graph.num_edges(etype_id) == 0 for g in graphs)
            frames = [
                g._edge_frames[etype_id]
                for g in graphs
                if g._graph.num_edges(etype_id) > 0 or all_empty
            ]
            # TODO: do we require graphs with no nodes/edges to have the same schema? Currently
            # we allow empty graphs to have no features during batching.
            ret_feat = _batch_feat_dicts(
                frames, edata, "edges[{}].data".format(etype)
            )
            retg.edges[etype].data.update(ret_feat)
    return retg
def _batch_feat_dicts(frames, keys, feat_dict_name):
    """Concatenate selected features of several frames along the first axis.

    Parameters
    ----------
    frames : list[Frame]
        Frames whose features are concatenated, in order.
    keys : list[str]
        Names of the features to concatenate. The special value
        '__ALL__' selects every feature of the first frame.
    feat_dict_name : str
        Human-readable name of the feature dictionary, used only in
        schema-mismatch error messages.

    Returns
    -------
    dict[str, Tensor]
        Mapping from feature name to the concatenated tensor. Empty when
        ``frames`` is empty.
    """
    if not frames:
        return {}

    schemas = [fr.schemes for fr in frames]

    # Validate that the frames agree on the schema of the selected features.
    if not is_all(keys):
        utils.check_all_same_schema_for_keys(schemas, keys, feat_dict_name)
        selected = keys
    else:
        utils.check_all_same_schema(schemas, feat_dict_name)
        selected = schemas[0].keys()

    batched = {}
    for name in selected:
        batched[name] = F.cat([fr[name] for fr in frames], 0)
    return batched
def unbatch(g, node_split=None, edge_split=None):
    """Revert the batch operation by splitting the given graph into a list of small ones.

    This is the reverse operation of :func:`dgl.batch`. If the ``node_split``
    or the ``edge_split`` is not given, it calls :func:`DGLGraph.batch_num_nodes`
    and :func:`DGLGraph.batch_num_edges` of the input graph to get the information.

    If the ``node_split`` or the ``edge_split`` arguments are given,
    it will partition the graph according to the given segments. One must ensure
    that the partition is valid -- edges of the i^th graph only connect nodes
    belonging to the i^th graph. Otherwise, DGL will throw an error.

    The function supports heterograph input, in which case the two split
    section arguments shall be of dictionary type -- similar to the
    :func:`DGLGraph.batch_num_nodes`
    and :func:`DGLGraph.batch_num_edges` attributes of a heterograph.

    Parameters
    ----------
    g : DGLGraph
        Input graph to unbatch.
    node_split : Tensor, dict[str, Tensor], optional
        Number of nodes of each result graph.
    edge_split : Tensor, dict[str, Tensor], optional
        Number of edges of each result graph.

    Returns
    -------
    list[DGLGraph]
        Unbatched list of graphs.

    Raises
    ------
    DGLError
        If a non-dictionary split is given for a graph with several node/edge
        types, if a split dictionary does not cover every node type /
        canonical edge type, or if the splits disagree on the number of
        resulting graphs.

    Examples
    --------

    Unbatch a batched graph

    >>> import dgl
    >>> import torch as th
    >>> # 4 nodes, 3 edges
    >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
    >>> # 3 nodes, 4 edges
    >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
    >>> # add features
    >>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
    >>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
    >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
    >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
    >>> bg = dgl.batch([g1, g2])
    >>> f1, f2 = dgl.unbatch(bg)
    >>> f1
    Graph(num_nodes=4, num_edges=3,
          ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={'w': Scheme(shape=(2,), dtype=torch.float32)})
    >>> f2
    Graph(num_nodes=3, num_edges=4,
          ndata_schemes={'x': Scheme(shape=(3,), dtype=torch.float32)}
          edata_schemes={'w': Scheme(shape=(2,), dtype=torch.float32)})

    With provided split arguments:

    >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
    >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
    >>> g3 = dgl.graph((th.tensor([0]), th.tensor([1])))
    >>> bg = dgl.batch([g1, g2, g3])
    >>> bg.batch_num_nodes()
    tensor([4, 3, 2])
    >>> bg.batch_num_edges()
    tensor([3, 4, 1])
    >>> # unbatch but merge g2 and g3
    >>> f1, f2 = dgl.unbatch(bg, th.tensor([4, 5]), th.tensor([3, 5]))
    >>> f1
    Graph(num_nodes=4, num_edges=3,
          ndata_schemes={}
          edata_schemes={})
    >>> f2
    Graph(num_nodes=5, num_edges=5,
          ndata_schemes={}
          edata_schemes={})

    Heterograph input

    >>> hg1 = dgl.heterograph({
    ...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
    >>> hg2 = dgl.heterograph({
    ...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
    >>> bhg = dgl.batch([hg1, hg2])
    >>> f1, f2 = dgl.unbatch(bhg)
    >>> f1
    Graph(num_nodes={'game': 1, 'user': 2},
          num_edges={('user', 'plays', 'game'): 2},
          metagraph=[('user', 'game', 'plays')])
    >>> f2
    Graph(num_nodes={'game': 3, 'user': 1},
          num_edges={('user', 'plays', 'game'): 3},
          metagraph=[('user', 'game', 'plays')])

    See Also
    --------
    batch
    """
    num_split = None

    # Parse node_split. Default to the per-graph node counts recorded on g.
    if node_split is None:
        node_split = {ntype: g.batch_num_nodes(ntype) for ntype in g.ntypes}
    elif not isinstance(node_split, Mapping):
        # A bare tensor is only unambiguous when there is one node type.
        if len(g.ntypes) != 1:
            raise DGLError(
                "Must provide a dictionary for argument node_split when"
                " there are multiple node types."
            )
        node_split = {g.ntypes[0]: node_split}
    if node_split.keys() != set(g.ntypes):
        raise DGLError("Must specify node_split for each node type.")
    for split in node_split.values():
        if num_split is not None and num_split != len(split):
            raise DGLError(
                "All node_split and edge_split must specify the same number"
                " of split sizes."
            )
        num_split = len(split)

    # Parse edge_split. Default to the per-graph edge counts recorded on g.
    if edge_split is None:
        edge_split = {
            etype: g.batch_num_edges(etype) for etype in g.canonical_etypes
        }
    elif not isinstance(edge_split, Mapping):
        # A bare tensor is only unambiguous when there is one edge type.
        if len(g.etypes) != 1:
            raise DGLError(
                "Must provide a dictionary for argument edge_split when"
                " there are multiple edge types."
            )
        edge_split = {g.canonical_etypes[0]: edge_split}
    if edge_split.keys() != set(g.canonical_etypes):
        raise DGLError("Must specify edge_split for each canonical edge type.")
    for split in edge_split.values():
        if num_split is not None and num_split != len(split):
            raise DGLError(
                "All node_split and edge_split must specify the same number"
                " of split sizes."
            )
        num_split = len(split)

    # Convert the split sizes to plain Python lists: they are passed to
    # F.split and indexed per resulting graph below.
    node_split = {
        k: F.asnumpy(split).tolist() for k, split in node_split.items()
    }
    edge_split = {
        k: F.asnumpy(split).tolist() for k, split in edge_split.items()
    }

    # Split edges for each relation. Endpoint IDs of the i-th piece are shifted
    # by the total node count of the preceding pieces so that every result
    # graph's node IDs start from 0.
    edge_dict_per = [{} for _ in range(num_split)]
    for rel in g.canonical_etypes:
        srctype, _, dsttype = rel
        srcnid_off = dstnid_off = 0
        u, v = g.edges(order="eid", etype=rel)
        us = F.split(u, edge_split[rel], 0)
        vs = F.split(v, edge_split[rel], 0)
        for i, (subu, subv) in enumerate(zip(us, vs)):
            edge_dict_per[i][rel] = (subu - srcnid_off, subv - dstnid_off)
            srcnid_off += node_split[srctype][i]
            dstnid_off += node_split[dsttype][i]
    num_nodes_dict_per = [
        {k: split[i] for k, split in node_split.items()}
        for i in range(num_split)
    ]

    # Create graphs.
    gs = [
        convert.heterograph(edge_dict, num_nodes_dict, idtype=g.idtype)
        for edge_dict, num_nodes_dict in zip(edge_dict_per, num_nodes_dict_per)
    ]

    # Unbatch node features.
    for ntype in g.ntypes:
        for key, feat in g.nodes[ntype].data.items():
            subfeats = F.split(feat, node_split[ntype], 0)
            for subg, subf in zip(gs, subfeats):
                subg.nodes[ntype].data[key] = subf

    # Unbatch edge features.
    for etype in g.canonical_etypes:
        for key, feat in g.edges[etype].data.items():
            subfeats = F.split(feat, edge_split[etype], 0)
            for subg, subf in zip(gs, subfeats):
                subg.edges[etype].data[key] = subf
    return gs
def slice_batch(g, gid, store_ids=False):
    """Get a particular graph from a batch of graphs.

    Parameters
    ----------
    g : DGLGraph
        Input batched graph.
    gid : int
        The ID of the graph to retrieve. Negative values count from the end of
        the batch, as with Python sequences (``-1`` is the last graph).
    store_ids : bool
        If True, it will store the raw IDs of the extracted nodes and edges in the ``ndata`` and
        ``edata`` of the resulting graph under name ``dgl.NID`` and ``dgl.EID``, respectively.

    Returns
    -------
    DGLGraph
        Retrieved graph.

    Raises
    ------
    IndexError
        If ``gid`` is outside ``[-g.batch_size, g.batch_size)``.

    Examples
    --------

    The following example uses PyTorch backend.

    >>> import dgl
    >>> import torch

    Create a batched graph.

    >>> g1 = dgl.graph(([0, 1], [2, 3]))
    >>> g2 = dgl.graph(([1], [2]))
    >>> bg = dgl.batch([g1, g2])

    Get the second component graph.

    >>> g = dgl.slice_batch(bg, 1)
    >>> print(g)
    Graph(num_nodes=3, num_edges=1,
          ndata_schemes={}
          edata_schemes={})

    Negative IDs count from the end, so this retrieves the same graph.

    >>> g = dgl.slice_batch(bg, -1)
    """
    # Normalize and validate gid up front. Without this, a negative gid made
    # the prefix-sum slice below have a negative length (an error on the
    # PyTorch backend), and an out-of-range gid failed deep inside a backend op.
    batch_size = g.batch_size
    orig_gid = gid
    if gid < 0:
        gid += batch_size
    if not 0 <= gid < batch_size:
        raise IndexError(
            "Graph ID {} is out of range for a batch of {} graphs.".format(
                orig_gid, batch_size
            )
        )

    # For every node type, the component's node count and its first node ID
    # (the sum of the counts of all preceding components).
    start_nid = []
    num_nodes = []
    for ntype in g.ntypes:
        batch_num_nodes = g.batch_num_nodes(ntype)
        num_nodes.append(F.as_scalar(batch_num_nodes[gid]))
        if gid == 0:
            start_nid.append(0)
        else:
            start_nid.append(
                F.as_scalar(F.sum(F.slice_axis(batch_num_nodes, 0, 0, gid), 0))
            )

    # Same for every canonical edge type.
    start_eid = []
    num_edges = []
    for etype in g.canonical_etypes:
        batch_num_edges = g.batch_num_edges(etype)
        num_edges.append(F.as_scalar(batch_num_edges[gid]))
        if gid == 0:
            start_eid.append(0)
        else:
            start_eid.append(
                F.as_scalar(F.sum(F.slice_axis(batch_num_edges, 0, 0, gid), 0))
            )

    # Slice graph structure.
    gidx = slice_gidx(
        g._graph,
        utils.toindex(num_nodes),
        utils.toindex(start_nid),
        utils.toindex(num_edges),
        utils.toindex(start_eid),
    )
    retg = DGLGraph(gidx, g.ntypes, g.etypes)

    # Slice node features; optionally record the original node IDs.
    for ntid, ntype in enumerate(g.ntypes):
        stnid = start_nid[ntid]
        for key, feat in g.nodes[ntype].data.items():
            subfeats = F.slice_axis(feat, 0, stnid, stnid + num_nodes[ntid])
            retg.nodes[ntype].data[key] = subfeats
        if store_ids:
            retg.nodes[ntype].data[NID] = F.arange(
                stnid, stnid + num_nodes[ntid], retg.idtype, retg.device
            )

    # Slice edge features; optionally record the original edge IDs.
    for etid, etype in enumerate(g.canonical_etypes):
        steid = start_eid[etid]
        for key, feat in g.edges[etype].data.items():
            subfeats = F.slice_axis(feat, 0, steid, steid + num_edges[etid])
            retg.edges[etype].data[key] = subfeats
        if store_ids:
            retg.edges[etype].data[EID] = F.arange(
                steid, steid + num_edges[etid], retg.idtype, retg.device
            )
    return retg