HeteroGNNExplainer
class dgl.nn.pytorch.explain.HeteroGNNExplainer(model, num_hops, lr=0.01, num_epochs=100, *, alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True)
Bases: torch.nn.modules.module.Module
The GNNExplainer model from GNNExplainer: Generating Explanations for Graph Neural Networks, adapted for heterogeneous graphs.
It identifies compact subgraph structures and small subsets of node features that play a critical role in GNN-based node classification and graph classification.
To generate an explanation, it learns an edge mask \(M\) and a feature mask \(F\) by optimizing the following objective function.
\[l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)\]
where \(l\) is the loss function, \(y\) is the original model prediction, \(\hat{y}\) is the model prediction with the edge and feature masks applied, and \(H\) is the entropy function.
Parameters
model (nn.Module) – The GNN model to explain.
The required arguments of its forward function are graph and feat; the latter holds the input node features as a dictionary keyed by node type.
It should also accept an optional eweight argument holding edge weights (a dictionary keyed by canonical edge type, as in the examples) and multiply the messages by these weights during message passing.
Its forward function should return the logits for the predicted node or graph classes.
See also the examples in explain_node() and explain_graph().
num_hops (int) – The number of hops for GNN information aggregation.
lr (float, optional) – The learning rate to use, defaults to 0.01.
num_epochs (int, optional) – The number of epochs to train, defaults to 100.
alpha1 (float, optional) – A higher value makes the explanation edge masks sparser by decreasing the sum of the edge mask; defaults to 0.005.
alpha2 (float, optional) – A higher value makes the explanation edge masks sparser by decreasing the entropy of the edge mask; defaults to 1.0.
beta1 (float, optional) – A higher value makes the explanation node feature masks sparser by decreasing the mean of the node feature mask; defaults to 1.0.
beta2 (float, optional) – A higher value makes the explanation node feature masks sparser by decreasing the entropy of the node feature mask; defaults to 0.1.
log (bool, optional) – If True, log the computation process; defaults to True.
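The eweight contract above can be illustrated without DGL: each message is multiplied by the weight of the edge it travels along and then averaged, which is what the fn.u_mul_e / fn.mean pair does in the Model class of the examples below. The following is a plain-Python sketch for a single edge type; mean_aggregate is a hypothetical helper, and nested lists stand in for tensors.

```python
from typing import List, Optional, Sequence


def mean_aggregate(
    src: Sequence[int],
    dst: Sequence[int],
    src_feat: Sequence[Sequence[float]],
    num_dst: int,
    eweight: Optional[Sequence[float]] = None,
) -> List[List[float]]:
    """Mean-aggregate messages along one edge type, optionally weighting each edge."""
    dim = len(src_feat[0])
    totals = [[0.0] * dim for _ in range(num_dst)]
    counts = [0] * num_dst
    for eid, (u, v) in enumerate(zip(src, dst)):
        # Without eweight every message counts fully (weight 1.0).
        w = 1.0 if eweight is None else eweight[eid]
        for j in range(dim):
            totals[v][j] += w * src_feat[u][j]
        counts[v] += 1
    # Divide by the in-degree; nodes with no in-edges keep a zero vector.
    return [[t / c for t in row] if c else row for row, c in zip(totals, counts)]


src, dst = [0, 1, 1, 2], [0, 0, 1, 1]  # the ('user', 'plays', 'game') edges from the examples
user_feat = [[1.0], [2.0], [4.0]]
print(mean_aggregate(src, dst, user_feat, num_dst=2))                        # [[1.5], [3.0]]
print(mean_aggregate(src, dst, user_feat, 2, eweight=[1.0, 0.0, 0.5, 1.0]))  # [[0.5], [2.5]]
```

Passing all-ones weights reproduces the unweighted result, so an edge mask close to 1 leaves this aggregation unchanged, while a mask close to 0 suppresses the corresponding message.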
explain_graph(graph, feat, **kwargs)
Learn and return node feature masks and edge masks that play a crucial role in explaining the prediction made by the GNN for a graph.
Parameters
graph (DGLGraph) – A heterogeneous graph that will be explained.
feat (dict[str, Tensor]) – The dictionary that associates input node features (values) with the respective node types (keys) present in the graph. The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes for node type \(t\) and \(D_t\) is the feature size for node type \(t\).
kwargs (dict) – Additional arguments passed to the GNN model.
Returns
feat_mask (dict[str, Tensor]) – The dictionary that associates the learned node feature importance masks (values) with the respective node types (keys). The masks are of shape \((D_t)\), where \(D_t\) is the node feature size for node type \(t\). The values are within range \((0, 1)\); the higher the value, the more important the feature.
edge_mask (dict[Tuple[str], Tensor]) – The dictionary that associates the learned edge importance masks (values) with the respective canonical edge types (keys). The masks are of shape \((E_t)\), where \(E_t\) is the number of edges for canonical edge type \(t\) in the graph. The values are within range \((0, 1)\); the higher the value, the more important the edge.
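Each feature mask has shape \((D_t)\), so one mask is shared by all \(N_t\) nodes of type \(t\). Assuming the mask is applied by element-wise multiplication (the objective above only says \(\hat{y}\) is computed with the masks applied), a plain-Python sketch looks like this; apply_feat_mask is a hypothetical helper.

```python
from typing import Dict, List, Sequence


def apply_feat_mask(
    feat: Dict[str, Sequence[Sequence[float]]],
    feat_mask: Dict[str, Sequence[float]],
) -> Dict[str, List[List[float]]]:
    """Scale every node's features by the (D_t,) mask of its node type."""
    masked = {}
    for ntype, rows in feat.items():
        mask = feat_mask[ntype]
        for row in rows:
            if len(row) != len(mask):
                raise ValueError(f"mask for {ntype!r} has {len(mask)} entries, features have {len(row)}")
        # The same mask is reused for all N_t rows (broadcast over nodes).
        masked[ntype] = [[x * m for x, m in zip(row, mask)] for row in rows]
    return masked


feat = {'user': [[1.0, 2.0], [3.0, 4.0]], 'game': [[5.0, 6.0]]}
feat_mask = {'user': [1.0, 0.5], 'game': [0.0, 1.0]}
print(apply_feat_mask(feat, feat_mask))
# {'user': [[1.0, 1.0], [3.0, 2.0]], 'game': [[0.0, 6.0]]}
```

With PyTorch tensors the same operation is feat[ntype] * feat_mask[ntype], which broadcasts the \((D_t)\) mask over the \((N_t, D_t)\) features.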
Examples
>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroGNNExplainer
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict({
...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)
...             for c_etype in canonical_etypes
...         })
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh
...                 if eweight is None:
...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
...                         fn.mean('m', 'h'))
...                 else:
...                     graph.edges[c_etype].data['w'] = eweight[c_etype]
...                     c_etype_func_dict[c_etype] = (
...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
...             graph.multi_update_all(c_etype_func_dict, 'sum')
...             hg = 0
...             for ntype in graph.ntypes:
...                 if graph.num_nodes(ntype):
...                     hg = hg + dgl.mean_nodes(graph, 'h', ntype=ntype)
...             return hg
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
>>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata['h']
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)
...     loss = F.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for the graph
>>> explainer = HeteroGNNExplainer(model, num_hops=1)
>>> feat_mask, edge_mask = explainer.explain_graph(g, feat)
>>> feat_mask
{'game': tensor([0.2684, 0.2597, 0.3135, 0.2976, 0.2607]), 'user': tensor([0.2216, 0.2908, 0.2644, 0.2738, 0.2663])}
>>> edge_mask
{('game', 'rev_plays', 'user'): tensor([0.8922, 0.1966, 0.8371, 0.1330]), ('user', 'plays', 'game'): tensor([0.1785, 0.1696, 0.8065, 0.2167])}
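A common way to read the returned edge_mask is to rank edges across all canonical edge types and keep the highest-scoring ones. The sketch below does this in plain Python on the edge_mask values printed above (they come from one random run and are only illustrative); top_k_edges is a hypothetical helper. With real output, convert each tensor first, e.g. {k: v.tolist() for k, v in edge_mask.items()}.

```python
from typing import Dict, List, Tuple

CanonicalEType = Tuple[str, str, str]


def top_k_edges(
    edge_mask: Dict[CanonicalEType, List[float]], k: int
) -> List[Tuple[CanonicalEType, int, float]]:
    """Return the k highest-scoring (edge type, edge ID, score) triples across all edge types."""
    scored = [
        (c_etype, eid, score)
        for c_etype, scores in edge_mask.items()
        for eid, score in enumerate(scores)
    ]
    # Highest importance first; Python's sort is stable, so ties keep insertion order.
    scored.sort(key=lambda item: item[2], reverse=True)
    return scored[:k]


# Edge masks copied from the explain_graph() output above.
edge_mask = {
    ('game', 'rev_plays', 'user'): [0.8922, 0.1966, 0.8371, 0.1330],
    ('user', 'plays', 'game'): [0.1785, 0.1696, 0.8065, 0.2167],
}
for c_etype, eid, score in top_k_edges(edge_mask, k=3):
    print(c_etype, eid, score)
# ('game', 'rev_plays', 'user') 0 0.8922
# ('game', 'rev_plays', 'user') 2 0.8371
# ('user', 'plays', 'game') 2 0.8065
```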
explain_node(ntype, node_id, graph, feat, **kwargs)
Learn and return node feature masks and a subgraph that play a crucial role in explaining the prediction made by the GNN for node node_id of type ntype.
It requires model to return a dictionary mapping node types to type-specific predictions.
Parameters
ntype (str) – The type of the node to explain. model must be trained to make predictions for this particular node type.
node_id (int) – The ID of the node to explain.
graph (DGLGraph) – A heterogeneous graph.
feat (dict[str, Tensor]) – The dictionary that associates input node features (values) with the respective node types (keys) present in the graph. The input features are of shape \((N_t, D_t)\), where \(N_t\) is the number of nodes for node type \(t\) and \(D_t\) is the feature size for node type \(t\).
kwargs (dict) – Additional arguments passed to the GNN model.
Returns
new_node_id (Tensor) – The new ID of the input center node in sg.
sg (DGLGraph) – The subgraph induced on the k-hop in-neighborhood of the input center node, where k is num_hops.
feat_mask (dict[str, Tensor]) – The dictionary that associates the learned node feature importance masks (values) with the respective node types (keys). The masks are of shape \((D_t)\), where \(D_t\) is the node feature size for node type \(t\). The values are within range \((0, 1)\); the higher the value, the more important the feature.
edge_mask (dict[Tuple[str], Tensor]) – The dictionary that associates the learned edge importance masks (values) with the respective canonical edge types (keys). The masks are of shape \((E_t)\), where \(E_t\) is the number of edges for canonical edge type \(t\) in the subgraph. The values are within range \((0, 1)\); the higher the value, the more important the edge.
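The subgraph sg is induced on the k-hop in-neighborhood of the center node: every node that can reach the center by following at most k edges in their stored direction, plus all edges between those nodes. The plain-Python sketch below computes this on the toy graph from the examples (after AddReverse). The helper names khop_in_nodes and induced_edges are hypothetical, and the relabeling rule (new IDs assigned in ascending order of original ID within each node type) is an assumption made for illustration, not a statement about DGL's internal ordering.

```python
from typing import Dict, List, Sequence, Set, Tuple

CanonicalEType = Tuple[str, str, str]
Edges = Dict[CanonicalEType, Tuple[Sequence[int], Sequence[int]]]


def khop_in_nodes(edges: Edges, center_type: str, center_id: int, k: int):
    """Return {ntype: sorted original IDs} of the k-hop in-neighborhood and the center's new ID."""
    reached: Set[Tuple[str, int]] = {(center_type, center_id)}
    frontier = set(reached)
    for _ in range(k):
        next_frontier: Set[Tuple[str, int]] = set()
        for (src_t, _etype, dst_t), (src, dst) in edges.items():
            for u, v in zip(src, dst):
                # u is an in-neighbor of frontier node v, so it joins the neighborhood.
                if (dst_t, v) in frontier and (src_t, u) not in reached:
                    next_frontier.add((src_t, u))
        reached |= next_frontier
        frontier = next_frontier
    kept: Dict[str, List[int]] = {}
    for ntype, nid in sorted(reached):
        kept.setdefault(ntype, []).append(nid)
    # Assumed relabeling: position of the center among the kept nodes of its type.
    return kept, kept[center_type].index(center_id)


def induced_edges(edges: Edges, kept: Dict[str, List[int]]) -> Dict[CanonicalEType, List[int]]:
    """Edge IDs (per canonical edge type) whose endpoints both lie in the neighborhood."""
    keep = {ntype: set(ids) for ntype, ids in kept.items()}
    return {
        c_etype: [
            eid
            for eid, (u, v) in enumerate(zip(src, dst))
            if u in keep.get(c_etype[0], set()) and v in keep.get(c_etype[2], set())
        ]
        for c_etype, (src, dst) in edges.items()
    }


# The toy graph from the examples after AddReverse.
edges: Edges = {
    ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1]),
    ('game', 'rev_plays', 'user'): ([0, 0, 1, 1], [0, 1, 1, 2]),
}
kept, new_id = khop_in_nodes(edges, 'user', 0, k=1)
print(kept, new_id)               # {'game': [0], 'user': [0]} 0
print(induced_edges(edges, kept))
# {('user', 'plays', 'game'): [0], ('game', 'rev_plays', 'user'): [0]}
```

For user 2 with k=2 the neighborhood is users 1 and 2 plus game 1, so under this relabeling the center's new ID becomes 1.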
Examples
>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroGNNExplainer
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict({
...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)
...             for c_etype in canonical_etypes
...         })
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh
...                 if eweight is None:
...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
...                         fn.mean('m', 'h'))
...                 else:
...                     graph.edges[c_etype].data['w'] = eweight[c_etype]
...                     c_etype_func_dict[c_etype] = (
...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
...             graph.multi_update_all(c_etype_func_dict, 'sum')
...             return graph.ndata['h']
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
>>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata['h']
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)['user']
...     loss = F.cross_entropy(logits, th.tensor([1, 1, 1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for node 0 of type 'user'
>>> explainer = HeteroGNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node('user', 0, g, feat)
>>> new_center
tensor([0])
>>> sg
Graph(num_nodes={'game': 1, 'user': 1},
      num_edges={('game', 'rev_plays', 'user'): 1, ('user', 'plays', 'game'): 1, ('user', 'rev_rev_plays', 'game'): 1},
      metagraph=[('game', 'user', 'rev_plays'), ('user', 'game', 'plays'), ('user', 'game', 'rev_rev_plays')])
>>> feat_mask
{'game': tensor([0.2348, 0.2780, 0.2611, 0.2513, 0.2823]), 'user': tensor([0.2716, 0.2450, 0.2658, 0.2876, 0.2738])}
>>> edge_mask
{('game', 'rev_plays', 'user'): tensor([0.0630]), ('user', 'plays', 'game'): tensor([0.1939]), ('user', 'rev_rev_plays', 'game'): tensor([0.9166])}
forward(*input: Any) → None
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.