diff --git a/plugin/federated/CMakeLists.txt b/plugin/federated/CMakeLists.txt index 24ba47abfb8e..8dbf8227f887 100644 --- a/plugin/federated/CMakeLists.txt +++ b/plugin/federated/CMakeLists.txt @@ -7,7 +7,7 @@ find_package(Threads) add_library(federated_proto federated.proto) target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++) target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR}) -set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON) +xgboost_target_properties(federated_proto) get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION) protobuf_generate(TARGET federated_proto LANGUAGE cpp) diff --git a/plugin/federated/federated_communicator.h b/plugin/federated/federated_communicator.h index cb1eb0b8109f..2dc50962a296 100644 --- a/plugin/federated/federated_communicator.h +++ b/plugin/federated/federated_communicator.h @@ -4,6 +4,7 @@ #pragma once #include +#include "../../src/c_api/c_api_utils.h" #include "../../src/collective/communicator.h" #include "../../src/common/io.h" #include "federated_client.h" @@ -89,31 +90,14 @@ class FederatedCommunicator : public Communicator { client_cert = value; } - // Runtime configuration overrides. - auto const &j_server_address = config["federated_server_address"]; - if (IsA(j_server_address)) { - server_address = get(j_server_address); - } - auto const &j_world_size = config["federated_world_size"]; - if (IsA(j_world_size)) { - world_size = static_cast(get(j_world_size)); - } - auto const &j_rank = config["federated_rank"]; - if (IsA(j_rank)) { - rank = static_cast(get(j_rank)); - } - auto const &j_server_cert = config["federated_server_cert"]; - if (IsA(j_server_cert)) { - server_cert = get(j_server_cert); - } - auto const &j_client_key = config["federated_client_key"]; - if (IsA(j_client_key)) { - client_key = get(j_client_key); - } - auto const &j_client_cert = config["federated_client_cert"]; - if (IsA(j_client_cert)) { - client_cert = get(j_client_cert); - } + // Runtime configuration overrides, optional as users can specify them as env vars. + server_address = OptionalArg(config, "federated_server_address", server_address); + world_size = + OptionalArg(config, "federated_world_size", static_cast(world_size)); + rank = OptionalArg(config, "federated_rank", static_cast(rank)); + server_cert = OptionalArg(config, "federated_server_cert", server_cert); + client_key = OptionalArg(config, "federated_client_key", client_key); + client_cert = OptionalArg(config, "federated_client_cert", client_cert); if (server_address.empty()) { LOG(FATAL) << "Federated server address must be set."; diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h index ba23f765e641..b7b2c93034d3 100644 --- a/src/c_api/c_api_utils.h +++ b/src/c_api/c_api_utils.h @@ -248,13 +248,23 @@ inline void GenerateFeatureMap(Learner const *learner, void XGBBuildInfoDevice(Json* p_info); +template +void TypeCheck(Json const &value, StringView name) { + using T = std::remove_const_t const; + if (!IsA(value)) { + LOG(FATAL) << "Incorrect type for: `" << name << "`, expecting: `" << T{}.TypeStr() + << "`, got: `" << value.GetValue().TypeStr() << "`."; + } +} + template auto const &RequiredArg(Json const &in, std::string const &key, StringView func) { auto const &obj = get(in); auto it = obj.find(key); if (it == obj.cend() || IsA(it->second)) { - LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`"; + LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`."; } + TypeCheck(it->second, StringView{key}); return get const>(it->second); } @@ -262,7 +272,8 @@ template auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) { auto const &obj = get(in); auto it = obj.find(key); - if (it != obj.cend()) { + if (it != obj.cend() && !IsA(it->second)) { + TypeCheck(it->second, StringView{key}); return get const>(it->second); } return dft; diff --git a/src/collective/noop_communicator.h b/src/collective/noop_communicator.h index 5d4fb89f409c..cad6da029530 100644 --- a/src/collective/noop_communicator.h +++ b/src/collective/noop_communicator.h @@ -17,9 +17,8 @@ class NoOpCommunicator : public Communicator { NoOpCommunicator() : Communicator(1, 0) {} bool IsDistributed() const override { return false; } bool IsFederated() const override { return false; } - void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type, - Operation op) override {} - void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {} + void AllReduce(void *, std::size_t, DataType, Operation) override {} + void Broadcast(void *, std::size_t, int) override {} std::string GetProcessorName() override { return ""; } void Print(const std::string &message) override { LOG(CONSOLE) << message; } diff --git a/tests/cpp/c_api/test_c_api.cc b/tests/cpp/c_api/test_c_api.cc index 30e2b5cd0475..ea1f893164f8 100644 --- a/tests/cpp/c_api/test_c_api.cc +++ b/tests/cpp/c_api/test_c_api.cc @@ -324,4 +324,36 @@ TEST(CAPI, NullPtr) { ASSERT_NE(pos, std::string::npos); XGBAPISetLastError(""); } + +TEST(CAPI, JArgs) { + { + Json args{Object{}}; + args["key"] = String{"value"}; + args["null"] = Null{}; + auto value = OptionalArg(args, "key", std::string{"foo"}); + ASSERT_EQ(value, "value"); + value = OptionalArg(args, "key", std::string{"foo"}); + ASSERT_EQ(value, "value"); + + ASSERT_THROW({ OptionalArg(args, "key", 0.0f); }, dmlc::Error); + value = OptionalArg(args, "bar", std::string{"foo"}); + ASSERT_EQ(value, "foo"); + value = OptionalArg(args, "null", std::string{"foo"}); + ASSERT_EQ(value, "foo"); + } + + { + Json args{Object{}}; + args["key"] = String{"value"}; + args["null"] = Null{}; + auto value = RequiredArg(args, "key", __func__); + ASSERT_EQ(value, "value"); + value = RequiredArg(args, "key", __func__); + ASSERT_EQ(value, "value"); + + ASSERT_THROW({ RequiredArg(args, "key", __func__); }, dmlc::Error); + ASSERT_THROW({ RequiredArg(args, "foo", __func__); }, dmlc::Error); + ASSERT_THROW({ RequiredArg(args, "null", __func__); }, dmlc::Error); + } +} } // namespace xgboost diff --git a/tests/cpp/plugin/test_federated_communicator.cc b/tests/cpp/plugin/test_federated_communicator.cc index 2d9f233db573..3cfa15fb1eb7 100644 --- a/tests/cpp/plugin/test_federated_communicator.cc +++ b/tests/cpp/plugin/test_federated_communicator.cc @@ -85,6 +85,28 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) { EXPECT_THROW(construct(), dmlc::Error); } +TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) { + auto construct = []() { + Json config{JsonObject()}; + config["federated_server_address"] = kServerAddress; + config["federated_world_size"] = std::string("1"); + config["federated_rank"] = Integer(0); + auto *comm = FederatedCommunicator::Create(config); + }; + EXPECT_THROW(construct(), dmlc::Error); +} + +TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) { + auto construct = []() { + Json config{JsonObject()}; + config["federated_server_address"] = kServerAddress; + config["federated_world_size"] = 1; + config["federated_rank"] = std::string("0"); + auto *comm = FederatedCommunicator::Create(config); + }; + EXPECT_THROW(construct(), dmlc::Error); +} + TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) { FederatedCommunicator comm{6, 3, kServerAddress}; EXPECT_EQ(comm.GetWorldSize(), 6);