Add back xgboost.rabit for backwards compatibility (#8408) #8411

Merged 1 commit on Nov 2, 2022
2 changes: 1 addition & 1 deletion python-package/xgboost/__init__.py
@@ -4,7 +4,7 @@
"""

from . import tracker # noqa
from . import collective, dask
from . import collective, dask, rabit
from .core import (
Booster,
DataIter,
168 changes: 168 additions & 0 deletions python-package/xgboost/rabit.py
@@ -0,0 +1,168 @@
"""Compatibility shim for xgboost.rabit; to be removed in 2.0"""
import logging
import warnings
from enum import IntEnum, unique
from typing import Any, TypeVar, Callable, Optional, List

import numpy as np

from . import collective

LOGGER = logging.getLogger("[xgboost.rabit]")


def _deprecation_warning() -> str:
return (
"The xgboost.rabit submodule is marked as deprecated in 1.7 and will be removed "
"in 2.0. Please use xgboost.collective instead."
)


def init(args: Optional[List[bytes]] = None) -> None:
"""Initialize the rabit library with arguments"""
warnings.warn(_deprecation_warning(), FutureWarning)
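# rabit took its configuration as a list of b"key=value" pairs; convert them
# into keyword arguments for collective.init().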
parsed = {}
if args:
for arg in args:
kv = arg.decode().split('=', 1)
if len(kv) == 2:
parsed[kv[0]] = kv[1]
collective.init(**parsed)


def finalize() -> None:
"""Finalize the process, notify tracker everything is done."""
collective.finalize()


def get_rank() -> int:
"""Get rank of current process.
Returns
-------
rank : int
Rank of current process.
"""
return collective.get_rank()


def get_world_size() -> int:
"""Get total number workers.
Returns
-------
n : int
Total number of process.
"""
return collective.get_world_size()


def is_distributed() -> int:
"""If rabit is distributed."""
return collective.is_distributed()


def tracker_print(msg: Any) -> None:
"""Print message to the tracker.
This function can be used to communicate the information of
the progress to the tracker
Parameters
----------
msg : str
The message to be printed to tracker.
"""
collective.communicator_print(msg)


def get_processor_name() -> bytes:
"""Get the processor name.
Returns
-------
name : str
the name of processor(host)
"""
return collective.get_processor_name().encode()


T = TypeVar("T") # pylint:disable=invalid-name


def broadcast(data: T, root: int) -> T:
"""Broadcast object from one node to all other nodes.
Parameters
----------
data : any type that can be pickled
Input data, if current rank does not equal root, this can be None
root : int
Rank of the node to broadcast data from.
Returns
-------
object : int
the result of broadcast.
"""
return collective.broadcast(data, root)


@unique
class Op(IntEnum):
"""Supported operations for rabit."""
MAX = 0
MIN = 1
SUM = 2
OR = 3


def allreduce( # pylint:disable=invalid-name
data: np.ndarray, op: Op, prepare_fun: Optional[Callable[[np.ndarray], None]] = None
) -> np.ndarray:
"""Perform allreduce, return the result.
Parameters
----------
data :
Input data.
op :
Reduction operators, can be MIN, MAX, SUM, BITOR
prepare_fun :
Lazy preprocessing function, if it is not None, prepare_fun(data)
will be called by the function before performing allreduce, to initialize the data
If the result of Allreduce can be recovered directly,
then prepare_fun will NOT be called
Returns
-------
result :
The result of allreduce, have same shape as data
Notes
-----
This function is not thread-safe.
"""
if prepare_fun is None:
return collective.allreduce(data, collective.Op(op))
raise Exception("preprocessing function is no longer supported")


def version_number() -> int:
"""Returns version number of current stored model.
This means how many calls to CheckPoint we made so far.
Returns
-------
version : int
Version number of currently stored model
"""
return 0


class RabitContext:
"""A context controlling rabit initialization and finalization."""

def __init__(self, args: Optional[List[bytes]] = None) -> None:
if args is None:
args = []
self.args = args

def __enter__(self) -> None:
init(self.args)
assert is_distributed()
LOGGER.warning(_deprecation_warning())
LOGGER.debug("-------------- rabit say hello ------------------")

def __exit__(self, *args: Any) -> None:
finalize()
LOGGER.debug("--------------- rabit say bye ------------------")
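
In practice, the shim keeps existing rabit-style code working (emitting a FutureWarning) while new code can target xgboost.collective directly. A minimal sketch of the two equivalent styles, assuming a tracker is already listening at the placeholder address and that the DMLC_TRACKER_* key names match what the tracker hands out (CommunicatorContext accepts the same key/value pairs as keyword arguments, as in the test further down):

import xgboost as xgb

# Deprecated style, served by this shim; emits a FutureWarning.
env = [b"DMLC_TRACKER_URI=127.0.0.1", b"DMLC_TRACKER_PORT=9091"]
with xgb.rabit.RabitContext(env):
    print(xgb.rabit.get_rank(), xgb.rabit.get_world_size())

# Equivalent style against the replacement module.
with xgb.collective.CommunicatorContext(DMLC_TRACKER_URI="127.0.0.1",
                                        DMLC_TRACKER_PORT="9091"):
    print(xgb.collective.get_rank(), xgb.collective.get_world_size())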
31 changes: 31 additions & 0 deletions tests/python/test_collective.py
@@ -39,6 +39,37 @@ def test_rabit_communicator():
assert worker.exitcode == 0


# TODO(rongou): remove this once we remove the rabit api.
def run_rabit_api_worker(rabit_env, world_size):
with xgb.rabit.RabitContext(rabit_env):
assert xgb.rabit.get_world_size() == world_size
assert xgb.rabit.is_distributed()
assert xgb.rabit.get_processor_name().decode() == socket.gethostname()
ret = xgb.rabit.broadcast('test1234', 0)
assert str(ret) == 'test1234'
ret = xgb.rabit.allreduce(np.asarray([1, 2, 3]), xgb.rabit.Op.SUM)
assert np.array_equal(ret, np.asarray([2, 4, 6]))


# TODO(rongou): remove this once we remove the rabit api.
def test_rabit_api():
world_size = 2
tracker = RabitTracker(host_ip='127.0.0.1', n_workers=world_size)
tracker.start(world_size)
rabit_env = []
for k, v in tracker.worker_envs().items():
rabit_env.append(f"{k}={v}".encode())
workers = []
for _ in range(world_size):
worker = multiprocessing.Process(target=run_rabit_api_worker,
args=(rabit_env, world_size))
workers.append(worker)
worker.start()
for worker in workers:
worker.join()
assert worker.exitcode == 0


def run_federated_worker(port, world_size, rank):
with xgb.collective.CommunicatorContext(xgboost_communicator='federated',
federated_server_address=f'localhost:{port}',