Skip to content

Commit

Permalink
Allow insecure gRPC connections for federated learning (#8181)
Browse files Browse the repository at this point in the history
* Allow insecure gRPC connections for federated learning

* format
  • Loading branch information
rongou committed Aug 19, 2022
1 parent 53d2a73 commit ad3bc0e
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 37 deletions.
17 changes: 9 additions & 8 deletions plugin/federated/engine_federated.cc
Expand Up @@ -42,8 +42,13 @@ class FederatedEngine : public IEngine {
}
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
server_address_.c_str(), world_size_, rank_);
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) {
utils::Printf("Certificates not specified, turning off SSL.");
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_));
} else {
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
}
}

/** @brief Tear down the engine by resetting client_, which drops the federated client it holds. */
void Finalize() { client_.reset(); }
Expand Down Expand Up @@ -84,13 +89,9 @@ class FederatedEngine : public IEngine {
}
}

int LoadCheckPoint() override {
return 0;
}
int LoadCheckPoint() override { return 0; }

void CheckPoint() override {
version_number_ += 1;
}
void CheckPoint() override { version_number_ += 1; }

/** @brief Return the current version_number_; CheckPoint() increments it by one per call. */
int VersionNumber() const override { return version_number_; }

Expand Down
2 changes: 1 addition & 1 deletion plugin/federated/federated_client.h
Expand Up @@ -33,7 +33,7 @@ class FederatedClient {
}()},
rank_{rank} {}

/** @brief Insecure client for testing only. */
/** @brief Insecure client for connecting to localhost only. */
FederatedClient(std::string const &server_address, int rank)
: stub_{Federated::NewStub(
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},
Expand Down
15 changes: 15 additions & 0 deletions plugin/federated/federated_server.cc
Expand Up @@ -231,5 +231,20 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
server->Wait();
}

void RunInsecureServer(int port, int world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};

grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< world_size;

server->Wait();
}

} // namespace federated
} // namespace xgboost
2 changes: 2 additions & 0 deletions plugin/federated/federated_server.h
Expand Up @@ -40,5 +40,7 @@ class FederatedService final : public Federated::Service {
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file);

/**
 * @brief Run the federated server without SSL.
 *
 * Listens on "0.0.0.0:<port>" with insecure gRPC server credentials and blocks until the
 * server stops.
 *
 * @param port       TCP port to listen on.
 * @param world_size Number of federated workers.
 */
void RunInsecureServer(int port, int world_size);

} // namespace federated
} // namespace xgboost
31 changes: 17 additions & 14 deletions python-package/xgboost/federated.py
Expand Up @@ -6,9 +6,9 @@
def run_federated_server(
port: int,
world_size: int,
server_key_path: str,
server_cert_path: str,
client_cert_path: str,
server_key_path: str = "",
server_cert_path: str = "",
client_cert_path: str = "",
) -> None:
"""Run the Federated Learning server.
Expand All @@ -19,22 +19,25 @@ def run_federated_server(
world_size: int
The number of federated workers.
server_key_path: str
Path to the server private key file.
Path to the server private key file. SSL is turned off if empty.
server_cert_path: str
Path to the server certificate file.
Path to the server certificate file. SSL is turned off if empty.
client_cert_path: str
Path to the client certificate file.
Path to the client certificate file. SSL is turned off if empty.
"""
if build_info()["USE_FEDERATED"]:
_check_call(
_LIB.XGBRunFederatedServer(
port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path),
if not server_key_path or not server_cert_path or not client_cert_path:
_check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size))
else:
_check_call(
_LIB.XGBRunFederatedServer(
port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path),
)
)
)
else:
raise XGBoostError(
"XGBoost needs to be built with the federated learning plugin "
Expand Down
7 changes: 7 additions & 0 deletions src/c_api/c_api.cc
Expand Up @@ -1377,6 +1377,13 @@ XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_k
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
API_END();
}

// Run a federated learning server without SSL, for local testing.
//
// C API entry point that forwards `port` and `world_size` unchanged to
// federated::RunInsecureServer().
// NOTE(review): API_BEGIN()/API_END() presumably convert exceptions into the integer
// status returned to the caller, as for the other C API functions -- confirm.
XGB_DLL int XGBRunInsecureFederatedServer(int port, int world_size) {
API_BEGIN();
federated::RunInsecureServer(port, world_size);
API_END();
}
#endif

// force link rabit
Expand Down
38 changes: 24 additions & 14 deletions tests/distributed/test_federated.py
Expand Up @@ -12,21 +12,28 @@
CLIENT_CERT = 'client-cert.pem'


def run_server(port: int, world_size: int) -> None:
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
CLIENT_CERT)
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
if with_ssl:
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
CLIENT_CERT)
else:
xgboost.federated.run_federated_server(port, world_size)


def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
# Always call this before using distributed module
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
rabit_env = [
f'federated_server_address=localhost:{port}',
f'federated_world_size={world_size}',
f'federated_rank={rank}',
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
f'federated_rank={rank}'
]
if with_ssl:
rabit_env = rabit_env + [
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
]

# Always call this before using distributed module
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
Expand All @@ -52,19 +59,20 @@ def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
xgb.rabit.tracker_print("Finished training\n")


def run_test(with_gpu: bool = False) -> None:
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
port = 9091
world_size = int(sys.argv[1])

server = multiprocessing.Process(target=run_server, args=(port, world_size))
server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
server.start()
time.sleep(1)
if not server.is_alive():
raise Exception("Error starting Federated Learning server")

workers = []
for rank in range(world_size):
worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank, with_gpu))
worker = multiprocessing.Process(target=run_worker,
args=(port, world_size, rank, with_ssl, with_gpu))
workers.append(worker)
worker.start()
for worker in workers:
Expand All @@ -73,5 +81,7 @@ def run_test(with_gpu: bool = False) -> None:


if __name__ == '__main__':
run_test()
run_test(with_gpu=True)
run_test(with_ssl=True, with_gpu=False)
run_test(with_ssl=False, with_gpu=False)
run_test(with_ssl=True, with_gpu=True)
run_test(with_ssl=False, with_gpu=True)

0 comments on commit ad3bc0e

Please sign in to comment.