"""Torch Module for PGExplainer"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .... import batch, ETYPE, khop_in_subgraph, NID, to_homogeneous
__all__ = ["PGExplainer", "HeteroPGExplainer"]
class PGExplainer(nn.Module):
r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/pdf/2011.04573>`__
PGExplainer adopts a deep neural network (explanation network) to
parameterize the generation process of explanations, which enables it to
explain multiple instances collectively. PGExplainer models the underlying
structure as edge distributions, from which the explanatory graph is
sampled.
Parameters
----------
model : nn.Module
The GNN model to explain, which tackles multi-class graph classification.
* Its forward function must have the form
:attr:`forward(self, graph, nfeat, embed, edge_weight)`.
* Its forward function must return the classification logits when
:attr:`embed=False` and the intermediate node embeddings otherwise.
num_features : int
Node embedding size used by :attr:`model`.
num_hops : int, optional
The number of hops for GNN information aggregation, which must match the
number of message passing layers employed by the GNN to be explained.
explain_graph : bool, optional
Whether to explain graph-level (True) or node-level (False) predictions. Default: True.
coff_budget : float, optional
Size regularization to constrain the explanation size. Default: 0.01.
coff_connect : float, optional
Entropy regularization to constrain the connectivity of explanation. Default: 5e-4.
sample_bias : float, optional
Bias applied to the uniform noise used when sampling edge masks, which
rescales the noise to lie in the open interval (bias, 1 - bias). Default: 0.0.
"""
def __init__(
self,
model,
num_features,
num_hops=None,
explain_graph=True,
coff_budget=0.01,
coff_connect=5e-4,
sample_bias=0.0,
):
super(PGExplainer, self).__init__()
self.model = model
self.graph_explanation = explain_graph
# Node explanation requires additional self-embedding data.
self.num_features = num_features * (2 if self.graph_explanation else 3)
self.num_hops = num_hops
# training hyperparameters for PGExplainer
self.coff_budget = coff_budget
self.coff_connect = coff_connect
self.sample_bias = sample_bias
self.init_bias = 0.0
# Explanation network in PGExplainer
self.elayers = nn.Sequential(
nn.Linear(self.num_features, 64), nn.ReLU(), nn.Linear(64, 1)
)
def set_masks(self, graph, edge_mask=None):
r"""Set the edge mask that plays a crucial role to explain the
prediction made by the GNN for a graph. Initialize learnable edge
mask if it is None.
Parameters
----------
graph : DGLGraph
A homogeneous graph.
edge_mask : Tensor, optional
Learned importance mask of the edges in the graph, which is a tensor
of shape :math:`(E)`, where :math:`E` is the number of edges in the
graph. The values are within range :math:`(0, 1)`. The higher,
the more important. Default: None.
"""
if edge_mask is None:
num_nodes = graph.num_nodes()
num_edges = graph.num_edges()
init_bias = self.init_bias
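# Randomly initialize the mask with a small standard deviation
# derived from the number of nodes in the graph.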
std = nn.init.calculate_gain("relu") * math.sqrt(
2.0 / (2 * num_nodes)
)
self.edge_mask = torch.randn(num_edges) * std + init_bias
else:
self.edge_mask = edge_mask
self.edge_mask = self.edge_mask.to(graph.device)
def clear_masks(self):
r"""Clear the edge mask that play a crucial role to explain the
prediction made by the GNN for a graph.
"""
self.edge_mask = None
def parameters(self):
r"""
Returns an iterator over the `Parameter` objects of the `nn.Linear`
layers in the `self.elayers` sequential module, i.e. the weights and
biases of the explanation network as learned during training.
Returns
-------
iterator
An iterator over the `Parameter` objects of the `nn.Linear`
layers in the `self.elayers` sequential module.
"""
return self.elayers.parameters()
def loss(self, prob, ori_pred):
r"""The loss function that is used to learn the edge
distribution.
Parameters
----------
prob: Tensor
Tensor of shape :math:`(B, L)` containing the predicted probability
of each class label for the batched graphs, where :math:`B` is the
batch size and :math:`L` is the number of classes in the dataset.
ori_pred: Tensor
Tensor of shape :math:`(B)` containing the class originally predicted
by the model for each graph, where :math:`B` is the batch size.
Returns
-------
Tensor
A scalar tensor representing the total loss, i.e. the sum of the
prediction, size, and entropy loss terms.
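Examples
--------
A minimal sketch with hand-set masks for illustration. ``nn.Identity()``
stands in for a real GNN, since this method only reads the cached mask
values and the tensors passed in:
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import PGExplainer
>>> explainer = PGExplainer(nn.Identity(), num_features=8)
>>> # Masks are normally cached by explain_graph; set them by hand here
>>> explainer.sparse_mask_values = torch.tensor([0.9, 0.1, 0.6])
>>> explainer.edge_mask = torch.tensor([0.9, 0.1, 0.6])
>>> prob = torch.tensor([[0.7, 0.3]])  # predicted class probabilities
>>> ori_pred = torch.tensor([0])       # originally predicted class
>>> total_loss = explainer.loss(prob, ori_pred)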
"""
target_prob = prob.gather(-1, ori_pred.unsqueeze(-1))
# 1e-6 added to prob to avoid taking the logarithm of zero
target_prob += 1e-6
# computing the log likelihood for a single prediction
pred_loss = torch.mean(-torch.log(target_prob))
# size
edge_mask = self.sparse_mask_values
if self.coff_budget <= 0:
size_loss = self.coff_budget * torch.sum(edge_mask)
else:
size_loss = self.coff_budget * F.relu(
torch.sum(edge_mask) - self.coff_budget
)
# entropy
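# Rescale the mask into [1 - scale, scale] so that both log terms
# below stay finite.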
scale = 0.99
edge_mask = self.edge_mask * (2 * scale - 1.0) + (1.0 - scale)
mask_ent = -edge_mask * torch.log(edge_mask) - (
1 - edge_mask
) * torch.log(1 - edge_mask)
mask_ent_loss = self.coff_connect * torch.mean(mask_ent)
loss = pred_loss + size_loss + mask_ent_loss
return loss
def concrete_sample(self, w, beta=1.0, training=True):
r"""Sample from the instantiation of concrete distribution when training.
Parameters
----------
w : Tensor
A tensor representing the log of the prior probability of choosing each edge.
beta : float, optional
Temperature that controls the smoothness of the sigmoid relaxation;
lower values push the outputs closer to 0 and 1.
training : bool, optional
Whether to inject random noise, as done during training.
Returns
-------
Tensor
If training is True, a tensor of probabilities of activating the gate
for each input element, sampled with random logistic noise.
If training is False, a tensor of probabilities determined solely by
:attr:`w`, without any random noise.
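Examples
--------
A minimal sketch of the sampling behavior. ``nn.Identity()`` stands in
for a real GNN, since this method does not use the wrapped model:
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import PGExplainer
>>> explainer = PGExplainer(nn.Identity(), num_features=8)
>>> w = torch.randn(5)
>>> # Deterministic gates at inference time: a plain sigmoid of w
>>> gates = explainer.concrete_sample(w, training=False)
>>> # Noisy, temperature-controlled gates during training
>>> noisy_gates = explainer.concrete_sample(w, beta=2.0, training=True)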
"""
if training:
bias = self.sample_bias
random_noise = torch.rand(w.size()).to(w.device)
random_noise = bias + (1 - 2 * bias) * random_noise
gate_inputs = torch.log(random_noise) - torch.log(
1.0 - random_noise
)
gate_inputs = (gate_inputs + w) / beta
gate_inputs = torch.sigmoid(gate_inputs)
else:
gate_inputs = torch.sigmoid(w)
return gate_inputs
def train_step(self, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network for graph classification
Parameters
----------
graph : DGLGraph
Input batched homogeneous graph.
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
A scalar tensor representing the loss.
"""
assert (
self.graph_explanation
), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
pred = self.model(graph, feat, embed=False, **kwargs)
pred = pred.argmax(-1).data
prob, _ = self.explain_graph(
graph, feat, temperature, training=True, **kwargs
)
loss = self.loss(prob, pred)
return loss
def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network for node classification
Parameters
----------
nodes : int, iterable[int], Tensor
The nodes from the graph used to train the explanation network,
which must not contain duplicate values.
graph : DGLGraph
Input homogeneous graph.
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
A scalar tensor representing the loss.
"""
assert (
not self.graph_explanation
), '"explain_graph" must be False when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
if isinstance(nodes, torch.Tensor):
nodes = nodes.tolist()
if isinstance(nodes, int):
nodes = [nodes]
prob, _, batched_graph, inverse_indices = self.explain_node(
nodes, graph, feat, temperature, training=True, **kwargs
)
pred = self.model(
batched_graph, self.batched_feats, embed=False, **kwargs
)
pred = pred.argmax(-1).data
loss = self.loss(prob[inverse_indices], pred[inverse_indices])
return loss
def explain_graph(
self, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for a graph. Also, return
the prediction made with the edges chosen based on the edge mask.
Parameters
----------
graph : DGLGraph
A homogeneous graph.
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
temperature : float
The temperature parameter fed to the sampling procedure.
training : bool
Training the explanation network.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
Classification probabilities given the masked graph. It is a tensor
of shape :math:`(B, L)`, where :math:`L` is the number of classes
in the dataset and :math:`B` is the batch size.
Tensor
Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
is the number of edges in the graph. A higher weight suggests a
larger contribution of the edge.
Examples
--------
>>> import torch as th
>>> import torch.nn as nn
>>> import dgl
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import GraphConv, PGExplainer
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
... def __init__(self, in_feats, out_feats):
... super().__init__()
... self.conv = GraphConv(in_feats, out_feats)
... self.fc = nn.Linear(out_feats, out_feats)
... nn.init.xavier_uniform_(self.fc.weight)
...
... def forward(self, g, h, embed=False, edge_weight=None):
... h = self.conv(g, h, edge_weight=edge_weight)
...
... if embed:
... return h
...
... with g.local_scope():
... g.ndata['h'] = h
... hg = dgl.mean_nodes(g, 'h')
... return self.fc(hg)
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
... preds = model(bg, bg.ndata['attr'])
... loss = criterion(preds, labels)
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
>>> # Initialize the explainer
>>> explainer = PGExplainer(model, data.gclasses)
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
... tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
... for bg, labels in dataloader:
... loss = explainer.train_step(bg, bg.ndata['attr'], tmp)
... optimizer_exp.zero_grad()
... loss.backward()
... optimizer_exp.step()
>>> # Explain the prediction for graph 0
>>> graph, l = data[0]
>>> graph_feat = graph.ndata.pop("attr")
>>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)
"""
assert (
self.graph_explanation
), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
embed = self.model(graph, feat, embed=True, **kwargs)
embed = embed.data
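# Score each edge by feeding the concatenated embeddings of its two
# endpoints to the explanation MLP.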
col, row = graph.edges()
col_emb = embed[col.long()]
row_emb = embed[row.long()]
emb = torch.cat([col_emb, row_emb], dim=-1)
emb = self.elayers(emb)
values = emb.reshape(-1)
values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values
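# Average each edge's sampled value with that of its reverse edge so
# that (u, v) and (v, u) share one weight; this requires the graph to
# contain both directions of every edge.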
reverse_eids = graph.edge_ids(row, col).long()
edge_mask = (values + values[reverse_eids]) / 2
self.set_masks(graph, edge_mask)
# the model prediction with the updated edge mask
logits = self.model(graph, feat, edge_weight=self.edge_mask, **kwargs)
probs = F.softmax(logits, dim=-1)
if training:
probs = probs.data
else:
self.clear_masks()
return (probs, edge_mask)
def explain_node(
self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for provided set of node IDs.
Also, return the prediction made with the graph and edge mask.
Parameters
----------
nodes : int, iterable[int], Tensor
The nodes from the graph, which must not contain duplicate values.
graph : DGLGraph
A homogeneous graph.
feat : Tensor
The input feature of shape :math:`(N, D)`. :math:`N` is the
number of nodes, and :math:`D` is the feature size.
temperature : float
The temperature parameter fed to the sampling procedure.
training : bool
Training the explanation network.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
Classification probabilities given the masked graph. It is a tensor
of shape :math:`(N, L)`, where :math:`L` is the number of node classes
in the dataset and :math:`N` is the number of nodes in the returned
batched graph.
Tensor
Edge weights which is a tensor of shape :math:`(E)`, where :math:`E`
is the number of edges in the graph. A higher weight suggests a
larger contribution of the edge.
DGLGraph
The batched set of subgraphs induced on the k-hop in-neighborhood
of the input center nodes.
Tensor
The new IDs of the subgraph center nodes.
Examples
--------
>>> import dgl
>>> import numpy as np
>>> import torch
>>> # Define the model
>>> class Model(torch.nn.Module):
... def __init__(self, in_feats, out_feats):
... super().__init__()
... self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
... self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)
...
... def forward(self, g, h, embed=False, edge_weight=None):
... h = self.conv1(g, h, edge_weight=edge_weight)
... if embed:
... return h
... return self.conv2(g, h)
>>> # Load dataset
>>> data = dgl.data.CoraGraphDataset(verbose=False)
>>> g = data[0]
>>> features = g.ndata["feat"]
>>> labels = g.ndata["label"]
>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = torch.nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(20):
... logits = model(g, features)
... loss = criterion(logits, labels)
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.PGExplainer(
... model, data.num_classes, num_hops=2, explain_graph=False
... )
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)
>>> epochs = 10
>>> for epoch in range(epochs):
... tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
... loss = explainer.train_step_node(g.nodes(), g, features, tmp)
... optimizer_exp.zero_grad()
... loss.backward()
... optimizer_exp.step()
>>> # Explain the prediction for node 0
>>> probs, edge_weight, bg, inverse_indices = explainer.explain_node(
... 0, g, features
... )
"""
assert (
not self.graph_explanation
), '"explain_graph" must be False when initializing the module.'
assert (
self.num_hops is not None
), '"num_hops" must be provided when initializing the module.'
if isinstance(nodes, torch.Tensor):
nodes = nodes.tolist()
if isinstance(nodes, int):
nodes = [nodes]
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
batched_graph = []
batched_embed = []
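# For every center node: extract its k-hop in-subgraph, mark which
# subgraph nodes are centers, and build per-edge features
# [src_emb, dst_emb, center_emb] for the explanation MLP.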
for node_id in nodes:
sg, inverse_indices = khop_in_subgraph(
graph, node_id, self.num_hops
)
sg.ndata["feat"] = feat[sg.ndata[NID].long()]
sg.ndata["train"] = torch.tensor(
[nid in inverse_indices for nid in sg.nodes()], device=sg.device
)
embed = self.model(sg, sg.ndata["feat"], embed=True, **kwargs)
embed = embed.data
col, row = sg.edges()
col_emb = embed[col.long()]
row_emb = embed[row.long()]
self_emb = embed[inverse_indices[0]].repeat(sg.num_edges(), 1)
emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)
batched_embed.append(emb)
batched_graph.append(sg)
batched_graph = batch(batched_graph)
batched_embed = torch.cat(batched_embed)
batched_embed = self.elayers(batched_embed)
values = batched_embed.reshape(-1)
values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values
col, row = batched_graph.edges()
reverse_eids = batched_graph.edge_ids(row, col).long()
edge_mask = (values + values[reverse_eids]) / 2
self.set_masks(batched_graph, edge_mask)
batched_feats = batched_graph.ndata["feat"]
# the model prediction with the updated edge mask
logits = self.model(
batched_graph, batched_feats, edge_weight=self.edge_mask, **kwargs
)
probs = F.softmax(logits, dim=-1)
batched_inverse_indices = (
batched_graph.ndata["train"].nonzero().squeeze(1)
)
if training:
self.batched_feats = batched_feats
probs = probs.data
else:
self.clear_masks()
return (
probs,
edge_mask,
batched_graph,
batched_inverse_indices,
)
class HeteroPGExplainer(PGExplainer):
r"""PGExplainer from `Parameterized Explainer for Graph Neural Network
<https://arxiv.org/pdf/2011.04573>`__, adapted for heterogeneous graphs
PGExplainer adopts a deep neural network (explanation network) to
parameterize the generation process of explanations, which enables it to
explain multiple instances collectively. PGExplainer models the underlying
structure as edge distributions, from which the explanatory graph is
sampled.
Parameters
----------
model : nn.Module
The GNN model to explain, which tackles multi-class graph classification.
* Its forward function must have the form
:attr:`forward(self, graph, nfeat, embed, edge_weight)`.
* Its forward function must return the classification logits when
:attr:`embed=False` and the intermediate node embeddings otherwise.
num_features : int
Node embedding size used by :attr:`model`.
coff_budget : float, optional
Size regularization to constrain the explanation size. Default: 0.01.
coff_connect : float, optional
Entropy regularization to constrain the connectivity of explanation. Default: 5e-4.
sample_bias : float, optional
Bias applied to the uniform noise used when sampling edge masks, which
rescales the noise to lie in the open interval (bias, 1 - bias). Default: 0.0.
"""
def train_step(self, graph, feat, temperature, **kwargs):
# pylint: disable=useless-super-delegation
r"""Compute the loss of the explanation network for graph classification
Parameters
----------
graph : DGLGraph
Input batched heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
A scalar tensor representing the loss.
"""
return super().train_step(graph, feat, temperature, **kwargs)
def train_step_node(self, nodes, graph, feat, temperature, **kwargs):
r"""Compute the loss of the explanation network for node classification
Parameters
----------
nodes : dict[str, Iterable[int]]
A dict mapping node types (keys) to an iterable set of node ids (values).
graph : DGLGraph
Input heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
A scalar tensor representing the loss.
"""
assert (
not self.graph_explanation
), '"explain_graph" must be False when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
prob, _, batched_graph, inverse_indices = self.explain_node(
nodes, graph, feat, temperature, training=True, **kwargs
)
pred = self.model(
batched_graph, self.batched_feats, embed=False, **kwargs
)
pred = {ntype: pred[ntype].argmax(-1).data for ntype in pred.keys()}
loss = self.loss(
torch.cat(
[prob[ntype][nid] for ntype, nid in inverse_indices.items()]
),
torch.cat(
[pred[ntype][nid] for ntype, nid in inverse_indices.items()]
),
)
return loss
def explain_graph(
self, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for a graph. Also, return
the prediction made with the edges chosen based on the edge mask.
Parameters
----------
graph : DGLGraph
A heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure.
training : bool
Training the explanation network.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
Tensor
Classification probabilities given the masked graph. It is a tensor
of shape :math:`(B, L)`, where :math:`L` is the number of classes
in the dataset and :math:`B` is the batch size.
dict[str, Tensor]
A dict mapping edge types (keys) to edge tensors (values) of shape
:math:`(E_t)`, where :math:`E_t` is the number of edges in the graph
for edge type :math:`t`. A higher weight suggests a larger
contribution of the edge.
Examples
--------
>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
... def __init__(self, in_feats, hid_feats, out_feats, rel_names):
... super().__init__()
... self.conv = dgl.nn.HeteroGraphConv(
... {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
... aggregate="sum",
... )
... self.fc = nn.Linear(hid_feats, out_feats)
... nn.init.xavier_uniform_(self.fc.weight)
...
... def forward(self, g, h, embed=False, edge_weight=None):
... if edge_weight:
... mod_kwargs = {
... etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
... }
... h = self.conv(g, h, mod_kwargs=mod_kwargs)
... else:
... h = self.conv(g, h)
...
... if embed:
... return h
...
... with g.local_scope():
... g.ndata["h"] = h
... hg = 0
... for ntype in g.ntypes:
... hg = hg + dgl.mean_nodes(g, "h", ntype=ntype)
... return self.fc(hg)
>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
... logits = model(g, g.ndata["h"])
... loss = th.nn.functional.cross_entropy(logits, th.tensor([1]))
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(model, hidden_dim)
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
... tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
... loss = explainer.train_step(g, g.ndata["h"], tmp)
... optimizer_exp.zero_grad()
... loss.backward()
... optimizer_exp.step()
>>> # Explain the graph
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask = explainer.explain_graph(g, feat)
"""
assert (
self.graph_explanation
), '"explain_graph" must be True when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
embed = self.model(graph, feat, embed=True, **kwargs)
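# Cache the per-type node embeddings on the graph, then flatten it to a
# homogeneous graph so that a single edge MLP can score every edge
# regardless of its type.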
for ntype, emb in embed.items():
graph.nodes[ntype].data["emb"] = emb.data
homo_graph = to_homogeneous(graph, ndata=["emb"])
homo_embed = homo_graph.ndata["emb"]
col, row = homo_graph.edges()
col_emb = homo_embed[col.long()]
row_emb = homo_embed[row.long()]
emb = torch.cat([col_emb, row_emb], dim=-1)
emb = self.elayers(emb)
values = emb.reshape(-1)
values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values
reverse_eids = homo_graph.edge_ids(row, col).long()
edge_mask = (values + values[reverse_eids]) / 2
self.set_masks(homo_graph, edge_mask)
# convert the edge mask back into heterogeneous format
hetero_edge_mask = self._edge_mask_to_heterogeneous(
edge_mask=edge_mask,
homograph=homo_graph,
heterograph=graph,
)
# the model prediction with the updated edge mask
logits = self.model(graph, feat, edge_weight=hetero_edge_mask, **kwargs)
probs = F.softmax(logits, dim=-1)
if training:
probs = probs.data
else:
self.clear_masks()
return (probs, hetero_edge_mask)
def explain_node(
self, nodes, graph, feat, temperature=1.0, training=False, **kwargs
):
r"""Learn and return an edge mask that plays a crucial role to
explain the prediction made by the GNN for provided set of node IDs.
Also, return the prediction made with the batched graph and edge mask.
Parameters
----------
nodes : dict[str, Iterable[int]]
A dict mapping node types (keys) to an iterable set of node ids (values).
graph : DGLGraph
A heterogeneous graph.
feat : dict[str, Tensor]
A dict mapping node types (keys) to feature tensors (values).
The input features are of shape :math:`(N_t, D_t)`. :math:`N_t` is
the number of nodes for node type :math:`t`, and :math:`D_t` is the
feature size for node type :math:`t`
temperature : float
The temperature parameter fed to the sampling procedure.
training : bool
Training the explanation network.
kwargs : dict
Additional arguments passed to the GNN model.
Returns
-------
dict[str, Tensor]
A dict mapping node types (keys) to classification probabilities
for node labels (values). The values are tensors of shape
:math:`(N_t, L)`, where :math:`L` is the number of node classes in
the dataset and :math:`N_t` is the number of nodes of type :math:`t`
in the returned batched graph.
dict[str, Tensor]
A dict mapping edge types (keys) to edge tensors (values) of shape
:math:`(E_t)`, where :math:`E_t` is the number of edges in the graph
for edge type :math:`t`. A higher weight suggests a larger
contribution of the edge.
DGLGraph
The batched set of subgraphs induced on the k-hop in-neighborhood
of the input center nodes.
dict[str, Tensor]
A dict mapping node types (keys) to a tensor of node IDs (values)
which correspond to the subgraph center nodes.
Examples
--------
>>> import dgl
>>> import torch as th
>>> import torch.nn as nn
>>> import numpy as np
>>> # Define the model
>>> class Model(nn.Module):
... def __init__(self, in_feats, hid_feats, out_feats, rel_names):
... super().__init__()
... self.conv = dgl.nn.HeteroGraphConv(
... {rel: dgl.nn.GraphConv(in_feats, hid_feats) for rel in rel_names},
... aggregate="sum",
... )
... self.fc = nn.Linear(hid_feats, out_feats)
... nn.init.xavier_uniform_(self.fc.weight)
...
... def forward(self, g, h, embed=False, edge_weight=None):
... if edge_weight:
... mod_kwargs = {
... etype: {"edge_weight": mask} for etype, mask in edge_weight.items()
... }
... h = self.conv(g, h, mod_kwargs=mod_kwargs)
... else:
... h = self.conv(g, h)
...
... return h
>>> # Load dataset
>>> input_dim = 5
>>> hidden_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, hidden_dim, num_classes, g.canonical_etypes)
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
... logits = model(g, g.ndata["h"])['user']
... loss = th.nn.functional.cross_entropy(logits, th.tensor([1,1,1]))
... optimizer.zero_grad()
... loss.backward()
... optimizer.step()
>>> # Initialize the explainer
>>> explainer = dgl.nn.HeteroPGExplainer(
... model, hidden_dim, num_hops=2, explain_graph=False
... )
>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
... tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
... loss = explainer.train_step_node(
... { ntype: g.nodes(ntype) for ntype in g.ntypes },
... g, g.ndata["h"], tmp
... )
... optimizer_exp.zero_grad()
... loss.backward()
... optimizer_exp.step()
>>> # Explain the prediction for node 0 of type "user"
>>> feat = g.ndata.pop("h")
>>> probs, edge_mask, bg, inverse_indices = explainer.explain_node(
... { "user": [0] }, g, feat
... )
"""
assert (
not self.graph_explanation
), '"explain_graph" must be False when initializing the module.'
assert (
self.num_hops is not None
), '"num_hops" must be provided when initializing the module.'
self.model = self.model.to(graph.device)
self.elayers = self.elayers.to(graph.device)
batched_embed = []
batched_homo_graph = []
batched_hetero_graph = []
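# For each (node type, center node) pair: extract the k-hop heterogeneous
# subgraph, mark the center nodes, flatten it to a homogeneous graph, and
# build per-edge features [src_emb, dst_emb, center_emb] for the
# explanation MLP.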
for target_ntype, target_nids in nodes.items():
if isinstance(target_nids, torch.Tensor):
target_nids = target_nids.tolist()
for target_nid in target_nids:
sg, inverse_indices = khop_in_subgraph(
graph, {target_ntype: target_nid}, self.num_hops
)
for sg_ntype in sg.ntypes:
sg_feat = feat[sg_ntype][sg.ndata[NID][sg_ntype].long()]
train_mask = [
sg_ntype in inverse_indices
and node_id in inverse_indices[sg_ntype]
for node_id in sg.nodes(sg_ntype)
]
sg.nodes[sg_ntype].data["feat"] = sg_feat
sg.nodes[sg_ntype].data["train"] = torch.tensor(
train_mask, device=sg.device
)
embed = self.model(sg, sg.ndata["feat"], embed=True, **kwargs)
for ntype in embed.keys():
sg.nodes[ntype].data["emb"] = embed[ntype].data
homo_sg = to_homogeneous(sg, ndata=["emb"])
homo_sg_embed = homo_sg.ndata["emb"]
col, row = homo_sg.edges()
col_emb = homo_sg_embed[col.long()]
row_emb = homo_sg_embed[row.long()]
self_emb = homo_sg_embed[
inverse_indices[target_ntype][0]
].repeat(sg.num_edges(), 1)
emb = torch.cat([col_emb, row_emb, self_emb], dim=-1)
batched_embed.append(emb)
batched_homo_graph.append(homo_sg)
batched_hetero_graph.append(sg)
batched_homo_graph = batch(batched_homo_graph)
batched_hetero_graph = batch(batched_hetero_graph)
batched_embed = torch.cat(batched_embed)
batched_embed = self.elayers(batched_embed)
values = batched_embed.reshape(-1)
values = self.concrete_sample(
values, beta=temperature, training=training
)
self.sparse_mask_values = values
col, row = batched_homo_graph.edges()
reverse_eids = batched_homo_graph.edge_ids(row, col).long()
edge_mask = (values + values[reverse_eids]) / 2
self.set_masks(batched_homo_graph, edge_mask)
# Convert the edge mask back into heterogeneous format.
hetero_edge_mask = self._edge_mask_to_heterogeneous(
edge_mask=edge_mask,
homograph=batched_homo_graph,
heterograph=batched_hetero_graph,
)
batched_feats = {
ntype: batched_hetero_graph.nodes[ntype].data["feat"]
for ntype in batched_hetero_graph.ntypes
}
# The model prediction with the updated edge mask.
logits = self.model(
batched_hetero_graph,
batched_feats,
edge_weight=hetero_edge_mask,
**kwargs,
)
probs = {
ntype: F.softmax(logits[ntype], dim=-1) for ntype in logits.keys()
}
batched_inverse_indices = {
ntype: batched_hetero_graph.nodes[ntype]
.data["train"]
.nonzero()
.squeeze(1)
for ntype in batched_hetero_graph.ntypes
}
if training:
self.batched_feats = batched_feats
probs = {ntype: probs[ntype].data for ntype in probs.keys()}
else:
self.clear_masks()
return (
probs,
hetero_edge_mask,
batched_hetero_graph,
batched_inverse_indices,
)
def _edge_mask_to_heterogeneous(self, edge_mask, homograph, heterograph):
r"""Convert an edge mask from homogeneous mappings built through
embeddings into heterogenous format by leveraging the context from
the source DGLGraphs in homogenous and heterogeneous form.
The `edge_mask` needs to have been built using the embedding of the
homogenous graph format for the mappings to work correctly.
Parameters
----------
edge_mask : Tensor
A tensor of edge weights over the edges of the homogeneous graph.
homograph : DGLGraph
The homogeneous form of the source graph.
heterograph : DGLGraph
The heterogeneous form of the source graph.
Returns
-------
dict[str, Tensor]
A dict mapping edge types (keys) to tensors of edge weights (values).
"""
return {
etype: edge_mask[
(homograph.edata[ETYPE] == heterograph.get_etype_id(etype))
.nonzero()
.squeeze(1)
]
for etype in heterograph.canonical_etypes
}