Skip to content

Commit

Permalink
[pyspark] Add tracker_on_driver to decide where the tracker will be launched
Browse files Browse the repository at this point in the history
  • Loading branch information
wbo4958 committed May 13, 2024
1 parent d81e319 commit b3a24cd
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 9 deletions.
34 changes: 29 additions & 5 deletions python-package/xgboost/spark/core.py
Expand Up @@ -86,6 +86,7 @@
CommunicatorContext,
_get_default_params_from_func,
_get_gpu_id,
_get_host_ip,
_get_max_num_concurrent_tasks,
_get_rabit_args,
_get_spark_session,
Expand Down Expand Up @@ -121,6 +122,7 @@
"repartition_random_shuffle",
"pred_contrib_col",
"use_gpu",
"tracker_on_driver",
]

_non_booster_params = ["missing", "n_estimators", "feature_types", "feature_weights"]
Expand Down Expand Up @@ -246,6 +248,13 @@ class _SparkXGBParams(
"A list of str to specify feature names.",
TypeConverters.toList,
)
tracker_on_driver = Param(
Params._dummy(),
"tracker_on_driver",
"A boolean variable. Set tracker_on_driver to true if you want the tracker to be launched "
"on the driver side; otherwise, it will be launched on the executor side.",
TypeConverters.toBoolean,
)

def set_device(self, value: str) -> "_SparkXGBParams":
"""Set device, optional value: cpu, cuda, gpu"""
Expand Down Expand Up @@ -616,6 +625,7 @@ def __init__(self) -> None:
feature_names=None,
feature_types=None,
arbitrary_params_dict={},
tracker_on_driver=True,
)

self.logger = get_logger(self.__class__.__name__)
Expand Down Expand Up @@ -1052,6 +1062,16 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel":

num_workers = self.getOrDefault(self.num_workers)

run_tracker_on_driver = self.getOrDefault(self.tracker_on_driver)

rabit_args = {}
if run_tracker_on_driver:
driver_host = (
_get_spark_session().sparkContext.getConf().get("spark.driver.host")
)
assert driver_host is not None
rabit_args = _get_rabit_args(driver_host, num_workers)

log_level = get_logger_level(_LOG_TAG)

def _train_booster(
Expand Down Expand Up @@ -1087,21 +1107,25 @@ def _train_booster(
if use_qdm and (booster_params.get("max_bin", None) is not None):
dmatrix_kwargs["max_bin"] = booster_params["max_bin"]

_rabit_args = {}
_rabit_args = rabit_args
if context.partitionId() == 0:
_rabit_args = _get_rabit_args(context, num_workers)
if not run_tracker_on_driver:
_rabit_args = _get_rabit_args(_get_host_ip(context), num_workers)
get_logger(_LOG_TAG, log_level).info(msg)

worker_message = {
"rabit_msg": _rabit_args,
worker_message: Dict[str, Any] = {
"use_qdm": use_qdm,
}

if not run_tracker_on_driver:
worker_message["rabit_msg"] = _rabit_args

messages = context.allGather(message=json.dumps(worker_message))
if len(set(json.loads(x)["use_qdm"] for x in messages)) != 1:
raise RuntimeError("The workers' cudf environments are in-consistent ")

_rabit_args = json.loads(messages[0])["rabit_msg"]
if not run_tracker_on_driver:
_rabit_args = json.loads(messages[0])["rabit_msg"]

evals_result: Dict[str, Any] = {}
with CommunicatorContext(context, **_rabit_args):
Expand Down
12 changes: 12 additions & 0 deletions python-package/xgboost/spark/estimator.py
Expand Up @@ -161,6 +161,9 @@ class SparkXGBRegressor(_SparkXGBEstimator):
Boolean value to specify if enabling sparse data optimization, if True,
Xgboost DMatrix object will be constructed from sparse matrix instead of
dense matrix.
tracker_on_driver:
Boolean value to indicate whether the tracker should be launched on the driver side or
the executor side.
kwargs:
A dictionary of xgboost parameters, please refer to
Expand Down Expand Up @@ -215,6 +218,7 @@ def __init__( # pylint:disable=too-many-arguments
force_repartition: bool = False,
repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False,
tracker_on_driver: bool = True,
**kwargs: Any,
) -> None:
super().__init__()
Expand Down Expand Up @@ -341,6 +345,9 @@ class SparkXGBClassifier(_SparkXGBEstimator, HasProbabilityCol, HasRawPrediction
Boolean value to specify if enabling sparse data optimization, if True,
Xgboost DMatrix object will be constructed from sparse matrix instead of
dense matrix.
tracker_on_driver:
Boolean value to indicate whether the tracker should be launched on the driver side or
the executor side.
kwargs:
A dictionary of xgboost parameters, please refer to
Expand Down Expand Up @@ -395,6 +402,7 @@ def __init__( # pylint:disable=too-many-arguments
force_repartition: bool = False,
repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False,
tracker_on_driver: bool = True,
**kwargs: Any,
) -> None:
super().__init__()
Expand Down Expand Up @@ -524,6 +532,9 @@ class SparkXGBRanker(_SparkXGBEstimator):
Boolean value to specify if enabling sparse data optimization, if True,
Xgboost DMatrix object will be constructed from sparse matrix instead of
dense matrix.
tracker_on_driver:
Boolean value to indicate whether the tracker should be launched on the driver side or
the executor side.
kwargs:
A dictionary of xgboost parameters, please refer to
Expand Down Expand Up @@ -584,6 +595,7 @@ def __init__( # pylint:disable=too-many-arguments
force_repartition: bool = False,
repartition_random_shuffle: bool = False,
enable_sparse_data_optim: bool = False,
tracker_on_driver: bool = True,
**kwargs: Any,
) -> None:
super().__init__()
Expand Down
7 changes: 3 additions & 4 deletions python-package/xgboost/spark/utils.py
Expand Up @@ -51,10 +51,9 @@ def __init__(self, context: BarrierTaskContext, **args: Any) -> None:
super().__init__(**args)


def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
def _start_tracker(host: str, n_workers: int) -> Dict[str, Any]:
"""Start Rabit tracker with n_workers"""
env: Dict[str, Any] = {"DMLC_NUM_WORKER": n_workers}
host = _get_host_ip(context)
rabit_context = RabitTracker(host_ip=host, n_workers=n_workers, sortby="task")
env.update(rabit_context.worker_envs())
rabit_context.start(n_workers)
Expand All @@ -64,9 +63,9 @@ def _start_tracker(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any
return env


def _get_rabit_args(context: BarrierTaskContext, n_workers: int) -> Dict[str, Any]:
def _get_rabit_args(host: str, n_workers: int) -> Dict[str, Any]:
"""Get rabit context arguments to send to each worker."""
env = _start_tracker(context, n_workers)
env = _start_tracker(host, n_workers)
return env


Expand Down

0 comments on commit b3a24cd

Please sign in to comment.