Chapter 8: Mixed Precision Training


DGL์€ mixed precision ํ•™์Šต์„ ์œ„ํ•ด์„œ PyTorchโ€™s automatic mixed precision package ์™€ ํ˜ธํ™˜๋œ๋‹ค. ๋”ฐ๋ผ์„œ, ํ•™์Šต ์‹œ๊ฐ„ ๋ฐ GPU ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ ˆ์•ฝํ•  ์ˆ˜ ์žˆ๋‹ค. ์ด ๊ธฐ๋Šฅ์„ ํ™œ์„ฑํ™”ํ•˜๊ธฐ ์œ„ํ•ด์„œ๋Š”, PyTorch 1.6+, python 3.7+์„ ์„ค์น˜ํ•˜๊ณ , float16 ๋ฐ์ดํ„ฐ ํƒ€์ž… ์ง€์›์„ ์œ„ํ•ด์„œ DGL์„ ์†Œ์Šค ํŒŒ์ผ์„ ์‚ฌ์šฉํ•ด์„œ ๋นŒ๋“œํ•ด์•ผ ํ•œ๋‹ค. (์ด ๊ธฐ๋Šฅ์€ ์•„์ง ๋ฒ ํƒ€ ๋‹จ๊ณ„์ด๊ณ , pre-built pip wheel ํ˜•ํƒœ๋กœ ์ œ๊ณตํ•˜์ง€ ์•Š๋Š”๋‹ค.)

Installation

First, clone the DGL source code from GitHub and build the shared library with the USE_FP16=ON flag.

git clone --recurse-submodules https://github.com/dmlc/dgl.git
cd dgl
mkdir build
cd build
cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
make -j

๋‹ค์Œ์œผ๋กœ Python ๋ฐ”์ธ๋”ฉ์„ ์„ค์น˜ํ•œ๋‹ค.

cd ../python
python setup.py install
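
To confirm that Python picks up the freshly built package rather than a previously installed wheel, a minimal check such as the following can help (it only prints the version and where the package was loaded from):

import dgl
print(dgl.__version__)  # should report the locally built version
print(dgl.__path__)     # should point to the source checkout, not a pip-installed copy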

Half precision์„ ์‚ฌ์šฉํ•œ ๋ฉ”์‹œ์ง€ ์ „๋‹ฌยถ

fp16์„ ์ง€์›ํ•˜๋Š” DGL์€ UDF(User Defined Function)์ด๋‚˜ ๋นŒํŠธ์ธ ํ•จ์ˆ˜(์˜ˆ, dgl.function.sum, dgl.function.copy_u)๋ฅผ ์‚ฌ์šฉํ•ด์„œ float16 ํ”ผ์ณ์— ๋Œ€ํ•œ ๋ฉ”์‹œ์ง€ ์ „๋‹ฌ์„ ํ—ˆ์šฉํ•œ๋‹ค.

๋‹ค์Œ ์˜ˆ์ œ๋Š” DGL ๋ฉ”์‹œ์ง€ ์ „๋‹ฌ API๋ฅผ half-precision ํ”ผ์ณ๋“ค์— ์‚ฌ์šฉํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ณด์—ฌ์ค€๋‹ค.

>>> import torch
>>> import dgl
>>> import dgl.function as fn
>>> g = dgl.rand_graph(30, 100).to(0)  # Create a graph on GPU w/ 30 nodes and 100 edges.
>>> g.ndata['h'] = torch.rand(30, 16).to(0).half()  # Create fp16 node features.
>>> g.edata['w'] = torch.rand(100, 1).to(0).half()  # Create fp16 edge features.
>>> # Use DGL's built-in functions for message passing on fp16 features.
>>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))
>>> g.ndata['x'][0]
tensor([0.3391, 0.2208, 0.7163, 0.6655, 0.7031, 0.5854, 0.9404, 0.7720, 0.6562,
        0.4028, 0.6943, 0.5908, 0.9307, 0.5962, 0.7827, 0.5034],
       device='cuda:0', dtype=torch.float16)
>>> g.apply_edges(fn.u_dot_v('h', 'x', 'hx'))
>>> g.edata['hx'][0]
tensor([5.4570], device='cuda:0', dtype=torch.float16)
>>> # Use UDF(User Defined Functions) for message passing on fp16 features.
>>> def message(edges):
...     return {'m': edges.src['h'] * edges.data['w']}
...
>>> def reduce(nodes):
...     return {'y': torch.sum(nodes.mailbox['m'], 1)}
...
>>> def dot(edges):
...     return {'hy': (edges.src['h'] * edges.dst['y']).sum(-1, keepdims=True)}
...
>>> g.update_all(message, reduce)
>>> g.ndata['y'][0]
tensor([0.3394, 0.2209, 0.7168, 0.6655, 0.7026, 0.5854, 0.9404, 0.7720, 0.6562,
        0.4028, 0.6943, 0.5908, 0.9307, 0.5967, 0.7827, 0.5039],
       device='cuda:0', dtype=torch.float16)
>>> g.apply_edges(dot)
>>> g.edata['hy'][0]
tensor([5.4609], device='cuda:0', dtype=torch.float16)

End-to-End Mixed Precision Training

DGL์€ PyTorch์˜ AMP package๋ฅผ ์‚ฌ์šฉํ•ด์„œ mixed precision ํ•™์Šต์„ ๊ตฌํ˜„ํ•˜๊ณ  ์žˆ์–ด์„œ, ์‚ฌ์šฉ ๋ฐฉ๋ฒ•์€ PyTorch์˜ ๊ฒƒ ๊ณผ ๋™์ผํ•˜๋‹ค.

GNN ๋ชจ๋ธ์˜ forward ํŒจ์Šค(loss ๊ณ„์‚ฐ ํฌํ•จ)๋ฅผ torch.cuda.amp.autocast() ๋กœ ๋ž˜ํ•‘ํ•˜๋ฉด PyTorch๋Š” ๊ฐ op ๋ฐ ํ…์„œ์— ๋Œ€ํ•ด์„œ ์ ์ ˆํ•œ ๋ฐ์ดํ„ฐ ํƒ€์ž…์„ ์ž๋™์œผ๋กœ ์„ ํƒํ•œ๋‹ค. Half precision ํ…์„œ๋Š” ๋ฉ”๋ชจ๋ฆฌ ํšจ์œจ์ ์ด๊ณ , half precision ํ…์„œ์— ๋Œ€ํ•œ ๋Œ€๋ถ€๋ถ„ ์—ฐ์‚ฐ๋“ค์€ GPU tensorcore๋“ค์„ ํ™œ์šฉํ•˜๊ธฐ ๋•Œ๋ฌธ์— ๋” ๋น ๋ฅด๋‹ค.

float16 ํฌ๋ฉง์˜ ์ž‘์€ graident๋“ค์€ ์–ธ๋”ํ”Œ๋กœ์šฐ(underflow) ๋ฌธ์ œ๋ฅผ ๊ฐ–๋Š”๋ฐ (0์ด ๋˜๋ฒ„๋ฆผ), PyTorch๋Š” ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด์„œ GradScaler ๋ชจ๋“ˆ์„ ์ œ๊ณตํ•œ๋‹ค. GradScaler ๋Š” loss ๊ฐ’์— factor๋ฅผ ๊ณฑํ•˜๊ณ , ์ด scaled loss์— backward pass๋ฅผ ์ˆ˜ํ–‰ํ•œ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ํŒŒ๋ผ๋ฉ”ํ„ฐ๋“ค์„ ์—…๋ฐ์ดํŠธํ•˜๋Š” optimizer๋ฅผ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์ „์— unscale ํ•œ๋‹ค.

๋‹ค์Œ์€ 3-๋ ˆ์ด์–ด GAT๋ฅผ Reddit ๋ฐ์ดํ„ฐ์…‹(1140์–ต๊ฐœ์˜ ์—์ง€๋ฅผ ๊ฐ–๋Š”)์— ํ•™์Šต์„ ํ•˜๋Š” ์Šคํฌ๋ฆฝํŠธ์ด๋‹ค. use_fp16 ๊ฐ€ ํ™œ์„ฑํ™”/๋น„ํ™œ์„ฑํ™”๋˜์—ˆ์„ ๋•Œ์˜ ์ฝ”๋“œ ์ฐจ์ด๋ฅผ ์‚ดํŽด๋ณด์ž.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import dgl
from dgl.data import RedditDataset
from dgl.nn import GATConv

use_fp16 = True


class GAT(nn.Module):
    def __init__(self,
                 in_feats,
                 n_hidden,
                 n_classes,
                 heads):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers.append(GATConv(in_feats, n_hidden, heads[0], activation=F.elu))
        self.layers.append(GATConv(n_hidden * heads[0], n_hidden, heads[1], activation=F.elu))
        self.layers.append(GATConv(n_hidden * heads[1], n_classes, heads[2], activation=F.elu))

    def forward(self, g, h):
        for l, layer in enumerate(self.layers):
            h = layer(g, h)
            if l != len(self.layers) - 1:
                h = h.flatten(1)
            else:
                h = h.mean(1)
        return h

# Data loading
data = RedditDataset()
device = torch.device(0)
g = data[0]
g = dgl.add_self_loop(g)
g = g.int().to(device)
train_mask = g.ndata['train_mask']
features = g.ndata['feat']
labels = g.ndata['label']
in_feats = features.shape[1]
n_hidden = 256
n_classes = data.num_classes
n_edges = g.number_of_edges()
heads = [1, 1, 1]
model = GAT(in_feats, n_hidden, n_classes, heads)
model = model.to(device)

# Create optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)
# Create gradient scaler
scaler = GradScaler()

for epoch in range(100):
    model.train()
    optimizer.zero_grad()

    # Wrap forward pass with autocast
    with autocast(enabled=use_fp16):
        logits = model(g, features)
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

    if use_fp16:
        # Backprop w/ gradient scaling
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()

    print('Epoch {} | Loss {}'.format(epoch, loss.item()))

On a machine with a single NVIDIA V100 (16GB) GPU, training this model without fp16 uses 15.2GB of GPU memory; with fp16 enabled, training uses 12.8GB of GPU memory, and in both cases the loss converges to similar values. If the number of heads is changed to [2, 2, 2], training without fp16 will hit a GPU OOM (out-of-memory) error, while training with fp16 runs using 15.7GB of GPU memory.

DGL์€ half-precision ์ง€์›์„ ๊ณ„์† ํ–ฅ์ƒํ•˜๊ณ  ์žˆ๊ณ , ์—ฐ์‚ฐ ์ปค๋„์˜ ์„ฑ๋Šฅ์€ ์•„์ง ์ตœ์ ์€ ์•„๋‹ˆ๋‹ค. ์•ž์œผ๋กœ์˜ ์—…๋ฐ์ดํŠธ๋ฅผ ๊ณ„์† ์ง€์ผœ๋ณด์ž.