Source code for dgl.nn.pytorch.network_emb

"""Network Embedding NN Modules"""
# pylint: disable= invalid-name

import random
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F
import tqdm

from ...base import NID
from ...convert import to_homogeneous, to_heterogeneous
from ...random import choice
from ...sampling import random_walk

__all__ = ['DeepWalk', 'MetaPath2Vec']

[docs]class DeepWalk(nn.Module): """DeepWalk module from `DeepWalk: Online Learning of Social Representations <https://arxiv.org/abs/1403.6652>`__ For a graph, it learns the node representations from scratch by maximizing the similarity of node pairs that are nearby (positive node pairs) and minimizing the similarity of other random node pairs (negative node pairs). Parameters ---------- g : DGLGraph Graph for learning node embeddings emb_dim : int, optional Size of each embedding vector. Default: 128 walk_length : int, optional Number of nodes in a random walk sequence. Default: 40 window_size : int, optional In a random walk :attr:`w`, a node :attr:`w[j]` is considered close to a node :attr:`w[i]` if :attr:`i - window_size <= j <= i + window_size`. Default: 5 neg_weight : float, optional Weight of the loss term for negative samples in the total loss. Default: 1.0 negative_size : int, optional Number of negative samples to use for each positive sample. Default: 1 fast_neg : bool, optional If True, it samples negative node pairs within a batch of random walks. Default: True sparse : bool, optional If True, gradients with respect to the learnable weights will be sparse. Default: True Attributes ---------- node_embed : nn.Embedding Embedding table of the nodes Examples -------- >>> import torch >>> from dgl.data import CoraGraphDataset >>> from dgl.nn import DeepWalk >>> from torch.optim import SparseAdam >>> from torch.utils.data import DataLoader >>> from sklearn.linear_model import LogisticRegression >>> dataset = CoraGraphDataset() >>> g = dataset[0] >>> model = DeepWalk(g) >>> dataloader = DataLoader(torch.arange(g.num_nodes()), batch_size=128, ... shuffle=True, collate_fn=model.sample) >>> optimizer = SparseAdam(model.parameters(), lr=0.01) >>> num_epochs = 5 >>> for epoch in range(num_epochs): ... for batch_walk in dataloader: ... loss = model(batch_walk) ... optimizer.zero_grad() ... loss.backward() ... optimizer.step() >>> train_mask = g.ndata['train_mask'] >>> test_mask = g.ndata['test_mask'] >>> X = model.node_embed.weight.detach() >>> y = g.ndata['label'] >>> clf = LogisticRegression().fit(X[train_mask].numpy(), y[train_mask].numpy()) >>> clf.score(X[test_mask].numpy(), y[test_mask].numpy()) """ def __init__(self, g, emb_dim=128, walk_length=40, window_size=5, neg_weight=1, negative_size=5, fast_neg=True, sparse=True): super().__init__() assert walk_length >= window_size + 1, \ f'Expect walk_length >= window_size + 1, got {walk_length} and {window_size + 1}' self.g = g self.emb_dim = emb_dim self.window_size = window_size self.walk_length = walk_length self.neg_weight = neg_weight self.negative_size = negative_size self.fast_neg = fast_neg num_nodes = g.num_nodes() # center node embedding self.node_embed = nn.Embedding(num_nodes, emb_dim, sparse=sparse) self.context_embed = nn.Embedding(num_nodes, emb_dim, sparse=sparse) self.reset_parameters() if not fast_neg: neg_prob = g.out_degrees().pow(0.75) # categorical distribution for true negative sampling self.neg_prob = neg_prob / neg_prob.sum() # Get list index pairs for positive samples. # Given i, positive index pairs are (i - window_size, i), ... , # (i - 1, i), (i + 1, i), ..., (i + window_size, i) idx_list_src = [] idx_list_dst = [] for i in range(walk_length): for j in range(max(0, i - window_size), i): idx_list_src.append(j) idx_list_dst.append(i) for j in range(i + 1, min(walk_length, i + 1 + window_size)): idx_list_src.append(j) idx_list_dst.append(i) self.idx_list_src = torch.LongTensor(idx_list_src) self.idx_list_dst = torch.LongTensor(idx_list_dst)
[docs] def reset_parameters(self): """Reinitialize learnable parameters""" init_range = 1.0 / self.emb_dim init.uniform_(self.node_embed.weight.data, -init_range, init_range) init.constant_(self.context_embed.weight.data, 0)
def sample(self, indices): """Sample random walks Parameters ---------- indices : torch.Tensor Nodes from which we perform random walk Returns ------- torch.Tensor Random walks in the form of node ID sequences. The Tensor is of shape :attr:`(len(indices), walk_length)`. """ return random_walk(self.g, indices, length=self.walk_length - 1)[0]
[docs] def forward(self, batch_walk): """Compute the loss for the batch of random walks Parameters ---------- batch_walk : torch.Tensor Random walks in the form of node ID sequences. The Tensor is of shape :attr:`(batch_size, walk_length)`. Returns ------- torch.Tensor Loss value """ batch_size = len(batch_walk) device = batch_walk.device batch_node_embed = self.node_embed(batch_walk).view(-1, self.emb_dim) batch_context_embed = self.context_embed(batch_walk).view(-1, self.emb_dim) batch_idx_list_offset = torch.arange(batch_size) * self.walk_length batch_idx_list_offset = batch_idx_list_offset.unsqueeze(1) idx_list_src = batch_idx_list_offset + self.idx_list_src.unsqueeze(0) idx_list_dst = batch_idx_list_offset + self.idx_list_dst.unsqueeze(0) idx_list_src = idx_list_src.view(-1).to(device) idx_list_dst = idx_list_dst.view(-1).to(device) pos_src_emb = batch_node_embed[idx_list_src] pos_dst_emb = batch_context_embed[idx_list_dst] neg_idx_list_src = idx_list_dst.unsqueeze(1) + torch.zeros( self.negative_size).unsqueeze(0).to(device) neg_idx_list_src = neg_idx_list_src.view(-1) neg_src_emb = batch_node_embed[neg_idx_list_src.long()] if self.fast_neg: neg_idx_list_dst = list(range(batch_size * self.walk_length)) \ * (self.negative_size * self.window_size * 2) random.shuffle(neg_idx_list_dst) neg_idx_list_dst = neg_idx_list_dst[:len(neg_idx_list_src)] neg_idx_list_dst = torch.LongTensor(neg_idx_list_dst).to(device) neg_dst_emb = batch_context_embed[neg_idx_list_dst] else: neg_dst = choice(self.g.num_nodes(), size=len(neg_src_emb), prob=self.neg_prob) neg_dst_emb = self.context_embed(neg_dst.to(device)) pos_score = torch.sum(torch.mul(pos_src_emb, pos_dst_emb), dim=1) pos_score = torch.clamp(pos_score, max=6, min=-6) pos_score = torch.mean(-F.logsigmoid(pos_score)) neg_score = torch.sum(torch.mul(neg_src_emb, neg_dst_emb), dim=1) neg_score = torch.clamp(neg_score, max=6, min=-6) neg_score = torch.mean(-F.logsigmoid(-neg_score)) * self.negative_size * self.neg_weight return torch.mean(pos_score + neg_score)
[docs]class MetaPath2Vec(nn.Module): r"""metapath2vec module from `metapath2vec: Scalable Representation Learning for Heterogeneous Networks <https://dl.acm.org/doi/pdf/10.1145/3097983.3098036>`__ To achieve efficient optimization, we leverage the negative sampling technique for the training process. Repeatedly for each node in meta-path, we treat it as the center node and sample nearby positive nodes within context size and draw negative samples among all types of nodes from all meta-paths. Then we can use the center-context paired nodes and context-negative paired nodes to update the network. Parameters ---------- g : DGLGraph Graph for learning node embeddings. Two different canonical edge types :attr:`(utype, etype, vtype)` are not allowed to have same :attr:`etype`. metapath : list[str] A sequence of edge types in the form of a string. It defines a new edge type by composing multiple edge types in order. Note that the start node type and the end one are commonly the same. window_size : int In a random walk :attr:`w`, a node :attr:`w[j]` is considered close to a node :attr:`w[i]` if :attr:`i - window_size <= j <= i + window_size`. emb_dim : int, optional Size of each embedding vector. Default: 128 negative_size : int, optional Number of negative samples to use for each positive sample. Default: 5 sparse : bool, optional If True, gradients with respect to the learnable weights will be sparse. Default: True Attributes ---------- node_embed : nn.Embedding Embedding table of all nodes local_to_global_nid : dict[str, list] Mapping from type-specific node IDs to global node IDs Examples -------- >>> import torch >>> import dgl >>> from torch.optim import SparseAdam >>> from torch.utils.data import DataLoader >>> from dgl.nn.pytorch import MetaPath2Vec >>> # Define a model >>> g = dgl.heterograph({ ... ('user', 'uc', 'company'): dgl.rand_graph(100, 1000).edges(), ... ('company', 'cp', 'product'): dgl.rand_graph(100, 1000).edges(), ... ('company', 'cu', 'user'): dgl.rand_graph(100, 1000).edges(), ... ('product', 'pc', 'company'): dgl.rand_graph(100, 1000).edges() ... }) >>> model = MetaPath2Vec(g, ['uc', 'cu'], window_size=1) >>> # Use the source node type of etype 'uc' >>> dataloader = DataLoader(torch.arange(g.num_nodes('user')), batch_size=128, ... shuffle=True, collate_fn=model.sample) >>> optimizer = SparseAdam(model.parameters(), lr=0.025) >>> for (pos_u, pos_v, neg_v) in dataloader: ... loss = model(pos_u, pos_v, neg_v) ... optimizer.zero_grad() ... loss.backward() ... optimizer.step() >>> # Get the embeddings of all user nodes >>> user_nids = torch.LongTensor(model.local_to_global_nid['user']) >>> user_emb = model.node_embed(user_nids) """ def __init__(self, g, metapath, window_size, emb_dim=128, negative_size=5, sparse=True): super().__init__() assert len(metapath) + 1 >= window_size, \ f'Expect len(metapath) >= window_size - 1, got {metapath} and {window_size}' self.hg = g self.emb_dim = emb_dim self.metapath = metapath self.window_size = window_size self.negative_size = negative_size # convert edge metapath to node metapath # get initial source node type src_type, _, _ = g.to_canonical_etype(metapath[0]) node_metapath = [src_type] for etype in metapath: _, _, dst_type = g.to_canonical_etype(etype) node_metapath.append(dst_type) self.node_metapath = node_metapath # Convert the graph into a homogeneous one for global to local node ID mapping g = to_homogeneous(g) # Convert it back to the hetero one for local to global node ID mapping hg = to_heterogeneous(g, self.hg.ntypes, self.hg.etypes) local_to_global_nid = hg.ndata[NID] for key, val in local_to_global_nid.items(): local_to_global_nid[key] = list(val.cpu().numpy()) self.local_to_global_nid = local_to_global_nid num_nodes_total = hg.num_nodes() node_frequency = torch.zeros(num_nodes_total) # random walk for idx in tqdm.trange(hg.num_nodes(node_metapath[0])): traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath) for tr in traces.cpu().numpy(): tr_nids = [ self.local_to_global_nid[node_metapath[i]][tr[i]] for i in range(len(tr))] node_frequency[torch.LongTensor(tr_nids)] += 1 neg_prob = node_frequency.pow(0.75) self.neg_prob = neg_prob / neg_prob.sum() # center node embedding self.node_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse) self.context_embed = nn.Embedding(num_nodes_total, self.emb_dim, sparse=sparse) self.reset_parameters()
[docs] def reset_parameters(self): """Reinitialize learnable parameters""" init_range = 1.0 / self.emb_dim init.uniform_(self.node_embed.weight.data, -init_range, init_range) init.constant_(self.context_embed.weight.data, 0)
def sample(self, indices): """Sample positive and negative samples Parameters ---------- indices : torch.Tensor Node IDs of the source node type from which we perform random walks Returns ------- torch.Tensor Positive center nodes torch.Tensor Positive context nodes torch.Tensor Negative context nodes """ traces, _ = random_walk(g=self.hg, nodes=indices, metapath=self.metapath) u_list = [] v_list = [] for tr in traces.cpu().numpy(): tr_nids = [ self.local_to_global_nid[self.node_metapath[i]][tr[i]] for i in range(len(tr))] for i, u in enumerate(tr_nids): for j, v in enumerate(tr_nids[max(i - self.window_size, 0):i + self.window_size]): if i == j: continue u_list.append(u) v_list.append(v) neg_v = choice(self.hg.num_nodes(), size=len(u_list) * self.negative_size, prob=self.neg_prob).reshape(len(u_list), self.negative_size) return torch.LongTensor(u_list), torch.LongTensor(v_list), neg_v
[docs] def forward(self, pos_u, pos_v, neg_v): r"""Compute the loss for the batch of positive and negative samples Parameters ---------- pos_u : torch.Tensor Positive center nodes pos_v : torch.Tensor Positive context nodes neg_v : torch.Tensor Negative context nodes Returns ------- torch.Tensor Loss value """ emb_u = self.node_embed(pos_u) emb_v = self.context_embed(pos_v) emb_neg_v = self.context_embed(neg_v) score = torch.sum(torch.mul(emb_u, emb_v), dim=1) score = torch.clamp(score, max=10, min=-10) score = -F.logsigmoid(score) neg_score = torch.bmm(emb_neg_v, emb_u.unsqueeze(2)).squeeze() neg_score = torch.clamp(neg_score, max=10, min=-10) neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1) return torch.mean(score + neg_score)