From 6e7abadd4ffb33a5b317f894f30ab647518e1ba9 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Sat, 8 Oct 2022 14:53:09 -0700 Subject: [PATCH] Switch to XGBoost Communicator API (#996) --- nvflare/app_opt/xgboost/README.rst | 4 ++-- .../xgboost/histogram_based/executor_base.py | 24 ++++++++----------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/nvflare/app_opt/xgboost/README.rst b/nvflare/app_opt/xgboost/README.rst index e47ee2db56..bde999291f 100644 --- a/nvflare/app_opt/xgboost/README.rst +++ b/nvflare/app_opt/xgboost/README.rst @@ -239,7 +239,7 @@ function is required, class CustomXGBExecutor(XGBExecutorBase): def xgb_train(self, params: XGBoostParams, fl_ctx: FLContext) -> Shareable: - with xgb.rabit.RabitContext([e.encode() for e in params.rabit_env]): + with xgb.collective.CommunicatorContext(**params.communicator_env): dtrain = xgb.DMatrix(params.train_data) dtest = xgb.DMatrix(params.test_data) watchlist = [(dtest, "eval"), (dtrain, "train")] @@ -259,7 +259,7 @@ function is required, run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN) run_dir = workspace.get_run_dir(run_number) bst.save_model(os.path.join(run_dir, "test.model.json")) - xgb.rabit.tracker_print("Finished training\n") + xgb.collective.communicator_print("Finished training\n") return make_reply(ReturnCode.OK) diff --git a/nvflare/app_opt/xgboost/histogram_based/executor_base.py b/nvflare/app_opt/xgboost/histogram_based/executor_base.py index f8da048483..288f596e57 100644 --- a/nvflare/app_opt/xgboost/histogram_based/executor_base.py +++ b/nvflare/app_opt/xgboost/histogram_based/executor_base.py @@ -170,27 +170,23 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) - verbose_eval=self.verbose_eval, ) - rabit_env = [ - f"federated_server_address={self._server_address}:{xgb_fl_server_port}", - f"federated_world_size={self.world_size}", - f"federated_rank={self.rank}", - ] + communicator_env = { + "federated_server_address": f"{self._server_address}:{xgb_fl_server_port}", + "federated_world_size": {self.world_size}, + "federated_rank": {self.rank}, + } if secure_comm: if not self._get_certificates(fl_ctx): return make_reply(ReturnCode.ERROR) - rabit_env.extend( - [ - f"federated_server_cert={self._ca_cert_path}", - f"federated_client_key={self._client_key_path}", - f"federated_client_cert={self._client_cert_path}", - ] - ) + communicator_env["federated_server_cert"] = self._ca_cert_path + communicator_env["federated_client_key"] = self._client_key_path + communicator_env["federated_client_cert"] = self._client_cert_path try: - with xgb.rabit.RabitContext([e.encode() for e in rabit_env]): + with xgb.collective.CommunicatorContext(**communicator_env): result = self.xgb_train(params, fl_ctx) - xgb.rabit.tracker_print("Finished training\n") + xgb.collective.communicator_print("Finished training\n") except BaseException as e: secure_log_traceback() self.log_error(fl_ctx, f"Exception happens when running xgb train: {secure_format_exception(e)}")