Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error message when world size and rank are set as strings #8316

Merged
merged 7 commits into from Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugin/federated/CMakeLists.txt
Expand Up @@ -7,7 +7,7 @@ find_package(Threads)
add_library(federated_proto federated.proto)
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)
xgboost_target_properties(federated_proto)

get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET federated_proto LANGUAGE cpp)
Expand Down
34 changes: 9 additions & 25 deletions plugin/federated/federated_communicator.h
Expand Up @@ -4,6 +4,7 @@
#pragma once
#include <xgboost/json.h>

#include "../../src/c_api/c_api_utils.h"
#include "../../src/collective/communicator.h"
#include "../../src/common/io.h"
#include "federated_client.h"
Expand Down Expand Up @@ -89,31 +90,14 @@ class FederatedCommunicator : public Communicator {
client_cert = value;
}

// Runtime configuration overrides.
auto const &j_server_address = config["federated_server_address"];
if (IsA<String const>(j_server_address)) {
server_address = get<String const>(j_server_address);
}
auto const &j_world_size = config["federated_world_size"];
if (IsA<Integer const>(j_world_size)) {
world_size = static_cast<int>(get<Integer const>(j_world_size));
}
auto const &j_rank = config["federated_rank"];
if (IsA<Integer const>(j_rank)) {
rank = static_cast<int>(get<Integer const>(j_rank));
}
auto const &j_server_cert = config["federated_server_cert"];
if (IsA<String const>(j_server_cert)) {
server_cert = get<String const>(j_server_cert);
}
auto const &j_client_key = config["federated_client_key"];
if (IsA<String const>(j_client_key)) {
client_key = get<String const>(j_client_key);
}
auto const &j_client_cert = config["federated_client_cert"];
if (IsA<String const>(j_client_cert)) {
client_cert = get<String const>(j_client_cert);
}
// Runtime configuration overrides, optional as users can specify them as env vars.
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
world_size =
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
client_key = OptionalArg<String>(config, "federated_client_key", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);

if (server_address.empty()) {
LOG(FATAL) << "Federated server address must be set.";
Expand Down
15 changes: 13 additions & 2 deletions src/c_api/c_api_utils.h
Expand Up @@ -248,21 +248,32 @@ inline void GenerateFeatureMap(Learner const *learner,

void XGBBuildInfoDevice(Json* p_info);

template <typename JT>
void TypeCheck(Json const &value, StringView name) {
using T = std::remove_const_t<JT> const;
if (!IsA<T>(value)) {
LOG(FATAL) << "Incorrect type for: `" << name << "`, expecting: `" << T{}.TypeStr()
<< "`, got: `" << value.GetValue().TypeStr() << "`.";
}
}

template <typename JT>
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it == obj.cend() || IsA<Null>(it->second)) {
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`";
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
}
TypeCheck<JT>(it->second, StringView{key});
return get<std::remove_const_t<JT> const>(it->second);
}

template <typename JT, typename T>
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it != obj.cend()) {
if (it != obj.cend() && !IsA<Null>(it->second)) {
TypeCheck<JT>(it->second, StringView{key});
return get<std::remove_const_t<JT> const>(it->second);
}
return dft;
Expand Down
5 changes: 2 additions & 3 deletions src/collective/noop_communicator.h
Expand Up @@ -17,9 +17,8 @@ class NoOpCommunicator : public Communicator {
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
void AllReduce(void *, std::size_t, DataType, Operation) override {}
void Broadcast(void *, std::size_t, int) override {}
std::string GetProcessorName() override { return ""; }
void Print(const std::string &message) override { LOG(CONSOLE) << message; }

Expand Down
32 changes: 32 additions & 0 deletions tests/cpp/c_api/test_c_api.cc
Expand Up @@ -324,4 +324,36 @@ TEST(CAPI, NullPtr) {
ASSERT_NE(pos, std::string::npos);
XGBAPISetLastError("");
}

TEST(CAPI, JArgs) {
{
Json args{Object{}};
args["key"] = String{"value"};
args["null"] = Null{};
auto value = OptionalArg<String>(args, "key", std::string{"foo"});
ASSERT_EQ(value, "value");
value = OptionalArg<String const>(args, "key", std::string{"foo"});
ASSERT_EQ(value, "value");

ASSERT_THROW({ OptionalArg<Number>(args, "key", 0.0f); }, dmlc::Error);
value = OptionalArg<String const>(args, "bar", std::string{"foo"});
ASSERT_EQ(value, "foo");
value = OptionalArg<String const>(args, "null", std::string{"foo"});
ASSERT_EQ(value, "foo");
}

{
Json args{Object{}};
args["key"] = String{"value"};
args["null"] = Null{};
auto value = RequiredArg<String>(args, "key", __func__);
ASSERT_EQ(value, "value");
value = RequiredArg<String const>(args, "key", __func__);
ASSERT_EQ(value, "value");

ASSERT_THROW({ RequiredArg<Integer>(args, "key", __func__); }, dmlc::Error);
ASSERT_THROW({ RequiredArg<String const>(args, "foo", __func__); }, dmlc::Error);
ASSERT_THROW({ RequiredArg<String>(args, "null", __func__); }, dmlc::Error);
}
}
} // namespace xgboost
22 changes: 22 additions & 0 deletions tests/cpp/plugin/test_federated_communicator.cc
Expand Up @@ -85,6 +85,28 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
auto construct = []() {
Json config{JsonObject()};
config["federated_server_address"] = kServerAddress;
config["federated_world_size"] = std::string("1");
config["federated_rank"] = Integer(0);
auto *comm = FederatedCommunicator::Create(config);
};
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
auto construct = []() {
Json config{JsonObject()};
config["federated_server_address"] = kServerAddress;
config["federated_world_size"] = 1;
config["federated_rank"] = std::string("0");
auto *comm = FederatedCommunicator::Create(config);
};
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
FederatedCommunicator comm{6, 3, kServerAddress};
EXPECT_EQ(comm.GetWorldSize(), 6);
Expand Down