From 8f026ca3f31e6e843f5a55e38798474ebe0514c1 Mon Sep 17 00:00:00 2001
From: Rong Ou
Date: Tue, 1 Nov 2022 21:47:41 -0700
Subject: [PATCH] Add back xgboost.rabit for backwards compatibility (#8408)

* Add back xgboost.rabit for backwards compatibility

* fix my errors

* Fix lint

* Use FutureWarning

Co-authored-by: Hyunsu Philip Cho
---
 python-package/xgboost/__init__.py |   2 +-
 python-package/xgboost/rabit.py    | 168 +++++++++++++++++++++++++++++
 tests/python/test_collective.py    |  31 ++++++
 3 files changed, 200 insertions(+), 1 deletion(-)
 create mode 100644 python-package/xgboost/rabit.py

diff --git a/python-package/xgboost/__init__.py b/python-package/xgboost/__init__.py
index 220093b47c4c..f17ac23ba61c 100644
--- a/python-package/xgboost/__init__.py
+++ b/python-package/xgboost/__init__.py
@@ -4,7 +4,7 @@
 """
 
 from . import tracker  # noqa
-from . import collective, dask
+from . import collective, dask, rabit
 from .core import (
     Booster,
     DataIter,
diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py
new file mode 100644
index 000000000000..258ec4b6da4e
--- /dev/null
+++ b/python-package/xgboost/rabit.py
@@ -0,0 +1,168 @@
+"""Compatibility shim for xgboost.rabit; to be removed in 2.0."""
+import logging
+import warnings
+from enum import IntEnum, unique
+from typing import Any, Callable, List, Optional, TypeVar
+
+import numpy as np
+
+from . import collective
+
+LOGGER = logging.getLogger("[xgboost.rabit]")
+
+
+def _deprecation_warning() -> str:
+    return (
+        "The xgboost.rabit submodule is marked as deprecated in 1.7 and will be removed "
+        "in 2.0. Please use xgboost.collective instead."
+    )
+
+
+def init(args: Optional[List[bytes]] = None) -> None:
+    """Initialize the rabit library with arguments."""
+    warnings.warn(_deprecation_warning(), FutureWarning)
+    parsed = {}
+    if args:
+        for arg in args:
+            kv = arg.decode().split('=')
+            if len(kv) == 2:
+                parsed[kv[0]] = kv[1]
+    collective.init(**parsed)
+
+
+def finalize() -> None:
+    """Finalize the process, notify tracker everything is done."""
+    collective.finalize()
+
+
+def get_rank() -> int:
+    """Get rank of current process.
+    Returns
+    -------
+    rank : int
+        Rank of current process.
+    """
+    return collective.get_rank()
+
+
+def get_world_size() -> int:
+    """Get total number of workers.
+    Returns
+    -------
+    n : int
+        Total number of processes.
+    """
+    return collective.get_world_size()
+
+
+def is_distributed() -> int:
+    """Whether rabit is running in distributed mode."""
+    return collective.is_distributed()
+
+
+def tracker_print(msg: Any) -> None:
+    """Print message to the tracker.
+    This function can be used to communicate progress information
+    to the tracker.
+    Parameters
+    ----------
+    msg : str
+        The message to be printed to the tracker.
+    """
+    collective.communicator_print(msg)
+
+
+def get_processor_name() -> bytes:
+    """Get the processor name.
+    Returns
+    -------
+    name : bytes
+        The name of the processor (host).
+    """
+    return collective.get_processor_name().encode()
+
+
+T = TypeVar("T")  # pylint:disable=invalid-name
+
+
+def broadcast(data: T, root: int) -> T:
+    """Broadcast object from one node to all other nodes.
+    Parameters
+    ----------
+    data : any type that can be pickled
+        Input data; if the current rank does not equal root, this can be None.
+    root : int
+        Rank of the node to broadcast data from.
+    Returns
+    -------
+    object : any type that can be pickled
+        The result of the broadcast.
+ """ + return collective.broadcast(data, root) + + +@unique +class Op(IntEnum): + """Supported operations for rabit.""" + MAX = 0 + MIN = 1 + SUM = 2 + OR = 3 + + +def allreduce( # pylint:disable=invalid-name + data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None +) -> np.ndarray: + """Perform allreduce, return the result. + Parameters + ---------- + data : + Input data. + op : + Reduction operators, can be MIN, MAX, SUM, BITOR + prepare_fun : + Lazy preprocessing function, if it is not None, prepare_fun(data) + will be called by the function before performing allreduce, to initialize the data + If the result of Allreduce can be recovered directly, + then prepare_fun will NOT be called + Returns + ------- + result : + The result of allreduce, have same shape as data + Notes + ----- + This function is not thread-safe. + """ + if prepare_fun is None: + return collective.allreduce(data, collective.Op(op)) + raise Exception("preprocessing function is no longer supported") + + +def version_number() -> int: + """Returns version number of current stored model. + This means how many calls to CheckPoint we made so far. + Returns + ------- + version : int + Version number of currently stored model + """ + return 0 + + +class RabitContext: + """A context controlling rabit initialization and finalization.""" + + def __init__(self, args: List[bytes] = None) -> None: + if args is None: + args = [] + self.args = args + + def __enter__(self) -> None: + init(self.args) + assert is_distributed() + LOGGER.warning(_deprecation_warning()) + LOGGER.debug("-------------- rabit say hello ------------------") + + def __exit__(self, *args: List) -> None: + finalize() + LOGGER.debug("--------------- rabit say bye ------------------") diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index f7de0400d21f..32b0a67a76e3 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -39,6 +39,37 @@ def test_rabit_communicator(): assert worker.exitcode == 0 +# TODO(rongou): remove this once we remove the rabit api. +def run_rabit_api_worker(rabit_env, world_size): + with xgb.rabit.RabitContext(rabit_env): + assert xgb.rabit.get_world_size() == world_size + assert xgb.rabit.is_distributed() + assert xgb.rabit.get_processor_name().decode() == socket.gethostname() + ret = xgb.rabit.broadcast('test1234', 0) + assert str(ret) == 'test1234' + ret = xgb.rabit.allreduce(np.asarray([1, 2, 3]), xgb.rabit.Op.SUM) + assert np.array_equal(ret, np.asarray([2, 4, 6])) + + +# TODO(rongou): remove this once we remove the rabit api. +def test_rabit_api(): + world_size = 2 + tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size) + tracker.start(world_size) + rabit_env = [] + for k, v in tracker.worker_envs().items(): + rabit_env.append(f"{k}={v}".encode()) + workers = [] + for _ in range(world_size): + worker = multiprocessing.Process(target=run_rabit_api_worker, + args=(rabit_env, world_size)) + workers.append(worker) + worker.start() + for worker in workers: + worker.join() + assert worker.exitcode == 0 + + def run_federated_worker(port, world_size, rank): with xgb.collective.CommunicatorContext(xgboost_communicator='federated', federated_server_address=f'localhost:{port}',