diff --git a/demo/nvflare/README.md b/demo/nvflare/README.md new file mode 100644 index 000000000000..226a90c38765 --- /dev/null +++ b/demo/nvflare/README.md @@ -0,0 +1,55 @@ +# Experimental Support of Federated XGBoost using NVFlare + +This directory contains a demo of Federated Learning using +[NVFlare](https://nvidia.github.io/NVFlare/). + +To run the demo, first build XGBoost with the federated learning plugin enabled (see the +[README](../../plugin/federated/README.md)). + +Install NVFlare (note that currently NVFlare only supports Python 3.8): +```shell +pip install nvflare +``` + +Prepare the data: +```shell +./prepare_data.sh +``` + +Start the NVFlare federated server: +```shell +./poc/server/startup/start.sh +``` + +In another terminal, start the first worker: +```shell +./poc/site-1/startup/start.sh +``` + +And the second worker: +```shell +./poc/site-2/startup/start.sh +``` + +Then start the admin CLI, using `admin/admin` as username/password: +```shell +./poc/admin/startup/fl_admin.sh +``` + +In the admin CLI, run the following commands: +```shell +upload_app hello-xgboost +set_run_number 1 +deploy_app hello-xgboost all +start_app all +``` + +Once the training finishes, the model file should be written into +`./poc/site-1/run_1/test.model.json` and `./poc/site-2/run_1/test.model.json` +respectively. 
+ +Finally, shut down everything from the admin CLI: +```shell +shutdown client +shutdown server +``` diff --git a/demo/nvflare/config/config_fed_client.json b/demo/nvflare/config/config_fed_client.json new file mode 100755 index 000000000000..39f23e9bca42 --- /dev/null +++ b/demo/nvflare/config/config_fed_client.json @@ -0,0 +1,22 @@ +{ + "format_version": 2, + "executors": [ + { + "tasks": [ + "train" + ], + "executor": { + "path": "trainer.XGBoostTrainer", + "args": { + "server_address": "localhost:9091", + "world_size": 2, + "server_cert_path": "server-cert.pem", + "client_key_path": "client-key.pem", + "client_cert_path": "client-cert.pem" + } + } + } + ], + "task_result_filters": [], + "task_data_filters": [] +} diff --git a/demo/nvflare/config/config_fed_server.json b/demo/nvflare/config/config_fed_server.json new file mode 100755 index 000000000000..32993b65215f --- /dev/null +++ b/demo/nvflare/config/config_fed_server.json @@ -0,0 +1,22 @@ +{ + "format_version": 2, + "server": { + "heart_beat_timeout": 600 + }, + "task_data_filters": [], + "task_result_filters": [], + "workflows": [ + { + "id": "server_workflow", + "path": "controller.XGBoostController", + "args": { + "port": 9091, + "world_size": 2, + "server_key_path": "server-key.pem", + "server_cert_path": "server-cert.pem", + "client_cert_path": "client-cert.pem" + } + } + ], + "components": [] +} diff --git a/demo/nvflare/custom/controller.py b/demo/nvflare/custom/controller.py new file mode 100644 index 000000000000..989a405bb68c --- /dev/null +++ b/demo/nvflare/custom/controller.py @@ -0,0 +1,68 @@ +""" +Example of training controller with NVFlare +=========================================== +""" +import multiprocessing + +import xgboost.federated +from nvflare.apis.client import Client +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller, Task +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal + +from trainer import 
class XGBoostController(Controller):
    """NVFlare controller that hosts the federated XGBoost gRPC server.

    The gRPC server is run in a child process started by
    ``start_controller`` and stopped by ``stop_controller``. The control flow
    broadcasts a single ``train`` task and waits for ``world_size`` responses.
    """

    # Upper bound (in seconds) on how long ``stop_controller`` waits for the
    # server process to exit after it has been asked to terminate.
    _SERVER_JOIN_TIMEOUT = 10

    def __init__(self, port: int, world_size: int, server_key_path: str,
                 server_cert_path: str, client_cert_path: str):
        """Controller for federated XGBoost.

        Args:
            port: the port for the gRPC server to listen on.
            world_size: the number of sites.
            server_key_path: the path to the server key file.
            server_cert_path: the path to the server certificate file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._port = port
        self._world_size = world_size
        self._server_key_path = server_key_path
        self._server_cert_path = server_cert_path
        self._client_cert_path = client_cert_path
        # Handle of the child process running the gRPC server; None until
        # start_controller() has been called.
        self._server = None

    def start_controller(self, fl_ctx: FLContext):
        """Start the federated gRPC server in a separate process."""
        self._server = multiprocessing.Process(
            target=xgboost.federated.run_federated_server,
            args=(self._port, self._world_size, self._server_key_path,
                  self._server_cert_path, self._client_cert_path))
        self._server.start()

    def stop_controller(self, fl_ctx: FLContext):
        """Terminate the gRPC server process, if one was started."""
        if self._server:
            self._server.terminate()
            # Reap the child so it does not linger as a zombie. The wait is
            # bounded so that shutdown can never hang on a stuck child.
            self._server.join(timeout=self._SERVER_JOIN_TIMEOUT)
            # Drop the stale handle so a repeated stop is a no-op.
            self._server = None

    def process_result_of_unknown_task(self, client: Client, task_name: str,
                                       client_task_id: str, result: Shareable,
                                       fl_ctx: FLContext):
        """Log a warning for a result that belongs to no known task."""
        self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        """Broadcast the train task and wait for every site to respond.

        Returns early, without logging completion, whenever ``abort_signal``
        has been triggered.
        """
        self.log_info(fl_ctx, "XGBoost training control flow started.")
        if abort_signal.triggered:
            return
        task = Task(name=SupportedTasks.TRAIN, data=Shareable())
        self.broadcast_and_wait(
            task=task,
            min_responses=self._world_size,
            fl_ctx=fl_ctx,
            wait_time_after_min_received=1,
            abort_signal=abort_signal,
        )
        if abort_signal.triggered:
            return

        self.log_info(fl_ctx, "XGBoost training control flow finished.")
import os

from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode, FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal

import xgboost as xgb
from xgboost import callback


class SupportedTasks(object):
    """Names of the tasks the trainer knows how to execute."""

    TRAIN = "train"


class XGBoostTrainer(Executor):
    """NVFlare executor that trains an XGBoost model in federated mode."""

    def __init__(self, server_address: str, world_size: int, server_cert_path: str,
                 client_key_path: str, client_cert_path: str):
        """Trainer for federated XGBoost.

        Args:
            server_address: address for the gRPC server to connect to.
            world_size: the number of sites.
            server_cert_path: the path to the server certificate file.
            client_key_path: the path to the client key file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._server_address = server_address
        self._world_size = world_size
        self._server_cert_path = server_cert_path
        self._client_key_path = client_key_path
        self._client_cert_path = client_cert_path

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
                abort_signal: Signal) -> Shareable:
        """Run the requested task and report the outcome as a reply.

        Returns:
            A reply with ``ReturnCode.OK`` after successful training,
            ``ReturnCode.TASK_UNKNOWN`` for an unsupported task name, or
            ``ReturnCode.EXECUTION_EXCEPTION`` when the task raised.
        """
        self.log_info(fl_ctx, f"Executing {task_name}")
        try:
            if task_name == SupportedTasks.TRAIN:
                self._do_training(fl_ctx)
                return make_reply(ReturnCode.OK)
            self.log_error(fl_ctx, f"{task_name} is not a supported task.")
            return make_reply(ReturnCode.TASK_UNKNOWN)
        except Exception as e:
            # Only ordinary errors become a failure reply; KeyboardInterrupt
            # and SystemExit are left to propagate so the process can stop.
            self.log_exception(fl_ctx,
                               f"Task {task_name} failed. Exception: {e}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    @staticmethod
    def _rank_from_client_name(client_name: str) -> int:
        """Derive the zero-based rank from a client name such as ``site-1``.

        Raises:
            ValueError: if the name has no integer after its first ``-``.
        """
        try:
            return int(client_name.split('-')[1]) - 1
        except (AttributeError, IndexError, ValueError) as err:
            raise ValueError(
                f"Cannot derive a rank from client name {client_name!r}; "
                "expected the form '<prefix>-<1-based index>'.") from err

    def _do_training(self, fl_ctx: FLContext):
        """Train on the local data files and save the model in the run dir."""
        client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
        rank = self._rank_from_client_name(client_name)
        rabit_env = [
            f'federated_server_address={self._server_address}',
            f'federated_world_size={self._world_size}',
            f'federated_rank={rank}',
            f'federated_server_cert={self._server_cert_path}',
            f'federated_client_key={self._client_key_path}',
            f'federated_client_cert={self._client_cert_path}'
        ]
        with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
            # Load file, file will not be sharded in federated mode.
            dtrain = xgb.DMatrix('agaricus.txt.train')
            dtest = xgb.DMatrix('agaricus.txt.test')

            # Specify parameters via map, definition are same as c++ version
            param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

            # Specify validations set to watch performance
            watchlist = [(dtest, 'eval'), (dtrain, 'train')]
            num_round = 20

            # Run training, all the features in training API is available.
            bst = xgb.train(param, dtrain, num_round, evals=watchlist,
                            early_stopping_rounds=2, verbose_eval=False,
                            callbacks=[callback.EvaluationMonitor(rank=rank)])

            # Save the model.
            workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
            run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
            run_dir = workspace.get_run_dir(run_number)
            bst.save_model(os.path.join(run_dir, "test.model.json"))
            xgb.rabit.tracker_print("Finished training\n")
# Number of simulated sites. It is normally assigned earlier in this script;
# the default keeps this section working when run on its own.
world_size="${world_size:-2}"

# Generate self-signed server and client certificates (valid for 7 days).
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"

# Split train and test files manually to simulate a federated environment.
split -n "l/${world_size}" --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train agaricus.txt.train-site-
split -n "l/${world_size}" --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site-

# Create the poc workspace with as many sites as there are data shards.
poc -n "${world_size}"
mkdir -p poc/admin/transfer/hello-xgboost
cp -fr config custom poc/admin/transfer/hello-xgboost
cp server-*.pem client-cert.pem poc/server/

# Give every site the certificates it needs and its own shard of the data.
for ((id = 1; id <= world_size; id++)); do
  cp server-cert.pem client-*.pem poc/site-"$id"/
  cp agaricus.txt.train-site-"$id" poc/site-"$id"/agaricus.txt.train
  cp agaricus.txt.test-site-"$id" poc/site-"$id"/agaricus.txt.test
done
class RabitContext:
    """A context controlling rabit initialization and finalization.

    Entering the context initializes rabit with the stored arguments and
    leaving it finalizes rabit again.
    """

    def __init__(self, args: List[bytes]) -> None:
        """Store the rabit arguments to be used on entry.

        Args:
            args: encoded ``key=value`` arguments passed to ``init``.
        """
        # Keep a private copy: subclasses extend ``self.args`` (the dask
        # context appends its task id) and that must not modify the list
        # owned by the caller.
        self.args = list(args)

    def __enter__(self) -> None:
        """Initialize rabit and check that it runs in distributed mode."""
        init(self.args)
        assert is_distributed()
        LOGGER.debug("-------------- rabit say hello ------------------")

    def __exit__(self, *args: Any) -> None:
        """Finalize rabit; the exception details are ignored."""
        finalize()
        LOGGER.debug("--------------- rabit say bye ------------------")