Demo of federated learning using NVFlare #7879
Merged
12 commits:

- ae2c2c5 federated learning demo using nvflare (rongou)
- a0a9ee0 use the same name for data files (rongou)
- 73d687e print eval in both sites (rongou)
- 4134649 clean up readme (rongou)
- ac2504f fix nvflare link (rongou)
- 90aeee2 Merge remote-tracking branch 'upstream/master' into nvflare-demo (rongou)
- d8ffca3 cleanup docs (rongou)
- a310ddd Merge remote-tracking branch 'upstream/master' into nvflare-demo (rongou)
- 2ef1230 note on nvflare python version (rongou)
- cda25f4 Merge remote-tracking branch 'upstream/master' into nvflare-demo (rongou)
- 5e8b12f move RabitContext to rabit.py (rongou)
- 57713cc Change back the name. (trivialfis)
`README.md`:

# Experimental Support of Federated XGBoost using NVFlare

This directory contains a demo of Federated Learning using
[NVFlare](https://nvidia.github.io/NVFlare/).

To run the demo, first build XGBoost with the federated learning plugin enabled (see the
[README](../../plugin/federated/README.md)).

Install NVFlare (note that currently NVFlare only supports Python 3.8):
```shell
pip install nvflare
```

Prepare the data:
```shell
./prepare_data.sh
```

Start the NVFlare federated server:
```shell
./poc/server/startup/start.sh
```

In another terminal, start the first worker:
```shell
./poc/site-1/startup/start.sh
```

And the second worker:
```shell
./poc/site-2/startup/start.sh
```

Then start the admin CLI, using `admin/admin` as username/password:
```shell
./poc/admin/startup/fl_admin.sh
```

In the admin CLI, run the following commands:
```shell
upload_app hello-xgboost
set_run_number 1
deploy_app hello-xgboost all
start_app all
```

Once the training finishes, a model file should be written into
`./poc/site-1/run_1/test.model.json` and `./poc/site-2/run_1/test.model.json`
respectively.

Finally, shut down everything from the admin CLI:
```shell
shutdown client
shutdown server
```
Client configuration (`config_fed_client.json`):

```json
{
  "format_version": 2,
  "executors": [
    {
      "tasks": [
        "train"
      ],
      "executor": {
        "path": "trainer.XGBoostTrainer",
        "args": {
          "server_address": "localhost:9091",
          "world_size": 2,
          "server_cert_path": "server-cert.pem",
          "client_key_path": "client-key.pem",
          "client_cert_path": "client-cert.pem"
        }
      }
    }
  ],
  "task_result_filters": [],
  "task_data_filters": []
}
```
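NVFlare instantiates the executor class named by `path` and passes the `args` object as keyword arguments to its constructor. That mapping can be sketched with a stand-in class (`FakeTrainer` is hypothetical, standing in for `trainer.XGBoostTrainer` so the sketch needs no NVFlare install):

```python
import json

# The "args" object from the client config above.
config = json.loads("""
{
  "server_address": "localhost:9091",
  "world_size": 2,
  "server_cert_path": "server-cert.pem",
  "client_key_path": "client-key.pem",
  "client_cert_path": "client-cert.pem"
}
""")

class FakeTrainer:
    # Stand-in with the same constructor signature as XGBoostTrainer.
    def __init__(self, server_address, world_size, server_cert_path,
                 client_key_path, client_cert_path):
        self.server_address = server_address
        self.world_size = world_size

# Config keys become constructor keyword arguments.
trainer = FakeTrainer(**config)
print(trainer.server_address, trainer.world_size)
```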
Server workflow configuration (`config_fed_server.json`):

```json
{
  "format_version": 2,
  "server": {
    "heart_beat_timeout": 600
  },
  "task_data_filters": [],
  "task_result_filters": [],
  "workflows": [
    {
      "id": "server_workflow",
      "path": "controller.XGBoostController",
      "args": {
        "port": 9091,
        "world_size": 2,
        "server_key_path": "server-key.pem",
        "server_cert_path": "server-cert.pem",
        "client_cert_path": "client-cert.pem"
      }
    }
  ],
  "components": []
}
```
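The client and server configs must agree on the gRPC port and the world size, or the workers will fail to connect. A quick consistency check, with the relevant values from the two configs inlined as an assumption:

```python
import json

# Relevant fields from config_fed_client.json and config_fed_server.json.
client_args = json.loads('{"server_address": "localhost:9091", "world_size": 2}')
server_args = json.loads('{"port": 9091, "world_size": 2}')

# The client connects to host:port; the port must match the server workflow.
port = int(client_args["server_address"].split(":")[1])
assert port == server_args["port"]
assert client_args["world_size"] == server_args["world_size"]
print("configs consistent")
```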
Controller (`controller.py`):

```python
"""
Example of training controller with NVFlare
===========================================
"""
import multiprocessing

import xgboost.federated
from nvflare.apis.client import Client
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import Controller, Task
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal

from trainer import SupportedTasks


class XGBoostController(Controller):
    def __init__(self, port: int, world_size: int, server_key_path: str,
                 server_cert_path: str, client_cert_path: str):
        """Controller for federated XGBoost.

        Args:
            port: the port for the gRPC server to listen on.
            world_size: the number of sites.
            server_key_path: the path to the server key file.
            server_cert_path: the path to the server certificate file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._port = port
        self._world_size = world_size
        self._server_key_path = server_key_path
        self._server_cert_path = server_cert_path
        self._client_cert_path = client_cert_path
        self._server = None

    def start_controller(self, fl_ctx: FLContext):
        self._server = multiprocessing.Process(
            target=xgboost.federated.run_federated_server,
            args=(self._port, self._world_size, self._server_key_path,
                  self._server_cert_path, self._client_cert_path))
        self._server.start()

    def stop_controller(self, fl_ctx: FLContext):
        if self._server:
            self._server.terminate()

    def process_result_of_unknown_task(self, client: Client, task_name: str,
                                       client_task_id: str, result: Shareable,
                                       fl_ctx: FLContext):
        self.log_warning(fl_ctx, f"Unknown task: {task_name} from client {client.name}.")

    def control_flow(self, abort_signal: Signal, fl_ctx: FLContext):
        self.log_info(fl_ctx, "XGBoost training control flow started.")
        if abort_signal.triggered:
            return
        task = Task(name=SupportedTasks.TRAIN, data=Shareable())
        self.broadcast_and_wait(
            task=task,
            min_responses=self._world_size,
            fl_ctx=fl_ctx,
            wait_time_after_min_received=1,
            abort_signal=abort_signal,
        )
        if abort_signal.triggered:
            return

        self.log_info(fl_ctx, "XGBoost training control flow finished.")
```
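The controller runs the gRPC server in a separate process so that `stop_controller` can terminate it cleanly even though the server call blocks. That lifecycle can be sketched with a placeholder target (`serve` here is hypothetical, standing in for `xgboost.federated.run_federated_server`):

```python
import multiprocessing
import time

def serve():
    # Placeholder for xgboost.federated.run_federated_server: blocks forever.
    while True:
        time.sleep(0.1)

def lifecycle() -> bool:
    server = multiprocessing.Process(target=serve)
    server.start()            # start_controller
    alive = server.is_alive()
    server.terminate()        # stop_controller
    server.join()             # reap the child process
    return alive and not server.is_alive()

if __name__ == "__main__":
    print(lifecycle())
```

`terminate()` sends SIGTERM rather than asking the server to drain in-flight requests, which is acceptable here because the controller only stops after training has finished.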
Trainer (`trainer.py`):

```python
import os

from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode, FLContextKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal

import xgboost as xgb
from xgboost import callback


class SupportedTasks(object):
    TRAIN = "train"


class XGBoostTrainer(Executor):
    def __init__(self, server_address: str, world_size: int, server_cert_path: str,
                 client_key_path: str, client_cert_path: str):
        """Trainer for federated XGBoost.

        Args:
            server_address: address for the gRPC server to connect to.
            world_size: the number of sites.
            server_cert_path: the path to the server certificate file.
            client_key_path: the path to the client key file.
            client_cert_path: the path to the client certificate file.
        """
        super().__init__()
        self._server_address = server_address
        self._world_size = world_size
        self._server_cert_path = server_cert_path
        self._client_key_path = client_key_path
        self._client_cert_path = client_cert_path

    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext,
                abort_signal: Signal) -> Shareable:
        self.log_info(fl_ctx, f"Executing {task_name}")
        try:
            if task_name == SupportedTasks.TRAIN:
                self._do_training(fl_ctx)
                return make_reply(ReturnCode.OK)
            else:
                self.log_error(fl_ctx, f"{task_name} is not a supported task.")
                return make_reply(ReturnCode.TASK_UNKNOWN)
        except BaseException as e:
            self.log_exception(fl_ctx, f"Task {task_name} failed. Exception: {e}")
            return make_reply(ReturnCode.EXECUTION_EXCEPTION)

    def _do_training(self, fl_ctx: FLContext):
        client_name = fl_ctx.get_prop(FLContextKey.CLIENT_NAME)
        rank = int(client_name.split('-')[1]) - 1
        rabit_env = [
            f'federated_server_address={self._server_address}',
            f'federated_world_size={self._world_size}',
            f'federated_rank={rank}',
            f'federated_server_cert={self._server_cert_path}',
            f'federated_client_key={self._client_key_path}',
            f'federated_client_cert={self._client_cert_path}'
        ]
        with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
            # Load the data; files are not sharded in federated mode.
            dtrain = xgb.DMatrix('agaricus.txt.train')
            dtest = xgb.DMatrix('agaricus.txt.test')

            # Specify parameters via a dict; definitions are the same as in the C++ version.
            param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}

            # Specify a validation set to watch performance.
            watchlist = [(dtest, 'eval'), (dtrain, 'train')]
            num_round = 20

            # Run training; all features of the training API are available.
            bst = xgb.train(param, dtrain, num_round, evals=watchlist,
                            early_stopping_rounds=2, verbose_eval=False,
                            callbacks=[callback.EvaluationMonitor(rank=rank)])

            # Save the model into this run's working directory.
            workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT)
            run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN)
            run_dir = workspace.get_run_dir(run_number)
            bst.save_model(os.path.join(run_dir, "test.model.json"))
            xgb.rabit.tracker_print("Finished training\n")
```
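The trainer derives its federated rank from the NVFlare client name (`site-1`, `site-2`, ...), so site numbering must start at 1 and the names must keep the `site-N` shape. The mapping on its own, as a sketch:

```python
def rank_from_client_name(client_name: str) -> int:
    # "site-1" -> rank 0, "site-2" -> rank 1, and so on;
    # federated ranks are zero-based while NVFlare site names start at 1.
    return int(client_name.split("-")[1]) - 1

print(rank_from_client_name("site-1"), rank_from_client_name("site-2"))
```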
Data preparation script (`prepare_data.sh`):

```shell
#!/bin/bash

set -e

rm -fr ./agaricus* ./*.pem ./poc

world_size=2

# Generate server and client certificates.
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout server-key.pem -out server-cert.pem -subj "/C=US/CN=localhost"
openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout client-key.pem -out client-cert.pem -subj "/C=US/CN=localhost"

# Split train and test files manually to simulate a federated environment.
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.train agaricus.txt.train-site-
split -n l/${world_size} --numeric-suffixes=1 -a 1 ../data/agaricus.txt.test agaricus.txt.test-site-

poc -n "${world_size}"
mkdir -p poc/admin/transfer/hello-xgboost
cp -fr config custom poc/admin/transfer/hello-xgboost
cp server-*.pem client-cert.pem poc/server/
for id in $(seq 1 "$world_size"); do
  cp server-cert.pem client-*.pem poc/site-"$id"/
  cp agaricus.txt.train-site-"$id" poc/site-"$id"/agaricus.txt.train
  cp agaricus.txt.test-site-"$id" poc/site-"$id"/agaricus.txt.test
done
```
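`split -n l/N` partitions the input into N contiguous chunks without breaking any line across shards, which is why each site ends up with a disjoint slice of the training data. A rough Python equivalent of that partitioning (a sketch only; GNU `split` balances by byte count on line boundaries, not by exact line count):

```python
def shard_lines(lines, world_size):
    # Split into world_size contiguous chunks, earlier shards taking
    # one extra line each when len(lines) is not evenly divisible.
    size, rem = divmod(len(lines), world_size)
    shards, start = [], 0
    for i in range(world_size):
        end = start + size + (1 if i < rem else 0)
        shards.append(lines[start:end])
        start = end
    return shards

shards = shard_lines([f"row{i}" for i in range(5)], 2)
print([len(s) for s in shards])  # [3, 2]
```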
Review comment: Is it possible to create a gRPC server at the Python layer?
Reply: Yes, we could write the gRPC server in Python, but it might have some limitations when it comes to threading. We are still talking with the NVFlare team to figure out the details, so this could change in the future.