From ae2c2c55c78f2c3f538296b43f67f932999b947c Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 5 May 2022 23:44:30 -0700 Subject: [PATCH 1/9] federated learning demo using nvflare --- demo/nvflare/README.md | 51 +++++++++++++ demo/nvflare/config/config_fed_client.json | 22 ++++++ demo/nvflare/config/config_fed_server.json | 22 ++++++ demo/nvflare/custom/controller.py | 69 ++++++++++++++++++ demo/nvflare/custom/trainer.py | 83 ++++++++++++++++++++++ demo/nvflare/prepare_data.sh | 23 ++++++ 6 files changed, 270 insertions(+) create mode 100644 demo/nvflare/README.md create mode 100755 demo/nvflare/config/config_fed_client.json create mode 100755 demo/nvflare/config/config_fed_server.json create mode 100644 demo/nvflare/custom/controller.py create mode 100644 demo/nvflare/custom/trainer.py create mode 100755 demo/nvflare/prepare_data.sh diff --git a/demo/nvflare/README.md b/demo/nvflare/README.md new file mode 100644 index 000000000000..1ba169a7b1de --- /dev/null +++ b/demo/nvflare/README.md @@ -0,0 +1,51 @@ +# Federated XGBoost Demo + +This directory contains a demo of Federated Learning using [NVFlare](https://nvidia.github.io/). + +To run the demo, first install NVFlare: +```console +pip install nvflare +``` + +Prepare the data: +```console +./prepare_data.sh +``` + +Start the NVFlare federated server: +```console +./poc/server/startup/start.sh +``` + +In another terminal, start the first worker: +```console +./poc/site-1/startup/start.sh +``` + +And the second worker: +```console +./poc/site-2/startup/start.sh +``` + +Then start the admin CLI, using `admin/admin` as username/password: +```console +./poc/admin/startup/fl_admin.sh +``` + +In the admin CLI, run the following commands: +```console +upload_app hello-xgboost +set_run_number 1 +deploy_app hello-xgboost all +start_app all +``` + +Once the training finishes, the model file should be written into +`./poc/site-1/run_1/test.model.json` and `./poc/site-2/run_1/test.model.json` +respectively. 
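As a quick sanity check (a minimal sketch, assuming the XGBoost Python package is importable locally; this is not part of the demo files), either saved model can be loaded and inspected:

```python
import xgboost as xgb

# Load the JSON model written by site-1 (path taken from the demo output above).
bst = xgb.Booster()
bst.load_model("poc/site-1/run_1/test.model.json")

# Basic checks: number of boosted rounds and per-feature split counts.
print("boosted rounds:", bst.num_boosted_rounds())
print(bst.get_score(importance_type="weight"))
```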
+ +Finally, shutdown everything from the admin CLI: +```console +shutdown client +shutdown server +``` diff --git a/demo/nvflare/config/config_fed_client.json b/demo/nvflare/config/config_fed_client.json new file mode 100755 index 000000000000..39f23e9bca42 --- /dev/null +++ b/demo/nvflare/config/config_fed_client.json @@ -0,0 +1,22 @@ +{ + "format_version": 2, + "executors": [ + { + "tasks": [ + "train" + ], + "executor": { + "path": "trainer.XGBoostTrainer", + "args": { + "server_address": "localhost:9091", + "world_size": 2, + "server_cert_path": "server-cert.pem", + "client_key_path": "client-key.pem", + "client_cert_path": "client-cert.pem" + } + } + } + ], + "task_result_filters": [], + "task_data_filters": [] +} diff --git a/demo/nvflare/config/config_fed_server.json b/demo/nvflare/config/config_fed_server.json new file mode 100755 index 000000000000..32993b65215f --- /dev/null +++ b/demo/nvflare/config/config_fed_server.json @@ -0,0 +1,22 @@ +{ + "format_version": 2, + "server": { + "heart_beat_timeout": 600 + }, + "task_data_filters": [], + "task_result_filters": [], + "workflows": [ + { + "id": "server_workflow", + "path": "controller.XGBoostController", + "args": { + "port": 9091, + "world_size": 2, + "server_key_path": "server-key.pem", + "server_cert_path": "server-cert.pem", + "client_cert_path": "client-cert.pem" + } + } + ], + "components": [] +} diff --git a/demo/nvflare/custom/controller.py b/demo/nvflare/custom/controller.py new file mode 100644 index 000000000000..66f2ac942ef1 --- /dev/null +++ b/demo/nvflare/custom/controller.py @@ -0,0 +1,69 @@ +""" +Example of training controller with NVFlare +=========================================== +""" +import multiprocessing + +import xgboost.federated +from nvflare.apis.client import Client +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller, Task +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal + +from trainer import SupportedTasks + + +class XGBoostController(Controller): + def __init__(self, port: int, world_size: int, server_key_path: str, + server_cert_path: str, client_cert_path: str): + """Controller for federated XGBoost. + + Args: + port: the port for the gRPC server to listen on. + world_size: the number of sites. + server_key_path: the path to the server key file. + server_cert_path: the path to the server certificate file. + client_cert_path: the path to the client certificate file. 
+ """ + super().__init__() + self._port = port + self._world_size = world_size + self._server_key_path = server_key_path + self._server_cert_path = server_cert_path + self._client_cert_path = client_cert_path + self._server = None + self.run_dir = None + + def start_controller(self, fl_ctx: FLContext): + self._server = multiprocessing.Process( + target=xgboost.federated.run_federated_server, + args=(self._port, self._world_size, self._server_key_path, + self._server_cert_path, self._client_cert_path)) + self._server.start() + + def stop_controller(self, fl_ctx: FLContext): + if self._server: + self._server.terminate() + + def process_result_of_unknown_task(self, client: Client, task_name: str, + client_task_id: str, result: Shareable, + fl_ctx: FLContext): + self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.") + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + self.log_info(fl_ctx, "XGBoost training control flow started.") + if abort_signal.triggered: + return + task = Task(name=SupportedTasks.TRAIN, data=Shareable()) + self.broadcast_and_wait( + task=task, + min_responses=self._world_size, + fl_ctx=fl_ctx, + wait_time_after_min_received=1, + abort_signal=abort_signal, + ) + if abort_signal.triggered: + return + + self.log_info(fl_ctx, "XGBoost training control flow finished.") diff --git a/demo/nvflare/custom/trainer.py b/demo/nvflare/custom/trainer.py new file mode 100644 index 000000000000..b353522c50a4 --- /dev/null +++ b/demo/nvflare/custom/trainer.py @@ -0,0 +1,83 @@ +import os + +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReservedKey, ReturnCode, FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal + +import xgboost as xgb + + +class SupportedTasks(object): + TRAIN = "train" + + +class XGBoostTrainer(Executor): + def __init__(self, server_address: str, world_size: int, server_cert_path: str, + client_key_path: str, client_cert_path: str): + """Trainer for federated XGBoost. + + Args: + data_root: directory with local train/test data. + """ + super().__init__() + self._server_address = server_address + self._world_size = world_size + self._server_cert_path = server_cert_path + self._client_key_path = client_key_path + self._client_cert_path = client_cert_path + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, + abort_signal: Signal) -> Shareable: + self.log_info(fl_ctx, f"Executing {task_name}") + try: + if task_name == SupportedTasks.TRAIN: + self._do_training(fl_ctx) + return make_reply(ReturnCode.OK) + else: + self.log_error(fl_ctx, f"{task_name} is not a supported task.") + return make_reply(ReturnCode.TASK_UNKNOWN) + except BaseException as e: + self.log_exception(fl_ctx, + f"Task {task_name} failed. Exception: {e.__str__()}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + def _do_training(self, fl_ctx: FLContext): + client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME) + rank = int(client_name.split('-')[1]) - 1 + rabit_env = [ + f'federated_server_address={self._server_address}', + f'federated_world_size={self._world_size}', + f'federated_rank={rank}', + f'federated_server_cert={self._server_cert_path}', + f'federated_client_key={self._client_key_path}', + f'federated_client_cert={self._client_cert_path}' + ] + xgb.rabit.init([e.encode() for e in rabit_env]) + + # Load file, file will not be sharded in federated mode. 
+ dtrain = xgb.DMatrix('agaricus.txt.train-%s' % client_name) + dtest = xgb.DMatrix('agaricus.txt.test-%s' % client_name) + + # Specify parameters via map, definition are same as c++ version + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} + + # Specify validations set to watch performance + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + num_round = 20 + + # Run training, all the features in training API is available. + bst = xgb.train(param, dtrain, num_round, evals=watchlist, + early_stopping_rounds=2) + + # Save the model. + workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN) + run_dir = workspace.get_run_dir(run_number) + bst.save_model(os.path.join(run_dir, "test.model.json")) + xgb.rabit.tracker_print("Finished training\n") + + # Notify the tracker all training has been successful + # This is only needed in distributed training. + xgb.rabit.finalize() diff --git a/demo/nvflare/prepare_data.sh b/demo/nvflare/prepare_data.sh new file mode 100755 index 000000000000..9227476d75a9 --- /dev/null +++ b/demo/nvflare/prepare_data.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +set -e + +rm -fr ./agaricus* ./*.pem ./poc + +world_size=2 + +# Generate server and client certificates. +openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost" +openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost" + +# Split train and test files manually to simulate a federated environment. +split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train agaricus.txt.train-site- +split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site- + +poc -n 2 +mkdir -p poc/admin/transfer/hello-xgboost +cp -fr config custom poc/admin/transfer/hello-xgboost +cp server-*.pem client-cert.pem poc/server/ +for id in $(eval echo "{1..$world_size}"); do + cp server-cert.pem client-*.pem agaricus.txt.{train,test}-site-"$id" poc/site-"$id"/ +done From a0a9ee0660eb251dd841ce8a84ab22ce35860889 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 May 2022 15:37:24 -0700 Subject: [PATCH 2/9] use the same name for data files --- demo/nvflare/custom/trainer.py | 4 ++-- demo/nvflare/prepare_data.sh | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/demo/nvflare/custom/trainer.py b/demo/nvflare/custom/trainer.py index b353522c50a4..ac27e3bacad2 100644 --- a/demo/nvflare/custom/trainer.py +++ b/demo/nvflare/custom/trainer.py @@ -57,8 +57,8 @@ def _do_training(self, fl_ctx: FLContext): xgb.rabit.init([e.encode() for e in rabit_env]) # Load file, file will not be sharded in federated mode. 
- dtrain = xgb.DMatrix('agaricus.txt.train-%s' % client_name) - dtest = xgb.DMatrix('agaricus.txt.test-%s' % client_name) + dtrain = xgb.DMatrix('agaricus.txt.train') + dtest = xgb.DMatrix('agaricus.txt.test') # Specify parameters via map, definition are same as c++ version param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} diff --git a/demo/nvflare/prepare_data.sh b/demo/nvflare/prepare_data.sh index 9227476d75a9..838d0b28f845 100755 --- a/demo/nvflare/prepare_data.sh +++ b/demo/nvflare/prepare_data.sh @@ -19,5 +19,7 @@ mkdir -p poc/admin/transfer/hello-xgboost cp -fr config custom poc/admin/transfer/hello-xgboost cp server-*.pem client-cert.pem poc/server/ for id in $(eval echo "{1..$world_size}"); do - cp server-cert.pem client-*.pem agaricus.txt.{train,test}-site-"$id" poc/site-"$id"/ + cp server-cert.pem client-*.pem poc/site-"$id"/ + cp agaricus.txt.train-site-"$id" poc/site-"$id"/agaricus.txt.train + cp agaricus.txt.test-site-"$id" poc/site-"$id"/agaricus.txt.test done From 73d687e1fe0f80cbeb78ea693d634c51af659da8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 May 2022 15:54:28 -0700 Subject: [PATCH 3/9] print eval in both sites --- demo/nvflare/custom/trainer.py | 6 ++++-- plugin/federated/engine_federated.cc | 4 +--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/demo/nvflare/custom/trainer.py b/demo/nvflare/custom/trainer.py index ac27e3bacad2..5a40a20c1246 100644 --- a/demo/nvflare/custom/trainer.py +++ b/demo/nvflare/custom/trainer.py @@ -1,12 +1,13 @@ import os from nvflare.apis.executor import Executor -from nvflare.apis.fl_constant import ReservedKey, ReturnCode, FLContextKey +from nvflare.apis.fl_constant import ReturnCode, FLContextKey from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable, make_reply from nvflare.apis.signal import Signal import xgboost as xgb +from xgboost import callback class SupportedTasks(object): @@ -69,7 +70,8 @@ def _do_training(self, fl_ctx: FLContext): # Run training, all the features in training API is available. bst = xgb.train(param, dtrain, num_round, evals=watchlist, - early_stopping_rounds=2) + early_stopping_rounds=2, verbose_eval=False, + callbacks=[callback.EvaluationMonitor(rank=rank)]) # Save the model. workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index ed7252ba117c..d18e9e095d0d 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -111,9 +111,7 @@ class FederatedEngine : public IEngine { void TrackerPrint(const std::string &msg) override { // simply print information into the tracker - if (GetRank() == 0) { - utils::Printf("%s", msg.c_str()); - } + utils::Printf("%s", msg.c_str()); } private: From 4134649b679fda381f332bd9208a2526d826eac3 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 May 2022 15:55:39 -0700 Subject: [PATCH 4/9] clean up readme --- demo/nvflare/README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/demo/nvflare/README.md b/demo/nvflare/README.md index 1ba169a7b1de..2db61f1d473d 100644 --- a/demo/nvflare/README.md +++ b/demo/nvflare/README.md @@ -3,37 +3,37 @@ This directory contains a demo of Federated Learning using [NVFlare](https://nvidia.github.io/). 
To run the demo, first install NVFlare: -```console +```shell pip install nvflare ``` Prepare the data: -```console +```shell ./prepare_data.sh ``` Start the NVFlare federated server: -```console +```shell ./poc/server/startup/start.sh ``` In another terminal, start the first worker: -```console +```shell ./poc/site-1/startup/start.sh ``` And the second worker: -```console +```shell ./poc/site-2/startup/start.sh ``` Then start the admin CLI, using `admin/admin` as username/password: -```console +```shell ./poc/admin/startup/fl_admin.sh ``` In the admin CLI, run the following commands: -```console +```shell upload_app hello-xgboost set_run_number 1 deploy_app hello-xgboost all @@ -45,7 +45,7 @@ Once the training finishes, the model file should be written into respectively. Finally, shutdown everything from the admin CLI: -```console +```shell shutdown client shutdown server ``` From ac2504fe0c3d7a84da9d4775c802ae3122b11032 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Fri, 6 May 2022 15:57:31 -0700 Subject: [PATCH 5/9] fix nvflare link --- demo/nvflare/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/demo/nvflare/README.md b/demo/nvflare/README.md index 2db61f1d473d..a758b743e5fd 100644 --- a/demo/nvflare/README.md +++ b/demo/nvflare/README.md @@ -1,6 +1,7 @@ # Federated XGBoost Demo -This directory contains a demo of Federated Learning using [NVFlare](https://nvidia.github.io/). +This directory contains a demo of Federated Learning using +[NVFlare](https://nvidia.github.io/NVFlare/). To run the demo, first install NVFlare: ```shell From d8ffca33a0e569de2c64ced0739101c3acbd288f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 9 May 2022 15:02:20 -0700 Subject: [PATCH 6/9] cleanup docs --- demo/nvflare/README.md | 7 +++++-- demo/nvflare/custom/controller.py | 1 - demo/nvflare/custom/trainer.py | 6 +++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/demo/nvflare/README.md b/demo/nvflare/README.md index a758b743e5fd..ec939f91605f 100644 --- a/demo/nvflare/README.md +++ b/demo/nvflare/README.md @@ -1,9 +1,12 @@ -# Federated XGBoost Demo +# Federated XGBoost Demo using NVFlare This directory contains a demo of Federated Learning using [NVFlare](https://nvidia.github.io/NVFlare/). -To run the demo, first install NVFlare: +To run the demo, first build XGBoost with the federated learning plugin enabled (see the +[README](../../plugin/federated/README.md)). + +Install NVFlare: ```shell pip install nvflare ``` diff --git a/demo/nvflare/custom/controller.py b/demo/nvflare/custom/controller.py index 66f2ac942ef1..989a405bb68c 100644 --- a/demo/nvflare/custom/controller.py +++ b/demo/nvflare/custom/controller.py @@ -33,7 +33,6 @@ def __init__(self, port: int, world_size: int, server_key_path: str, self._server_cert_path = server_cert_path self._client_cert_path = client_cert_path self._server = None - self.run_dir = None def start_controller(self, fl_ctx: FLContext): self._server = multiprocessing.Process( diff --git a/demo/nvflare/custom/trainer.py b/demo/nvflare/custom/trainer.py index 5a40a20c1246..29d76adada12 100644 --- a/demo/nvflare/custom/trainer.py +++ b/demo/nvflare/custom/trainer.py @@ -20,7 +20,11 @@ def __init__(self, server_address: str, world_size: int, server_cert_path: str, """Trainer for federated XGBoost. Args: - data_root: directory with local train/test data. + server_address: address for the gRPC server to connect to. + world_size: the number of sites. + server_cert_path: the path to the server certificate file. 
+ client_key_path: the path to the client key file. + client_cert_path: the path to the client certificate file. """ super().__init__() self._server_address = server_address From 2ef12309a0d5816343f7a8c7476eac550c98f9b8 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Wed, 11 May 2022 09:17:33 -0700 Subject: [PATCH 7/9] note on nvflare python version --- demo/nvflare/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/demo/nvflare/README.md b/demo/nvflare/README.md index ec939f91605f..0e13c01dd68a 100644 --- a/demo/nvflare/README.md +++ b/demo/nvflare/README.md @@ -6,7 +6,7 @@ This directory contains a demo of Federated Learning using To run the demo, first build XGBoost with the federated learning plugin enabled (see the [README](../../plugin/federated/README.md)). -Install NVFlare: +Install NVFlare (note that currently NVFlare only supports Python 3.8): ```shell pip install nvflare ``` From 5e8b12ffd6e229f9150d5936243b3f8156659326 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 12 May 2022 11:27:48 -0700 Subject: [PATCH 8/9] move RabitContext to rabit.py --- demo/nvflare/README.md | 2 +- demo/nvflare/custom/trainer.py | 43 ++++++++++++-------------- python-package/xgboost/dask.py | 17 +++------- python-package/xgboost/rabit.py | 19 ++++++++++++ tests/python-gpu/test_gpu_with_dask.py | 2 +- tests/python/test_tracker.py | 4 +-- tests/python/test_with_dask.py | 2 +- 7 files changed, 47 insertions(+), 42 deletions(-) diff --git a/demo/nvflare/README.md b/demo/nvflare/README.md index 0e13c01dd68a..226a90c38765 100644 --- a/demo/nvflare/README.md +++ b/demo/nvflare/README.md @@ -1,4 +1,4 @@ -# Federated XGBoost Demo using NVFlare +# Experimental Support of Federated XGBoost using NVFlare This directory contains a demo of Federated Learning using [NVFlare](https://nvidia.github.io/NVFlare/). diff --git a/demo/nvflare/custom/trainer.py b/demo/nvflare/custom/trainer.py index 29d76adada12..9403fec00215 100644 --- a/demo/nvflare/custom/trainer.py +++ b/demo/nvflare/custom/trainer.py @@ -59,31 +59,26 @@ def _do_training(self, fl_ctx: FLContext): f'federated_client_key={self._client_key_path}', f'federated_client_cert={self._client_cert_path}' ] - xgb.rabit.init([e.encode() for e in rabit_env]) + with xgb.rabit.RabitContext([e.encode() for e in rabit_env]): + # Load file, file will not be sharded in federated mode. + dtrain = xgb.DMatrix('agaricus.txt.train') + dtest = xgb.DMatrix('agaricus.txt.test') - # Load file, file will not be sharded in federated mode. - dtrain = xgb.DMatrix('agaricus.txt.train') - dtest = xgb.DMatrix('agaricus.txt.test') + # Specify parameters via map, definition are same as c++ version + param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} - # Specify parameters via map, definition are same as c++ version - param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'} + # Specify validations set to watch performance + watchlist = [(dtest, 'eval'), (dtrain, 'train')] + num_round = 20 - # Specify validations set to watch performance - watchlist = [(dtest, 'eval'), (dtrain, 'train')] - num_round = 20 + # Run training, all the features in training API is available. + bst = xgb.train(param, dtrain, num_round, evals=watchlist, + early_stopping_rounds=2, verbose_eval=False, + callbacks=[callback.EvaluationMonitor(rank=rank)]) - # Run training, all the features in training API is available. 
- bst = xgb.train(param, dtrain, num_round, evals=watchlist, - early_stopping_rounds=2, verbose_eval=False, - callbacks=[callback.EvaluationMonitor(rank=rank)]) - - # Save the model. - workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) - run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN) - run_dir = workspace.get_run_dir(run_number) - bst.save_model(os.path.join(run_dir, "test.model.json")) - xgb.rabit.tracker_print("Finished training\n") - - # Notify the tracker all training has been successful - # This is only needed in distributed training. - xgb.rabit.finalize() + # Save the model. + workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN) + run_dir = workspace.get_run_dir(run_number) + bst.save_model(os.path.join(run_dir, "test.model.json")) + xgb.rabit.tracker_print("Finished training\n") diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 942893f0a32d..97f3ab1acf4f 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -97,7 +97,7 @@ TrainReturnT = Dict[str, Any] # type:ignore __all__ = [ - "RabitContext", + "DaskRabitContext", "DaskDMatrix", "DaskDeviceQuantileDMatrix", "DaskXGBRegressor", @@ -224,25 +224,16 @@ def _assert_dask_support() -> None: LOGGER.warning(msg) -class RabitContext: +class DaskRabitContext(rabit.RabitContext): """A context controlling rabit initialization and finalization.""" def __init__(self, args: List[bytes]) -> None: - self.args = args + super().__init__(args) worker = distributed.get_worker() self.args.append( ("DMLC_TASK_ID=[xgboost.dask]:" + str(worker.address)).encode() ) - def __enter__(self) -> None: - rabit.init(self.args) - assert rabit.is_distributed() - LOGGER.debug("-------------- rabit say hello ------------------") - - def __exit__(self, *args: List) -> None: - rabit.finalize() - LOGGER.debug("--------------- rabit say bye ------------------") - def concat(value: Any) -> Any: # pylint: disable=too-many-return-statements """To be replaced with dask builtin.""" @@ -953,7 +944,7 @@ def dispatched_train( n_threads = worker.nthreads local_param.update({"nthread": n_threads, "n_jobs": n_threads}) local_history: TrainingCallback.EvalsLog = {} - with RabitContext(rabit_args), config.config_context(**global_config): + with DaskRabitContext(rabit_args), config.config_context(**global_config): Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads) evals: List[Tuple[DMatrix, str]] = [] for i, ref in enumerate(refs): diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py index a28448df8a67..465a5611a2d1 100644 --- a/python-package/xgboost/rabit.py +++ b/python-package/xgboost/rabit.py @@ -1,6 +1,7 @@ """Distributed XGBoost Rabit related API.""" import ctypes from enum import IntEnum, unique +import logging import pickle from typing import Any, TypeVar, Callable, Optional, cast, List, Union @@ -8,6 +9,8 @@ from .core import _LIB, c_str, _check_call +LOGGER = logging.getLogger("[xgboost.rabit]") + def _init_rabit() -> None: """internal library initializer.""" @@ -224,5 +227,21 @@ def version_number() -> int: return ret +class RabitContext: + """A context controlling rabit initialization and finalization.""" + + def __init__(self, args: List[bytes]) -> None: + self.args = args + + def __enter__(self) -> None: + init(self.args) + assert is_distributed() + LOGGER.debug("-------------- rabit say hello ------------------") + + def __exit__(self, *args: List) -> None: + finalize() + 
LOGGER.debug("--------------- rabit say bye ------------------") + + # initialization script _init_rabit() diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index 6edcbd2f5b47..f5a501170919 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -405,7 +405,7 @@ def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client) def worker_fn(worker_addr: str, data_ref: Dict) -> None: - with dxgb.RabitContext(rabit_args): + with dxgb.DaskRabitContext(rabit_args): local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7) fw_rows = local_dtrain.get_float_info("feature_weights").shape[0] assert fw_rows == local_dtrain.num_col() diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 2e113898f4de..2074af08c586 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -24,7 +24,7 @@ def test_rabit_tracker(): def run_rabit_ops(client, n_workers): from test_with_dask import _get_client_workers - from xgboost.dask import RabitContext, _get_rabit_args + from xgboost.dask import DaskRabitContext, _get_rabit_args from xgboost import rabit workers = _get_client_workers(client) @@ -34,7 +34,7 @@ def run_rabit_ops(client, n_workers): assert n_workers == n_workers_from_dask def local_test(worker_id): - with RabitContext(rabit_args): + with DaskRabitContext(rabit_args): a = 1 assert rabit.is_distributed() a = np.array([a]) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index a023112321e7..36872baeb1fb 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1489,7 +1489,7 @@ def test_no_duplicated_partition(self) -> None: n_workers = len(workers) def worker_fn(worker_addr: str, data_ref: Dict) -> None: - with xgb.dask.RabitContext(rabit_args): + with xgb.dask.DaskRabitContext(rabit_args): local_dtrain = xgb.dask._dmatrix_from_list_of_parts( **data_ref, nthread=7 ) From 57713cc373ff185e7bfac1687eeb092f5f107b84 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Sat, 14 May 2022 00:40:32 +0800 Subject: [PATCH 9/9] Change back the name. 
--- python-package/xgboost/dask.py | 6 +++--- tests/python-gpu/test_gpu_with_dask.py | 2 +- tests/python/test_tracker.py | 4 ++-- tests/python/test_with_dask.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py index 97f3ab1acf4f..b54e26c9d550 100644 --- a/python-package/xgboost/dask.py +++ b/python-package/xgboost/dask.py @@ -97,7 +97,7 @@ TrainReturnT = Dict[str, Any] # type:ignore __all__ = [ - "DaskRabitContext", + "RabitContext", "DaskDMatrix", "DaskDeviceQuantileDMatrix", "DaskXGBRegressor", @@ -224,7 +224,7 @@ def _assert_dask_support() -> None: LOGGER.warning(msg) -class DaskRabitContext(rabit.RabitContext): +class RabitContext(rabit.RabitContext): """A context controlling rabit initialization and finalization.""" def __init__(self, args: List[bytes]) -> None: @@ -944,7 +944,7 @@ def dispatched_train( n_threads = worker.nthreads local_param.update({"nthread": n_threads, "n_jobs": n_threads}) local_history: TrainingCallback.EvalsLog = {} - with DaskRabitContext(rabit_args), config.config_context(**global_config): + with RabitContext(rabit_args), config.config_context(**global_config): Xy = _dmatrix_from_list_of_parts(**train_ref, nthread=n_threads) evals: List[Tuple[DMatrix, str]] = [] for i, ref in enumerate(refs): diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index f5a501170919..6edcbd2f5b47 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -405,7 +405,7 @@ def test_data_initialization(self, local_cuda_cluster: LocalCUDACluster) -> None rabit_args = client.sync(dxgb._get_rabit_args, len(workers), None, client) def worker_fn(worker_addr: str, data_ref: Dict) -> None: - with dxgb.DaskRabitContext(rabit_args): + with dxgb.RabitContext(rabit_args): local_dtrain = dxgb._dmatrix_from_list_of_parts(**data_ref, nthread=7) fw_rows = local_dtrain.get_float_info("feature_weights").shape[0] assert fw_rows == local_dtrain.num_col() diff --git a/tests/python/test_tracker.py b/tests/python/test_tracker.py index 2074af08c586..2e113898f4de 100644 --- a/tests/python/test_tracker.py +++ b/tests/python/test_tracker.py @@ -24,7 +24,7 @@ def test_rabit_tracker(): def run_rabit_ops(client, n_workers): from test_with_dask import _get_client_workers - from xgboost.dask import DaskRabitContext, _get_rabit_args + from xgboost.dask import RabitContext, _get_rabit_args from xgboost import rabit workers = _get_client_workers(client) @@ -34,7 +34,7 @@ def run_rabit_ops(client, n_workers): assert n_workers == n_workers_from_dask def local_test(worker_id): - with DaskRabitContext(rabit_args): + with RabitContext(rabit_args): a = 1 assert rabit.is_distributed() a = np.array([a]) diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 36872baeb1fb..a023112321e7 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1489,7 +1489,7 @@ def test_no_duplicated_partition(self) -> None: n_workers = len(workers) def worker_fn(worker_addr: str, data_ref: Dict) -> None: - with xgb.dask.DaskRabitContext(rabit_args): + with xgb.dask.RabitContext(rabit_args): local_dtrain = xgb.dask._dmatrix_from_list_of_parts( **data_ref, nthread=7 )
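Taken together, the last two patches leave a shared `xgboost.rabit.RabitContext` base class, with `xgboost.dask.RabitContext` as a thin subclass that keeps its original public name. The sketch below is illustrative only: it shows how a non-Dask caller such as the NVFlare trainer drives the base context manager, assuming the federated gRPC server from the demo is already running; the connection values are placeholders.

```python
import xgboost as xgb

# Placeholder federated settings; the real values (including certificate paths)
# come from config_fed_client.json in the demo.
rabit_env = [
    "federated_server_address=localhost:9091",
    "federated_world_size=2",
    "federated_rank=0",
]

# RabitContext initializes the collective on entry and finalizes it on exit,
# replacing the explicit xgb.rabit.init()/xgb.rabit.finalize() calls from patch 1.
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
    print(f"worker {xgb.rabit.get_rank()} of {xgb.rabit.get_world_size()}")
```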