Skip to content

Commit

Permalink
add backend for heter training (#41526) (#41651)
Browse files Browse the repository at this point in the history
  • Loading branch information
lilong12 committed Apr 15, 2022
1 parent 9f2ae36 commit 86bbb0f
Showing 1 changed file with 26 additions and 1 deletion.
27 changes: 26 additions & 1 deletion python/paddle/distributed/collective.py
Expand Up @@ -138,7 +138,7 @@ def _get_global_env():
# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"

# List of recognized backend names (where it is checked is not visible here).
# NOTE(review): two consecutive assignments follow; this looks like the
# removed/added pair of a rendered diff rather than intentional code. As
# written, the second assignment wins, so 'heter' is in the effective list.
_valid_backend_list = ['nccl', 'gloo', 'hccl']
_valid_backend_list = ['nccl', 'gloo', 'hccl', 'heter']
_default_store = None # the default tcp store
# Default backend name; initialized to None at module import.
_default_backend = None

Expand Down Expand Up @@ -234,6 +234,31 @@ def _new_process_group_impl(backend,
pg = core.ProcessGroupNCCL(store, rank, world_size, group_id)
elif backend == "hccl":
pg = core.ProcessGroupHCCL(store, rank, world_size, group_id)
elif backend == "heter":
cluster_id = int(os.getenv("CLUSTER_ID", "-1"))
assert cluster_id >= 0, "please set the CLUSTER_ID variable."
cluster_size = os.getenv("CLUSTER_SIZE", None)
assert cluster_size, "please set the CLUSTER_SIZE variable."
cluster_size = cluster_size.split(",")
cluster_size = [int(s) for s in cluster_size]
switch_ep = os.getenv("CLUSTER_SWITCH", None)
assert switch_ep, "please set the CLUSTER_SWITCH variable."
cluster_size_cumsum = np.cumsum(cluster_size)
cluster_offset = 0 if cluster_id == 0 else cluster_size_cumsum[
cluster_id - 1]
global_rank = cluster_offset + rank
global_world_size = cluster_size_cumsum[-1]
pg = core.ProcessGroupHeter(
store,
rank=global_rank,
world_size=global_world_size,
gid=0,
local_rank=rank,
local_size=world_size,
gloo_rank=cluster_id,
gloo_size=len(cluster_size),
with_switch=True,
switch_endpoint=switch_ep)

return pg

Expand Down

0 comments on commit 86bbb0f

Please sign in to comment.