Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow insecure gRPC connections for federated learning #8181

Merged
merged 3 commits into from Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 9 additions & 8 deletions plugin/federated/engine_federated.cc
Expand Up @@ -42,8 +42,13 @@ class FederatedEngine : public IEngine {
}
utils::Printf("Connecting to federated server %s, world size %d, rank %d",
server_address_.c_str(), world_size_, rank_);
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
if (server_cert_.empty() || client_key_.empty() || client_cert_.empty()) {
utils::Printf("Certificates not specified, turning off SSL.");
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_));
} else {
client_.reset(new xgboost::federated::FederatedClient(server_address_, rank_, server_cert_,
client_key_, client_cert_));
}
}

/** @brief Shut the engine down by releasing the federated client held in client_. */
void Finalize() { client_.reset(); }
Expand Down Expand Up @@ -84,13 +89,9 @@ class FederatedEngine : public IEngine {
}
}

int LoadCheckPoint() override {
return 0;
}
/** @brief Checkpoint loading hook; this engine restores nothing and always returns 0. */
int LoadCheckPoint() override { return 0; }

void CheckPoint() override {
version_number_ += 1;
}
/** @brief Checkpoint hook; only advances version_number_ by one, nothing else is done here. */
void CheckPoint() override { version_number_ += 1; }

/** @brief Return the current value of version_number_ (incremented by CheckPoint()). */
int VersionNumber() const override { return version_number_; }

Expand Down
2 changes: 1 addition & 1 deletion plugin/federated/federated_client.h
Expand Up @@ -33,7 +33,7 @@ class FederatedClient {
}()},
rank_{rank} {}

/** @brief Insecure client for testing only. */
/** @brief Insecure client for connecting to localhost only. */
FederatedClient(std::string const &server_address, int rank)
: stub_{Federated::NewStub(
grpc::CreateChannel(server_address, grpc::InsecureChannelCredentials()))},
Expand Down
15 changes: 15 additions & 0 deletions plugin/federated/federated_server.cc
Expand Up @@ -231,5 +231,20 @@ void RunServer(int port, int world_size, char const* server_key_file, char const
server->Wait();
}

