HeteroEmbedding
class dgl.nn.pytorch.HeteroEmbedding(num_embeddings, embedding_dim)

Bases: torch.nn.modules.module.Module
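Create a heterogeneous embedding table.

Internally, it holds one torch.nn.Embedding per key. Each table can have a different dictionary size (number of rows), and all tables share the same embedding dimension.

Parameters
- num_embeddings (dict[key, int]): Size of the dictionary for each key. A key can be a string (e.g. a node type such as 'user') or a tuple of strings (e.g. a canonical edge type such as ('user', 'follows', 'user')).
- embedding_dim (int): Size of each embedding vector, shared by all keys.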
Examples
>>> import dgl
>>> import torch
>>> from dgl.nn import HeteroEmbedding

>>> layer = HeteroEmbedding({'user': 2, ('user', 'follows', 'user'): 3}, 4)
>>> # Get the heterogeneous embedding table
>>> embeds = layer.weight
>>> print(embeds['user'].shape)
torch.Size([2, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([3, 4])

>>> # Get the embeddings for a subset
>>> input_ids = {'user': torch.LongTensor([0]),
...              ('user', 'follows', 'user'): torch.LongTensor([0, 2])}
>>> embeds = layer(input_ids)
>>> print(embeds['user'].shape)
torch.Size([1, 4])
>>> print(embeds[('user', 'follows', 'user')].shape)
torch.Size([2, 4])
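To illustrate the idea of "multiple torch.nn.Embedding tables with different dictionary sizes", here is a minimal sketch of how such a module could be assembled in plain PyTorch. It is not DGL's source. The class name HeteroEmbeddingSketch and its raw_keys attribute are hypothetical. The sketch assumes that tuple keys must be converted to strings with str(), because torch.nn.ModuleDict only accepts string keys, and it leaves weight initialization at PyTorch's defaults.

```python
import torch
import torch.nn as nn


class HeteroEmbeddingSketch(nn.Module):
    """Hypothetical illustration of a per-type embedding table (not DGL's code)."""

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        # nn.ModuleDict requires string keys, so remember which original
        # key (a string or a tuple of strings) each stringified name maps to.
        self.raw_keys = {}
        tables = {}
        for key, size in num_embeddings.items():
            name = str(key)
            self.raw_keys[name] = key
            # One independent lookup table per key; row counts may differ,
            # but every table uses the same embedding_dim.
            tables[name] = nn.Embedding(size, embedding_dim)
        self.embeds = nn.ModuleDict(tables)

    @property
    def weight(self):
        # Expose each table's weight matrix under its original key.
        return {self.raw_keys[name]: emb.weight for name, emb in self.embeds.items()}

    def forward(self, input_ids):
        # Look up rows only for the keys present in input_ids.
        return {key: self.embeds[str(key)](ids) for key, ids in input_ids.items()}


layer = HeteroEmbeddingSketch({'user': 2, ('user', 'follows', 'user'): 3}, 4)
out = layer({'user': torch.LongTensor([0]),
             ('user', 'follows', 'user'): torch.LongTensor([0, 2])})
print(out['user'].shape)                       # torch.Size([1, 4])
print(out[('user', 'follows', 'user')].shape)  # torch.Size([2, 4])
```

Keeping a separate table per key lets each type have its own number of rows. Because every lookup returns vectors of width embedding_dim, downstream layers receive features of the same size for every type.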