Compatibility for xgboost>=1.7.0, fix master CI #242

Merged
6 commits merged on Oct 31, 2022
3 changes: 2 additions & 1 deletion requirements-test.txt
@@ -6,10 +6,11 @@ packaging
petastorm
pytest
pyarrow
ray[tune]
ray[tune, data]
scikit-learn
modin
dask

#workaround for now
protobuf<4.0.0
tensorboardX==2.2
51 changes: 38 additions & 13 deletions xgboost_ray/main.py
@@ -25,6 +25,17 @@
class EarlyStopException(XGBoostError):
pass


# In xgboost>=1.7.0, rabit is replaced by a collective communicator
try:
from xgboost.collective import CommunicatorContext
rabit = None
HAS_COLLECTIVE = True
except ImportError:
from xgboost import rabit # noqa
CommunicatorContext = None
HAS_COLLECTIVE = False

from xgboost_ray.callback import DistributedCallback, \
DistributedCallbackContainer
from xgboost_ray.compat import TrainingCallback, RabitTracker, LEGACY_CALLBACK
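
For context, the two communicator APIs are initialised quite differently, which is why the shim above keeps both CommunicatorContext and the legacy rabit name around. A minimal sketch, not from the diff; the tracker values below are illustrative and would normally come from the Rabit tracker started by the driver:

# Minimal sketch with illustrative tracker values.
env = {"DMLC_TRACKER_URI": "10.0.0.1", "DMLC_TRACKER_PORT": "9091"}

if HAS_COLLECTIVE:
    # xgboost>=1.7.0: the collective communicator takes keyword arguments.
    with CommunicatorContext(**env):
        pass  # distributed training runs inside the context
else:
    # xgboost<1.7.0: rabit expects a list of b"KEY=VALUE" entries.
    rabit.init([("%s=%s" % item).encode() for item in env.items()])
    try:
        pass  # distributed training runs here
    finally:
        rabit.finalize()
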
@@ -66,7 +77,7 @@ def inner_f(*args, **kwargs):
RayDeviceQuantileDMatrix, RayDataIter, concat_dataframes, \
LEGACY_MATRIX
from xgboost_ray.session import init_session, put_queue, \
set_session_queue
set_session_queue, get_rabit_rank


def _get_environ(item: str, old_val: Any):
@@ -237,25 +248,40 @@ def _stop_rabit_tracker(rabit_process: multiprocessing.Process):
rabit_process.terminate()


class _RabitContext:
class _RabitContextBase:
"""This context is used by local training actors to connect to the
Rabit tracker.

Args:
actor_id (str): Unique actor ID
args (list): Arguments for Rabit initialisation. These are
args (dict): Arguments for Rabit initialisation. These are
environment variables to configure Rabit clients.
"""

def __init__(self, actor_id, args):
def __init__(self, actor_id: str, args: dict):
args["DMLC_TASK_ID"] = "[xgboost.ray]:" + actor_id
self.args = args
self.args.append(("DMLC_TASK_ID=[xgboost.ray]:" + actor_id).encode())

def __enter__(self):
xgb.rabit.init(self.args)

def __exit__(self, *args):
xgb.rabit.finalize()
# In xgboost>=1.7.0, rabit is replaced by a collective communicator
if HAS_COLLECTIVE:

class _RabitContext(_RabitContextBase, CommunicatorContext):
pass

else:

class _RabitContext(_RabitContextBase):
def __init__(self, actor_id: str, args: dict):
super().__init__(actor_id, args)
self._list_args = [("%s=%s" % item).encode()
for item in self.args.items()]

def __enter__(self):
xgb.rabit.init(self._list_args)

def __exit__(self, *args):
xgb.rabit.finalize()


def _ray_get_actor_cpus():
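
A short usage sketch, not from the diff: whichever _RabitContext class was defined above, a training actor enters it the same way; the args dict is either forwarded to xgboost.collective (>=1.7.0) or encoded into rabit's legacy list form (<1.7.0) by the subclass. The helper below is hypothetical:

def _train_on_actor(actor_id: str, rabit_args: dict, params: dict, dtrain):
    # Hypothetical helper for illustration only. A copy of the args dict
    # is passed so the context can add DMLC_TASK_ID without mutating the
    # caller's dict.
    with _RabitContext(actor_id, dict(rabit_args)):
        return xgb.train(params, dtrain, num_boost_round=10)
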
@@ -517,12 +543,12 @@ def _save_checkpoint_callback(self):

class _SaveInternalCheckpointCallback(TrainingCallback):
def after_iteration(self, model, epoch, evals_log):
if xgb.rabit.get_rank() == 0 and \
if get_rabit_rank() == 0 and \
epoch % this.checkpoint_frequency == 0:
put_queue(_Checkpoint(epoch, pickle.dumps(model)))

def after_training(self, model):
if xgb.rabit.get_rank() == 0:
if get_rabit_rank() == 0:
put_queue(_Checkpoint(-1, pickle.dumps(model)))
return model

@@ -1054,8 +1080,7 @@ def handle_actor_failure(actor_id):
maybe_log("[RayXGBoost] Starting XGBoost training.")

# Start Rabit tracker for gradient sharing
rabit_process, env = _start_rabit_tracker(alive_actors)
rabit_args = [("%s=%s" % item).encode() for item in env.items()]
rabit_process, rabit_args = _start_rabit_tracker(alive_actors)

# Load checkpoint if we have one. In that case we need to adjust the
# number of training rounds.
6 changes: 5 additions & 1 deletion xgboost_ray/session.py
@@ -63,7 +63,11 @@ def get_actor_rank() -> int:
@PublicAPI
def get_rabit_rank() -> int:
import xgboost as xgb
return xgb.rabit.get_rank()
try:
# In xgboost>=1.7.0, rabit is replaced by a collective communicator
return xgb.collective.get_rank()
except (ImportError, AttributeError):
return xgb.rabit.get_rank()


@PublicAPI
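
Finally, a usage sketch, not part of the PR: because get_rabit_rank now tries xgboost.collective before falling back to xgboost.rabit, user code such as a custom callback can gate rank-0-only work without caring which xgboost version is installed. The callback name and queued payload below are illustrative, and the example assumes it runs inside an xgboost_ray training actor where the session queue has been set up:

import xgboost as xgb
from xgboost_ray.session import get_rabit_rank, put_queue

class _ReportFromRankZero(xgb.callback.TrainingCallback):
    # Illustrative callback: only the rank-0 worker pushes progress
    # to the driver queue.
    def after_iteration(self, model, epoch, evals_log):
        if get_rabit_rank() == 0:
            put_queue({"epoch": epoch})
        return False  # False means "do not stop training early"
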