Source code for dgl.data.superpixel

import os
import pickle

import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm

from .. import backend as F
from ..convert import graph as dgl_graph

from .dgl_dataset import DGLDataset
from .utils import download, extract_archive, load_graphs, save_graphs, Subset


def sigma(dists, kth=8):
    num_nodes = dists.shape[0]

    # Compute sigma and reshape.
    if kth > num_nodes:
        # Handling for graphs with num_nodes less than kth.
        sigma = np.array([1] * num_nodes).reshape(num_nodes, 1)
    else:
        # Get k-nearest neighbors for each node.
        knns = np.partition(dists, kth, axis=-1)[:, : kth + 1]
        sigma = knns.sum(axis=1).reshape((knns.shape[0], 1)) / kth

    return sigma + 1e-8


def compute_adjacency_matrix_images(coord, feat, use_feat=True):
    coord = coord.reshape(-1, 2)
    # Compute coordinate distance.
    c_dist = cdist(coord, coord)

    if use_feat:
        # Compute feature distance.
        f_dist = cdist(feat, feat)
        # Compute adjacency.
        A = np.exp(
            -((c_dist / sigma(c_dist)) ** 2) - (f_dist / sigma(f_dist)) ** 2
        )
    else:
        A = np.exp(-((c_dist / sigma(c_dist)) ** 2))

    # Convert to symmetric matrix.
    A = 0.5 * (A + A.T)
    A[np.diag_indices_from(A)] = 0
    return A


def compute_edges_list(A, kth=9):
    # Get k-similar neighbor indices for each node.
    num_nodes = A.shape[0]
    new_kth = num_nodes - kth

    if num_nodes > kth:
        knns = np.argpartition(A, new_kth - 1, axis=-1)[:, new_kth:-1]
        knn_values = np.partition(A, new_kth - 1, axis=-1)[:, new_kth:-1]
    else:
        # Handling for graphs with less than kth nodes.
        # In such cases, the resulting graph will be fully connected.
        knns = np.tile(np.arange(num_nodes), num_nodes).reshape(
            num_nodes, num_nodes
        )
        knn_values = A

        # Removing self loop.
        if num_nodes != 1:
            knn_values = A[knns != np.arange(num_nodes)[:, None]].reshape(
                num_nodes, -1
            )
            knns = knns[knns != np.arange(num_nodes)[:, None]].reshape(
                num_nodes, -1
            )
    return knns, knn_values


