From 85f8a9bb753e95288392b23e1b5a96a5640b196f Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Thu, 6 Oct 2022 11:34:57 -0700 Subject: [PATCH 1/5] Allow world size and rank to be specified as strings --- plugin/federated/federated_communicator.h | 6 +++++ .../cpp/plugin/test_federated_communicator.cc | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index cb1eb0b8109f..65fe6a24de1e 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -98,10 +98,16 @@ class FederatedCommunicator : public Communicator { if (IsA(j_world_size)) { world_size = static_cast(get(j_world_size)); } + if (IsA(j_world_size)) { + world_size = std::stoi(get(j_world_size)); + } auto const &j_rank = config["federated_rank"]; if (IsA(j_rank)) { rank = static_cast(get(j_rank)); } + if (IsA(j_rank)) { + rank = std::stoi(get(j_rank)); + } auto const &j_server_cert = config["federated_server_cert"]; if (IsA(j_server_cert)) { server_cert = get(j_server_cert); diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 2d9f233db573..6a53f2a039c3 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -40,6 +40,9 @@ class FederatedCommunicatorTest : public ::testing::Test { } void TearDown() override { + while (!server_) { + sleep(1); + } server_->Shutdown(); server_thread_->join(); } @@ -96,6 +99,28 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) { EXPECT_TRUE(comm.IsDistributed()); } +TEST_F(FederatedCommunicatorTest, Create) { + Json config{JsonObject()}; + config["federated_server_address"] = kServerAddress; + config["federated_world_size"] = std::string("1"); + config["federated_rank"] = std::string("0"); + auto *comm = FederatedCommunicator::Create(config); + EXPECT_EQ(1, comm->GetWorldSize()); + EXPECT_EQ(0, comm->GetRank()); + delete comm; +} + +TEST_F(FederatedCommunicatorTest, CreateFromIntegers) { + Json config{JsonObject()}; + config["federated_server_address"] = kServerAddress; + config["federated_world_size"] = 1; + config["federated_rank"] = Integer(0); + auto *comm = FederatedCommunicator::Create(config); + EXPECT_EQ(1, comm->GetWorldSize()); + EXPECT_EQ(0, comm->GetRank()); + delete comm; +} + TEST_F(FederatedCommunicatorTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) { From 5fedba865c1193f8d02030ed06d5008538ba1f4a Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Mon, 10 Oct 2022 10:17:21 -0700 Subject: [PATCH 2/5] enforce integer --- plugin/federated/federated_communicator.h | 4 +-- .../cpp/plugin/test_federated_communicator.cc | 33 ++++++++++++------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 65fe6a24de1e..25c750f4c0f8 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -99,14 +99,14 @@ class FederatedCommunicator : public Communicator { world_size = static_cast(get(j_world_size)); } if (IsA(j_world_size)) { - world_size = std::stoi(get(j_world_size)); + LOG(FATAL) << "Federated world size must be an integer."; } auto const &j_rank = config["federated_rank"]; if (IsA(j_rank)) { rank = static_cast(get(j_rank)); } if (IsA(j_rank)) { - rank = std::stoi(get(j_rank)); + LOG(FATAL) << "Federated rank must be an integer."; } auto const &j_server_cert = config["federated_server_cert"]; if (IsA(j_server_cert)) { diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 6a53f2a039c3..525564067b88 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -88,6 +88,28 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) { EXPECT_THROW(construct(), dmlc::Error); } +TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { + auto construct = []() { + Json config{JsonObject()}; + config["federated_server_address"] = kServerAddress; + config["federated_world_size"] = std::string("1"); + config["federated_rank"] = Integer(0); + auto *comm = FederatedCommunicator::Create(config); + }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { + auto construct = []() { + Json config{JsonObject()}; + config["federated_server_address"] = kServerAddress; + config["federated_world_size"] = 1; + config["federated_rank"] = std::string("0"); + auto *comm = FederatedCommunicator::Create(config); + }; + EXPECT_THROW(construct(), dmlc::Error); +} + TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { FederatedCommunicator comm{6, 3, kServerAddress}; EXPECT_EQ(comm.GetWorldSize(), 6); @@ -100,17 +122,6 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) { } TEST_F(FederatedCommunicatorTest, Create) { - Json config{JsonObject()}; - config["federated_server_address"] = kServerAddress; - config["federated_world_size"] = std::string("1"); - config["federated_rank"] = std::string("0"); - auto *comm = FederatedCommunicator::Create(config); - EXPECT_EQ(1, comm->GetWorldSize()); - EXPECT_EQ(0, comm->GetRank()); - delete comm; -} - -TEST_F(FederatedCommunicatorTest, CreateFromIntegers) { Json config{JsonObject()}; config["federated_server_address"] = kServerAddress; config["federated_world_size"] = 1; From 8d4841dc61381f79585148fa8f4aa78292a42ea3 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Tue, 11 Oct 2022 16:14:35 +0800 Subject: [PATCH 3/5] Use optional argument instead. --- plugin/federated/CMakeLists.txt | 2 +- plugin/federated/federated_communicator.h | 40 +++++------------------ src/c_api/c_api_utils.h | 13 +++++++- src/collective/noop_communicator.h | 5 ++- tests/cpp/c_api/test_c_api.cc | 32 ++++++++++++++++++ 5 files changed, 56 insertions(+), 36 deletions(-) diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 24ba47abfb8e..8dbf8227f887 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -7,7 +7,7 @@ find_package(Threads) add_library(federated_proto federated.proto) target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) -set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON) +xgboost_target_properties(federated_proto) get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION) protobuf_generate(TARGET federated_proto LANGUAGE cpp) diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index 25c750f4c0f8..2dc50962a296 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -4,6 +4,7 @@ #pragma once #include +#include "../../src/c_api/c_api_utils.h" #include "../../src/collective/communicator.h" #include "../../src/common/io.h" #include "federated_client.h" @@ -89,37 +90,14 @@ class FederatedCommunicator : public Communicator { client_cert = value; } - // Runtime configuration overrides. - auto const &j_server_address = config["federated_server_address"]; - if (IsA(j_server_address)) { - server_address = get(j_server_address); - } - auto const &j_world_size = config["federated_world_size"]; - if (IsA(j_world_size)) { - world_size = static_cast(get(j_world_size)); - } - if (IsA(j_world_size)) { - LOG(FATAL) << "Federated world size must be an integer."; - } - auto const &j_rank = config["federated_rank"]; - if (IsA(j_rank)) { - rank = static_cast(get(j_rank)); - } - if (IsA(j_rank)) { - LOG(FATAL) << "Federated rank must be an integer."; - } - auto const &j_server_cert = config["federated_server_cert"]; - if (IsA(j_server_cert)) { - server_cert = get(j_server_cert); - } - auto const &j_client_key = config["federated_client_key"]; - if (IsA(j_client_key)) { - client_key = get(j_client_key); - } - auto const &j_client_cert = config["federated_client_cert"]; - if (IsA(j_client_cert)) { - client_cert = get(j_client_cert); - } + // Runtime configuration overrides, optional as users can specify them as env vars. + server_address = OptionalArg(config, "federated_server_address", server_address); + world_size = + OptionalArg(config, "federated_world_size", static_cast(world_size)); + rank = OptionalArg(config, "federated_rank", static_cast(rank)); + server_cert = OptionalArg(config, "federated_server_cert", server_cert); + client_key = OptionalArg(config, "federated_client_key", client_key); + client_cert = OptionalArg(config, "federated_client_cert", client_cert); if (server_address.empty()) { LOG(FATAL) << "Federated server address must be set."; diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index ba23f765e641..0eae04bf46cc 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -248,6 +248,15 @@ inline void GenerateFeatureMap(Learner const *learner, void XGBBuildInfoDevice(Json* p_info); +template +void TypeCheck(Json const &value, StringView name) { + using T = std::remove_const_t const; + if (!IsA(value)) { + LOG(FATAL) << "Incorrect type for: " << name << ", expecting: " << T{}.TypeStr() + << ", got:" << value.GetValue().TypeStr(); + } +} + template auto const &RequiredArg(Json const &in, std::string const &key, StringView func) { auto const &obj = get(in); @@ -255,6 +264,7 @@ auto const &RequiredArg(Json const &in, std::string const &key, StringView func) if (it == obj.cend() || IsA(it->second)) { LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`"; } + TypeCheck(it->second, StringView{key}); return get const>(it->second); } @@ -262,7 +272,8 @@ template auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) { auto const &obj = get(in); auto it = obj.find(key); - if (it != obj.cend()) { + if (it != obj.cend() && !IsA(it->second)) { + TypeCheck(it->second, StringView{key}); return get const>(it->second); } return dft; diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h index 5d4fb89f409c..cad6da029530 100644 --- a/src/collective/noop_communicator.h +++ b/src/collective/noop_communicator.h @@ -17,9 +17,8 @@ class NoOpCommunicator : public Communicator { NoOpCommunicator() : Communicator(1, 0) {} bool IsDistributed() const override { return false; } bool IsFederated() const override { return false; } - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override {} - void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {} + void AllReduce(void *, std::size_t, DataType, Operation) override {} + void Broadcast(void *, std::size_t, int) override {} std::string GetProcessorName() override { return ""; } void Print(const std::string &message) override { LOG(CONSOLE) << message; } diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 30e2b5cd0475..ea1f893164f8 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -324,4 +324,36 @@ TEST(CAPI, NullPtr) { ASSERT_NE(pos, std::string::npos); XGBAPISetLastError(""); } + +TEST(CAPI, JArgs) { + { + Json args{Object{}}; + args["key"] = String{"value"}; + args["null"] = Null{}; + auto value = OptionalArg(args, "key", std::string{"foo"}); + ASSERT_EQ(value, "value"); + value = OptionalArg(args, "key", std::string{"foo"}); + ASSERT_EQ(value, "value"); + + ASSERT_THROW({ OptionalArg(args, "key", 0.0f); }, dmlc::Error); + value = OptionalArg(args, "bar", std::string{"foo"}); + ASSERT_EQ(value, "foo"); + value = OptionalArg(args, "null", std::string{"foo"}); + ASSERT_EQ(value, "foo"); + } + + { + Json args{Object{}}; + args["key"] = String{"value"}; + args["null"] = Null{}; + auto value = RequiredArg(args, "key", __func__); + ASSERT_EQ(value, "value"); + value = RequiredArg(args, "key", __func__); + ASSERT_EQ(value, "value"); + + ASSERT_THROW({ RequiredArg(args, "key", __func__); }, dmlc::Error); + ASSERT_THROW({ RequiredArg(args, "foo", __func__); }, dmlc::Error); + ASSERT_THROW({ RequiredArg(args, "null", __func__); }, dmlc::Error); + } +} } // namespace xgboost From 77eb80d4708f173e74d04d442a44e00c48d044c4 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Tue, 11 Oct 2022 16:26:05 +0800 Subject: [PATCH 4/5] error message. --- src/c_api/c_api_utils.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index 0eae04bf46cc..b7b2c93034d3 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -252,8 +252,8 @@ template void TypeCheck(Json const &value, StringView name) { using T = std::remove_const_t const; if (!IsA(value)) { - LOG(FATAL) << "Incorrect type for: " << name << ", expecting: " << T{}.TypeStr() - << ", got:" << value.GetValue().TypeStr(); + LOG(FATAL) << "Incorrect type for: `" << name << "`, expecting: `" << T{}.TypeStr() + << "`, got: `" << value.GetValue().TypeStr() << "`."; } } @@ -262,7 +262,7 @@ auto const &RequiredArg(Json const &in, std::string const &key, StringView func) auto const &obj = get(in); auto it = obj.find(key); if (it == obj.cend() || IsA(it->second)) { - LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`"; + LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`."; } TypeCheck(it->second, StringView{key}); return get const>(it->second); From 5fcb36020146da932056437daec5a105b9b8bf0e Mon Sep 17 00:00:00 2001 From: Rong Ou Date: Tue, 11 Oct 2022 13:51:50 -0700 Subject: [PATCH 5/5] remove create test --- tests/cpp/plugin/test_federated_communicator.cc | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 525564067b88..3cfa15fb1eb7 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -40,9 +40,6 @@ class FederatedCommunicatorTest : public ::testing::Test { } void TearDown() override { - while (!server_) { - sleep(1); - } server_->Shutdown(); server_thread_->join(); } @@ -121,17 +118,6 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) { EXPECT_TRUE(comm.IsDistributed()); } -TEST_F(FederatedCommunicatorTest, Create) { - Json config{JsonObject()}; - config["federated_server_address"] = kServerAddress; - config["federated_world_size"] = 1; - config["federated_rank"] = Integer(0); - auto *comm = FederatedCommunicator::Create(config); - EXPECT_EQ(1, comm->GetWorldSize()); - EXPECT_EQ(0, comm->GetRank()); - delete comm; -} - TEST_F(FederatedCommunicatorTest, Allreduce) { std::vector threads; for (auto rank = 0; rank < kWorldSize; rank++) {