"""Node embedding optimizers for distributed training"""
import abc
import warnings
from abc import abstractmethod
import os
from os.path import exists
import torch as th
import dgl
from .... import backend as F
from ...dist_tensor import DistTensor
from ...graph_partition_book import EDGE_PART_POLICY, NODE_PART_POLICY
from ...nn.pytorch import DistEmbedding
from .utils import alltoall_cpu, alltoallv_cpu
EMB_STATES = "emb_states"
WORLD_SIZE = "world_size"
IDS = "ids"
PARAMS = "params"
STATES = "states"
class DistSparseGradOptimizer(abc.ABC):
r"""The abstract dist sparse optimizer.
Note: dgl dist sparse optimizer only work with dgl.distributed.DistEmbedding
Parameters
----------
params : list of DistEmbedding
The list of DistEmbedding.
lr : float
The learning rate.
"""
def __init__(self, params, lr):
self._params = params
self._lr = lr
self._rank = None
self._world_size = None
self._shared_cache = {}
self._clean_grad = False
self._opt_meta = {}
self._state = {}
# Collect all hyperparameters for saving.
self._defaults = {}
if th.distributed.is_initialized():
self._rank = th.distributed.get_rank()
self._world_size = th.distributed.get_world_size()
else:
self._rank = 0
self._world_size = 1
def local_state_dict(self):
"""Return the state pertaining to current rank of the optimizer.
Returns
-------
dict
Local state dict
Example Dict of Adagrad Optimizer:
.. code-block:: json
{
"params": {
"_lr": 0.01,
"_eps": "1e-8",
"world_size": 2
},
"emb_states": {
"emb_name1": {
"ids": [0, 2, 4, 6 ,8 ,10], ## tensor,
"emb_name1_sum": [0.1 , 0.2, 0.5, 0.1, 0.2] ## tensor,
},
"emb_name2": {
"ids": [0, 2, 4, 6 ,8 ,10], ## tensor,
"emb_name2_sum": [0.3 , 0.2, 0.4, 0.5, 0.2] ## tensor,
}
}
}
:param json: json object
See Also
--------
load_local_state_dict
"""
local_state_dict = {}
local_state_dict[EMB_STATES] = {}
local_state_dict[PARAMS] = {WORLD_SIZE: self._world_size}
for emb in self._params:
trainers_per_machine = self._world_size // max(
1, dgl.distributed.get_num_machines()
)
emb_state_dict = {}
part_policy = (
emb.part_policy if emb.part_policy else emb.weight.part_policy
)
idx = self._get_local_ids(part_policy)
if trainers_per_machine > 1:
kv_idx_split = (idx % trainers_per_machine).long()
local_rank = self._rank % trainers_per_machine
mask = kv_idx_split == local_rank
idx = F.boolean_mask(idx, mask)
emb_state_dict.update({IDS: idx})
emb_state = {}
states = (
list(self._state[emb.name])
if isinstance(self._state[emb.name], tuple)
else [self._state[emb.name]]
)
emb_state = {state.name: state[idx] for state in states}
emb_state_dict.update({STATES: emb_state})
local_state_dict[EMB_STATES].update({emb.name: emb_state_dict})
local_state_dict[PARAMS].update(self._defaults)
return local_state_dict
def load_local_state_dict(self, local_state_dict):
"""Load the local state from the input state_dict,
updating the optimizer as needed.
Parameters
----------
local_state_dict : dict
Optimizer state; should be an object returned
from a call to local_state_dict().
See Also
--------
local_state_dict
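Examples
--------
A round-trip sketch through :meth:`local_state_dict` (``optimizer`` is an
already constructed sparse optimizer; the snippet is illustrative only):
>>> state = optimizer.local_state_dict()
>>> optimizer.load_local_state_dict(state)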
"""
for emb_name, emb_state in local_state_dict[EMB_STATES].items():
idx = emb_state[IDS]
# The state of an embedding can be a single DistTensor (Adagrad) or a
# tuple of DistTensors (Adam); convert it to a list for consistency.
# The list contains reference(s) to the original DistTensor(s).
states = (
list(self._state[emb_name])
if isinstance(self._state[emb_name], tuple)
else [self._state[emb_name]]
)
if len(emb_state[STATES]) != len(states):
raise ValueError(
f"loaded state dict has a different number of states"
f" of embedding {emb_name}"
)
name_to_index = {
state.name: index for index, state in enumerate(states)
}
for name, state in emb_state[STATES].items():
if name not in name_to_index:
raise ValueError(
"loaded state dict contains a state {name}"
"that can't be found in the optimizer states"
)
state_idx = name_to_index[name]
state = state.to(th.device("cpu"), states[state_idx].dtype)
states[state_idx][idx] = state
self._defaults.update(local_state_dict[PARAMS])
self.__dict__.update(local_state_dict[PARAMS])
def save(self, f):
"""Save the local state_dict to disk on per rank.
Saved dict contains 2 parts:
* 'params': hyper parameters of the optimizer.
* 'emb_states': partial optimizer states, each embedding contains 2 items:
1. ```ids```: global id of the nodes/edges stored in this rank.
2. ```states```: state data corrseponding to ```ids```.
NOTE: This needs to be called on all ranks.
Parameters
----------
f : Union[str, os.PathLike]
The path of the file to save to.
See Also
--------
load
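Examples
--------
The rank is appended to the given path, so with 2 trainers the call below
writes ``emb_opt.pt_0`` and ``emb_opt.pt_1`` (the file name is illustrative):
>>> optimizer.save("emb_opt.pt")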
"""
if self._world_size > 1:
th.distributed.barrier()
f = f if isinstance(f, str) else os.fsdecode(f)
f = f"{f}_{self._rank}"
th.save(self.local_state_dict(), f)
if self._world_size > 1:
th.distributed.barrier()
def load(self, f):
"""Load the local state of the optimizer from the file on per rank.
NOTE: This needs to be called on all ranks.
Parameters
----------
f : Union[str, os.PathLike]
The path of the file to load from.
See Also
--------
save
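Examples
--------
Each rank reads back the shard written by :meth:`save`; with 2 trainers,
rank 0 reads ``emb_opt.pt_0`` (the file name is illustrative):
>>> optimizer.load("emb_opt.pt")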
"""
if self._world_size > 1:
th.distributed.barrier()
f = f if isinstance(f, str) else os.fsdecode(f)
f_attach_rank = f"{f}_{self._rank}"
# Don't raise an error here, so that the number of devices can be
# scaled out after reloading. Make sure the hyperparameters are the
# same as before, because newly added local optimizers will be left
# empty.
if not exists(f_attach_rank):
warnings.warn(f"File {f_attach_rank} cannot be found; nothing is loaded.")
else:
old_world_size = self._load_state_from(f_attach_rank)
# Device number scale-in
if self._world_size < old_world_size:
for rank in range(
self._rank + self._world_size,
old_world_size,
self._world_size,
):
self._load_state_from(f"{f}_{rank}")
if self._world_size > 1:
th.distributed.barrier()
def _load_state_from(self, f):
local_state_dict = th.load(f)
world_size = local_state_dict[PARAMS].pop(WORLD_SIZE)
self.load_local_state_dict(local_state_dict)
return world_size
def _get_local_ids(self, part_policy):
if EDGE_PART_POLICY in part_policy.policy_str:
return part_policy.partition_book.partid2eids(
part_policy.part_id, part_policy.type_name
)
elif NODE_PART_POLICY in part_policy.policy_str:
return part_policy.partition_book.partid2nids(
part_policy.part_id, part_policy.type_name
)
else:
raise RuntimeError(
"Unsupported partition policy: %s" % part_policy.policy_str
)
def step(self):
"""The step function.
The step function is invoked at the end of every batch to push the gradients
of the embeddings involved in a mini-batch to DGL's servers and update the embeddings.
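Examples
--------
A sketch of a typical mini-batch loop; ``emb``, ``nids`` and ``dataloader``
are hypothetical and an initialized ``dgl.distributed`` setup is assumed:
>>> optimizer = dgl.distributed.optim.SparseAdagrad([emb], lr=0.001)
>>> for blocks in dataloader:
...     feats = emb(nids)
...     loss = feats.sum()
...     loss.backward()
...     optimizer.step()  # exchange gradients and update the embeddings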
"""
with th.no_grad():
local_indics = {emb.name: [] for emb in self._params}
local_grads = {emb.name: [] for emb in self._params}
device = th.device("cpu")
for emb in self._params:
name = emb.weight.name
kvstore = emb.weight.kvstore
trainers_per_server = self._world_size // kvstore.num_servers
idics = []
grads = []
for trace in emb._trace:
if trace[1].grad is not None:
idics.append(trace[0])
grads.append(trace[1].grad.data)
else:
assert len(trace[0]) == 0
# If the sparse embedding was not used in the previous forward pass,
# idx and grad will be empty. Initialize them as empty tensors to
# avoid crashing the optimizer step logic.
#
# Note: we cannot skip the gradient exchange and update steps, as
# other worker processes may send gradient update requests for this
# embedding to this process.
idics = (
th.cat(idics, dim=0)
if len(idics) != 0
else th.zeros((0,), dtype=th.long, device=th.device("cpu"))
)
grads = (
th.cat(grads, dim=0)
if len(grads) != 0
else th.zeros(
(0, emb.embedding_dim),
dtype=th.float32,
device=th.device("cpu"),
)
)
device = grads.device
# will send grad to each corresponding trainer
if self._world_size > 1:
# get idx split from kvstore
idx_split = kvstore.get_partid(emb.data_name, idics)
idx_split_size = []
idics_list = []
grad_list = []
# split idx and grad first
for i in range(kvstore.num_servers):
mask = idx_split == i
idx_i = idics[mask]
grad_i = grads[mask]
if trainers_per_server <= 1:
idx_split_size.append(
th.tensor([idx_i.shape[0]], dtype=th.int64)
)
idics_list.append(idx_i)
grad_list.append(grad_i)
else:
kv_idx_split = th.remainder(
idx_i, trainers_per_server
).long()
for j in range(trainers_per_server):
mask = kv_idx_split == j
idx_j = idx_i[mask]
grad_j = grad_i[mask]
idx_split_size.append(
th.tensor([idx_j.shape[0]], dtype=th.int64)
)
idics_list.append(idx_j)
grad_list.append(grad_j)
# If one machine launches multiple KVServers, they share the same storage.
# For each machine, the PyTorch rank of trainer i is
# num_trainers * machine_id + i.
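# For example (illustrative numbers): with 4 trainers per machine,
# trainer 1 on machine 1 has rank 4 * 1 + 1 = 5.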
# Use scatter to sync the p2p tensor sizes across trainers.
# Note: with GPU NCCL support, all_to_all could be used to sync
# this information here.
gather_list = list(
th.empty([self._world_size], dtype=th.int64).chunk(
self._world_size
)
)
alltoall_cpu(
self._rank,
self._world_size,
gather_list,
idx_split_size,
)
# use cpu until we have GPU alltoallv
idx_gather_list = [
th.empty((int(num_emb),), dtype=idics.dtype)
for num_emb in gather_list
]
alltoallv_cpu(
self._rank,
self._world_size,
idx_gather_list,
idics_list,
)
local_indics[name] = idx_gather_list
grad_gather_list = [
th.empty(
(int(num_emb), grads.shape[1]), dtype=grads.dtype
)
for num_emb in gather_list
]
alltoallv_cpu(
self._rank,
self._world_size,
grad_gather_list,
grad_list,
)
local_grads[name] = grad_gather_list
else:
local_indics[name] = [idics]
local_grads[name] = [grads]
if self._clean_grad:
# clean gradient track
for emb in self._params:
emb.reset_trace()
self._clean_grad = False
# do local update
for emb in self._params:
name = emb.weight.name
idx = th.cat(local_indics[name], dim=0)
grad = th.cat(local_grads[name], dim=0)
self.update(
idx.to(device, non_blocking=True),
grad.to(device, non_blocking=True),
emb,
)
# synchronized gradient update
if self._world_size > 1:
th.distributed.barrier()
@abstractmethod
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.distributed.DistEmbedding
Sparse node embedding to update.
"""
def zero_grad(self):
"""clean grad cache"""
self._clean_grad = True
def initializer(shape, dtype):
"""Sparse optimizer state initializer
Parameters
----------
shape : tuple of ints
The shape of the state tensor
dtype : torch dtype
The data type of the state tensor
"""
arr = th.zeros(shape, dtype=dtype)
return arr
class SparseAdagrad(DistSparseGradOptimizer):
r"""Distributed Node embedding optimizer using the Adagrad algorithm.
This optimizer implements a distributed sparse version of the Adagrad algorithm for
optimizing :class:`dgl.distributed.DistEmbedding`. Being sparse means it only updates
the embeddings whose gradients have updates, which are usually a very
small portion of the total embeddings.
Adagrad maintains a :math:`G_{t,i,j}` for every parameter in the embeddings, where
:math:`G_{t,i,j}=G_{t-1,i,j} + g_{t,i,j}^2` and :math:`g_{t,i,j}` is the gradient of
the dimension :math:`j` of embedding :math:`i` at step :math:`t`.
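The embeddings are then updated by
:math:`emb_{t,i,j}=emb_{t-1,i,j} - lr \cdot g_{t,i,j} / (\sqrt{G_{t,i,j}} + eps)`.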
NOTE: Support for the sparse Adagrad optimizer is experimental.
Parameters
----------
params : list[dgl.distributed.DistEmbedding]
The list of dgl.distributed.DistEmbedding.
lr : float
The learning rate.
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-10
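Examples
--------
A minimal usage sketch; ``g``, ``nids`` and ``dataloader`` are hypothetical
and an initialized ``dgl.distributed`` setup is assumed:
>>> emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10)
>>> optimizer = dgl.distributed.optim.SparseAdagrad([emb], lr=0.001)
>>> for blocks in dataloader:
...     feats = emb(nids)
...     loss = feats.sum()
...     loss.backward()
...     optimizer.step()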
"""
def __init__(self, params, lr, eps=1e-10):
super(SparseAdagrad, self).__init__(params, lr)
self._eps = eps
self._defaults = {"_lr": lr, "_eps": eps}
# We need to register a state sum for each embedding in the kvstore.
for emb in params:
assert isinstance(
emb, DistEmbedding
), "SparseAdagrad only supports dgl.distributed.DistEmbedding"
name = emb.name + "_sum"
state = DistTensor(
(emb.num_embeddings, emb.embedding_dim),
th.float32,
name,
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False,
)
assert (
emb.name not in self._state
), "{} already registered in the optimizer".format(emb.name)
self._state[emb.name] = state
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.distributed.DistEmbedding
Sparse embedding to update.
"""
eps = self._eps
clr = self._lr
state_dev = th.device("cpu")
exec_dev = grad.device
# Only perform async copies CPU -> GPU or GPU -> GPU, but block when
# copying to the CPU, to ensure the copy has finished before operating
# on the data on the CPU.
state_block = state_dev == th.device("cpu") and exec_dev != state_dev
# the update is non-linear so indices must be unique
grad_indices, inverse, cnt = th.unique(
idx, return_inverse=True, return_counts=True
)
grad_values = th.zeros(
(grad_indices.shape[0], grad.shape[1]), device=exec_dev
)
grad_values.index_add_(0, inverse, grad)
grad_values = grad_values / cnt.unsqueeze(1)
grad_sum = grad_values * grad_values
# update grad state
grad_state = self._state[emb.name][grad_indices].to(exec_dev)
grad_state += grad_sum
grad_state_dst = grad_state.to(state_dev, non_blocking=True)
if state_block:
# use events to try and overlap CPU and GPU as much as possible
update_event = th.cuda.Event()
update_event.record()
# update emb
std_values = grad_state.sqrt_().add_(eps)
tmp = clr * grad_values / std_values
tmp_dst = tmp.to(state_dev, non_blocking=True)
if state_block:
std_event = th.cuda.Event()
std_event.record()
# wait for our transfers from exec_dev to state_dev to finish
# before we can use them
update_event.wait()
self._state[emb.name][grad_indices] = grad_state_dst
if state_block:
# wait for the transfer of std_values to finish before we
# can use it
std_event.wait()
emb._tensor[grad_indices] -= tmp_dst
class SparseAdam(DistSparseGradOptimizer):
r"""Distributed Node embedding optimizer using the Adam algorithm.
This optimizer implements a distributed sparse version of the Adam algorithm for
optimizing :class:`dgl.distributed.DistEmbedding`. Being sparse means it only updates
the embeddings whose gradients have updates, which are usually a very
small portion of the total embeddings.
Adam maintains :math:`Gm_{t,i,j}` and :math:`Gp_{t,i,j}` for every parameter
in the embeddings, where
:math:`Gm_{t,i,j}=beta1 * Gm_{t-1,i,j} + (1-beta1) * g_{t,i,j}`,
:math:`Gp_{t,i,j}=beta2 * Gp_{t-1,i,j} + (1-beta2) * g_{t,i,j}^2`, and
:math:`g_{t,i,j}` is the gradient of dimension :math:`j` of embedding :math:`i`
at step :math:`t`. The embedding is then updated by
:math:`lr \cdot (Gm_{t,i,j} / (1-beta1^t)) / (\sqrt{Gp_{t,i,j} / (1-beta2^t)} + eps)`.
NOTE: Support for the sparse Adam optimizer is experimental.
Parameters
----------
params : list[dgl.distributed.DistEmbedding]
The list of dgl.distributed.DistEmbedding.
lr : float
The learning rate.
betas : tuple[float, float], Optional
Coefficients used for computing running averages of gradient and its square.
Default: (0.9, 0.999)
eps : float, Optional
The term added to the denominator to improve numerical stability
Default: 1e-8
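Examples
--------
A minimal usage sketch; ``g``, ``nids`` and ``dataloader`` are hypothetical
and an initialized ``dgl.distributed`` setup is assumed:
>>> emb = dgl.distributed.DistEmbedding(g.num_nodes(), 10)
>>> optimizer = dgl.distributed.optim.SparseAdam([emb], lr=0.001, betas=(0.9, 0.999))
>>> for blocks in dataloader:
...     feats = emb(nids)
...     loss = feats.sum()
...     loss.backward()
...     optimizer.step()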
"""
def __init__(self, params, lr, betas=(0.9, 0.999), eps=1e-08):
super(SparseAdam, self).__init__(params, lr)
self._eps = eps
# We need to register the optimizer states (step, mem, power) for each
# embedding in the kvstore.
self._beta1 = betas[0]
self._beta2 = betas[1]
self._defaults = {
"_lr": lr,
"_eps": eps,
"_beta1": betas[0],
"_beta2": betas[1],
}
for emb in params:
assert isinstance(
emb, DistEmbedding
), "SparseAdam only supports dgl.distributed.DistEmbedding"
state_step = DistTensor(
(emb.num_embeddings,),
th.float32,
emb.name + "_step",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False,
)
state_mem = DistTensor(
(emb.num_embeddings, emb.embedding_dim),
th.float32,
emb.name + "_mem",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False,
)
state_power = DistTensor(
(emb.num_embeddings, emb.embedding_dim),
th.float32,
emb.name + "_power",
init_func=initializer,
part_policy=emb.part_policy,
is_gdata=False,
)
state = (state_step, state_mem, state_power)
assert (
emb.name not in self._state
), "{} already registered in the optimizer".format(emb.name)
self._state[emb.name] = state
def update(self, idx, grad, emb):
"""Update embeddings in a sparse manner
Sparse embeddings are updated in mini batches. We maintain gradient states for
each embedding so they can be updated separately.
Parameters
----------
idx : tensor
Index of the embeddings to be updated.
grad : tensor
Gradient of each embedding.
emb : dgl.distributed.DistEmbedding
Sparse embedding to update.
"""
beta1 = self._beta1
beta2 = self._beta2
eps = self._eps
clr = self._lr
state_step, state_mem, state_power = self._state[emb.name]
state_dev = th.device("cpu")
exec_dev = grad.device
# Only perform async copies CPU -> GPU or GPU -> GPU, but block when
# copying to the CPU, to ensure the copy has finished before operating
# on the data on the CPU.
state_block = state_dev == th.device("cpu") and exec_dev != state_dev
# the update is non-linear so indices must be unique
grad_indices, inverse, cnt = th.unique(
idx, return_inverse=True, return_counts=True
)
# update grad state
state_idx = grad_indices.to(state_dev)
# The original implementation causes read/write contention:
#     state_step[state_idx] += 1
#     state_step = state_step[state_idx].to(exec_dev, non_blocking=True)
# In a distributed environment, the first line sends asynchronous write
# requests to the kvstore servers to update state_step, while the second
# line sends read requests. The write and read requests may be handled by
# different kvstore servers managing the same portion of the state_step
# dist tensor on the same node, so the read may return a stale value
# (e.g., 0 in the first iteration), which makes update_power_corr NaN.
state_val = state_step[state_idx] + 1
state_step[state_idx] = state_val
state_step = state_val.to(exec_dev)
orig_mem = state_mem[state_idx].to(exec_dev)
orig_power = state_power[state_idx].to(exec_dev)
grad_values = th.zeros(
(grad_indices.shape[0], grad.shape[1]), device=exec_dev
)
grad_values.index_add_(0, inverse, grad)
grad_values = grad_values / cnt.unsqueeze(1)
grad_mem = grad_values
grad_power = grad_values * grad_values
update_mem = beta1 * orig_mem + (1.0 - beta1) * grad_mem
update_power = beta2 * orig_power + (1.0 - beta2) * grad_power
update_mem_dst = update_mem.to(state_dev, non_blocking=True)
update_power_dst = update_power.to(state_dev, non_blocking=True)
if state_block:
# use events to try and overlap CPU and GPU as much as possible
update_event = th.cuda.Event()
update_event.record()
update_mem_corr = update_mem / (
1.0 - th.pow(th.tensor(beta1, device=exec_dev), state_step)
).unsqueeze(1)
update_power_corr = update_power / (
1.0 - th.pow(th.tensor(beta2, device=exec_dev), state_step)
).unsqueeze(1)
std_values = clr * update_mem_corr / (th.sqrt(update_power_corr) + eps)
std_values_dst = std_values.to(state_dev, non_blocking=True)
if state_block:
std_event = th.cuda.Event()
std_event.record()
# wait for our transfers from exec_dev to state_dev to finish
# before we can use them
update_event.wait()
state_mem[state_idx] = update_mem_dst
state_power[state_idx] = update_power_dst
if state_block:
# wait for the transfer of std_values to finish before we
# can use it
std_event.wait()
emb._tensor[state_idx] -= std_values_dst