Demo of federated learning using NVFlare #7879

Merged
12 commits merged on May 14, 2022
55 changes: 55 additions & 0 deletions demo/nvflare/README.md
@@ -0,0 +1,55 @@
# Federated XGBoost Demo using NVFlare

This directory contains a demo of Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).

To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).

Install NVFlare:
```shell
pip install nvflare
```

Prepare the data:
```shell
./prepare_data.sh
```

Start the NVFlare federated server:
```shell
./poc/server/startup/start.sh
```

In another terminal, start the first worker:
```shell
./poc/site-1/startup/start.sh
```

And the second worker:
```shell
./poc/site-2/startup/start.sh
```

Then start the admin CLI, using `admin/admin` as username/password:
```shell
./poc/admin/startup/fl_admin.sh
```

In the admin CLI, run the following commands:
```shell
upload_app hello-xgboost
set_run_number 1
deploy_app hello-xgboost all
start_app all
```

Once the training finishes, the model files should be written to
`./poc/site-1/run_1/test.model.json` and `./poc/site-2/run_1/test.model.json`,
respectively.
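
To sanity-check the result, either model file can be loaded back with the regular
(non-federated) XGBoost Python API. A minimal sketch, assuming XGBoost is installed
locally:
```python
import xgboost as xgb

# Load the model produced by site-1 during run 1.
bst = xgb.Booster(model_file="./poc/site-1/run_1/test.model.json")
print("boosted rounds:", bst.num_boosted_rounds())
```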

Finally, shut down everything from the admin CLI:
```shell
shutdown client
shutdown server
```
22 changes: 22 additions & 0 deletions demo/nvflare/config/config_fed_client.json
@@ -0,0 +1,22 @@
{
"format_version": 2,
"executors": [
{
"tasks": [
"train"
],
"executor": {
"path": "trainer.XGBoostTrainer",
"args": {
"server_address": "localhost:9091",
"world_size": 2,
"server_cert_path": "server-cert.pem",
"client_key_path": "client-key.pem",
"client_cert_path": "client-cert.pem"
}
}
}
],
"task_result_filters": [],
"task_data_filters": []
}
22 changes: 22 additions & 0 deletions demo/nvflare/config/config_fed_server.json
@@ -0,0 +1,22 @@
{
"format_version": 2,
"server": {
"heart_beat_timeout": 600
},
"task_data_filters": [],
"task_result_filters": [],
"workflows": [
{
"id": "server_workflow",
"path": "controller.XGBoostController",
"args": {
"port": 9091,
"world_size": 2,
"server_key_path": "server-key.pem",
"server_cert_path": "server-cert.pem",
"client_cert_path": "client-cert.pem"
}
}
],
"components": []
}
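
These two configuration files simply forward their `args` to the constructors of the custom `XGBoostController` and `XGBoostTrainer` classes shown below. For reference, the NVFlare runtime ends up building objects roughly equivalent to the following (a sketch only; the instantiation is performed by NVFlare, not by user code):

```python
from controller import XGBoostController
from trainer import XGBoostTrainer

# Equivalent of the workflow entry in config_fed_server.json.
controller = XGBoostController(
    port=9091,
    world_size=2,
    server_key_path="server-key.pem",
    server_cert_path="server-cert.pem",
    client_cert_path="client-cert.pem",
)

# Equivalent of the executor entry in config_fed_client.json.
trainer = XGBoostTrainer(
    server_address="localhost:9091",
    world_size=2,
    server_cert_path="server-cert.pem",
    client_key_path="client-key.pem",
    client_cert_path="client-cert.pem",
)
```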
68 changes: 68 additions & 0 deletions demo/nvflare/custom/controller.py
@@ -0,0 +1,68 @@
"""
Example of training controller with NVFlare
===========================================
"""
import multiprocessing

import xgboost.federated
from nvflare.apis.client import Client
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal

from trainer import SupportedTasks


class XGBoostController(Controller):
def __init__(self, port: int, world_size: int, server_key_path: str,
server_cert_path: str, client_cert_path: str):
"""Controller for federated XGBoost.

Args:
port: the port for the gRPC server to listen on.
world_size: the number of sites.
server_key_path: the path to the server key file.
server_cert_path: the path to the server certificate file.
client_cert_path: the path to the client certificate file.
"""
super().__init__()
self._port = port
self._world_size = world_size
self._server_key_path = server_key_path
self._server_cert_path = server_cert_path
self._client_cert_path = client_cert_path
self._server = None

def start_controller(self, fl_ctx: FLContext):
self._server = multiprocessing.Process(
target=xgboost.federated.run_federated_server,
> **Member:** Is it possible to create a grpc server at the python layer?
>
> **Contributor Author:** Yes, we could write the gRPC server in Python, but it might have some limitations when it comes to threading. We are still talking with the NVFlare team to figure out the details, so this could change in the future.

args=(self._port, self._world_size, self._server_key_path,
self._server_cert_path, self._client_cert_path))
self._server.start()

def stop_controller(self, fl_ctx: FLContext):
if self._server:
self._server.terminate()

def process_result_of_unknown_task(self, client: Client, task_name: str,
client_task_id: str, result: Shareable,
fl_ctx: FLContext):
self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")

def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
self.log_info(fl_ctx, "XGBoost training control flow started.")
if abort_signal.triggered:
return
task = Task(name=SupportedTasks.TRAIN, data=Shareable())
self.broadcast_and_wait(
task=task,
min_responses=self._world_size,
fl_ctx=fl_ctx,
wait_time_after_min_received=1,
abort_signal=abort_signal,
)
if abort_signal.triggered:
return

self.log_info(fl_ctx, "XGBoost training control flow finished.")
89 changes: 89 additions & 0 deletions demo/nvflare/custom/trainer.py
@@ -0,0 +1,89 @@
import os

from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode, FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal

import xgboost as xgb
from xgboost import callback


class SupportedTasks(object):
TRAIN = "train"


class XGBoostTrainer(Executor):
def __init__(self, server_address: str, world_size: int, server_cert_path: str,
client_key_path: str, client_cert_path: str):
"""Trainer for federated XGBoost.

Args:
server_address: address for the gRPC server to connect to.
world_size: the number of sites.
server_cert_path: the path to the server certificate file.
client_key_path: the path to the client key file.
client_cert_path: the path to the client certificate file.
"""
super().__init__()
self._server_address = server_address
self._world_size = world_size
self._server_cert_path = server_cert_path
self._client_key_path = client_key_path
self._client_cert_path = client_cert_path

def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
abort_signal: Signal) -> Shareable:
self.log_info(fl_ctx, f"Executing {task_name}")
try:
if task_name == SupportedTasks.TRAIN:
self._do_training(fl_ctx)
return make_reply(ReturnCode.OK)
else:
self.log_error(fl_ctx, f"{task_name} is not a supported task.")
return make_reply(ReturnCode.TASK_UNKNOWN)
except BaseException as e:
self.log_exception(fl_ctx,
f"Task {task_name} failed. Exception: {e.__str__()}")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

def _do_training(self, fl_ctx: FLContext):
client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
rank = int(client_name.split('-')[1]) - 1
rabit_env = [
f'federated_server_address={self._server_address}',
f'federated_world_size={self._world_size}',
f'federated_rank={rank}',
f'federated_server_cert={self._server_cert_path}',
f'federated_client_key={self._client_key_path}',
f'federated_client_cert={self._client_cert_path}'
]
xgb.rabit.init([e.encode() for e in rabit_env])
> **Member:** Can we move the RabitContext from dask module to rabit module and reuse it here?
>
> **Contributor Author:** Done. But this changes the class name in dask. Is that what we want? Maybe we can keep the same name RabitContext in the dask module?
>
> **Member:** Sounds good. I made the change to your PR.


        # Load the data file. In federated mode the file is not sharded automatically;
        # each site reads its own local copy prepared by prepare_data.sh.
dtrain = xgb.DMatrix('agaricus.txt.train')
> **Member:** So this is not using the split data?
>
> **Contributor Author:** In prepare_data.sh we copy each split into the site-specific directory. This probably looks more like the real federated environment.

dtest = xgb.DMatrix('agaricus.txt.test')

        # Specify parameters via map; definitions are the same as in the C++ version.
param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

        # Specify validation sets to watch performance.
watchlist = [(dtest, 'eval'), (dtrain, 'train')]
num_round = 20

        # Run training; all the features of the training API are available.
bst = xgb.train(param, dtrain, num_round, evals=watchlist,
early_stopping_rounds=2, verbose_eval=False,
callbacks=[callback.EvaluationMonitor(rank=rank)])

# Save the model.
workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
run_dir = workspace.get_run_dir(run_number)
bst.save_model(os.path.join(run_dir, "test.model.json"))
xgb.rabit.tracker_print("Finished training\n")

        # Notify the tracker that all training has completed successfully.
        # This is only needed in distributed training.
xgb.rabit.finalize()
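
As discussed in the review thread above, the `RabitContext` helper was moved from the dask module into the rabit module in a later revision of this PR, so the explicit `init`/`finalize` pair can be replaced with a context manager. A sketch of how the training section of `_do_training` could look with it, assuming `xgb.rabit.RabitContext` accepts the same encoded argument list as `xgb.rabit.init`:

```python
# Hypothetical variant of the training section of _do_training.
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
    dtrain = xgb.DMatrix('agaricus.txt.train')
    dtest = xgb.DMatrix('agaricus.txt.test')
    param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
    bst = xgb.train(param, dtrain, 20,
                    evals=[(dtest, 'eval'), (dtrain, 'train')],
                    early_stopping_rounds=2, verbose_eval=False)
    bst.save_model(os.path.join(run_dir, "test.model.json"))
# rabit.finalize() runs automatically when the context exits.
```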
25 changes: 25 additions & 0 deletions demo/nvflare/prepare_data.sh
@@ -0,0 +1,25 @@
#!/bin/bash

set -e

rm -fr ./agaricus* ./*.pem ./poc

world_size=2

# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"

# Split train and test files manually to simulate a federated environment.
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train agaricus.txt.train-site-
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site-

poc -n "$world_size"
mkdir -p poc/admin/transfer/hello-xgboost
cp -fr config custom poc/admin/transfer/hello-xgboost
cp server-*.pem client-cert.pem poc/server/
for id in $(eval echo "{1..$world_size}"); do
cp server-cert.pem client-*.pem poc/site-"$id"/
cp agaricus.txt.train-site-"$id" poc/site-"$id"/agaricus.txt.train
cp agaricus.txt.test-site-"$id" poc/site-"$id"/agaricus.txt.test
done
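
The `split` calls partition the training and test files line-wise into `world_size` shards, one per site, and the loop copies each shard into the matching site directory. A quick sanity check of the resulting partition (a sketch in Python, assuming the script has already been run from `demo/nvflare`):

```python
# The per-site shards should together cover the original agaricus.txt.train.
for site in (1, 2):
    path = f"poc/site-{site}/agaricus.txt.train"
    with open(path) as f:
        print(path, sum(1 for _ in f), "rows")
```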
4 changes: 1 addition & 3 deletions plugin/federated/engine_federated.cc
Expand Up @@ -111,9 +111,7 @@ class FederatedEngine : public IEngine {

void TrackerPrint(const std::string &msg) override {
// simply print information into the tracker
if (GetRank() == 0) {
utils::Printf("%s", msg.c_str());
}
utils::Printf("%s", msg.c_str());
}

private:
Expand Down