class SuperPixelDataset(DGLDataset):
    def __init__(
        self,
        raw_dir=None,
        name="MNIST",
        split="train",
        use_feature=False,
        force_reload=False,
        verbose=False,
        transform=None,
    ):
        assert split in ["train", "test"], "split not valid."
        assert name in ["MNIST", "CIFAR10"], "name not valid."

        self.use_feature = use_feature
        self.split = split
        self._dataset_name = name
        self.graphs = []
        self.labels = []

        super().__init__(
            name="Superpixel",
            raw_dir=raw_dir,
            url="""
            https://www.dropbox.com/s/y2qwa77a0fxem47/superpixels.zip?dl=1
            """,
            force_reload=force_reload,
            verbose=verbose,
            transform=transform,
        )

    @property
    def img_size(self):
        r"""Size of dataset image."""
        if self._dataset_name == "MNIST":
            return 28
        return 32

    @property
    def save_path(self):
        r"""Directory to save the processed dataset."""
        return os.path.join(self.raw_path, "processed")

    @property
    def raw_data_path(self):
        r"""Path to save the raw dataset file."""
        return os.path.join(self.raw_path, "superpixels.zip")

    @property
    def graph_path(self):
        r"""Path to save the processed dataset file."""
        if self.use_feature:
            return os.path.join(
                self.save_path,
                f"use_feat_{self._dataset_name}_{self.split}.pkl",
            )
        return os.path.join(
            self.save_path, f"{self._dataset_name}_{self.split}.pkl"
        )

    def download(self):
        path = download(self.url, path=self.raw_data_path)
        extract_archive(path, target_dir=self.raw_path, overwrite=True)

    def process(self):
        if self._dataset_name == "MNIST":
            plk_file = "mnist_75sp"
        elif self._dataset_name == "CIFAR10":
            plk_file = "cifar10_150sp"

        with open(
            os.path.join(
                self.raw_path, "superpixels", f"{plk_file}_{self.split}.pkl"
            ),
            "rb",
        ) as f:
            self.labels, self.sp_data = pickle.load(f)
            self.labels = F.tensor(self.labels)

        self.Adj_matrices = []
        self.node_features = []
        self.edges_lists = []
        self.edge_features = []

        for index, sample in enumerate(
            tqdm(self.sp_data, desc=f"Processing {self.split} dataset")
        ):
            mean_px, coord = sample[:2]
            coord = coord / self.img_size

            if self.use_feature:
                A = compute_adjacency_matrix_images(
                    coord, mean_px
                )  # using super-pixel locations + features
            else:
                A = compute_adjacency_matrix_images(
                    coord, mean_px, False
                )  # using only super-pixel locations
            edges_list, edge_values_list = compute_edges_list(A)

            N_nodes = A.shape[0]

            mean_px = mean_px.reshape(N_nodes, -1)
            coord = coord.reshape(N_nodes, 2)
            x = np.concatenate((mean_px, coord), axis=1)

            edge_values_list = edge_values_list.reshape(-1)

            self.node_features.append(x)
            self.edge_features.append(edge_values_list)
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

        for index in tqdm(
            range(len(self.sp_data)), desc=f"Dump {self.split} dataset"
        ):
            N = self.node_features[index].shape[0]

            src_nodes = []
            dst_nodes = []
            for src, dsts in enumerate(self.edges_lists[index]):
                # handling for 1 node where the self loop would be the only edge
                if N == 1:
                    src_nodes.append(src)
                    dst_nodes.append(dsts)
                else:
                    dsts = dsts[dsts != src]
                    srcs = [src] * len(dsts)
                    src_nodes.extend(srcs)
                    dst_nodes.extend(dsts)

            src_nodes = F.tensor(src_nodes)
            dst_nodes = F.tensor(dst_nodes)

            g = dgl_graph((src_nodes, dst_nodes), num_nodes=N)
            g.ndata["feat"] = F.zerocopy_from_numpy(
                self.node_features[index]
            ).to(F.float32)
            g.edata["feat"] = (
                F.zerocopy_from_numpy(self.edge_features[index])
                .to(F.float32)
                .unsqueeze(1)
            )

            self.graphs.append(g)

    def load(self):
        self.graphs, label_dict = load_graphs(self.graph_path)
        self.labels = label_dict["labels"]

    def save(self):
        save_graphs(
            self.graph_path, self.graphs, labels={"labels": self.labels}
        )

    def has_cache(self):
        return os.path.exists(self.graph_path)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        """Get the idx-th sample.

        Parameters
        ---------
        idx : int or tensor
            The sample index.
            1-D tensor as `idx` is allowed when transform is None.

        Returns
        -------
        (:class:`dgl.DGLGraph`, Tensor)
            Graph with node feature stored in ``feat`` field and its label.
        or
        :class:`dgl.data.utils.Subset`
            Subset of the dataset at specified indices
        """
        if F.is_tensor(idx) and idx.dim() == 1:
            if self._transform is None:
                return Subset(self, idx.cpu())

            raise ValueError(
                "Tensor idx not supported when transform is not None."
            )

        if self._transform is None:
            return self.graphs[idx], self.labels[idx]

        return self._transform(self.graphs[idx]), self.labels[idx]


