Skip to content

Commit

Permalink
Switch to XGBoost Communicator API (#996)
Browse files (browse the repository at this point in the history)
  • Loading branch information
rongou committed Oct 8, 2022
1 parent 783f192 commit 6e7abad
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
4 changes: 2 additions & 2 deletions nvflare/app_opt/xgboost/README.rst
Expand Up @@ -239,7 +239,7 @@ function is required,

class CustomXGBExecutor(XGBExecutorBase):
def xgb_train(self, params: XGBoostParams, fl_ctx: FLContext) -> Shareable:
with xgb.rabit.RabitContext([e.encode() for e in params.rabit_env]):
with xgb.collective.CommunicatorContext(**params.communicator_env):
dtrain = xgb.DMatrix(params.train_data)
dtest = xgb.DMatrix(params.test_data)
watchlist = [(dtest, "eval"), (dtrain, "train")]
Expand All @@ -259,7 +259,7 @@ function is required,
run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
run_dir = workspace.get_run_dir(run_number)
bst.save_model(os.path.join(run_dir, "test.model.json"))
xgb.rabit.tracker_print("Finished training\n")
xgb.collective.communicator_print("Finished training\n")

return make_reply(ReturnCode.OK)

Expand Down
24 changes: 10 additions & 14 deletions nvflare/app_opt/xgboost/histogram_based/executor_base.py
Expand Up @@ -170,27 +170,23 @@ def train(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -
verbose_eval=self.verbose_eval,
)

rabit_env = [
f"federated_server_address={self._server_address}:{xgb_fl_server_port}",
f"federated_world_size={self.world_size}",
f"federated_rank={self.rank}",
]
# Keyword arguments expanded into xgb.collective.CommunicatorContext below.
# World size and rank are passed as the plain attribute values: wrapping them
# in braces (a leftover from the old f-string form) would build one-element
# sets instead of the values themselves.
communicator_env = {
    "federated_server_address": f"{self._server_address}:{xgb_fl_server_port}",
    "federated_world_size": self.world_size,
    "federated_rank": self.rank,
}
if secure_comm:
if not self._get_certificates(fl_ctx):
return make_reply(ReturnCode.ERROR)

rabit_env.extend(
[
f"federated_server_cert={self._ca_cert_path}",
f"federated_client_key={self._client_key_path}",
f"federated_client_cert={self._client_cert_path}",
]
)
communicator_env["federated_server_cert"] = self._ca_cert_path
communicator_env["federated_client_key"] = self._client_key_path
communicator_env["federated_client_cert"] = self._client_cert_path

try:
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
with xgb.collective.CommunicatorContext(**communicator_env):
result = self.xgb_train(params, fl_ctx)
xgb.rabit.tracker_print("Finished training\n")
xgb.collective.communicator_print("Finished training\n")
except BaseException as e:
secure_log_traceback()
self.log_error(fl_ctx, f"Exception happens when running xgb train: {secure_format_exception(e)}")
Expand Down

0 comments on commit 6e7abad

Please sign in to comment.