#### Note

**Definition**: a :class:`~dgl.batched_graph.BatchedDGLGraph` is a\n :class:`~dgl.DGLGraph` that unions a list of :class:`~dgl.DGLGraph`\\ s. \n\n - The union includes all the nodes,\n edges, and their features. The order of nodes, edges, and features is\n preserved. \n\n - Given that we have $V_i$ nodes for graph\n $\\mathcal{G}_i$, the node ID $j$ in graph\n $\\mathcal{G}_i$ corresponds to node ID\n $j + \\sum_{k=1}^{i-1} V_k$ in the batched graph. \n\n - Therefore, performing feature transformation and message passing on\n ``BatchedDGLGraph`` is equivalent to doing those\n on all ``DGLGraph`` constituents in parallel. \n\n - Duplicate references to the same graph are\n treated as deep copies; the nodes, edges, and features are duplicated,\n and mutation on one reference does not affect the other. \n\n - Currently, ``BatchedDGLGraph`` is immutable in\n graph structure (i.e. one can't add\n nodes and edges to it). Support for mutable batched graphs is left for the\n (far) future. \n\n - The ``BatchedDGLGraph`` keeps track of the meta\n information of the constituents so it can be\n :func:`~dgl.batched_graph.unbatch`\\ ed into a list of ``DGLGraph``\\ s.

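As a small illustration of the node ID rule in the definition above, the following pure-Python sketch computes the offset of each constituent graph and maps a local node ID to its ID in the batched graph. The helper names `batch_offsets` and `to_batched_id` are hypothetical and not part of DGL, and the graph index is 0-based here, whereas the formula above counts graphs from 1.

```python
from itertools import accumulate


def batch_offsets(num_nodes_per_graph):
    # offsets[i] is the total number of nodes in all graphs before graph i
    return [0] + list(accumulate(num_nodes_per_graph))[:-1]


def to_batched_id(num_nodes_per_graph, graph_index, node_id):
    # map node `node_id` of graph `graph_index` (0-based) to its batched ID
    if not 0 <= node_id < num_nodes_per_graph[graph_index]:
        raise ValueError('node_id out of range for this graph')
    return batch_offsets(num_nodes_per_graph)[graph_index] + node_id


sizes = [3, 2, 4]                  # V_1, V_2, V_3
print(batch_offsets(sizes))        # [0, 3, 5]
print(to_batched_id(sizes, 2, 1))  # 6
```

Because every graph only shifts its IDs by a constant offset, the relative order of nodes inside each constituent is unchanged, which is what allows the batched graph to be split back into its parts.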
\n\nFor more details about the :class:`~dgl.batched_graph.BatchedDGLGraph`\nmodule in DGL, you can click the class name.\n\nStep 2: Tree-LSTM Cell with message-passing APIs\n------------------------------------------------\n\nThe authors proposed two types of Tree-LSTM: Child-Sum\nTree-LSTMs, and $N$-ary Tree-LSTMs. In this tutorial we focus \non applying *Binary* Tree-LSTM to binarized constituency trees (this \napplication is also known as *Constituency Tree-LSTM*). We use PyTorch \nas our backend framework to set up the network.\n\nIn an $N$-ary Tree-LSTM, each unit at node $j$ maintains a hidden\nrepresentation $h_j$ and a memory cell $c_j$. The unit\n$j$ takes the input vector $x_j$ and the hidden\nrepresentations of its child units, $h_{jl}, 1\\leq l\\leq N$, as\ninput, then updates its hidden representation $h_j$ and memory\ncell $c_j$ by: \n\n\\begin{align}i_j & = & \\sigma\\left(W^{(i)}x_j + \\sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\\right), & (1)\\\\\n f_{jk} & = & \\sigma\\left(W^{(f)}x_j + \\sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \\right), & (2)\\\\\n o_j & = & \\sigma\\left(W^{(o)}x_j + \\sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \\right), & (3) \\\\\n u_j & = & \\textrm{tanh}\\left(W^{(u)}x_j + \\sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \\right), & (4)\\\\\n c_j & = & i_j \\odot u_j + \\sum_{l=1}^{N} f_{jl} \\odot c_{jl}, &(5) \\\\\n h_j & = & o_j \\odot \\textrm{tanh}(c_j), &(6) \\\\\\end{align}\n\nThis computation can be decomposed into three phases: ``message_func``,\n``reduce_func`` and ``apply_node_func``.\n\n#### Note

``apply_node_func`` is a new node UDF that we have not introduced before. In\n ``apply_node_func``, the user specifies what to do with node features,\n without considering edge features and messages. In the Tree-LSTM case,\n ``apply_node_func`` is required, since there exist (leaf) nodes with\n $0$ incoming edges, which would not be updated via\n ``reduce_func``.

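To make the leaf case concrete, here is a minimal scalar sketch of equations (1)-(6) in plain Python, assuming a hidden size of 1 so that every weight is a single number. The function `tree_lstm_unit` and all weight values are hypothetical and chosen only for illustration; this is not the DGL implementation. For a leaf the list of children is empty, so every sum over children vanishes and only the input terms remain, which is the part that a node-level function has to handle.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def tree_lstm_unit(x, children, W, U, U_f, b):
    # x: scalar input; children: list of (h, c) pairs, empty for a leaf
    # W, b: dicts keyed by gate name; U[gate][l] weights child l
    # U_f[k][l]: weight of child l inside the forget gate of child k
    hs = [h for h, _ in children]

    def pre(gate):
        # shared form of the pre-activations in equations (1), (3), (4)
        child_term = sum(U[gate][l] * h for l, h in enumerate(hs))
        return W[gate] * x + child_term + b[gate]

    i = sigmoid(pre('i'))      # equation (1)
    o = sigmoid(pre('o'))      # equation (3)
    u = math.tanh(pre('u'))    # equation (4)
    c = i * u                  # first term of equation (5)
    for k, (_, c_k) in enumerate(children):
        child_term = sum(U_f[k][l] * h for l, h in enumerate(hs))
        f_k = sigmoid(W['f'] * x + child_term + b['f'])  # equation (2)
        c += f_k * c_k         # second term of equation (5)
    h = o * math.tanh(c)       # equation (6)
    return h, c


gates = ('i', 'f', 'o', 'u')
W = {g: 1.0 for g in gates}
b = {g: 0.0 for g in gates}
U = {g: [0.5, 0.5] for g in gates}
U_f = [[0.5, 0.5], [0.5, 0.5]]

leaf = tree_lstm_unit(1.0, [], W, U, U_f, b)            # no children
parent = tree_lstm_unit(0.0, [leaf, leaf], W, U, U_f, b)
print(leaf, parent)
```

In this sketch the leaf never enters the loop over children, so its state comes entirely from the input term; the parent combines both children through the forget gates.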
\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch as th\nimport torch.nn as nn\n\nclass TreeLSTMCell(nn.Module):\n def __init__(self, x_size, h_size):\n super(TreeLSTMCell, self).__init__()\n self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)\n self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)\n self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))\n self.U_f = nn.Linear(2 * h_size, 2 * h_size)\n\n def message_func(self, edges):\n return {'h': edges.src['h'], 'c': edges.src['c']}\n\n def reduce_func(self, nodes):\n # concatenate h_jl for equation (1), (2), (3), (4)\n h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)\n # equation (2)\n f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())\n # second term of equation (5)\n c = th.sum(f * nodes.mailbox['c'], 1)\n return {'iou': self.U_iou(h_cat), 'c': c}\n\n def apply_node_func(self, nodes):\n # equation (1), (3), (4)\n iou = nodes.data['iou'] + self.b_iou\n i, o, u = th.chunk(iou, 3, 1)\n i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)\n # equation (5)\n c = i * u + nodes.data['c']\n # equation (6)\n h = o * th.tanh(c)\n return {'h' : h, 'c' : c}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Step 3: define traversal\n------------------------\n\nAfter defining the message passing functions, we then need to induce the\nright order to trigger them. This is a significant departure from models\nsuch as GCN, where all nodes are pulling messages from upstream ones\n*simultaneously*.\n\nIn the case of Tree-LSTM, messages start from leaves of the tree, and\npropagate/processed upwards until they reach the roots. A visualization\nis as follows:\n\n.. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif\n :alt:\n\nDGL defines a generator to perform the topological sort, each item is a\ntensor recording the nodes from bottom level to the roots. One can\nappreciate the degree of parallelism by inspecting the difference of the\nfollowings:\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print('Traversing one tree:')\nprint(dgl.topological_nodes_generator(a_tree))\n\nprint('Traversing many trees at the same time:')\nprint(dgl.topological_nodes_generator(graph))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We then call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import dgl.function as fn\nimport torch as th\n\ngraph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)\ngraph.register_message_func(fn.copy_src('a', 'a'))\ngraph.register_reduce_func(fn.sum('a', 'a'))\n\ntraversal_order = dgl.topological_nodes_generator(graph)\ngraph.prop_nodes(traversal_order)\n\n# the following is a syntax sugar that does the same\n# dgl.prop_nodes_topo(graph)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Note

Before we call :meth:`~dgl.DGLGraph.prop_nodes`, we must specify a\n `message_func` and a `reduce_func` in advance. Here we use the built-in\n copy-from-source and sum functions as our message function and reduce\n function for demonstration.

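The effect of combining a level-by-level traversal with copy-from-source and sum can be sketched in plain Python, without DGL. The functions `topological_frontiers` and `propagate_copy_sum` below are hypothetical helpers that mimic the idea, not DGL's implementation. The sketch assumes that edges point from child to parent and that nodes without incoming edges keep their initial value.

```python
def topological_frontiers(num_nodes, edges):
    # edges are (child, parent) pairs: messages flow from child to parent
    in_degree = [0] * num_nodes
    out_edges = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        in_degree[dst] += 1
        out_edges[src].append(dst)
    # the first frontier holds all nodes with no incoming edges (the leaves)
    frontier = [v for v in range(num_nodes) if in_degree[v] == 0]
    frontiers = []
    while frontier:
        frontiers.append(frontier)
        next_frontier = []
        for v in frontier:
            for dst in out_edges[v]:
                in_degree[dst] -= 1
                if in_degree[dst] == 0:  # all children already processed
                    next_frontier.append(dst)
        frontier = next_frontier
    return frontiers


def propagate_copy_sum(num_nodes, edges):
    a = [1.0] * num_nodes  # every node starts with feature a = 1
    in_edges = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        in_edges[dst].append(src)
    for frontier in topological_frontiers(num_nodes, edges):
        for v in frontier:
            if in_edges[v]:  # leaves receive no messages and are skipped
                a[v] = sum(a[src] for src in in_edges[v])
    return a


# node 0 is the root with children 1 and 2; node 1 has children 3 and 4
tree_edges = [(1, 0), (2, 0), (3, 1), (4, 1)]
print(topological_frontiers(5, tree_edges))  # [[2, 3, 4], [1], [0]]
print(propagate_copy_sum(5, tree_edges))     # [3.0, 2.0, 1.0, 1.0, 1.0]
```

Under these assumptions each internal node ends up holding the number of leaves in its subtree, and all nodes in one frontier can be updated in parallel because they do not depend on each other.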
\n\nPutting it together\n-------------------\n\nHere is the complete code that specifies the ``Tree-LSTM`` class:\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"class TreeLSTM(nn.Module):\n def __init__(self,\n num_vocabs,\n x_size,\n h_size,\n num_classes,\n dropout,\n pretrained_emb=None):\n super(TreeLSTM, self).__init__()\n self.x_size = x_size\n self.embedding = nn.Embedding(num_vocabs, x_size)\n if pretrained_emb is not None:\n print('Using glove')\n self.embedding.weight.data.copy_(pretrained_emb)\n self.embedding.weight.requires_grad = True\n self.dropout = nn.Dropout(dropout)\n self.linear = nn.Linear(h_size, num_classes)\n self.cell = TreeLSTMCell(x_size, h_size)\n\n def forward(self, batch, h, c):\n \"\"\"Compute tree-lstm prediction given a batch.\n\n Parameters\n ----------\n batch : dgl.data.SSTBatch\n The data batch.\n h : Tensor\n Initial hidden state.\n c : Tensor\n Initial cell state.\n\n Returns\n -------\n logits : Tensor\n The prediction of each node.\n \"\"\"\n g = batch.graph\n g.register_message_func(self.cell.message_func)\n g.register_reduce_func(self.cell.reduce_func)\n g.register_apply_node_func(self.cell.apply_node_func)\n # feed embedding\n embeds = self.embedding(batch.wordid * batch.mask)\n g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)\n g.ndata['h'] = h\n g.ndata['c'] = c\n # propagate\n dgl.prop_nodes_topo(g)\n # compute logits\n h = self.dropout(g.ndata.pop('h'))\n logits = self.linear(h)\n return logits"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Main Loop\n---------\n\nFinally, we could write a training paradigm in PyTorch:\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\nimport torch.nn.functional as F\n\ndevice = th.device('cpu')\n# hyper parameters\nx_size = 256\nh_size = 256\ndropout = 0.5\nlr = 0.05\nweight_decay = 1e-4\nepochs = 10\n\n# create the model\nmodel = TreeLSTM(trainset.num_vocabs,\n x_size,\n h_size,\n trainset.num_classes,\n dropout)\nprint(model)\n\n# create the optimizer\noptimizer = th.optim.Adagrad(model.parameters(),\n lr=lr,\n weight_decay=weight_decay)\n\ndef batcher(dev):\n def batcher_dev(batch):\n batch_trees = dgl.batch(batch)\n return SSTBatch(graph=batch_trees,\n mask=batch_trees.ndata['mask'].to(device),\n wordid=batch_trees.ndata['x'].to(device),\n label=batch_trees.ndata['y'].to(device))\n return batcher_dev\n\ntrain_loader = DataLoader(dataset=tiny_sst,\n batch_size=5,\n collate_fn=batcher(device),\n shuffle=False,\n num_workers=0)\n\n# training loop\nfor epoch in range(epochs):\n for step, batch in enumerate(train_loader):\n g = batch.graph\n n = g.number_of_nodes()\n h = th.zeros((n, h_size))\n c = th.zeros((n, h_size))\n logits = model(batch, h, c)\n logp = F.log_softmax(logits, 1)\n loss = F.nll_loss(logp, batch.label, reduction='sum') \n optimizer.zero_grad()\n loss.backward()\n optimizer.step()\n pred = th.argmax(logits, 1)\n acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)\n print(\"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |\".format(\n epoch, step, loss.item(), acc))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To train the model on full dataset with different settings(CPU/GPU,\netc.), please refer to our repo's\n`example