Stochastic Training of GNN for Link Prediction¶
This tutorial will show how to train a multi-layer GraphSAGE for link prediction on ogbn-arxiv, provided by Open Graph Benchmark (OGB). The dataset contains around 170 thousand nodes and 1 million edges.
By the end of this tutorial, you will be able to
Train a GNN model for link prediction on a single GPU with DGL’s neighbor sampling components.
This tutorial assumes that you have read the Introduction of Neighbor Sampling for GNN Training and Neighbor Sampling for Node Classification.
Link Prediction Overview¶
Link prediction requires the model to predict the probability of existence of an edge. This tutorial does so by computing a dot product between the representations of both incident nodes:

$\hat{y}_{u\sim v} = \sigma(h_u^\top h_v)$

It then minimizes the following binary cross-entropy loss, where $\mathcal{D}$ is the set of training node pairs and $y_{u\sim v}$ is 1 for an existing edge and 0 for a sampled non-edge:

$\mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)$

This is identical to the link prediction formulation in the previous tutorial on link prediction.
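In code, the score and the loss for a batch of node pairs boil down to a couple of lines. A minimal sketch, using stand-in tensors for the endpoint representations and labels:

import torch
import torch.nn.functional as F

# Stand-in tensors: a representation for each endpoint of 8 node pairs,
# and a 0/1 label per pair (1 = edge exists, 0 = negative example).
h_src, h_dst = torch.randn(8, 16), torch.randn(8, 16)
labels = torch.randint(0, 2, (8,)).float()

scores = (h_src * h_dst).sum(dim=1)  # dot product per pair (the logits)
loss = F.binary_cross_entropy_with_logits(scores, labels)  # applies the sigmoid internally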
Loading Dataset¶
This tutorial loads the dataset from the ogb package as in the previous tutorial.
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset('ogbn-arxiv')
device = 'cpu' # change to 'cuda' for GPU
graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
print(graph)
print(node_labels)
node_features = graph.ndata['feat']
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']
Out:
Graph(num_nodes=169343, num_edges=2332486,
ndata_schemes={'year': Scheme(shape=(1,), dtype=torch.int64), 'feat': Scheme(shape=(128,), dtype=torch.float32)}
edata_schemes={})
tensor([[ 4],
[ 5],
[28],
...,
[10],
[ 4],
[ 1]])
Number of classes: 40
Defining Neighbor Sampler and Data Loader in DGL¶
Different from the full-graph link prediction tutorial, a common practice for training GNNs on large graphs is to iterate over the edges in minibatches, since computing the probability of all edges is usually intractable. For each minibatch of edges, you compute the output representations of their incident nodes using neighbor sampling and a GNN, in a fashion similar to the one introduced in the large-scale node classification tutorial.
DGL provides dgl.dataloading.as_edge_prediction_sampler to iterate over edges for edge classification or link prediction tasks.

To perform link prediction, you need to specify a negative sampler. DGL provides builtin negative samplers such as dgl.dataloading.negative_sampler.Uniform. Here this tutorial uniformly draws 5 negative examples per positive example.
negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)
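Concretely, Uniform(5) keeps the source node of every sampled edge and pairs it with 5 destination nodes drawn uniformly at random. The following is a minimal sketch of the equivalent logic with a hypothetical helper, for intuition only, not DGL's actual implementation:

def uniform_negatives(graph, eids, k=5):
    # Source nodes of the positive edges in the minibatch.
    src, _ = graph.find_edges(eids)
    # Repeat each source k times and draw k random destinations for it;
    # the resulting pairs are treated as non-edges.
    neg_src = src.repeat_interleave(k)
    neg_dst = torch.randint(0, graph.num_nodes(), (len(src) * k,))
    return neg_src, neg_dst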
After defining the negative sampler, one can then define the edge data loader with neighbor sampling. To create a DataLoader for link prediction, provide a neighbor sampler object as well as the negative sampler object created above.
sampler = dgl.dataloading.NeighborSampler([4, 4])
sampler = dgl.dataloading.as_edge_prediction_sampler(
sampler, negative_sampler=negative_sampler)
train_dataloader = dgl.dataloading.DataLoader(
    # The following arguments are specific to DataLoader.
    graph,                                  # The graph
    torch.arange(graph.number_of_edges()),  # The edges to iterate over
    sampler,                                # The neighbor sampler
    device=device,                          # Put the MFGs on CPU or GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,    # Batch size
    shuffle=True,       # Whether to shuffle the edges for every epoch
    drop_last=False,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)
You can peek at one minibatch from train_dataloader and see what it gives you.
input_nodes, pos_graph, neg_graph, mfgs = next(iter(train_dataloader))
print('Number of input nodes:', len(input_nodes))
print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges())
print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())
print(mfgs)
Out:
Number of input nodes: 57093
Positive graph # nodes: 6881 # edges: 1024
Negative graph # nodes: 6881 # edges: 5120
[Block(num_src_nodes=57093, num_dst_nodes=23780, num_edges=88602), Block(num_src_nodes=23780, num_dst_nodes=6881, num_edges=24147)]
The example minibatch consists of four elements.
The first element is an ID tensor for the input nodes, i.e., nodes whose input features are needed on the first GNN layer for this minibatch.
The second element and the third element are the positive graph and the negative graph for this minibatch. The concept of positive and negative graphs has been introduced in the full-graph link prediction tutorial. In minibatch training, the positive graph and the negative graph only contain the nodes necessary for computing the pair-wise scores of positive and negative examples in the current minibatch.
The last element is a list of MFGs storing the computation dependencies for each GNN layer. The MFGs are used to compute the GNN outputs of the nodes involved in the positive/negative graphs, as the sanity check below illustrates.
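Because the positive and negative graphs are compacted over the same node set, and that node set equals the destination nodes of the last MFG, the GNN output can be scored on both graphs directly. A quick sanity check on the minibatch unpacked above (the compacted graphs also carry the original node IDs in ndata[dgl.NID]):

# Both pair graphs share one compacted node set...
assert pos_graph.num_nodes() == neg_graph.num_nodes()
# ...which is exactly the set of destination nodes of the last MFG.
assert pos_graph.num_nodes() == mfgs[-1].num_dst_nodes()
# The original IDs of the first few compacted nodes.
print(pos_graph.ndata[dgl.NID][:5])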
Defining Model for Node Representation¶
The model is almost identical to the one in the node classification tutorial. The only difference is that since you are doing link prediction, the output dimension will not be the number of classes in the dataset.
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
class Model(nn.Module):
    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        # In an MFG, the destination nodes are listed first among the
        # source nodes, so slicing the input gives their features.
        h_dst = x[:mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[:mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h
model = Model(num_features, 128).to(device)
Defining the Score Predictor for Edges¶
After getting the node representations necessary for the minibatch, the last thing to do is to predict the scores of the edges and non-existent edges in the sampled minibatch.
The following score predictor, copied from the link prediction tutorial, takes a dot product between the incident nodes’ representations.
import dgl.function as fn
class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            # Compute a new edge feature named 'score' by a dot-product
            # between the source node feature 'h' and destination node
            # feature 'h'.
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            # u_dot_v returns a 1-element vector for each edge so you
            # need to squeeze it.
            return g.edata['score'][:, 0]
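The dot product has no parameters of its own. As an alternative, you can score each edge with a small MLP over the concatenated endpoint representations, along the lines of the MLPPredictor in the full-graph link prediction tutorial. A sketch, not used in the rest of this tutorial:

class MLPPredictor(nn.Module):
    def __init__(self, h_feats):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        # Concatenate the source and destination representations and
        # pass them through a two-layer MLP to get one score per edge.
        h = torch.cat([edges.src['h'], edges.dst['h']], dim=1)
        return {'score': self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']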
Evaluating Performance with Unsupervised Learning (Optional)¶
There are various ways to evaluate the performance of link prediction. This tutorial follows the practice of the GraphSAGE paper: first train a GNN via link prediction to obtain an embedding for each node, then train a downstream classifier on top of the embeddings and use its accuracy as an assessment of the embedding quality.
To obtain the representations of all the nodes, this tutorial uses neighbor sampling as introduced in the node classification tutorial.
Note
If you would like to obtain node representations without neighbor sampling during inference, please refer to this user guide.
def inference(model, graph, node_features):
    with torch.no_grad():
        sampler = dgl.dataloading.NeighborSampler([4, 4])
        dataloader = dgl.dataloading.DataLoader(
            graph, torch.arange(graph.number_of_nodes()), sampler,
            batch_size=1024,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            device=device)

        result = []
        for input_nodes, output_nodes, mfgs in dataloader:
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata['feat']
            result.append(model(mfgs, inputs))
        return torch.cat(result)
import sklearn.metrics
def evaluate(emb, label, train_nids, valid_nids, test_nids):
    classifier = nn.Linear(emb.shape[1], num_classes).to(device)
    opt = torch.optim.LBFGS(classifier.parameters())

    def compute_loss():
        pred = classifier(emb[train_nids].to(device))
        loss = F.cross_entropy(pred, label[train_nids].to(device))
        return loss

    def closure():
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        return loss

    prev_loss = float('inf')
    for i in range(1000):
        opt.step(closure)
        with torch.no_grad():
            loss = compute_loss().item()
            if np.abs(loss - prev_loss) < 1e-4:
                print('Converges at iteration', i)
                break
            else:
                prev_loss = loss

    with torch.no_grad():
        pred = classifier(emb.to(device)).cpu()
        valid_acc = sklearn.metrics.accuracy_score(
            label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1))
        test_acc = sklearn.metrics.accuracy_score(
            label[test_nids].numpy(), pred[test_nids].numpy().argmax(1))
    return valid_acc, test_acc
Defining Training Loop¶
The following initializes the score predictor and defines the optimizer, which updates the parameters of both the GNN model and the predictor:
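predictor = DotPredictor().to(device)
# Optimize the GNN and the predictor jointly. Adam is used here as a
# reasonable default; any standard optimizer would do.
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))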
The following is the training loop for link prediction and evaluation, and also saves the model that performs the best on the validation set:
import tqdm

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, mfgs) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = mfgs[0].srcdata['feat']

            outputs = model(mfgs, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            score = torch.cat([pos_score, neg_score])
            label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            loss = F.binary_cross_entropy_with_logits(score, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tq.set_postfix({'loss': '%.03f' % loss.item()}, refresh=False)

            if (step + 1) % 500 == 0:
                model.eval()
                emb = inference(model, graph, node_features)
                valid_acc, test_acc = evaluate(emb, node_labels, train_nids, valid_nids, test_nids)
                print('Epoch {} Validation Accuracy {} Test Accuracy {}'.format(epoch, valid_acc, test_acc))
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

                # Note that this tutorial does not train the whole model to the end.
                break
Out:
  0%|          | 0/2278 [00:00<?, ?it/s]
  0%|          | 1/2278 [00:00<08:43, 4.35it/s, loss=34.708]
  ... (progress output truncated; the loss drops from ~34.7 to ~0.65 over the first 499 minibatches) ...
 22%|##1       | 499/2278 [01:12<04:19, 6.84it/s, loss=0.652]
Converges at iteration 10
Epoch 0 Validation Accuracy 0.07674754186382093 Test Accuracy 0.059070427751373375
 22%|##1       | 499/2278 [01:31<05:27, 5.43it/s, loss=0.652]
Evaluating Performance with Link Prediction (Optional)¶
In practice, it is more common to evaluate the link prediction model to see whether it can predict new edges. There are different evaluation metrics such as AUC or various metrics from information retrieval. Ultimately, they require the model to predict one scalar score given a node pair among a set of node pairs.
Assume that you have the following test set with labels, where test_pos_src and test_pos_dst are ground-truth node pairs with edges between them (positive pairs), and test_neg_src and test_neg_dst are ground-truth node pairs without edges between them (negative pairs).
# Positive pairs
# These are randomly generated as an example. You will need to
# replace them with your own ground truth.
n_test_pos = 1000
test_pos_src, test_pos_dst = (
torch.randint(0, graph.num_nodes(), (n_test_pos,)),
torch.randint(0, graph.num_nodes(), (n_test_pos,)))
# Negative pairs. Likewise, you will need to replace them with your
# own ground truth.
test_neg_src = test_pos_src
test_neg_dst = torch.randint(0, graph.num_nodes(), (n_test_pos,))
First you need to compute the node representations for all the nodes with the inference function defined above:
node_reprs = inference(model, graph, node_features)
Since the predictor is a dot product, you can now easily compute the score of positive and negative test pairs to compute metrics such as AUC:
h_pos_src = node_reprs[test_pos_src]
h_pos_dst = node_reprs[test_pos_dst]
h_neg_src = node_reprs[test_neg_src]
h_neg_dst = node_reprs[test_neg_dst]
score_pos = (h_pos_src * h_pos_dst).sum(1)
score_neg = (h_neg_src * h_neg_dst).sum(1)
test_preds = torch.cat([score_pos, score_neg]).cpu().numpy()
test_labels = torch.cat([torch.ones_like(score_pos), torch.zeros_like(score_neg)]).cpu().numpy()
auc = sklearn.metrics.roc_auc_score(test_labels, test_preds)
print('Link Prediction AUC:', auc)
Out:
Link Prediction AUC: 0.532262
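Besides AUC, information-retrieval metrics such as Hits@K are also common. Below is a minimal sketch of an unfiltered Hits@K over the scores computed above; this is an illustration only, and official evaluators such as OGB's differ in details like candidate filtering:

def hits_at_k(score_pos, score_neg, k=50):
    # Rank each positive pair against all negative pairs: its rank is one
    # plus the number of negatives that score strictly higher.
    ranks = (score_neg.unsqueeze(0) > score_pos.unsqueeze(1)).sum(1) + 1
    # A positive pair counts as a hit if it ranks within the top k.
    return (ranks <= k).float().mean().item()

print('Link Prediction Hits@50:', hits_at_k(score_pos, score_neg))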
Conclusion¶
In this tutorial, you have learned how to train a multi-layer GraphSAGE for link prediction with neighbor sampling.
Total running time of the script: 1 minute 36.605 seconds