dgl.udf.EdgeBatch.batch_size

EdgeBatch.batch_size()

Return the number of edges in the batch.

Return type

int

Examples

The following example uses the PyTorch backend.

>>> import dgl
>>> import torch
>>> # Instantiate a graph
>>> g = dgl.graph((torch.tensor([0, 1, 1]), torch.tensor([1, 1, 0])))
>>> # Define an edge UDF that returns a feature of ones, one row per edge
>>> def edge_udf(edges):
...     return {'h': torch.ones(edges.batch_size(), 1)}
>>> # Create an edge feature 'h'
>>> g.apply_edges(edge_udf)
>>> g.edata['h']
tensor([[1.],
        [1.],
        [1.]])
>>> # Use the edge UDF as the message function in message passing
>>> import dgl.function as fn
>>> g.update_all(edge_udf, fn.sum('h', 'h'))
>>> g.ndata['h']
tensor([[1.],
        [2.]])