TWIRLSUnfoldingAndAttention

class dgl.nn.pytorch.conv.TWIRLSUnfoldingAndAttention(d, alp, lam, prop_step, attn_aft=-1, tau=0.2, T=-1, p=1, use_eta=False, init_att=False, attn_dropout=0, precond=True)[source]

Bases: Module

Description

Combine propagation and attention in a single module.

param d:

Dimension of the graph features.

type d:

int

param alp:

Step size; \(\alpha\) in the paper.

type alp:

float

param lam:

Coefficient of the graph smoothing term; \(\lambda\) in the paper.

type lam:

float

param prop_step:

Number of propagation steps.

type prop_step:

int

param attn_aft:

Where to put the attention layer, i.e., the number of propagation steps before attention. If set to -1, no attention layer is used.

type attn_aft:

int

param tau:

The lower thresholding parameter; corresponds to \(\tau\) in the paper.

type tau:

float

param T:

The upper thresholding parameter; corresponds to \(T\) in the paper.

type T:

float

param p:

Corresponds to \(\rho\) in the paper.

type p:

float

param use_eta:

If True, learn a weight vector for each dimension when doing attention.

type use_eta:

bool

param init_att:

If True, add an extra attention layer before propagation.

type init_att:

bool

param attn_dropout:

The dropout rate applied to the attention values. Default: 0.0.

type attn_dropout:

float

param precond:

If True, use the preconditioned and reparameterized version of propagation (Eq. 28 in the paper); otherwise use the normalized-Laplacian version (Eq. 30). A sketch of the latter update follows the parameter list.

type precond:

bool
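To make the roles of alp and lam concrete, below is a minimal sketch (an illustrative assumption, not the library's implementation) of one unfolded gradient step on the quadratic energy \(\|Y - X\|_F^2 + \lambda \operatorname{tr}(Y^\top L Y)\) with normalized Laplacian \(L = I - \hat{A}\), which is the kind of update the precond=False variant performs. Here adj_norm is an assumed dense normalized adjacency matrix, and the factor of 2 from the gradient is absorbed into alp:

>>> import torch as th
>>> def unfolded_step(adj_norm, Y, X, alp, lam):
...     # One gradient step: Y <- Y - alp * ((Y - X) + lam * (I - adj_norm) @ Y)
...     return (1 - alp - alp * lam) * Y + alp * X + alp * lam * (adj_norm @ Y)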

Example

>>> import dgl
>>> from dgl.nn import TWIRLSUnfoldingAndAttention
>>> import torch as th
>>> g = dgl.graph(([0, 1, 2, 3, 2, 5], [1, 2, 3, 4, 0, 3])).add_self_loop()
>>> feat = th.ones(6, 5)
>>> prop = TWIRLSUnfoldingAndAttention(5, 1, 1, prop_step=3)
>>> res = prop(g, feat)
>>> res
tensor([[2.5000, 2.5000, 2.5000, 2.5000, 2.5000],
        [2.5000, 2.5000, 2.5000, 2.5000, 2.5000],
        [2.5000, 2.5000, 2.5000, 2.5000, 2.5000],
        [3.7656, 3.7656, 3.7656, 3.7656, 3.7656],
        [2.5217, 2.5217, 2.5217, 2.5217, 2.5217],
        [4.0000, 4.0000, 4.0000, 4.0000, 4.0000]])
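With attention enabled (a hypothetical configuration for illustration; output values are not reproduced here), attn_aft inserts the attention layer after the given number of propagation steps, with tau and T as the thresholds:

>>> prop_attn = TWIRLSUnfoldingAndAttention(5, 1, 1, prop_step=3, attn_aft=1, tau=0.2, T=3)
>>> prop_attn(g, feat).shape
torch.Size([6, 5])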
forward(g, X)[source]

Description

Compute the forward pass of propagation and attention.

param g:

The graph.

type g:

DGLGraph

param X:

Initial node features.

type X:

torch.Tensor

returns:

The propagated node features.

rtype:

torch.Tensor
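In practice, the module is typically applied to features produced by a preceding transform. The following is a hedged sketch of such a wrapper, reusing g and feat from the example above (TWIRLSLikeModel, the MLP sizes, and the hyperparameter values are illustrative assumptions, not part of the DGL API):

>>> import torch.nn as nn
>>> class TWIRLSLikeModel(nn.Module):
...     # Hypothetical wrapper: an MLP produces base features,
...     # then the unfolding module propagates them over the graph.
...     def __init__(self, in_feats, hidden, n_classes, prop_step=8):
...         super().__init__()
...         self.mlp = nn.Sequential(
...             nn.Linear(in_feats, hidden), nn.ReLU(),
...             nn.Linear(hidden, n_classes))
...         self.prop = TWIRLSUnfoldingAndAttention(
...             n_classes, alp=1, lam=1, prop_step=prop_step)
...     def forward(self, g, x):
...         return self.prop(g, self.mlp(x))
>>> model = TWIRLSLikeModel(5, 16, 3)
>>> model(g, feat).shape
torch.Size([6, 3])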