Skip to content

Commit

Permalink
Add back xgboost.rabit for backwards compatibility (#8408) (#8411)
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Nov 2, 2022
1 parent 4bc59ef commit 5b76acc
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python-package/xgboost/__init__.py
Expand Up @@ -4,7 +4,7 @@
"""

from . import tracker # noqa
from . import collective, dask
from . import collective, dask, rabit
from .core import (
Booster,
DataIter,
Expand Down
168 changes: 168 additions & 0 deletions python-package/xgboost/rabit.py
@@ -0,0 +1,168 @@
"""Compatibility shim for xgboost.rabit; to be removed in 2.0"""
import logging
import warnings
from enum import IntEnum, unique
from typing import Any, TypeVar, Callable, Optional, List

import numpy as np

from . import collective

# Module-level logger; RabitContext uses it for the deprecation notice and debug messages.
LOGGER = logging.getLogger("[xgboost.rabit]")


def _deprecation_warning() -> str:
    """Return the notice shown to users of the deprecated ``xgboost.rabit`` shim."""
    message = (
        "The xgboost.rabit submodule is marked as deprecated in 1.7 and will be removed "
        "in 2.0. Please use xgboost.collective instead."
    )
    return message


def init(args: Optional[List[bytes]] = None) -> None:
    """Initialize the communicator from rabit-style ``key=value`` arguments.

    Emits a ``FutureWarning`` carrying the deprecation notice, then forwards the
    parsed arguments to ``collective.init`` as keyword arguments.

    Parameters
    ----------
    args :
        Byte strings of the form ``b"key=value"``. Each entry is decoded and split
        on its first ``=``. Entries that contain no ``=`` are ignored. ``None`` or
        an empty list results in ``collective.init()`` being called with no
        keyword arguments.
    """
    warnings.warn(_deprecation_warning(), FutureWarning)
    parsed = {}
    for arg in args or []:
        # Split on the first "=" only, so a value that itself contains "="
        # (e.g. b"key=a=b") is forwarded intact instead of being silently dropped.
        key, sep, value = arg.decode().partition("=")
        if sep:
            parsed[key] = value
    collective.init(**parsed)


def finalize() -> None:
    """Shut down the communicator by delegating to ``collective.finalize``."""
    collective.finalize()
    return None


def get_rank() -> int:
    """Return the rank of the current process.

    Returns
    -------
    rank : int
        The value reported by ``collective.get_rank``.
    """
    rank = collective.get_rank()
    return rank


def get_world_size() -> int:
    """Return the total number of workers.

    Returns
    -------
    n : int
        The value reported by ``collective.get_world_size``.
    """
    n_workers = collective.get_world_size()
    return n_workers


def is_distributed() -> int:
    """Report whether the communicator is distributed.

    Returns
    -------
    flag : int
        The value reported by ``collective.is_distributed``.
    """
    flag = collective.is_distributed()
    return flag


def tracker_print(msg: Any) -> None:
    """Send a message to the tracker.

    The message is handed unchanged to ``collective.communicator_print``.

    Parameters
    ----------
    msg :
        The message to print on the tracker.
    """
    collective.communicator_print(msg)
    return None


def get_processor_name() -> bytes:
    """Return the processor (host) name as bytes.

    Returns
    -------
    name : bytes
        The result of ``collective.get_processor_name()`` after ``.encode()``.
    """
    name = collective.get_processor_name()
    return name.encode()


# Generic type variable: broadcast() is annotated to return the same type it is given.
T = TypeVar("T") # pylint:disable=invalid-name


def broadcast(data: T, root: int) -> T:
    """Broadcast an object from one node to all other nodes.

    Thin wrapper that delegates to ``collective.broadcast``.

    Parameters
    ----------
    data :
        The object to broadcast.
    root :
        Rank of the node to broadcast from.

    Returns
    -------
    result :
        Whatever ``collective.broadcast(data, root)`` returns.
    """
    result = collective.broadcast(data, root)
    return result


@unique
class Op(IntEnum):
    """Integer codes for the reduction operations exposed by the rabit API.

    ``allreduce`` converts a member to ``collective.Op`` by its integer value.
    """

    MAX = 0
    MIN = 1
    SUM = 2
    OR = 3


def allreduce(  # pylint:disable=invalid-name
    data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
    """Perform allreduce and return the result.

    Parameters
    ----------
    data :
        Input data.
    op :
        Reduction operator, a member of ``Op`` (MAX, MIN, SUM, OR). It is
        converted with ``collective.Op(op)`` before the call is forwarded.
    prepare_fun :
        No longer supported. Passing anything other than ``None`` raises.

    Returns
    -------
    result :
        Whatever ``collective.allreduce`` returns for ``data`` and the operator.

    Raises
    ------
    Exception
        If ``prepare_fun`` is not ``None``.
    """
    # Reject the legacy preprocessing hook up front; nothing is reduced in that case.
    if prepare_fun is not None:
        raise Exception("preprocessing function is no longer supported")
    reduce_op = collective.Op(op)
    return collective.allreduce(data, reduce_op)


def version_number() -> int:
    """Return the version number of the stored model.

    This compatibility shim always returns 0.

    Returns
    -------
    version : int
        Always ``0``.
    """
    version = 0
    return version


class RabitContext:
    """Context manager that initializes rabit on entry and finalizes it on exit.

    Parameters
    ----------
    args :
        Rabit-style ``b"key=value"`` arguments passed to ``init`` on entry.
        ``None`` is replaced by an empty list.
    """

    def __init__(self, args: Optional[List[bytes]] = None) -> None:
        # Explicit Optional: the default really is None, matching init()'s signature.
        self.args = [] if args is None else args

    def __enter__(self) -> None:
        """Initialize rabit with the stored arguments and log the deprecation notice."""
        init(self.args)
        # Sanity check only: like any assert, it is skipped under ``python -O``.
        assert is_distributed()
        LOGGER.warning(_deprecation_warning())
        LOGGER.debug("-------------- rabit say hello ------------------")

    def __exit__(self, *args: Any) -> None:
        """Finalize rabit; the exception details passed in are ignored."""
        finalize()
        LOGGER.debug("--------------- rabit say bye ------------------")
31 changes: 31 additions & 0 deletions tests/python/test_collective.py
Expand Up @@ -39,6 +39,37 @@ def test_rabit_communicator():
assert worker.exitcode == 0


# TODO(rongou): remove this once we remove the rabit api.
def run_rabit_api_worker(rabit_env, world_size):
    """Worker process body: exercise the deprecated rabit API inside a RabitContext."""
    with xgb.rabit.RabitContext(rabit_env):
        assert xgb.rabit.get_world_size() == world_size
        assert xgb.rabit.is_distributed()
        assert xgb.rabit.get_processor_name().decode() == socket.gethostname()

        broadcast_result = xgb.rabit.broadcast('test1234', 0)
        assert str(broadcast_result) == 'test1234'

        # The expected sum is twice the input, i.e. it assumes two workers contribute.
        reduced = xgb.rabit.allreduce(np.asarray([1, 2, 3]), xgb.rabit.Op.SUM)
        assert np.array_equal(reduced, np.asarray([2, 4, 6]))


# TODO(rongou): remove this once we remove the rabit api.
def test_rabit_api():
    """Start a tracker and two worker processes that run the deprecated rabit API."""
    world_size = 2
    tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
    tracker.start(world_size)
    # Encode the tracker's worker environment as rabit-style b"key=value" arguments.
    rabit_env = [f"{k}={v}".encode() for k, v in tracker.worker_envs().items()]

    workers = []
    for _ in range(world_size):
        proc = multiprocessing.Process(
            target=run_rabit_api_worker, args=(rabit_env, world_size)
        )
        workers.append(proc)
        proc.start()

    for proc in workers:
        proc.join()
        assert proc.exitcode == 0


def run_federated_worker(port, world_size, rank):
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
federated_server_address=f'localhost:{port}',
Expand Down

0 comments on commit 5b76acc

Please sign in to comment.