MetaPath2Vec
class dgl.nn.pytorch.MetaPath2Vec(g, metapath, window_size, emb_dim=128, negative_size=5, sparse=True)
Bases: torch.nn.modules.module.Module
metapath2vec module from the paper "metapath2vec: Scalable Representation Learning for Heterogeneous Networks".
To make optimization efficient, training uses negative sampling. Each node in a meta-path-guided random walk is treated in turn as the center node: nearby positive (context) nodes are taken from within the context window, and negative samples are drawn from nodes of all types across all meta-paths. The center-context pairs and the center-negative pairs are then used to update the embeddings.
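The objective described above can be sketched as the standard skip-gram negative-sampling loss for a single (center, context) pair. This is a minimal pure-Python illustration, not DGL's implementation: the helper names (`dot`, `log_sigmoid`, `sgns_loss`), the plain-list vectors, and the toy embedding values are all assumptions made for the example.

```python
import math
from typing import List, Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sgns_loss(center: Sequence[float],
              context: Sequence[float],
              negatives: List[Sequence[float]]) -> float:
    """Negative-sampling loss for one (center, context) pair (hypothetical helper).

    Pulls the center embedding toward the true context node and pushes it
    away from each sampled negative node:
        -log sigma(u . v) - sum_k log sigma(-u . n_k)
    """
    loss = -log_sigmoid(dot(center, context))
    for neg in negatives:
        loss -= log_sigmoid(-dot(center, neg))
    return loss


# Toy 2-d embeddings (hypothetical values, for illustration only).
u = [1.0, 0.0]
v_close = [2.0, 0.0]   # aligned with u, so the positive term is small
v_far = [-2.0, 0.0]    # opposite to u, so the positive term is large
negs = [[0.0, 1.0], [-1.0, 0.0]]

print(sgns_loss(u, v_close, negs) < sgns_loss(u, v_far, negs))  # True
```

A context node whose embedding points the same way as the center yields a lower loss, which is exactly the pressure that gradient descent on this objective applies.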
Parameters
g (DGLGraph) – Graph for learning node embeddings. Two different canonical edge types (utype, etype, vtype) are not allowed to have the same etype.
metapath (list[str]) – A sequence of edge types, each given as a string. It defines a new edge type by composing the given edge types in order. Note that the start node type and the end node type are commonly the same.
window_size (int) – In a random walk w, a node w[j] is considered close to a node w[i] if i - window_size <= j <= i + window_size.
emb_dim (int, optional) – Size of each embedding vector. Default: 128
negative_size (int, optional) – Number of negative samples to use for each positive sample. Default: 5
sparse (bool, optional) – If True, gradients with respect to the learnable weights will be sparse. Default: True
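To make `metapath` and `window_size` concrete, the sketch below walks a toy heterogeneous graph along a meta-path and then enumerates the (center, context) pairs that fall inside the window. It is plain Python and only illustrative: the adjacency-dict graph, the `metapath_walk` and `context_pairs` helpers, and the uniform choice of the next neighbor are assumptions, not DGL APIs. Excluding `j == i` (a node is not its own context) is also an assumed convention.

```python
import random
from typing import Dict, List, Tuple

# Hypothetical toy graph: edge type -> {source node: [destination nodes]}.
# Nodes are (node_type, local_id) tuples so the types stay visible.
Node = Tuple[str, int]
Adjacency = Dict[str, Dict[Node, List[Node]]]

toy_graph: Adjacency = {
    "uc": {("user", 0): [("company", 0)],
           ("user", 1): [("company", 0), ("company", 1)]},
    "cu": {("company", 0): [("user", 0), ("user", 1)],
           ("company", 1): [("user", 1)]},
}


def metapath_walk(graph: Adjacency, start: Node, metapath: List[str],
                  rng: random.Random) -> List[Node]:
    """Follow the edge types in `metapath` in order, picking a neighbor uniformly.

    Stops early if the current node has no out-edge of the required type.
    """
    walk = [start]
    for etype in metapath:
        neighbors = graph.get(etype, {}).get(walk[-1], [])
        if not neighbors:
            break
        walk.append(rng.choice(neighbors))
    return walk


def context_pairs(walk: List[Node], window_size: int) -> List[Tuple[Node, Node]]:
    """All (w[i], w[j]) with i - window_size <= j <= i + window_size and j != i."""
    pairs = []
    for i, center in enumerate(walk):
        lo = max(0, i - window_size)
        hi = min(len(walk) - 1, i + window_size)
        for j in range(lo, hi + 1):
            if j != i:
                pairs.append((center, walk[j]))
    return pairs


rng = random.Random(0)
walk = metapath_walk(toy_graph, ("user", 0), ["uc", "cu"], rng)
print(walk)  # user 0 -> company 0 -> some user
print(len(context_pairs(walk, window_size=1)))  # 4 pairs for a 3-node walk
```

With the meta-path ['uc', 'cu'], a walk that starts at a user returns to a user, matching the note that the start and end node types are commonly the same.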
node_embed – Embedding table of all nodes.
Type: nn.Embedding
Examples
>>> import torch
>>> import dgl
>>> from torch.optim import SparseAdam
>>> from torch.utils.data import DataLoader
>>> from dgl.nn.pytorch import MetaPath2Vec

>>> # Define a model
>>> g = dgl.heterograph({
...     ('user', 'uc', 'company'): dgl.rand_graph(100, 1000).edges(),
...     ('company', 'cp', 'product'): dgl.rand_graph(100, 1000).edges(),
...     ('company', 'cu', 'user'): dgl.rand_graph(100, 1000).edges(),
...     ('product', 'pc', 'company'): dgl.rand_graph(100, 1000).edges()
... })
>>> model = MetaPath2Vec(g, ['uc', 'cu'], window_size=1)

>>> # Use the source node type of etype 'uc'
>>> dataloader = DataLoader(torch.arange(g.num_nodes('user')), batch_size=128,
...                         shuffle=True, collate_fn=model.sample)
>>> optimizer = SparseAdam(model.parameters(), lr=0.025)

>>> for (pos_u, pos_v, neg_v) in dataloader:
...     loss = model(pos_u, pos_v, neg_v)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()

>>> # Get the embeddings of all user nodes
>>> user_nids = torch.LongTensor(model.local_to_global_nid['user'])
>>> user_emb = model.node_embed(user_nids)
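The training loop in the example consumes batches of (pos_u, pos_v, neg_v). The sketch below shows, in plain Python, one way such a batch could be assembled from (center, context) pairs: pos_u holds the centers, pos_v their contexts, and neg_v holds negative_size negatives per pair. The `make_batch` helper, the node IDs and counts, and drawing negatives in proportion to node frequency raised to the 0.75 power (the word2vec convention) are assumptions for illustration; this is not the code behind `model.sample`.

```python
import random
from collections import Counter
from typing import List, Tuple


def make_batch(pairs: List[Tuple[int, int]],
               node_counts: Counter,
               negative_size: int,
               rng: random.Random) -> Tuple[List[int], List[int], List[List[int]]]:
    """Split pairs into centers/contexts and draw negatives for each pair.

    Negatives are sampled with replacement, with probability proportional to
    count ** 0.75 (an assumed word2vec-style smoothing of the frequencies).
    """
    nodes = list(node_counts)
    weights = [node_counts[n] ** 0.75 for n in nodes]
    pos_u = [u for u, _ in pairs]
    pos_v = [v for _, v in pairs]
    neg_v = [rng.choices(nodes, weights=weights, k=negative_size) for _ in pairs]
    return pos_u, pos_v, neg_v


# Hypothetical global node IDs and how often each appeared in the walks.
pairs = [(0, 10), (10, 0), (10, 1), (1, 10)]
counts = Counter({0: 3, 1: 2, 10: 4, 11: 1})
pos_u, pos_v, neg_v = make_batch(pairs, counts, negative_size=5, rng=random.Random(0))
print(pos_u)  # [0, 10, 10, 1]
print(len(neg_v), len(neg_v[0]))  # 4 5
```

Smoothing the counts with an exponent below 1 makes frequent nodes slightly less dominant as negatives, so rarer nodes still get sampled occasionally.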