PGExplainer

class dgl.nn.pytorch.explain.PGExplainer(model, num_features, num_hops=None, explain_graph=True, coff_budget=0.01, coff_connect=0.0005, sample_bias=0.0)[source]

Bases: torch.nn.modules.module.Module

PGExplainer from Parameterized Explainer for Graph Neural Network <https://arxiv.org/pdf/2011.04573>
PGExplainer adopts a deep neural network (explanation network) to parameterize the generation process of explanations, which enables it to explain multiple instances collectively. PGExplainer models the underlying structure as edge distributions, from which the explanatory graph is sampled.
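Concretely, each edge is assigned a Bernoulli-like importance distribution whose parameters are predicted by the explanation network, and a relaxed (binary concrete) sample of that distribution acts as a differentiable edge mask during training. The sketch below is a minimal illustration of such a sampling step; the function concrete_sample and the use of sample_bias to narrow the uniform noise are assumptions for illustration, not DGL's internal API.

>>> import torch

>>> def concrete_sample(log_alpha, temperature=1.0, training=True, sample_bias=0.0):
...     # log_alpha: per-edge logits produced by the explanation network
...     if training:
...         # Clamp uniform noise away from 0 and 1; sample_bias narrows its range
...         noise = sample_bias + (1.0 - 2.0 * sample_bias) * torch.rand_like(log_alpha)
...         gate = (noise.log() - (1.0 - noise).log() + log_alpha) / temperature
...         return torch.sigmoid(gate)
...     # At inference time, take the deterministic mean of the distribution
...     return torch.sigmoid(log_alpha)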
- Parameters

  model (nn.Module) – The GNN model to explain, which tackles multiclass graph classification. Its forward function must have the form forward(self, graph, nfeat, embed, edge_weight). The output of its forward function is the logits if embed=False, else the intermediate node embeddings.

  num_features (int) – Node embedding size used by model.

  num_hops (int, optional) – The number of hops for GNN information aggregation, which must match the number of message passing layers employed by the GNN to be explained.

  explain_graph (bool, optional) – Whether to initialize the model for graph-level (True) or node-level (False) predictions.

  coff_budget (float, optional) – Size regularization coefficient that constrains the size of the explanation. Default: 0.01.

  coff_connect (float, optional) – Entropy regularization coefficient that constrains the connectivity of the explanation. Default: 5e-4.

  sample_bias (float, optional) – Bias in edge sampling, making some edges systematically more likely to be selected than others. Default: 0.0.
- explain_graph(graph, feat, temperature=1.0, training=False, **kwargs)[source]

Learn and return an edge mask that plays a crucial role in explaining the prediction made by the GNN for a graph. Also return the prediction made with the edges chosen based on the edge mask.
- Parameters

  graph (DGLGraph) – A homogeneous graph.

  feat (Tensor) – The input feature of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.

  temperature (float) – The temperature parameter fed to the sampling procedure.

  training (bool) – Whether the explanation network is being trained.

  kwargs (dict) – Additional arguments passed to the GNN model.

- Returns

  Tensor – Classification probabilities given the masked graph. It is a tensor of shape \((B, L)\), where \(L\) is the number of label types in the dataset and \(B\) is the batch size.

  Tensor – Edge weights, a tensor of shape \((E)\), where \(E\) is the number of edges in the graph. A higher weight suggests a larger contribution of the edge.
Examples
>>> import torch as th
>>> import torch.nn as nn
>>> import dgl
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import GraphConv, PGExplainer
>>> import numpy as np

>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super().__init__()
...         self.conv = GraphConv(in_feats, out_feats)
...         self.fc = nn.Linear(out_feats, out_feats)
...         nn.init.xavier_uniform_(self.fc.weight)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         h = self.conv(g, h, edge_weight=edge_weight)
...
...         if embed:
...             return h
...
...         with g.local_scope():
...             g.ndata['h'] = h
...             hg = dgl.mean_nodes(g, 'h')
...             return self.fc(hg)

>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)

>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = th.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     preds = model(bg, bg.ndata['attr'])
...     loss = criterion(preds, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = PGExplainer(model, data.gclasses)

>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01)
>>> for epoch in range(20):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20))
...     for bg, labels in dataloader:
...         loss = explainer.train_step(bg, bg.ndata['attr'], tmp)
...         optimizer_exp.zero_grad()
...         loss.backward()
...         optimizer_exp.step()

>>> # Explain the prediction for graph 0
>>> graph, l = data[0]
>>> graph_feat = graph.ndata.pop("attr")
>>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)
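The returned edge_weight can then be ranked to extract a compact explanatory subgraph. A minimal sketch continuing the example above; the budget k and the use of dgl.edge_subgraph are illustrative choices, not part of this API:

>>> # Keep the k highest-weighted edges as the explanatory subgraph
>>> k = 10
>>> top_edges = th.topk(edge_weight, k).indices
>>> explanation = dgl.edge_subgraph(graph, top_edges)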
- explain_node(nodes, graph, feat, temperature=1.0, training=False, **kwargs)[source]

Learn and return an edge mask that plays a crucial role in explaining the predictions made by the GNN for the provided set of node IDs. Also return the prediction made with the graph and edge mask.
- Parameters

  nodes (int, iterable[int], tensor) – The nodes from the graph, which cannot contain duplicates.

  graph (DGLGraph) – A homogeneous graph.

  feat (Tensor) – The input feature of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.

  temperature (float) – The temperature parameter fed to the sampling procedure.

  training (bool) – Whether the explanation network is being trained.

  kwargs (dict) – Additional arguments passed to the GNN model.

- Returns

  Tensor – Classification probabilities given the masked graph. It is a tensor of shape \((N, L)\), where \(L\) is the number of node label types in the dataset and \(N\) is the number of nodes in the graph.

  Tensor – Edge weights, a tensor of shape \((E)\), where \(E\) is the number of edges in the graph. A higher weight suggests a larger contribution of the edge.

  DGLGraph – The batched set of subgraphs induced on the k-hop in-neighborhood of the input center nodes.

  Tensor – The new IDs of the subgraph center nodes.
Examples
>>> import dgl
>>> import numpy as np
>>> import torch

>>> # Define the model
>>> class Model(torch.nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super().__init__()
...         self.conv1 = dgl.nn.GraphConv(in_feats, out_feats)
...         self.conv2 = dgl.nn.GraphConv(out_feats, out_feats)
...
...     def forward(self, g, h, embed=False, edge_weight=None):
...         h = self.conv1(g, h, edge_weight=edge_weight)
...         if embed:
...             return h
...         return self.conv2(g, h)

>>> # Load dataset
>>> data = dgl.data.CoraGraphDataset(verbose=False)
>>> g = data[0]
>>> features = g.ndata["feat"]
>>> labels = g.ndata["label"]

>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = torch.nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(20):
...     logits = model(g, features)
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Initialize the explainer
>>> explainer = dgl.nn.PGExplainer(
...     model, data.num_classes, num_hops=2, explain_graph=False
... )

>>> # Train the explainer
>>> # Define explainer temperature parameter
>>> init_tmp, final_tmp = 5.0, 1.0
>>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01)
>>> epochs = 10
>>> for epoch in range(epochs):
...     tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs))
...     loss = explainer.train_step_node(g.nodes(), g, features, tmp)
...     optimizer_exp.zero_grad()
...     loss.backward()
...     optimizer_exp.step()

>>> # Explain the prediction for node 0
>>> probs, edge_weight, bg, inverse_indices = explainer.explain_node(
...     0, g, features
... )
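Assuming probs is aligned with the nodes of the returned subgraph bg, with inverse_indices locating the center node inside it, a minimal sketch of reading off the explained node's prediction and its most important edges follows; the top-5 budget is an illustrative choice:

>>> # inverse_indices gives the center node's new ID inside bg
>>> center_probs = probs[inverse_indices]
>>> pred_class = center_probs.argmax(dim=-1)
>>> # Rank the subgraph's edges by learned importance
>>> top_edges = torch.topk(edge_weight, k=5).indices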
- forward(*input: Any) → None

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- train_step(graph, feat, temperature, **kwargs)[source]

Compute the loss of the explanation network for graph classification.

- Parameters

  graph (DGLGraph) – Input batched homogeneous graph.

  feat (Tensor) – The input feature of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.

  temperature (float) – The temperature parameter fed to the sampling procedure.

  kwargs (dict) – Additional arguments passed to the GNN model.

- Returns

  A scalar tensor representing the loss.

- Return type

  Tensor
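The returned scalar folds in the two regularizers configured on the constructor: a size term weighted by coff_budget and an edge-entropy term weighted by coff_connect, on top of the prediction loss. A conceptual sketch of how such a regularized objective is typically assembled, as an illustration rather than DGL's internal implementation:

>>> import torch

>>> def explanation_loss(pred_loss, edge_mask, coff_budget=0.01, coff_connect=5e-4):
...     # Size term: penalize the total mass of the edge mask
...     size_reg = coff_budget * torch.sum(edge_mask)
...     # Entropy term: push mask entries toward 0 or 1
...     eps = 1e-6
...     ent = -(edge_mask * torch.log(edge_mask + eps)
...             + (1.0 - edge_mask) * torch.log(1.0 - edge_mask + eps))
...     connect_reg = coff_connect * torch.mean(ent)
...     return pred_loss + size_reg + connect_reg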
- train_step_node(nodes, graph, feat, temperature, **kwargs)[source]

Compute the loss of the explanation network for node classification.

- Parameters

  nodes (int, iterable[int], tensor) – The nodes from the graph used to train the explanation network, which cannot contain duplicates.

  graph (DGLGraph) – Input homogeneous graph.

  feat (Tensor) – The input feature of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.

  temperature (float) – The temperature parameter fed to the sampling procedure.

  kwargs (dict) – Additional arguments passed to the GNN model.

- Returns

  A scalar tensor representing the loss.

- Return type

  Tensor