"""
.. _model-gat:

Understand Graph Attention Network
==================================

**Authors:** Hao Zhang, Mufei Li, Minjie Wang, Zheng Zhang

From Graph Convolutional Network (GCN),
we learned that combining local graph structure and node-level features yields
good performance on node classification tasks. However, the way GCN aggregates
neighbor features is structure-dependent, which may hurt its generalizability.

One workaround is to simply average over all neighbor node features, as in
GraphSAGE.
Graph Attention Network (GAT) proposes an
alternative: weighting neighbor features with a feature-dependent and
structure-free normalization, in the style of attention.

The goals of this tutorial:

* Explain what a Graph Attention Network is.
* Demonstrate how it can be implemented in DGL.
* Understand the attention learnt.
* Introduce inductive learning.
"""
###############################################################
# Introducing Attention to GCN
# ----------------------------
#
# The key difference between GAT and GCN is how the information from the one-hop neighborhood is aggregated.
#
# For GCN, a graph convolution operation produces the normalized sum of the node features of neighbors:
#
#
# .. math::
#
#   h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\frac{1}{c_{ij}} W^{(l)}h^{(l)}_j}\right)
#
#
# where :math:`\mathcal{N}(i)` is the set of one-hop neighbors of node
# :math:`v_i` (to include :math:`v_i` itself in the set, simply add a self-loop
# to each node), :math:`c_{ij}=\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}` is a
# normalization constant based on graph structure, :math:`\sigma` is an
# activation function (GCN uses ReLU), and :math:`W^{(l)}` is a shared
# weight matrix for node-wise feature transformation. Another model, proposed in
# GraphSAGE,
# employs the same update rule except that it sets
# :math:`c_{ij}=|\mathcal{N}(i)|`.
#
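# To make the two normalization choices concrete, here is a minimal sketch in
# plain Python that computes :math:`c_{ij}` on a toy graph. The adjacency
# dictionary ``toy_neighbors`` and the helpers ``gcn_norm`` and ``mean_norm``
# are hypothetical names introduced only for this illustration; they are not
# part of DGL.

```python
import math

# Hypothetical toy graph: node -> one-hop neighbors (self-loops included).
toy_neighbors = {
    0: [0, 1, 2],
    1: [0, 1],
    2: [0, 2],
}


def gcn_norm(i, j, neighbors):
    """GCN-style constant: c_ij = sqrt(|N(i)|) * sqrt(|N(j)|)."""
    return math.sqrt(len(neighbors[i])) * math.sqrt(len(neighbors[j]))


def mean_norm(i, neighbors):
    """Mean-style constant: c_ij = |N(i)|, i.e. a plain average."""
    return len(neighbors[i])


c_gcn_01 = gcn_norm(0, 1, toy_neighbors)
c_mean_0 = mean_norm(0, toy_neighbors)
print("GCN c_01 = {:.4f}, mean c_0 = {}".format(c_gcn_01, c_mean_0))
```

###############################################################
# Both constants depend only on node degrees and never on node features, which
# is the structure dependence that attention is meant to replace.
#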
# GAT introduces the attention mechanism as a substitute for the statically
# normalized convolution operation. Below are the equations to compute the node
# embedding :math:`h_i^{(l+1)}` of layer :math:`l+1` from the embeddings of
# layer :math:`l`:
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/gat.png
#   :width: 450px
#   :align: center
#
# .. math::
#
#   \begin{align}
#   z_i^{(l)}&=W^{(l)}h_i^{(l)},&(1) \\
#   e_{ij}^{(l)}&=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),&(2)\\
#   \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&(3)\\
#   h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),&(4)
#   \end{align}
#
#
# Explanations:
#
#
# * Equation (1) is a linear transformation of the lower layer embedding :math:`h_i^{(l)}`,
#   and :math:`W^{(l)}` is its learnable weight matrix.
# * Equation (2) computes a pair-wise *unnormalized* attention score between two neighbors.
#   Here, it first concatenates the :math:`z` embeddings of the two nodes, where :math:`||`
#   denotes concatenation, then takes a dot product of the result with a learnable weight
#   vector :math:`\vec a^{(l)}`, and finally applies a LeakyReLU. This form of attention is
#   usually called *additive attention*, in contrast with the dot-product attention in the
#   Transformer model.
# * Equation (3) applies a softmax to normalize the attention scores on each node's
#   incoming edges.
# * Equation (4) is similar to GCN. The embeddings from neighbors are aggregated together,
#   scaled by the attention scores.
#
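# The four equations can be traced by hand for a single destination node. The
# sketch below does so in plain Python under stated assumptions: the values of
# ``toy_h``, ``toy_W`` and ``toy_a`` are made up, the LeakyReLU slope of 0.01
# is the PyTorch default, and the final activation :math:`\sigma` of equation
# (4) is left out. All names are hypothetical and not part of DGL.

```python
import math


def leaky_relu(x, negative_slope=0.01):
    return x if x > 0 else negative_slope * x


def matvec(W, h):
    """Multiply matrix W (list of rows) with vector h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]


# Hypothetical inputs: node 0 attends over neighbors {0, 1, 2} (self-loop included).
toy_h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
toy_W = [[1.0, 0.5], [0.0, 1.0]]    # stands in for W^{(l)}
toy_a = [0.5, -0.5, 1.0, 0.25]      # stands in for the attention vector a^{(l)}

# Equation (1): z_j = W h_j
toy_z = {j: matvec(toy_W, h) for j, h in toy_h.items()}

# Equation (2): e_0j = LeakyReLU(a^T [z_0 || z_j])
toy_e = {
    j: leaky_relu(sum(a * x for a, x in zip(toy_a, toy_z[0] + toy_z[j])))
    for j in toy_z
}

# Equation (3): softmax over the neighbors of node 0
toy_denominator = sum(math.exp(e) for e in toy_e.values())
toy_alpha = {j: math.exp(e) / toy_denominator for j, e in toy_e.items()}

# Equation (4) without the activation: attention-weighted sum of z_j
toy_h_next = [sum(toy_alpha[j] * toy_z[j][d] for j in toy_z) for d in range(2)]
print(toy_alpha, toy_h_next)
```

###############################################################
# The neighbor with the largest score :math:`e_{0j}` receives the largest
# weight :math:`\alpha_{0j}`, and the weights of node 0 sum to one.
#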
# There are other details from the paper, such as dropout and skip connections.
# For the purpose of simplicity, we omit them in this tutorial and leave the
# link to the full example at the end for interested readers.
#
# In its essence, GAT is just a different aggregation function with attention
# over features of neighbors, instead of a simple mean aggregation.
#
# GAT in DGL
# ----------
#
# Let's first get an overall impression of how a ``GATLayer`` module is
# implemented in DGL. Don't worry, we will break down the four equations above
# one by one.
import torch
import torch.nn as nn
import torch.nn.functional as F


class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')
##################################################################
# Equation (1)
# ^^^^^^^^^^^^
#
# .. math::
#
#   z_i^{(l)}=W^{(l)}h_i^{(l)},\qquad (1)
#
# The first one is simple. Linear transformation is very common and can be
# easily implemented in PyTorch using ``torch.nn.Linear``.
#
# Equation (2)
# ^^^^^^^^^^^^
#
# .. math::
#
#   e_{ij}^{(l)}=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),\qquad (2)
#
# The unnormalized attention score :math:`e_{ij}` is calculated using the
# embeddings of adjacent nodes :math:`i` and :math:`j`. This suggests that the
# attention scores can be viewed as edge data, which can be calculated by the
# ``apply_edges`` API. The argument to ``apply_edges`` is an **Edge UDF**,
# which is defined as follows:
def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
    a = self.attn_fc(z2)
    return {'e': F.leaky_relu(a)}
########################################################################
# Here, the dot product with the learnable weight vector :math:`\vec a^{(l)}`
# is again implemented using PyTorch's linear transformation ``attn_fc``. Note
# that ``apply_edges`` will **batch** all the edge data in one tensor, so the
# ``cat`` and ``attn_fc`` here are applied on all the edges in parallel.
#
# Equation (3) & (4)
# ^^^^^^^^^^^^^^^^^^
#
# .. math::
#
#   \begin{align}
#   \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&(3)\\
#   h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),&(4)
#   \end{align}
#
# Similar to GCN, the ``update_all`` API is used to trigger message passing on
# all the nodes. The message function sends out two tensors: the transformed
# ``z`` embedding of the source node and the unnormalized attention score ``e``
# on each edge. The reduce function then performs two tasks:
#
#
# * Normalize the attention scores using softmax (equation (3)).
# * Aggregate neighbor embeddings weighted by the attention scores (equation (4)).
#
# Both tasks first fetch data from the mailbox and then manipulate it on the
# second dimension (``dim=1``), on which the messages are batched.
def reduce_func(self, nodes):
    # reduce UDF for equation (3) & (4)
    # equation (3)
    alpha = F.softmax(nodes.mailbox['e'], dim=1)
    # equation (4)
    h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
    return {'h': h}
#####################################################################
# Multi-head Attention
# ^^^^^^^^^^^^^^^^^^^^
#
# Analogous to multiple channels in a ConvNet, GAT introduces **multi-head
# attention** to enrich the model capacity and to stabilize the learning
# process. Each attention head has its own parameters, and the head outputs can
# be merged in two ways:
#
# .. math:: \text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# or
#
# .. math:: \text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# where :math:`K` is the number of heads. The authors suggest using
# concatenation for intermediary layers and average for the final layer.
#
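# The two merge rules differ only in the axis they act on. As a rough
# illustration, the sketch below assumes made-up outputs of :math:`K=2` heads
# for three nodes with two features each and leaves out the activation
# :math:`\sigma`. ``merge_concat`` and ``merge_average`` are hypothetical
# helper names, not DGL functions.

```python
# Hypothetical head outputs, indexed as [head][node][feature].
toy_head_outs = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],   # head 1
    [[0.0, 1.0], [1.0, 0.0], [2.0, 2.0]],   # head 2
]


def merge_concat(head_outs):
    """Concatenate along the feature axis: N rows of K * D features."""
    return [sum(rows, []) for rows in zip(*head_outs)]


def merge_average(head_outs):
    """Average over the head axis: N rows of D features."""
    num_heads = len(head_outs)
    return [
        [sum(vals) / num_heads for vals in zip(*rows)]
        for rows in zip(*head_outs)
    ]


toy_cat = merge_concat(toy_head_outs)
toy_avg = merge_average(toy_head_outs)
print(toy_cat)
print(toy_avg)
```

###############################################################
# Concatenation grows the feature dimension to :math:`K` times the per-head
# size, whereas averaging keeps the per-head size.
#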
# We can use the above defined single-head ``GATLayer`` as the building block
# for the ``MultiHeadGATLayer`` below:
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average over the head dimension (dim=0);
            # without ``dim`` the mean would collapse to a single scalar
            return torch.mean(torch.stack(head_outs), dim=0)
###########################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Now, we can define a two-layer GAT model:
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h
#############################################################################
# We then load the Cora dataset using DGL's built-in data module.
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh


def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    # boolean mask; indexing with a ByteTensor is deprecated in recent PyTorch
    mask = torch.BoolTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask
##############################################################################
# The training loop is exactly the same as in the GCN tutorial.
import time
import numpy as np

g, features, labels, mask = load_cora_data()

# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))
#########################################################################
# Visualizing and Understanding Attention Learnt
# ----------------------------------------------
#
# Cora
# ^^^^
#
# The following table summarizes the model performance on Cora reported in
# the GAT paper and obtained with the DGL
# implementations.
#
# .. list-table::
#   :header-rows: 1
#
#   * - Model
#     - Accuracy
#   * - GCN (paper)
#     - :math:`81.4\pm 0.5\%`
#   * - GCN (dgl)
#     - :math:`82.05\pm 0.33\%`
#   * - GAT (paper)
#     - :math:`83.0\pm 0.7\%`
#   * - GAT (dgl)
#     - :math:`83.69\pm 0.529\%`
#
# *What kind of attention distribution has our model learnt?*
#
# Because the attention weight :math:`\alpha_{ij}` is associated with edges, we
# can visualize it by coloring edges. Below we pick a subgraph of Cora and plot
# the attention weights of the last ``GATLayer``. The nodes are colored
# according to their labels, whereas the edges are colored according to the
# magnitude of the attention weights, which can be read off the colorbar on the
# right.
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/cora-attention.png
#   :width: 600px
#   :align: center
#
# You can see that the model seems to learn different attention weights. To
# understand the distribution more thoroughly, we measure the entropy
# of the
# attention distribution. For any node :math:`i`,
# :math:`\{\alpha_{ij}\}_{j\in\mathcal{N}(i)}` forms a discrete probability
# distribution over all its neighbors with the entropy given by
#
# .. math:: H(\{\alpha_{ij}\}_{j\in\mathcal{N}(i)})=-\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\log\alpha_{ij}
#
# Intuitively, a low entropy means a high degree of concentration, and vice
# versa; an entropy of 0 means all attention is on one source node. The uniform
# distribution has the highest entropy of :math:`\log(|\mathcal{N}(i)|)`.
# Ideally, we want to see the model learn a distribution of lower entropy
# (i.e., one or two neighbors are much more important than the others).
#
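# The sketch below evaluates this entropy for three hand-picked weight vectors
# of a hypothetical node with four neighbors. ``attention_entropy`` is an
# illustrative helper (not a DGL function); it assumes the natural logarithm
# and the usual convention :math:`0\log 0=0`.

```python
import math


def attention_entropy(alphas):
    """Shannon entropy (natural log) of one node's attention weights."""
    return -sum(a * math.log(a) for a in alphas if a > 0)


# Made-up attention distributions over four neighbors.
toy_uniform = [0.25, 0.25, 0.25, 0.25]
toy_peaked = [0.97, 0.01, 0.01, 0.01]
toy_one_hot = [1.0, 0.0, 0.0, 0.0]

ent_uniform = attention_entropy(toy_uniform)
ent_peaked = attention_entropy(toy_peaked)
ent_one_hot = attention_entropy(toy_one_hot)
print("uniform {:.4f} | peaked {:.4f} | one-hot {:.4f}".format(
    ent_uniform, ent_peaked, ent_one_hot))
```

###############################################################
# The uniform vector reaches the maximum :math:`\log 4`, the one-hot vector
# gives 0, and the peaked vector lies in between.
#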
# Note that since nodes can have different degrees, the maximum entropy will
# also be different. Therefore, we plot the aggregated histogram of entropy
# values of all nodes in the entire graph. Below are the attention entropy
# histograms learned by each attention head.
#
# |image2|
#
# As a reference, here is the histogram if all the nodes have uniform attention weight distribution.
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/cora-attention-uniform-hist.png
#   :width: 250px
#   :align: center
#
# One can see that **the attention values learned are quite similar to a uniform distribution**
# (i.e., all neighbors are equally important). This partially
# explains why the performance of GAT is close to that of GCN on Cora
# (according to the authors' reported results, the accuracy difference averaged
# over 100 runs is less than 2%); attention does not matter much
# since it barely differentiates between neighbors anyway.
#
# *Does that mean the attention mechanism is not useful?* No! A different
# dataset exhibits an entirely different pattern, as we show next.
#
# Protein-Protein Interaction (PPI) networks
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The PPI dataset used here consists of :math:`24` graphs corresponding to
# different human tissues. Nodes can have up to :math:`121` kinds of labels, so
# the label of a node is represented as a binary tensor of size :math:`121`.
# The task is to predict node labels.
#
# We use :math:`20` graphs for training, :math:`2` for validation and :math:`2`
# for test. The average number of nodes per graph is :math:`2372`. Each node
# has :math:`50` features that are composed of positional gene sets, motif gene
# sets and immunological signatures. Critically, test graphs remain completely
# unobserved during training, a setting called "inductive learning".
#
# We compare the performance of GAT and GCN for :math:`10` random runs on this
# task and use hyperparameter search on the validation set to find the best
# model.
#
# .. list-table::
#   :header-rows: 1
#
#   * - Model
#     - F1 score (micro)
#   * - GAT
#     - :math:`0.975 \pm 0.006`
#   * - GCN
#     - :math:`0.509 \pm 0.025`
#   * - Paper
#     - :math:`0.973 \pm 0.002`
#
# The table above shows the results of this experiment, where we use the micro
# F1 score to evaluate model
# performance.
#
# .. note::
#
#   Below is the calculation process of the F1 score:
#
#   .. math::
#
#      precision=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t} +FP_{t})}
#
#      recall=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t} +FN_{t})}
#
#      F1_{micro}=2\frac{precision*recall}{precision+recall}
#
#   * :math:`TP_{t}` is the number of nodes that both have and are predicted to have label :math:`t`.
#   * :math:`FP_{t}` is the number of nodes that do not have but are predicted to have label :math:`t`.
#   * :math:`FN_{t}` is the number of nodes that have label :math:`t` but are not predicted to have it.
#   * :math:`n` is the number of labels, i.e. :math:`121` in our case.
#
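# As a small worked example of the formulas in the note above, the sketch
# below computes the micro F1 score from per-label counts. The counts for
# three labels are made up for illustration, and ``micro_f1`` is a
# hypothetical helper rather than the evaluation code used for the table.

```python
def micro_f1(tp, fp, fn):
    """Micro-averaged F1 from per-label TP, FP and FN counts."""
    total_tp = sum(tp)
    if total_tp == 0:
        # no true positives at all: define the score as 0 to avoid dividing by 0
        return 0.0
    precision = total_tp / (total_tp + sum(fp))
    recall = total_tp / (total_tp + sum(fn))
    return 2 * precision * recall / (precision + recall)


# Made-up counts for n = 3 labels.
toy_tp = [8, 5, 2]
toy_fp = [2, 1, 1]
toy_fn = [1, 3, 1]

toy_f1 = micro_f1(toy_tp, toy_fp, toy_fn)
print("micro F1 = {:.4f}".format(toy_f1))
```

###############################################################
# Because the counts are pooled over labels before dividing, frequent labels
# influence the micro score more than rare ones.
#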
# During training, we use ``BCEWithLogitsLoss`` as the loss function. The
# learning curves of GAT and GCN are presented below; what is evident is the
# dramatic performance advantage of GAT over GCN.
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/ppi-curve.png
#   :width: 300px
#   :align: center
#
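# For reference, the quantity that ``BCEWithLogitsLoss`` computes for one node
# can be sketched in plain Python. The version below assumes the default mean
# reduction and no class weights; the logits and targets are made-up values,
# and both helper functions are illustrative rather than PyTorch code.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy on a raw (pre-sigmoid) logit."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multilabel_loss(logits, targets):
    """Mean of the element-wise loss over all labels of one node."""
    losses = [bce_with_logits(x, y) for x, y in zip(logits, targets)]
    return sum(losses) / len(losses)


# Made-up logits and binary targets for a node with three labels.
toy_logits = [2.0, -1.0, 0.0]
toy_targets = [1.0, 0.0, 1.0]

toy_loss = multilabel_loss(toy_logits, toy_targets)
print("multi-label BCE loss = {:.4f}".format(toy_loss))
```

###############################################################
# Each label is treated as an independent binary problem, which suits the
# multi-label PPI task.
#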
# As before, we can get a statistical understanding of the attention learnt
# by plotting the histogram of the node-wise attention entropy. Below are
# the attention entropy histograms learnt by the different attention layers.
#
# *Attention learnt in layer 1:*
#
# |image5|
#
# *Attention learnt in layer 2:*
#
# |image6|
#
# *Attention learnt in final layer:*
#
# |image7|
#
# Again, compare with the uniform distribution:
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/ppi-uniform-hist.png
#   :width: 250px
#   :align: center
#
# Clearly, **GAT does learn sharp attention weights**! There is a clear pattern
# over the layers as well: **the attention gets sharper in higher
# layers**.
#
# Unlike the Cora dataset, where GAT's gain is lukewarm at best, for PPI there
# is a significant performance gap between GAT and other GNN variants compared
# in the GAT paper (at least 20%),
# and the attention distributions between the two clearly differ. While this
# deserves further research, one immediate conclusion is that GAT's advantage
# lies perhaps more in its ability to handle a graph with more complex
# neighborhood structure.
#
# What's Next?
# ------------
#
# So far, we have demonstrated how to use DGL to implement GAT. There are some
# missing details such as dropout, skip connections and hyper-parameter tuning,
# which are common practices and do not involve DGL-related concepts. We refer
# interested readers to the full example.
#
# * See the optimized full example for these details.
# * Stay tuned for our next tutorial about how to speed up GAT models by parallelizing multiple attention heads and SpMV optimization.
#
# .. |image2| image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/cora-attention-hist.png
# .. |image5| image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/ppi-first-layer-hist.png
# .. |image6| image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/ppi-second-layer-hist.png
# .. |image7| image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/ppi-final-layer-hist.png