Multi-GPU Node Classification

This tutorial shows how to train a multi-layer GraphSAGE for node classification on the ogbn-products dataset provided by Open Graph Benchmark (OGB). The dataset contains around 2.4 million nodes and 62 million edges.


By the end of this tutorial, you will be able to

  • Train a GNN model for node classification on multiple GPUs with DGL’s neighbor sampling components. Once you know how to use multiple GPUs, you will be able to apply the same approach to other scenarios such as link prediction.

Install DGL package and other dependencies

[1]:
# Install required packages.
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

# Install the CUDA version. If you want to install CPU version, please
# refer to https://www.dgl.ai/pages/start.html.
!pip install --pre dgl -f https://data.dgl.ai/wheels-test/cu121/repo.html
!pip install torchmetrics multiprocess

try:
    import dgl
    import dgl.graphbolt as gb
    installed = True
except ImportError as error:
    installed = False
    print(error)
print("DGL installed!" if installed else "DGL not found!")
DGL installed!

Defining Neighbor Sampler and Data Loader in DGL

The major difference from the previous tutorial is that we use DistributedItemSampler instead of ItemSampler to sample minibatches of nodes. DistributedItemSampler is a distributed version of ItemSampler designed to work with DistributedDataParallel: the item set is split so that each replica (process) receives its own exclusive subset of items. With drop_last, the last non-full minibatch is dropped. With drop_uneven_inputs, replicas that would otherwise get more minibatches than others drop the extra ones, so every process runs the same number of training steps.
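To make these options concrete, here is a pure-Python sketch of how an item set could be divided into per-replica minibatches. It is a simplified model, not GraphBolt's implementation: the helper `replica_batches` is hypothetical, and the contiguous sharding after a shared shuffle is an assumption; the exact assignment DistributedItemSampler uses may differ.

```python
import random


def replica_batches(items, batch_size, rank, world_size, *, shuffle=False,
                    drop_last=False, drop_uneven_inputs=False, seed=0):
    """Hypothetical, simplified model of per-replica minibatching.

    Not GraphBolt's DistributedItemSampler; it only illustrates the effect
    of drop_last and drop_uneven_inputs on disjoint per-replica shards.
    """
    items = list(items)
    if shuffle:
        # Assumption of this sketch: every replica uses the same seed, hence
        # the same permutation, so the shards computed below stay disjoint.
        random.Random(seed).shuffle(items)

    # Contiguous, near-equal shards: the first `extra` ranks get one more item.
    base, extra = divmod(len(items), world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    shard = items[start:end]

    batches = [shard[i:i + batch_size] for i in range(0, len(shard), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # Drop the final non-full minibatch.

    if drop_uneven_inputs:
        # Truncate to the smallest batch count over all replicas so that
        # every process runs the same number of steps.
        fewest = min(
            len(replica_batches(items, batch_size, r, world_size,
                                drop_last=drop_last))
            for r in range(world_size)
        )
        batches = batches[:fewest]
    return batches


for r in range(3):
    print(r, replica_batches(range(10), 2, r, 3),
          replica_batches(range(10), 2, r, 3,
                          drop_last=True, drop_uneven_inputs=True))
```

With 10 items, 3 replicas and a batch size of 2, replica 0 gets two full minibatches while replicas 1 and 2 each get one full and one partial minibatch; setting both `drop_last` and `drop_uneven_inputs` leaves every replica with exactly one full minibatch.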

[2]:
def create_dataloader(graph, features, itemset, device, is_train):
    datapipe = gb.DistributedItemSampler(
        item_set=itemset,
        batch_size=1024,
        drop_last=is_train,
        shuffle=is_train,
        drop_uneven_inputs=is_train,
    )
    datapipe = datapipe.copy_to(device)
    # Now that we have moved to device, sample_neighbor and fetch_feature steps
    # will be executed on GPUs.
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])
    datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
    return gb.DataLoader(datapipe)
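The fanouts `[10, 10, 10]` passed to `sample_neighbor` request at most 10 sampled neighbors per node in each of the three layers. The pure-Python sketch below illustrates the idea on an in-memory adjacency dictionary. It is a simplified model, not GraphBolt's GPU sampler; the helper `sample_layers` and its return format are hypothetical, chosen to mirror the input-to-output ordering of `data.blocks`.

```python
import random


def sample_layers(in_neighbors, seeds, fanouts, seed=0):
    """Hypothetical, simplified multi-layer neighbor sampling.

    in_neighbors maps node -> list of in-neighbors. Returns (layers, input_nodes):
    layers[i] is the list of (src, dst) edges for layer i, ordered from the
    input layer to the output layer; input_nodes are the nodes whose features
    the first layer needs.
    """
    rng = random.Random(seed)
    layers = []
    frontier = list(seeds)
    # Sample backwards from the output layer, so fanouts[-1] applies to the seeds.
    for fanout in reversed(fanouts):
        edges = []
        for dst in frontier:
            neighbors = in_neighbors.get(dst, [])
            # Sample without replacement; keep all neighbors if there are few.
            if len(neighbors) <= fanout:
                picked = neighbors
            else:
                picked = rng.sample(neighbors, fanout)
            edges.extend((src, dst) for src in picked)
        layers.insert(0, edges)
        # Destinations of the previous layer: current destinations first,
        # followed by newly reached source nodes.
        seen = set(frontier)
        next_frontier = list(frontier)
        for src, _ in edges:
            if src not in seen:
                seen.add(src)
                next_frontier.append(src)
        frontier = next_frontier
    return layers, frontier


adjacency = {0: [1, 2, 3], 1: [0, 2], 2: [0], 3: [0, 1, 2]}
layers, input_nodes = sample_layers(adjacency, seeds=[0], fanouts=[2, 2])
print(layers, input_nodes)
```

Because each layer can add up to `fanout` new nodes per frontier node, the set of input nodes can grow quickly with depth, which is one reason features are fetched only for the sampled nodes after sampling.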

Weighted reduction across GPUs

Because different GPUs may process different numbers of items, we define a function that computes the exact average of values such as loss or accuracy by weighting each GPU's contribution by the number of items it processed.

[3]:
import torch.distributed as dist

def weighted_reduce(tensor, weight, dst=0):
    ########################################################################
    # (HIGHLIGHT) Collect accuracy and loss values from sub-processes and
    # obtain overall average values.
    #
    # `torch.distributed.reduce` is used to reduce tensors from all the
    # sub-processes to a specified process, ReduceOp.SUM is used by default.
    #
    # Because the GPUs may have differing numbers of processed items, we
    # perform a weighted mean to calculate the exact loss and accuracy.
    ########################################################################
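    dist.reduce(tensor=tensor, dst=dst)
    weight = torch.tensor(weight, device=tensor.device)
    dist.reduce(tensor=weight, dst=dst)
    # Only the `dst` rank holds the complete sums, so the returned value is
    # meaningful on that rank only (rank 0 here, which does the printing).
    return tensor / weight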
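As a quick sanity check of the reasoning behind `weighted_reduce`, the following pure-Python sketch mimics what the destination rank computes: each rank contributes a sum and a count, both are added up (as `dist.reduce` does with `ReduceOp.SUM`), and the totals are divided. The helper names are hypothetical, and no process group is involved.

```python
def weighted_mean_over_ranks(per_rank_values):
    """Mimic weighted_reduce: SUM-reduce per-rank sums and counts, then divide."""
    sums = [sum(values) for values in per_rank_values]    # each rank's `tensor`
    counts = [len(values) for values in per_rank_values]  # each rank's `weight`
    return sum(sums) / sum(counts)


def naive_mean_over_ranks(per_rank_values):
    """Average the per-rank means, ignoring how many items each rank saw."""
    means = [sum(values) / len(values) for values in per_rank_values]
    return sum(means) / len(means)


# Rank 0 saw three items with value 1.0; rank 1 saw a single item with value 0.0.
per_rank_values = [[1.0, 1.0, 1.0], [0.0]]
print(weighted_mean_over_ranks(per_rank_values))  # exact mean over all items
print(naive_mean_over_ranks(per_rank_values))     # gives rank 1 too much weight
```

Here the naive mean of per-rank means is 0.5, whereas the weighted result is 0.75, which is the mean over all four items.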

Defining Model

Let’s consider training a 3-layer GraphSAGE with neighbor sampling. The model can be written as follows:

[4]:
from torch import nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class SAGE(nn.Module):
    def __init__(self, in_size, hidden_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # Three-layer GraphSAGE-mean.
        self.layers.append(SAGEConv(in_size, hidden_size, "mean"))
        self.layers.append(SAGEConv(hidden_size, hidden_size, "mean"))
        self.layers.append(SAGEConv(hidden_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Set the dtype for the layers manually.
        self.float()

    def forward(self, blocks, x):
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x
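For reference, each SAGEConv layer with the "mean" aggregator combines a destination node's own representation with the mean of its sampled in-neighbors' representations through two separate linear maps plus a bias; ReLU and dropout are then applied between layers in `forward` above. The sketch below uses scalar features, so the weight matrices become the hypothetical scalars `w_self` and `w_neigh`. It illustrates the computation and is not DGL's implementation.

```python
def sage_mean_layer(src_feats, edges, num_dst, w_self, w_neigh, bias=0.0):
    """Scalar-feature sketch of one GraphSAGE layer with mean aggregation.

    src_feats: features of the block's source nodes; as in DGL blocks, the
        first num_dst source nodes are taken to be the destination nodes.
    edges: (src, dst) index pairs, with src indexing src_feats and dst < num_dst.
    w_self, w_neigh, bias: scalars standing in for the weight matrices and bias.
    """
    neigh_sum = [0.0] * num_dst
    in_degree = [0] * num_dst
    for src, dst in edges:
        neigh_sum[dst] += src_feats[src]
        in_degree[dst] += 1

    out = []
    for v in range(num_dst):
        # Nodes without sampled neighbors get a zero neighbor mean.
        neigh_mean = neigh_sum[v] / in_degree[v] if in_degree[v] else 0.0
        out.append(w_self * src_feats[v] + w_neigh * neigh_mean + bias)
    return out


# A block with 4 source nodes, of which the first 2 are destination nodes.
print(sage_mean_layer([1.0, 2.0, 3.0, 5.0], [(2, 0), (3, 0), (3, 1)],
                      num_dst=2, w_self=1.0, w_neigh=0.5))
```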

Evaluation function

The evaluation function computes the validation accuracy during training and the test accuracy at the end of training. Unlike in the previous tutorial, it also returns the number of items processed by the current GPU, so that the per-GPU accuracies can be combined into a weighted average.

[5]:
import torchmetrics.functional as MF
import tqdm

@torch.no_grad()
def evaluate(rank, model, graph, features, itemset, num_classes, device):
    model.eval()
    y = []
    y_hats = []
    dataloader = create_dataloader(
        graph,
        features,
        itemset,
        device,
        is_train=False,
    )

    for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:
        blocks = data.blocks
        x = data.node_features["feat"]
        y.append(data.labels)
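        # Call the wrapped module directly, bypassing the DDP wrapper:
        # evaluation needs no gradient synchronization, and ranks may
        # process different numbers of batches here.
        y_hats.append(model.module(blocks, x))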

    res = MF.accuracy(
        torch.cat(y_hats),
        torch.cat(y),
        task="multiclass",
        num_classes=num_classes,
    )

    return res.to(device), sum(y_i.size(0) for y_i in y)
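Returning the item count is what makes the combined metric exact. With `task="multiclass"` and the default micro averaging, `MF.accuracy` is the fraction of correctly classified items, so `accuracy * count` recovers each rank's number of correct predictions. The pure-Python sketch below checks this on toy logits; `micro_accuracy` and `combine_accuracies` are hypothetical helpers, not torchmetrics code.

```python
def micro_accuracy(logits, labels):
    """Fraction of rows whose arg-max class equals the label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


def combine_accuracies(results):
    """Combine (accuracy, num_items) pairs from several ranks into one accuracy."""
    correct = sum(acc * n for acc, n in results)
    total = sum(n for _, n in results)
    return correct / total


# Toy predictions for two ranks that evaluated different numbers of items.
rank0_logits, rank0_labels = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]], [0, 1, 1]
rank1_logits, rank1_labels = [[0.3, 0.7]], [0]

per_rank = [
    (micro_accuracy(rank0_logits, rank0_labels), len(rank0_labels)),
    (micro_accuracy(rank1_logits, rank1_labels), len(rank1_labels)),
]
combined = combine_accuracies(per_rank)
overall = micro_accuracy(rank0_logits + rank1_logits, rank0_labels + rank1_labels)
print(combined, overall)
```

The weighted combination matches the accuracy computed on all predictions at once, whereas a plain average of the two per-rank accuracies would not.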

Training Loop

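The training loop is almost identical to the one in the previous tutorial. Here we explicitly avoid uneven inputs by setting drop_uneven_inputs=True on the training dataloader, so every GPU runs the same number of steps per epoch. Alternatively, PyTorch's Join context manager can be used to train with uneven inputs, where some GPUs run out of minibatches before others at the end of an epoch. See the PyTorch tutorial "Distributed Training with Uneven Inputs Using the Join Context Manager" for more information.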

[6]:
import time

def train(
    rank,
    graph,
    features,
    train_set,
    valid_set,
    num_classes,
    model,
    device,
):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Create training data loader.
    dataloader = create_dataloader(
        graph,
        features,
        train_set,
        device,
        is_train=True,
    )

    for epoch in range(5):
        epoch_start = time.time()

        model.train()
        total_loss = torch.tensor(0, dtype=torch.float, device=device)
        num_train_items = 0
        for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:
            # The input features are from the source nodes in the first
            # layer's computation graph.
            x = data.node_features["feat"]

            # The ground truth labels are from the destination nodes
            # in the last layer's computation graph.
            y = data.labels

            blocks = data.blocks

            y_hat = model(blocks, x)

            # Compute loss.
            loss = F.cross_entropy(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.detach() * y.size(0)
            num_train_items += y.size(0)

        # Evaluate the model.
        if rank == 0:
            print("Validating...")
        acc, num_val_items = evaluate(
            rank,
            model,
            graph,
            features,
            valid_set,
            num_classes,
            device,
        )
        total_loss = weighted_reduce(total_loss, num_train_items)
        acc = weighted_reduce(acc * num_val_items, num_val_items)

        # We synchronize before measuring the epoch time.
        torch.cuda.synchronize()
        epoch_end = time.time()
        if rank == 0:
            print(
                f"Epoch {epoch:05d} | "
                f"Average Loss {total_loss.item():.4f} | "
                f"Accuracy {acc.item():.4f} | "
                f"Time {epoch_end - epoch_start:.4f}"
            )

Defining Training and Evaluation Procedures

The following code defines the main function for each process. It is similar to the previous tutorial except that we need to initialize a distributed training context with torch.distributed and wrap the model with torch.nn.parallel.DistributedDataParallel.
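Wrapping the model in DistributedDataParallel keeps the replicas in sync: during `loss.backward()`, DDP all-reduces the gradients so that each replica holds their average across replicas, and every optimizer then applies the same update. The pure-Python sketch below models one such step with plain SGD instead of the Adam optimizer used in this tutorial; `ddp_style_sgd_step` is a hypothetical helper, not part of PyTorch.

```python
def ddp_style_sgd_step(params, per_rank_grads, lr):
    """One data-parallel SGD step on plain Python lists.

    per_rank_grads[r][i] is the gradient of parameter i computed by replica r
    on its own minibatch. The gradients are averaged over replicas (the role
    of DDP's all-reduce), and that same update is applied to the parameters,
    so every replica would end up with identical new parameters.
    """
    world_size = len(per_rank_grads)
    avg_grads = [
        sum(grads[i] for grads in per_rank_grads) / world_size
        for i in range(len(params))
    ]
    return [p - lr * g for p, g in zip(params, avg_grads)]


# Two replicas computed different gradients for the same two parameters.
new_params = ddp_style_sgd_step([1.0, 2.0], [[0.5, 1.0], [1.5, -1.0]], lr=0.1)
print(new_params)
```

Because `drop_last=True` and `drop_uneven_inputs=True` give every replica full minibatches and the same number of steps, averaging the per-replica gradients of the mean loss equals the gradient of the mean loss over all replicas' minibatches combined.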

[7]:
def run(rank, world_size, devices, dataset):
    # Set up multiprocessing environment.
    device = devices[rank]
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",  # Use NCCL backend for distributed GPU training
        init_method="tcp://127.0.0.1:12345",
        world_size=world_size,
        rank=rank,
    )

    # Pin the graph and features in-place to enable GPU access.
    graph = dataset.graph.pin_memory_()
    features = dataset.feature.pin_memory_()
    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    num_classes = dataset.tasks[0].metadata["num_classes"]

    in_size = features.size("node", None, "feat")[0]
    hidden_size = 256
    out_size = num_classes

    # Create GraphSAGE model. It should be copied onto a GPU as a replica.
    model = SAGE(in_size, hidden_size, out_size).to(device)
    model = nn.parallel.DistributedDataParallel(model)

    # Model training.
    if rank == 0:
        print("Training...")
    train(
        rank,
        graph,
        features,
        train_set,
        valid_set,
        num_classes,
        model,
        device,
    )

    # Test the model.
    if rank == 0:
        print("Testing...")
    test_set = dataset.tasks[0].test_set
    test_acc, num_test_items = evaluate(
        rank,
        model,
        graph,
        features,
        itemset=test_set,
        num_classes=num_classes,
        device=device,
    )
    test_acc = weighted_reduce(test_acc * num_test_items, num_test_items)

    if rank == 0:
        print(f"Test Accuracy {test_acc.item():.4f}")

Spawning Trainer Processes

The following code spawns a process for each GPU and calls the run function defined above.
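`mp.spawn` calls the target function as `fn(i, *args)`, where `i` is the process index. This is why `run` takes `rank` as its first parameter, and why the single-GPU branch can call `run(0, 1, devices, dataset)` directly. The sketch below reproduces only that calling convention, sequentially and in the current process; `spawn_sequentially` and `fake_run` are hypothetical, and no processes or process groups are created.

```python
def spawn_sequentially(fn, args=(), nprocs=1):
    """Sequential stand-in for the calling convention of torch.multiprocessing.spawn.

    Calls fn(rank, *args) for each rank, one after another, in this process,
    and collects the results. The real mp.spawn runs each call in a separate
    process and does not return fn's results.
    """
    return [fn(rank, *args) for rank in range(nprocs)]


def fake_run(rank, world_size, devices):
    # Stand-in for `run`: report which device this rank would use.
    return f"rank {rank}/{world_size} on {devices[rank]}"


print(spawn_sequentially(fake_run, args=(2, ["cuda:0", "cuda:1"]), nprocs=2))
```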

[8]:
import torch.multiprocessing as mp

def main():
    if not torch.cuda.is_available():
        print("No GPU found!")
        return

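    # Only the first GPU is used here, because the mp.spawn launch below is
    # not supported in a notebook. Remove `[:1]` to use all GPUs when running
    # this code as a script.
    devices = [
        torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())
    ][:1]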
    world_size = len(devices)

    print(f"Training with {world_size} gpus.")

    # Load and preprocess dataset.
    dataset = gb.BuiltinDataset("ogbn-products").load()

    # Thread limiting to avoid resource competition.
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // world_size)

    if world_size > 1:
        # The following launch method is not supported in a notebook.
        mp.set_sharing_strategy("file_system")
        mp.spawn(
            run,
            args=(world_size, devices, dataset),
            nprocs=world_size,
            join=True,
        )
    else:
        run(0, 1, devices, dataset)


if __name__ == "__main__":
    main()
Training with 1 gpus.
The dataset is already preprocessed.
Training...
192it [00:09, 21.32it/s]
Validating...
39it [00:00, 78.32it/s]
Epoch 00000 | Average Loss 1.2953 | Accuracy 0.8556 | Time 9.5520
192it [00:03, 61.08it/s]
Validating...
39it [00:00, 79.10it/s]
Epoch 00001 | Average Loss 0.5859 | Accuracy 0.8788 | Time 3.6609
192it [00:03, 62.82it/s]
Validating...
39it [00:00, 80.55it/s]
Epoch 00002 | Average Loss 0.4858 | Accuracy 0.8852 | Time 3.5646
192it [00:03, 60.34it/s]
Validating...
39it [00:00, 44.41it/s]
Epoch 00003 | Average Loss 0.4407 | Accuracy 0.8920 | Time 4.0852
192it [00:03, 58.87it/s]
Validating...
39it [00:00, 78.52it/s]
Epoch 00004 | Average Loss 0.4122 | Accuracy 0.8943 | Time 3.7938
Testing...
2162it [00:24, 89.75it/s]
Test Accuracy 0.7514