MetaPath2Vec

class dgl.nn.pytorch.MetaPath2Vec(g, metapath, window_size, emb_dim=128, negative_size=5, sparse=True)

Bases: torch.nn.modules.module.Module

The metapath2vec module from the paper "metapath2vec: Scalable Representation Learning for Heterogeneous Networks".

To make optimization efficient, the module trains with negative sampling. Each node in a meta-path-guided random walk is treated in turn as the center node: nodes within the context window around it are taken as positive (context) nodes, and negative samples are drawn from nodes of all types, not only the center node's type. The center-context pairs and center-negative pairs are then used to update the network.
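The sampling step described above can be sketched in plain Python for a single walk. This is a minimal sketch, assuming the walk is already a list of global node IDs and that negatives are drawn uniformly at random from all node IDs; the helper name `sample_skipgram_triples` is hypothetical and is not DGL's `sample` method, which may use a different negative distribution.

```python
import random


def sample_skipgram_triples(walk, window_size, negative_size, num_nodes, rng=None):
    """Turn one random walk into (center, context, negatives) training triples.

    Hypothetical helper, not part of DGL. walk[j] counts as a context node of
    walk[i] when i - window_size <= j <= i + window_size and j != i.
    Negatives are drawn uniformly over all node IDs (an assumption).
    """
    rng = rng or random.Random()
    triples = []
    for i, center in enumerate(walk):
        lo = max(0, i - window_size)
        hi = min(len(walk) - 1, i + window_size)
        for j in range(lo, hi + 1):
            if j == i:
                continue  # a node is not its own context
            negatives = [rng.randrange(num_nodes) for _ in range(negative_size)]
            triples.append((center, walk[j], negatives))
    return triples


# A toy walk over global node IDs, e.g. user -> company -> user -> company -> user
walk = [0, 5, 2, 7, 1]
triples = sample_skipgram_triples(walk, window_size=1, negative_size=2,
                                  num_nodes=10, rng=random.Random(0))
print(len(triples))                            # 8
print([(c, x) for c, x, _ in triples[:3]])     # [(0, 5), (5, 0), (5, 2)]
```

Each positive (center, context) pair gets its own `negative_size` negatives, which matches how `forward` receives one row of negative context nodes per positive pair.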

Parameters
  • g (DGLGraph) – Graph for learning node embeddings. No two canonical edge types (utype, etype, vtype) may share the same etype.

  • metapath (list[str]) – A sequence of edge type names. Composing these edge types in order defines a new, composite edge type. The start and end node types are usually the same.

  • window_size (int) – In a random walk w, a node w[j] is considered close to a node w[i] if i - window_size <= j <= i + window_size.

  • emb_dim (int, optional) – Size of each embedding vector. Default: 128

  • negative_size (int, optional) – Number of negative samples to use for each positive sample. Default: 5

  • sparse (bool, optional) – If True, gradients with respect to the learnable weights will be sparse. Default: True
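To see what a metapath composes to, and why `g` must not reuse an etype name across canonical edge types, the sketch below expands a list of etype names into the node-type sequence it visits. It is a hypothetical helper, not part of DGL; it assumes the canonical edge types are given as (utype, etype, vtype) triples, such as those listed by `g.canonical_etypes`.

```python
def metapath_node_types(canonical_etypes, metapath):
    """Expand a metapath (list of etype names) into the node types it visits.

    Hypothetical helper, not part of DGL.
    canonical_etypes: iterable of (utype, etype, vtype) triples.
    """
    by_etype = {}
    for utype, etype, vtype in canonical_etypes:
        # The uniqueness requirement on g: an etype name must identify one relation,
        # otherwise a metapath entry would be ambiguous.
        if etype in by_etype:
            raise ValueError(f"etype {etype!r} is used by more than one canonical edge type")
        by_etype[etype] = (utype, vtype)

    node_types = [by_etype[metapath[0]][0]]
    for etype in metapath:
        utype, vtype = by_etype[etype]
        # Consecutive edge types must chain: each one starts where the previous ended.
        if utype != node_types[-1]:
            raise ValueError(f"{etype!r} starts at {utype!r}, but the path is at {node_types[-1]!r}")
        node_types.append(vtype)
    return node_types


# Canonical edge types matching the graph in the Examples section
cets = [('user', 'uc', 'company'), ('company', 'cp', 'product'),
        ('company', 'cu', 'user'), ('product', 'pc', 'company')]
print(metapath_node_types(cets, ['uc', 'cu']))              # ['user', 'company', 'user']
print(metapath_node_types(cets, ['uc', 'cp', 'pc', 'cu']))  # ['user', 'company', 'product', 'company', 'user']
```

Both example metapaths start and end at 'user', which is the common case the metapath parameter describes.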

Variables
  • node_embed (nn.Embedding) – Embedding table of all nodes

  • local_to_global_nid (dict[str, list]) – Mapping from type-specific node IDs to global node IDs
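The property the Examples rely on is that `local_to_global_nid[ntype][local_id]` gives the row of `node_embed` for that node. The sketch below builds such a mapping under an assumed layout in which each node type receives a contiguous block of global IDs; `build_local_to_global` is hypothetical, and the module's actual ordering of types and IDs may differ.

```python
def build_local_to_global(num_nodes_per_type):
    """Assign each node type a contiguous block of global IDs (assumed layout).

    Hypothetical helper, not part of DGL. The returned lists are indexed by
    type-specific (local) node ID and hold global node IDs.
    """
    local_to_global = {}
    offset = 0
    for ntype, count in num_nodes_per_type.items():
        local_to_global[ntype] = list(range(offset, offset + count))
        offset += count
    return local_to_global


mapping = build_local_to_global({'user': 3, 'company': 2})
print(mapping)                # {'user': [0, 1, 2], 'company': [3, 4]}
print(mapping['company'][1])  # 4, i.e. the embedding row of company node 1
```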

Examples

>>> import torch
>>> import dgl
>>> from torch.optim import SparseAdam
>>> from torch.utils.data import DataLoader
>>> from dgl.nn.pytorch import MetaPath2Vec
>>> # Build a heterograph, then define a model on it
>>> g = dgl.heterograph({
...     ('user', 'uc', 'company'): dgl.rand_graph(100, 1000).edges(),
...     ('company', 'cp', 'product'): dgl.rand_graph(100, 1000).edges(),
...     ('company', 'cu', 'user'): dgl.rand_graph(100, 1000).edges(),
...     ('product', 'pc', 'company'): dgl.rand_graph(100, 1000).edges()
... })
>>> model = MetaPath2Vec(g, ['uc', 'cu'], window_size=1)
>>> # Start walks from nodes of the source type of etype 'uc', i.e. 'user'
>>> dataloader = DataLoader(torch.arange(g.num_nodes('user')), batch_size=128,
...                         shuffle=True, collate_fn=model.sample)
>>> optimizer = SparseAdam(model.parameters(), lr=0.025)
>>> for (pos_u, pos_v, neg_v) in dataloader:
...     loss = model(pos_u, pos_v, neg_v)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Get the embeddings of all user nodes
>>> user_nids = torch.LongTensor(model.local_to_global_nid['user'])
>>> user_emb = model.node_embed(user_nids)

forward(pos_u, pos_v, neg_v)

Compute the negative-sampling loss for a batch of positive and negative samples.

Parameters
  • pos_u (torch.Tensor) – Positive center nodes

  • pos_v (torch.Tensor) – Positive context nodes

  • neg_v (torch.Tensor) – Negative context nodes

Returns

Loss value

Return type

torch.Tensor
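For intuition about what `forward` computes, here is a minimal sketch of the standard skip-gram negative-sampling loss over such a batch. It assumes, as in the word2vec formulation, separate center and context embedding tables, with `pos_u` and `pos_v` of shape (B,) and `neg_v` of shape (B, K); the function name and its explicit embedding arguments are hypothetical, and the module's own implementation may differ in details.

```python
import torch
import torch.nn.functional as F


def skipgram_negative_sampling_loss(center_embed, context_embed, pos_u, pos_v, neg_v):
    """Mean skip-gram negative-sampling loss (hypothetical helper, not DGL's code).

    pos_u, pos_v: LongTensor of shape (B,), positive center/context node IDs
    neg_v: LongTensor of shape (B, K), K negative context nodes per positive pair
    """
    emb_u = center_embed(pos_u)     # (B, D)
    emb_v = context_embed(pos_v)    # (B, D)
    emb_neg = context_embed(neg_v)  # (B, K, D)

    # Positive term: -log sigmoid(u . v) pulls center and context together
    pos_score = (emb_u * emb_v).sum(dim=1)                         # (B,)
    pos_loss = -F.logsigmoid(pos_score)

    # Negative term: -sum_k log sigmoid(-u . n_k) pushes negatives away
    neg_score = torch.bmm(emb_neg, emb_u.unsqueeze(2)).squeeze(2)  # (B, K)
    neg_loss = -F.logsigmoid(-neg_score).sum(dim=1)

    return (pos_loss + neg_loss).mean()


torch.manual_seed(0)
num_nodes, dim = 10, 4
center_embed = torch.nn.Embedding(num_nodes, dim, sparse=True)
context_embed = torch.nn.Embedding(num_nodes, dim, sparse=True)
pos_u = torch.tensor([0, 5, 2])
pos_v = torch.tensor([5, 0, 7])
neg_v = torch.randint(0, num_nodes, (3, 5))

loss = skipgram_negative_sampling_loss(center_embed, context_embed, pos_u, pos_v, neg_v)
loss.backward()  # embeddings created with sparse=True receive sparse gradients
```

Because the embeddings use `sparse=True`, their gradients are sparse tensors, which is why the Examples pair the module's default `sparse=True` with `SparseAdam`.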

reset_parameters()

Reinitialize learnable parameters