[dask] Deterministic rank assignment. (#8018)
trivialfis committed Aug 11, 2022
1 parent 20d1bba commit 36e7c53
Showing 3 changed files with 90 additions and 19 deletions.
18 changes: 15 additions & 3 deletions python-package/xgboost/dask.py
@@ -170,9 +170,11 @@ def _try_start_tracker(
use_logger=False,
)
else:
assert isinstance(addrs[0], str) or addrs[0] is None
addr = addrs[0]
assert isinstance(addr, str) or addr is None
host_ip = get_host_ip(addr)
rabit_context = RabitTracker(
host_ip=get_host_ip(addrs[0]), n_workers=n_workers, use_logger=False
host_ip=host_ip, n_workers=n_workers, use_logger=False, sortby="task"
)
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
@@ -222,8 +224,16 @@ class RabitContext(rabit.RabitContext):
def __init__(self, args: List[bytes]) -> None:
super().__init__(args)
worker = distributed.get_worker()
with distributed.worker_client() as client:
info = client.scheduler_info()
w = info["workers"][worker.address]
wid = w["id"]
        # We use the task ID for rank assignment, which keeps the RABIT rank
        # consistent with the dask worker ID (though not identical, since task IDs
        # are strings and "10" sorts before "2"). This outsources rank assignment
        # to dask and prevents non-deterministic assignment.
self.args.append(
("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode()
(f"DMLC_TASK_ID=[xgboost.dask-{wid}]:" + str(worker.address)).encode()
)


@@ -841,6 +851,8 @@ async def _get_rabit_args(
except Exception: # pylint: disable=broad-except
sched_addr = None

    # Make sure all workers are online so that we can obtain a reliable scheduler_info.
client.wait_for_workers(n_workers)
env = await client.run_on_scheduler(
_start_tracker, n_workers, sched_addr, user_addr
)
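To illustrate the effect of the RabitContext change above, here is a minimal sketch that is not part of this commit; the helper name is hypothetical. Each dask worker reports a DMLC_TASK_ID that embeds its dask worker ID, and the tracker assigns ranks by sorting those task IDs, so the worker-to-rank mapping no longer depends on connection order.

    # Minimal sketch (hypothetical helper, not part of this commit): assign ranks
    # by sorting the DMLC_TASK_ID strings reported by the workers. Because the IDs
    # embed the dask worker IDs, the mapping is identical on every run, regardless
    # of the order in which workers connect to the tracker.
    from typing import Dict, List

    def assign_ranks_by_task_id(task_ids: List[str]) -> Dict[str, int]:
        """Return a task-id -> rank mapping based on lexicographic task-id order."""
        return {tid: rank for rank, tid in enumerate(sorted(task_ids))}

    print(assign_ranks_by_task_id(
        ["[xgboost.dask-1]:tcp://10.0.0.2:43210", "[xgboost.dask-0]:tcp://10.0.0.1:43210"]
    ))
    # {'[xgboost.dask-0]:tcp://10.0.0.1:43210': 0, '[xgboost.dask-1]:tcp://10.0.0.2:43210': 1}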
59 changes: 43 additions & 16 deletions python-package/xgboost/tracker.py
@@ -32,15 +32,15 @@ def recvall(self, nbytes: int) -> bytes:
chunk = self.sock.recv(min(nbytes - nread, 1024))
nread += len(chunk)
res.append(chunk)
return b''.join(res)
return b"".join(res)

def recvint(self) -> int:
"""Receive an integer of 32 bytes"""
return struct.unpack('@i', self.recvall(4))[0]
return struct.unpack("@i", self.recvall(4))[0]

def sendint(self, value: int) -> None:
"""Send an integer of 32 bytes"""
self.sock.sendall(struct.pack('@i', value))
self.sock.sendall(struct.pack("@i", value))

def sendstr(self, value: str) -> None:
"""Send a Python string"""
@@ -69,6 +69,7 @@ def get_family(addr: str) -> int:

class WorkerEntry:
"""Hanlder to each worker."""

def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
worker = ExSocket(sock)
self.sock = worker
@@ -78,7 +79,7 @@ def __init__(self, sock: socket.socket, s_addr: Tuple[str, int]):
worker.sendint(MAGIC_NUM)
self.rank = worker.recvint()
self.world_size = worker.recvint()
self.jobid = worker.recvstr()
self.task_id = worker.recvstr()
self.cmd = worker.recvstr()
self.wait_accept = 0
self.port: Optional[int] = None
@@ -96,8 +97,8 @@ def decide_rank(self, job_map: Dict[str, int]) -> int:
"""Get the rank of current entry."""
if self.rank >= 0:
return self.rank
if self.jobid != 'NULL' and self.jobid in job_map:
return job_map[self.jobid]
if self.task_id != "NULL" and self.task_id in job_map:
return job_map[self.task_id]
return -1

def assign_rank(
@@ -180,7 +181,12 @@ class RabitTracker:
"""

def __init__(
self, host_ip: str, n_workers: int, port: int = 0, use_logger: bool = False
self,
host_ip: str,
n_workers: int,
port: int = 0,
use_logger: bool = False,
sortby: str = "host",
) -> None:
"""A Python implementation of RABIT tracker.
@@ -190,6 +196,13 @@ def __init__(
            Use logging.info for the tracker print command. When set to False, the
            Python print function is used instead.
        sortby:
            How to sort the workers for rank assignment. The default is host, but
            users can set `DMLC_TASK_ID` via the RABIT initialization arguments and
            sort by task to obtain a deterministic rank assignment. Available
            options are:
            - host
            - task
"""
sock = socket.socket(get_family(host_ip), socket.SOCK_STREAM)
sock.bind((host_ip, port))
@@ -200,6 +213,7 @@ def __init__(
self.thread: Optional[Thread] = None
self.n_workers = n_workers
self._use_logger = use_logger
self._sortby = sortby
logging.info("start listen on %s:%d", host_ip, self.port)

def __del__(self) -> None:
@@ -223,7 +237,7 @@ def worker_envs(self) -> Dict[str, Union[str, int]]:
        Get environment variables for workers; these can be passed in as args or
        envs.
"""
return {'DMLC_TRACKER_URI': self.host_ip, 'DMLC_TRACKER_PORT': self.port}
return {"DMLC_TRACKER_URI": self.host_ip, "DMLC_TRACKER_PORT": self.port}

def _get_tree(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int]]:
tree_map: _TreeMap = {}
@@ -296,8 +310,16 @@ def get_link_map(self, n_workers: int) -> Tuple[_TreeMap, Dict[int, int], _RingM
parent_map_[rmap[k]] = -1
return tree_map_, parent_map_, ring_map_

def _sort_pending(self, pending: List[WorkerEntry]) -> List[WorkerEntry]:
if self._sortby == "host":
pending.sort(key=lambda s: s.host)
elif self._sortby == "task":
pending.sort(key=lambda s: s.task_id)
return pending

def accept_workers(self, n_workers: int) -> None:
"""Wait for all workers to connect to the tracker."""

        # set of nodes that have finished the job
shutdown: Dict[int, WorkerEntry] = {}
        # set of nodes that are waiting for connections
@@ -341,27 +363,32 @@ def accept_workers(self, n_workers: int) -> None:
assert todo_nodes
pending.append(s)
if len(pending) == len(todo_nodes):
pending.sort(key=lambda x: x.host)
pending = self._sort_pending(pending)
for s in pending:
rank = todo_nodes.pop(0)
if s.jobid != 'NULL':
job_map[s.jobid] = rank
if s.task_id != "NULL":
job_map[s.task_id] = rank
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
if s.wait_accept > 0:
wait_conn[rank] = s
logging.debug('Received %s signal from %s; assign rank %d',
s.cmd, s.host, s.rank)
logging.debug(
"Received %s signal from %s; assign rank %d",
s.cmd,
s.host,
s.rank,
)
if not todo_nodes:
logging.info('@tracker All of %d nodes getting started', n_workers)
logging.info("@tracker All of %d nodes getting started", n_workers)
else:
s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
logging.debug('Received %s signal from %d', s.cmd, s.rank)
logging.debug("Received %s signal from %d", s.cmd, s.rank)
if s.wait_accept > 0:
wait_conn[rank] = s
logging.info('@tracker All nodes finishes job')
logging.info("@tracker All nodes finishes job")

def start(self, n_workers: int) -> None:
"""Strat the tracker, it will wait for `n_workers` to connect."""

def run() -> None:
self.accept_workers(n_workers)

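For reference, a usage sketch of the new `sortby` parameter added above (illustrative host and worker count; not part of the diff): constructing the tracker with sortby="task" makes _sort_pending order pending workers by the DMLC_TASK_ID they report instead of by hostname.

    # Sketch: start a tracker that assigns ranks by task ID (values illustrative).
    from xgboost.tracker import RabitTracker

    tracker = RabitTracker(
        host_ip="127.0.0.1", n_workers=2, use_logger=False, sortby="task"
    )
    env = tracker.worker_envs()  # {'DMLC_TRACKER_URI': '127.0.0.1', 'DMLC_TRACKER_PORT': ...}
    tracker.start(2)  # accept_workers runs in a background thread
    # Each worker passes DMLC_TASK_ID (plus the tracker URI/port from `env`) in its
    # RABIT initialization arguments; ranks then follow the lexicographic order of
    # those task IDs.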
32 changes: 32 additions & 0 deletions tests/python/test_tracker.py
@@ -4,6 +4,7 @@
import testing as tm
import numpy as np
import sys
import re

if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -58,3 +59,34 @@ def test_rabit_ops():
with LocalCluster(n_workers=n_workers) as cluster:
with Client(cluster) as client:
run_rabit_ops(client, n_workers)


def test_rank_assignment() -> None:
from distributed import Client, LocalCluster
from test_with_dask import _get_client_workers

def local_test(worker_id):
with xgb.dask.RabitContext(args):
for val in args:
sval = val.decode("utf-8")
if sval.startswith("DMLC_TASK_ID"):
task_id = sval
break
matched = re.search(".*-([0-9]).*", task_id)
rank = xgb.rabit.get_rank()
            # As long as the number of workers is less than 10, the rank and the
            # worker id should be the same.
assert rank == int(matched.group(1))

with LocalCluster(n_workers=8) as cluster:
with Client(cluster) as client:
workers = _get_client_workers(client)
args = client.sync(
xgb.dask._get_rabit_args,
len(workers),
None,
client,
)

futures = client.map(local_test, range(len(workers)), workers=workers)
client.gather(futures)
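
A short illustration of the caveat noted in the dask.py comment (not part of the commit): task IDs are compared as strings, so once there are 10 or more workers the lexicographic order diverges from numeric order, which is why the test above uses 8 workers and can assert that the rank equals the worker ID.

    # "10" sorts before "2" when compared as strings, so with 10+ workers the rank
    # order no longer matches the numeric worker-ID order.
    ids = [str(i) for i in range(11)]
    print(sorted(ids)[:4])
    # ['0', '1', '10', '2']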
