From ad3bc0edeedc25c917cbf4699df61c91cd9fce65 Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 18 Aug 2022 21:16:14 -0700 Subject: [PATCH] Allow insecure gRPC connections for federated learning (#8181) * Allow insecure gRPC connections for federated learning * format --- plugin/federated/engine_federated.cc | 17 +++++++------ plugin/federated/federated_client.h | 2 +- plugin/federated/federated_server.cc | 15 +++++++++++ plugin/federated/federated_server.h | 2 ++ python-package/xgboost/federated.py | 31 +++++++++++++---------- src/c_api/c_api.cc | 7 +++++ tests/distributed/test_federated.py | 38 ++++++++++++++++++---------- 7 files changed, 75 insertions(+), 37 deletions(-) diff --git a/plugin/federated/engine_federated.cc b/plugin/federated/engine_federated.cc index 9b43c3997cc3..8f2447cf2b90 100644 --- a/plugin/federated/engine_federated.cc +++ b/plugin/federated/engine_federated.cc @@ -42,8 +42,13 @@ class FederatedEngine : public IEngine { } utils::Printf("Connecting to federated server %s, world size %d, rank %d", server_address_.c_str(), world_size_, rank_); - client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_, - client_key_, client_cert_)); + if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) { + utils::Printf("Certificates not specified, turning off SSL."); + client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_)); + } else { + client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_, + client_key_, client_cert_)); + } } void Finalize() { client_.reset(); } @@ -84,13 +89,9 @@ class FederatedEngine : public IEngine { } } - int LoadCheckPoint() override { - return 0; - } + int LoadCheckPoint() override { return 0; } - void CheckPoint() override { - version_number_ += 1; - } + void CheckPoint() override { version_number_ += 1; } int VersionNumber() const override { return version_number_; } diff --git a/plugin/federated/federated_client.h b/plugin/federated/federated_client.h index b00bf0f0f225..ab9fc895e342 100644 --- a/plugin/federated/federated_client.h +++ b/plugin/federated/federated_client.h @@ -33,7 +33,7 @@ class FederatedClient { }()}, rank_{rank} {} - /** @brief Insecure client for testing only. */ + /** @brief Insecure client for connecting to localhost only. */ FederatedClient(std::string const &server_address, int rank) : stub_{Federated::NewStub( grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))}, diff --git a/plugin/federated/federated_server.cc b/plugin/federated/federated_server.cc index c38bdc4b8171..9e11fd51ee0b 100644 --- a/plugin/federated/federated_server.cc +++ b/plugin/federated/federated_server.cc @@ -231,5 +231,20 @@ void RunServer(int port, int world_size, char const* server_key_file, char const server->Wait(); } +void RunInsecureServer(int port, int world_size) { + std::string const server_address = "0.0.0.0:" + std::to_string(port); + FederatedService service{world_size}; + + grpc::ServerBuilder builder; + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); + builder.RegisterService(&service); + std::unique_ptr server(builder.BuildAndStart()); + LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size " + << world_size; + + server->Wait(); +} + } // namespace federated } // namespace xgboost diff --git a/plugin/federated/federated_server.h b/plugin/federated/federated_server.h index 108d78d47a96..122499d0d9c0 100644 --- a/plugin/federated/federated_server.h +++ b/plugin/federated/federated_server.h @@ -40,5 +40,7 @@ class FederatedService final : public Federated::Service { void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file, char const* client_cert_file); +void RunInsecureServer(int port, int world_size); + } // namespace federated } // namespace xgboost diff --git a/python-package/xgboost/federated.py b/python-package/xgboost/federated.py index f520ecb2e300..0214e4e2066a 100644 --- a/python-package/xgboost/federated.py +++ b/python-package/xgboost/federated.py @@ -6,9 +6,9 @@ def run_federated_server( port: int, world_size: int, - server_key_path: str, - server_cert_path: str, - client_cert_path: str, + server_key_path: str = "", + server_cert_path: str = "", + client_cert_path: str = "", ) -> None: """Run the Federated Learning server. @@ -19,22 +19,25 @@ def run_federated_server( world_size: int The number of federated workers. server_key_path: str - Path to the server private key file. + Path to the server private key file. SSL is turned off if empty. server_cert_path: str - Path to the server certificate file. + Path to the server certificate file. SSL is turned off if empty. client_cert_path: str - Path to the client certificate file. + Path to the client certificate file. SSL is turned off if empty. """ if build_info()["USE_FEDERATED"]: - _check_call( - _LIB.XGBRunFederatedServer( - port, - world_size, - c_str(server_key_path), - c_str(server_cert_path), - c_str(client_cert_path), + if not server_key_path or not server_cert_path or not client_cert_path: + _check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size)) + else: + _check_call( + _LIB.XGBRunFederatedServer( + port, + world_size, + c_str(server_key_path), + c_str(server_cert_path), + c_str(client_cert_path), + ) ) - ) else: raise XGBoostError( "XGBoost needs to be built with the federated learning plugin " diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 7fae49bd29c0..698b8fb8fba9 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1377,6 +1377,13 @@ XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_k federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path); API_END(); } + +// Run a server without SSL for local testing. +XGB_DLL int XGBRunInsecureFederatedServer(int port, int world_size) { + API_BEGIN(); + federated::RunInsecureServer(port, world_size); + API_END(); +} #endif // force link rabit diff --git a/tests/distributed/test_federated.py b/tests/distributed/test_federated.py index cddd104e922c..d83d7f2292fe 100644 --- a/tests/distributed/test_federated.py +++ b/tests/distributed/test_federated.py @@ -12,21 +12,28 @@ CLIENT_CERT = 'client-cert.pem' -def run_server(port: int, world_size: int) -> None: - xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT, - CLIENT_CERT) +def run_server(port: int, world_size: int, with_ssl: bool) -> None: + if with_ssl: + xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT, + CLIENT_CERT) + else: + xgboost.federated.run_federated_server(port, world_size) -def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None: - # Always call this before using distributed module +def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None: rabit_env = [ f'federated_server_address=localhost:{port}', f'federated_world_size={world_size}', - f'federated_rank={rank}', - f'federated_server_cert={SERVER_CERT}', - f'federated_client_key={CLIENT_KEY}', - f'federated_client_cert={CLIENT_CERT}' + f'federated_rank={rank}' ] + if with_ssl: + rabit_env = rabit_env + [ + f'federated_server_cert={SERVER_CERT}', + f'federated_client_key={CLIENT_KEY}', + f'federated_client_cert={CLIENT_CERT}' + ] + + # Always call this before using distributed module with xgb.rabit.RabitContext([e.encode() for e in rabit_env]): # Load file, file will not be sharded in federated mode. dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank) @@ -52,11 +59,11 @@ def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None: xgb.rabit.tracker_print("Finished training\n") -def run_test(with_gpu: bool = False) -> None: +def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None: port = 9091 world_size = int(sys.argv[1]) - server = multiprocessing.Process(target=run_server, args=(port, world_size)) + server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl)) server.start() time.sleep(1) if not server.is_alive(): @@ -64,7 +71,8 @@ def run_test(with_gpu: bool = False) -> None: workers = [] for rank in range(world_size): - worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank, with_gpu)) + worker = multiprocessing.Process(target=run_worker, + args=(port, world_size, rank, with_ssl, with_gpu)) workers.append(worker) worker.start() for worker in workers: @@ -73,5 +81,7 @@ def run_test(with_gpu: bool = False) -> None: if __name__ == '__main__': - run_test() - run_test(with_gpu=True) + run_test(with_ssl=True, with_gpu=False) + run_test(with_ssl=False, with_gpu=False) + run_test(with_ssl=True, with_gpu=True) + run_test(with_ssl=False, with_gpu=True)