void RunInsecureServer(int port, int world_size) {
std::string const server_address = "0.0.0.0:" + std::to_string(port);
FederatedService service{world_size};

grpc::ServerBuilder builder;
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
builder.RegisterService(&service);
std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
LOG(CONSOLE) << "Insecure federated server listening on " << server_address << ", world size "
<< world_size;

server->Wait();
}

} // namespace federated
} // namespace xgboost
2 changes: 2 additions & 0 deletions plugin/federated/federated_server.h
Expand Up @@ -40,5 +40,7 @@ class FederatedService final : public Federated::Service {
void RunServer(int port, int world_size, char const* server_key_file, char const* server_cert_file,
char const* client_cert_file);

/**
 * @brief Run the federated learning server without SSL.
 *
 * Unlike RunServer(), no key or certificate files are taken.
 *
 * @param port       Port the server listens on.
 * @param world_size Number of federated workers.
 */
void RunInsecureServer(int port, int world_size);

} // namespace federated
} // namespace xgboost
31 changes: 17 additions & 14 deletions python-package/xgboost/federated.py
Expand Up @@ -6,9 +6,9 @@
def run_federated_server(
port: int,
world_size: int,
server_key_path: str,
server_cert_path: str,
client_cert_path: str,
server_key_path: str = "",
server_cert_path: str = "",
client_cert_path: str = "",
) -> None:
"""Run the Federated Learning server.

Expand All @@ -19,22 +19,25 @@ def run_federated_server(
world_size: int
The number of federated workers.
server_key_path: str
Path to the server private key file.
Path to the server private key file. SSL is turned off if empty.
server_cert_path: str
Path to the server certificate file.
Path to the server certificate file. SSL is turned off if empty.
client_cert_path: str
Path to the client certificate file.
Path to the client certificate file. SSL is turned off if empty.
"""
if build_info()["USE_FEDERATED"]:
_check_call(
_LIB.XGBRunFederatedServer(
port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path),
if not server_key_path or not server_cert_path or not client_cert_path:
_check_call(_LIB.XGBRunInsecureFederatedServer(port, world_size))
else:
_check_call(
_LIB.XGBRunFederatedServer(
port,
world_size,
c_str(server_key_path),
c_str(server_cert_path),
c_str(client_cert_path),
)
)
)
else:
raise XGBoostError(
"XGBoost needs to be built with the federated learning plugin "
Expand Down
7 changes: 7 additions & 0 deletions src/c_api/c_api.cc
Expand Up @@ -1377,6 +1377,13 @@ XGB_DLL int XGBRunFederatedServer(int port, int world_size, char const *server_k
federated::RunServer(port, world_size, server_key_path, server_cert_path, client_cert_path);
API_END();
}

// Run a federated learning server without SSL, for local testing.
// C API entry point: forwards port and world_size unchanged to
// federated::RunInsecureServer; the call is wrapped in API_BEGIN()/API_END()
// like the other C API functions in this file.
XGB_DLL int XGBRunInsecureFederatedServer(int port, int world_size) {
API_BEGIN();
federated::RunInsecureServer(port, world_size);
API_END();
}
#endif

// force link rabit
Expand Down
38 changes: 24 additions & 14 deletions tests/distributed/test_federated.py
Expand Up @@ -12,21 +12,28 @@
CLIENT_CERT = 'client-cert.pem'


def run_server(port: int, world_size: int) -> None:
xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
CLIENT_CERT)
def run_server(port: int, world_size: int, with_ssl: bool) -> None:
    """Start the federated server, passing the certificate paths only when SSL is requested.

    With ``with_ssl`` false the key/certificate arguments are omitted, so
    ``run_federated_server`` is called with just the port and world size.
    """
    cert_args = (SERVER_KEY, SERVER_CERT, CLIENT_CERT) if with_ssl else ()
    xgboost.federated.run_federated_server(port, world_size, *cert_args)


def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
# Always call this before using distributed module
def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
rabit_env = [
f'federated_server_address=localhost:{port}',
f'federated_world_size={world_size}',
f'federated_rank={rank}',
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
f'federated_rank={rank}'
]
if with_ssl:
rabit_env = rabit_env + [
f'federated_server_cert={SERVER_CERT}',
f'federated_client_key={CLIENT_KEY}',
f'federated_client_cert={CLIENT_CERT}'
]

# Always call this before using distributed module
with xgb.rabit.RabitContext([e.encode() for e in rabit_env]):
# Load file, file will not be sharded in federated mode.
dtrain = xgb.DMatrix('agaricus.txt.train-%02d' % rank)
Expand All @@ -52,19 +59,20 @@ def run_worker(port: int, world_size: int, rank: int, with_gpu: bool) -> None:
xgb.rabit.tracker_print("Finished training\n")


def run_test(with_gpu: bool = False) -> None:
def run_test(with_ssl: bool = True, with_gpu: bool = False) -> None:
port = 9091
world_size = int(sys.argv[1])

server = multiprocessing.Process(target=run_server, args=(port, world_size))
server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
server.start()
time.sleep(1)
if not server.is_alive():
raise Exception("Error starting Federated Learning server")

workers = []
for rank in range(world_size):
worker = multiprocessing.Process(target=run_worker, args=(port, world_size, rank, with_gpu))
worker = multiprocessing.Process(target=run_worker,
args=(port, world_size, rank, with_ssl, with_gpu))
workers.append(worker)
worker.start()
for worker in workers:
Expand All @@ -73,5 +81,7 @@ def run_test(with_gpu: bool = False) -> None:


if __name__ == '__main__':
run_test()
run_test(with_gpu=True)
run_test(with_ssl=True, with_gpu=False)
run_test(with_ssl=False, with_gpu=False)
run_test(with_ssl=True, with_gpu=True)
run_test(with_ssl=False, with_gpu=True)