[docs]class MNISTSuperPixelDataset(SuperPixelDataset): r"""MNIST superpixel dataset for the graph classification task. DGL dataset of MNIST and CIFAR10 in the benchmark-gnn which contains graphs converted fromt the original MINST and CIFAR10 images. Reference `<http://arxiv.org/abs/2003.00982>`_ Statistics: - Train examples: 60,000 - Test examples: 10,000 - Size of dataset images: 28 Parameters ---------- raw_dir : str Directory to store all the downloaded raw datasets. Default: "~/.dgl/". split : str Should be chosen from ["train", "test"] Default: "train". use_feature: bool - True: Adj matrix defined from super-pixel locations + features - False: Adj matrix defined from super-pixel locations (only) Default: False. force_reload : bool Whether to reload the dataset. Default: False. verbose : bool Whether to print out progress information. Default: False. transform : callable, optional A transform that takes in a :class:`~dgl.DGLGraph` object and returns a transformed version. The :class:`~dgl.DGLGraph` object will be transformed before every access. Examples --------- >>> from dgl.data import MNISTSuperPixelDataset >>> # MNIST dataset >>> train_dataset = MNISTSuperPixelDataset(split="train") >>> len(train_dataset) 60000 >>> graph, label = train_dataset[0] >>> graph Graph(num_nodes=71, num_edges=568, ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}) >>> # support tensor to be index when transform is None >>> # see details in __getitem__ function >>> import torch >>> idx = torch.tensor([0, 1, 2]) >>> train_dataset_subset = train_dataset[idx] >>> train_dataset_subset[0] Graph(num_nodes=71, num_edges=568, ndata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}) """ def __init__( self, raw_dir=None, split="train", use_feature=False, force_reload=False, verbose=False, transform=None, ): super().__init__( raw_dir=raw_dir, name="MNIST", split=split, use_feature=use_feature, force_reload=force_reload, verbose=verbose, transform=transform, )
[docs]class CIFAR10SuperPixelDataset(SuperPixelDataset): r"""CIFAR10 superpixel dataset for the graph classification task. DGL dataset of CIFAR10 in the benchmark-gnn which contains graphs converted fromt the original CIFAR10 images. Reference `<http://arxiv.org/abs/2003.00982>`_ Statistics: - Train examples: 50,000 - Test examples: 10,000 - Size of dataset images: 32 Parameters ---------- raw_dir : str Directory to store all the downloaded raw datasets. Default: "~/.dgl/". split : str Should be chosen from ["train", "test"] Default: "train". use_feature: bool - True: Adj matrix defined from super-pixel locations + features - False: Adj matrix defined from super-pixel locations (only) Default: False. force_reload : bool Whether to reload the dataset. Default: False. verbose : bool Whether to print out progress information. Default: False. transform : callable, optional A transform that takes in a :class:`~dgl.DGLGraph` object and returns a transformed version. The :class:`~dgl.DGLGraph` object will be transformed before every access. Examples --------- >>> from dgl.data import CIFAR10SuperPixelDataset >>> # CIFAR10 dataset >>> train_dataset = CIFAR10SuperPixelDataset(split="train") >>> len(train_dataset) 50000 >>> graph, label = train_dataset[0] >>> graph Graph(num_nodes=123, num_edges=984, ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}), >>> # support tensor to be index when transform is None >>> # see details in __getitem__ function >>> import torch >>> idx = torch.tensor([0, 1, 2]) >>> train_dataset_subset = train_dataset[idx] >>> train_dataset_subset[0] Graph(num_nodes=123, num_edges=984, ndata_schemes={'feat': Scheme(shape=(5,), dtype=torch.float32)} edata_schemes={'feat': Scheme(shape=(1,), dtype=torch.float32)}), """ def __init__( self, raw_dir=None, split="train", use_feature=False, force_reload=False, verbose=False, transform=None, ): super().__init__( raw_dir=raw_dir, name="CIFAR10", split=split, use_feature=use_feature, force_reload=force_reload, verbose=verbose, transform=transform, )