HeteroSubgraphX
class dgl.nn.pytorch.explain.HeteroSubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)
Bases: torch.nn.modules.module.Module
SubgraphX from On Explainability of Graph Neural Networks via Subgraph Explorations, adapted for heterogeneous graphs.
It identifies the subgraph of the original graph that is most important to the prediction of a GNN-based graph classifier.
It uses Monte Carlo tree search (MCTS) to explore candidate subgraphs efficiently and Shapley values to measure subgraph importance.
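The Shapley-value scoring mentioned above can be illustrated with a small, library-free sketch. It treats the candidate subgraph as a single player, samples random coalitions of the remaining nodes, and averages the marginal contribution of the subgraph. The names `mc_shapley`, `value_fn`, and `neighbours` are hypothetical, and the toy `value_fn` stands in for a model's output; this is not DGL's implementation.

```python
import random


def mc_shapley(value_fn, subgraph, neighbours, steps=100, seed=0):
    """Monte Carlo estimate of the Shapley value of `subgraph` as one player.

    value_fn   -- maps a frozenset of node ids to a float score
                  (hypothetical stand-in for the model's output)
    subgraph   -- node ids of the candidate explanation
    neighbours -- the other node ids taking part in the game
    """
    rng = random.Random(seed)
    subgraph = frozenset(subgraph)
    others = list(neighbours)
    total = 0.0
    for _ in range(steps):
        # A random permutation of the other players plus a uniform cut point
        # yields the set of players that "arrive" before the subgraph.
        rng.shuffle(others)
        cut = rng.randint(0, len(others))
        coalition = frozenset(others[:cut])
        # Marginal contribution of adding the subgraph to this coalition.
        total += value_fn(coalition | subgraph) - value_fn(coalition)
    return total / steps


# Toy additive game: the score of a node set is the sum of its node weights.
weights = {0: 1.0, 1: 2.0, 2: 0.5, 3: 4.0}


def toy_value(nodes):
    return sum(weights[n] for n in nodes)


score = mc_shapley(toy_value, subgraph={0, 1}, neighbours=[2, 3], steps=50)
print(score)  # 3.0, since every marginal contribution is 1.0 + 2.0 here
```

In an additive game the estimate is exact after any number of steps; with a real model the marginal contributions vary between coalitions, which is why more sampling steps reduce the variance of the estimate.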
- Parameters
model (nn.Module) – The GNN model to explain, which performs multiclass graph classification. Its forward function must have the form forward(self, graph, nfeat), and its output must be the logits.
num_hops (int) – Number of message passing layers in the model
coef (float, optional) – Hyperparameter that controls the trade-off between exploration and exploitation. A higher value encourages the algorithm to explore relatively unvisited nodes. Default: 10.0
high2low (bool, optional) – If True, use the “High2low” strategy for pruning actions, expanding child nodes from high degree to low degree when extending the search tree. Otherwise, use the “Low2high” strategy. Default: True
num_child (int, optional) – Number of child nodes to expand when extending the search tree. Default: 12
num_rollouts (int, optional) – Number of rollouts for MCTS. Default: 20
node_min (int, optional) – Threshold on the number of nodes in a subgraph that defines a leaf node of the search tree. Default: 3
shapley_steps (int, optional) – Number of Monte Carlo sampling steps used to estimate Shapley values. Default: 100
log (bool, optional) – If True, log the progress. Default: False
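As a rough illustration of how coef, high2low, and num_child might interact during the search, the sketch below uses one common form of an MCTS selection score (an average reward plus an exploration bonus scaled by coef) and a degree-based ordering of candidate nodes. The function names, the `prior` argument, and the exact formula are assumptions made for illustration; DGL's internal implementation may differ.

```python
import math


def selection_score(q, prior, visits, total_visits, coef=10.0):
    """Exploitation term `q` plus an exploration bonus scaled by `coef`.

    q            -- average reward observed for this child so far
    prior        -- immediate reward of the child (hypothetical input)
    visits       -- how often this child has been visited
    total_visits -- visits summed over all sibling children
    """
    return q + coef * prior * math.sqrt(total_visits) / (1 + visits)


def prune_candidates(degrees, num_child=12, high2low=True):
    """Order candidate nodes by degree and keep at most `num_child` of them."""
    ordered = sorted(degrees, key=degrees.get, reverse=high2low)
    return ordered[:num_child]


# Child A has been visited often and looks slightly better; child B is unvisited.
a_greedy = selection_score(q=0.6, prior=0.1, visits=10, total_visits=10, coef=0.0)
b_greedy = selection_score(q=0.5, prior=0.1, visits=0, total_visits=10, coef=0.0)
a_explore = selection_score(q=0.6, prior=0.1, visits=10, total_visits=10, coef=10.0)
b_explore = selection_score(q=0.5, prior=0.1, visits=0, total_visits=10, coef=10.0)

# Toy node degrees (assumed values).
degrees = {"a": 3, "b": 1, "c": 2}
print(prune_candidates(degrees, num_child=2, high2low=True))   # ['a', 'c']
print(prune_candidates(degrees, num_child=2, high2low=False))  # ['b', 'c']
```

With coef set to 0 the frequently visited child A scores higher, while with coef set to 10 the unvisited child B does, which matches the description of coef as favouring relatively unvisited nodes.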
explain_graph(graph, feat, target_class, **kwargs)
Find the most important subgraph from the original graph for the model to classify the graph into the target class.
- Parameters
graph (DGLGraph) – A heterogeneous graph
feat (dict[str, Tensor]) – The dictionary that associates input node features (values) with the respective node types (keys) present in the graph. The input features are of shape \((N_t, D_t)\). \(N_t\) is the number of nodes for node type \(t\), and \(D_t\) is the feature size for node type \(t\)
target_class (int) – The target class to explain
kwargs (dict) – Additional arguments passed to the GNN model
- Returns
A dictionary mapping each node type (key) to a tensor of node IDs (value); together these nodes form the most important subgraph
- Return type
dict[str, Tensor]
Examples
>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroSubgraphX

>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict(
...             {
...                 "_".join(c_etype): nn.Linear(in_dim, num_classes)
...                 for c_etype in canonical_etypes
...             }
...         )
...
...     def forward(self, graph, feat):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights["_".join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f"h_{c_etype}"] = wh
...                 c_etype_func_dict[c_etype] = (
...                     fn.copy_u(f"h_{c_etype}", "m"),
...                     fn.mean("m", "h"),
...                 )
...             graph.multi_update_all(c_etype_func_dict, "sum")
...             hg = 0
...             for ntype in graph.ntypes:
...                 if graph.num_nodes(ntype):
...                     hg = hg + dgl.mean_nodes(graph, "h", ntype=ntype)
...             return hg

>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes["user"].data["h"] = th.randn(g.num_nodes("user"), input_dim)
>>> g.nodes["game"].data["h"] = th.randn(g.num_nodes("game"), input_dim)

>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)

>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata["h"]
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)
...     loss = F.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Explain for the graph
>>> explainer = HeteroSubgraphX(model, num_hops=1)
>>> explainer.explain_graph(g, feat, target_class=1)
{'game': tensor([0, 1]), 'user': tensor([1, 2])}
forward(*input: Any) → None
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.