Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error message when world size and rank are set as strings #8316

Merged
merged 7 commits into from Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion plugin/federated/CMakeLists.txt
Expand Up @@ -7,7 +7,7 @@ find_package(Threads)
add_library(federated_proto federated.proto)
target_link_libraries(federated_proto PUBLIC protobuf::libprotobuf gRPC::grpc gRPC::grpc++)
target_include_directories(federated_proto PUBLIC ${CMAKE_CURRENT_BINARY_DIR})
set_property(TARGET federated_proto PROPERTY POSITION_INDEPENDENT_CODE ON)
xgboost_target_properties(federated_proto)

get_target_property(grpc_cpp_plugin_location gRPC::grpc_cpp_plugin LOCATION)
protobuf_generate(TARGET federated_proto LANGUAGE cpp)
Expand Down
34 changes: 9 additions & 25 deletions plugin/federated/federated_communicator.h
Expand Up @@ -4,6 +4,7 @@
#pragma once
#include <xgboost/json.h>

#include "../../src/c_api/c_api_utils.h"
#include "../../src/collective/communicator.h"
#include "../../src/common/io.h"
#include "federated_client.h"
Expand Down Expand Up @@ -89,31 +90,14 @@ class FederatedCommunicator : public Communicator {
client_cert = value;
}

// Runtime configuration overrides.
auto const &j_server_address = config["federated_server_address"];
if (IsA<String const>(j_server_address)) {
server_address = get<String const>(j_server_address);
}
auto const &j_world_size = config["federated_world_size"];
if (IsA<Integer const>(j_world_size)) {
world_size = static_cast<int>(get<Integer const>(j_world_size));
}
auto const &j_rank = config["federated_rank"];
if (IsA<Integer const>(j_rank)) {
rank = static_cast<int>(get<Integer const>(j_rank));
}
auto const &j_server_cert = config["federated_server_cert"];
if (IsA<String const>(j_server_cert)) {
server_cert = get<String const>(j_server_cert);
}
auto const &j_client_key = config["federated_client_key"];
if (IsA<String const>(j_client_key)) {
client_key = get<String const>(j_client_key);
}
auto const &j_client_cert = config["federated_client_cert"];
if (IsA<String const>(j_client_cert)) {
client_cert = get<String const>(j_client_cert);
}
// Runtime configuration overrides, optional as users can specify them as env vars.
server_address = OptionalArg<String>(config, "federated_server_address", server_address);
world_size =
OptionalArg<Integer>(config, "federated_world_size", static_cast<Integer::Int>(world_size));
rank = OptionalArg<Integer>(config, "federated_rank", static_cast<Integer::Int>(rank));
server_cert = OptionalArg<String>(config, "federated_server_cert", server_cert);
client_key = OptionalArg<String>(config, "federated_client_key", client_key);
client_cert = OptionalArg<String>(config, "federated_client_cert", client_cert);

if (server_address.empty()) {
LOG(FATAL) << "Federated server address must be set.";
Expand Down
15 changes: 13 additions & 2 deletions src/c_api/c_api_utils.h
Expand Up @@ -248,21 +248,32 @@ inline void GenerateFeatureMap(Learner const *learner,

void XGBBuildInfoDevice(Json* p_info);

template <typename JT>
void TypeCheck(Json const &value, StringView name) {
using T = std::remove_const_t<JT> const;
if (!IsA<T>(value)) {
LOG(FATAL) << "Incorrect type for: `" << name << "`, expecting: `" << T{}.TypeStr()
<< "`, got: `" << value.GetValue().TypeStr() << "`.";
}
}

template <typename JT>
auto const &RequiredArg(Json const &in, std::string const &key, StringView func) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it == obj.cend() || IsA<Null>(it->second)) {
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`";
LOG(FATAL) << "Argument `" << key << "` is required for `" << func << "`.";
}
TypeCheck<JT>(it->second, StringView{key});
return get<std::remove_const_t<JT> const>(it->second);
}

template <typename JT, typename T>
auto const &OptionalArg(Json const &in, std::string const &key, T const &dft) {
auto const &obj = get<Object const>(in);
auto it = obj.find(key);
if (it != obj.cend()) {
if (it != obj.cend() && !IsA<Null>(it->second)) {
TypeCheck<JT>(it->second, StringView{key});
return get<std::remove_const_t<JT> const>(it->second);
}
return dft;
Expand Down
5 changes: 2 additions & 3 deletions src/collective/noop_communicator.h
Expand Up @@ -17,9 +17,8 @@ class NoOpCommunicator : public Communicator {
NoOpCommunicator() : Communicator(1, 0) {}
bool IsDistributed() const override { return false; }
bool IsFederated() const override { return false; }
void AllReduce(void *send_receive_buffer, std::size_t count, DataType data_type,
Operation op) override {}
void Broadcast(void *send_receive_buffer, std::size_t size, int root) override {}
void AllReduce(void *, std::size_t, DataType, Operation) override {}
void Broadcast(void *, std::size_t, int) override {}
std::string GetProcessorName() override { return ""; }
void Print(const std::string &message) override { LOG(CONSOLE) << message; }

Expand Down
32 changes: 32 additions & 0 deletions tests/cpp/c_api/test_c_api.cc
Expand Up @@ -324,4 +324,36 @@ TEST(CAPI, NullPtr) {
ASSERT_NE(pos, std::string::npos);
XGBAPISetLastError("");
}

TEST(CAPI, JArgs) {
{
Json args{Object{}};
args["key"] = String{"value"};
args["null"] = Null{};
auto value = OptionalArg<String>(args, "key", std::string{"foo"});
ASSERT_EQ(value, "value");
value = OptionalArg<String const>(args, "key", std::string{"foo"});
ASSERT_EQ(value, "value");

ASSERT_THROW({ OptionalArg<Number>(args, "key", 0.0f); }, dmlc::Error);
value = OptionalArg<String const>(args, "bar", std::string{"foo"});
ASSERT_EQ(value, "foo");
value = OptionalArg<String const>(args, "null", std::string{"foo"});
ASSERT_EQ(value, "foo");
}

{
Json args{Object{}};
args["key"] = String{"value"};
args["null"] = Null{};
auto value = RequiredArg<String>(args, "key", __func__);
ASSERT_EQ(value, "value");
value = RequiredArg<String const>(args, "key", __func__);
ASSERT_EQ(value, "value");

ASSERT_THROW({ RequiredArg<Integer>(args, "key", __func__); }, dmlc::Error);
ASSERT_THROW({ RequiredArg<String const>(args, "foo", __func__); }, dmlc::Error);
ASSERT_THROW({ RequiredArg<String>(args, "null", __func__); }, dmlc::Error);
}
}
} // namespace xgboost
36 changes: 36 additions & 0 deletions tests/cpp/plugin/test_federated_communicator.cc
Expand Up @@ -40,6 +40,9 @@ class FederatedCommunicatorTest : public ::testing::Test {
}

void TearDown() override {
while (!server_) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems odd, could you please share when it helps?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was somehow needed to get the create test to work. Seems more trouble than its worth. Removed.

sleep(1);
}
server_->Shutdown();
server_thread_->join();
}
Expand Down Expand Up @@ -85,6 +88,28 @@ TEST(FederatedCommunicatorSimpleTest, ThrowOnRankTooBig) {
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, ThrowOnWorldSizeNotInteger) {
auto construct = []() {
Json config{JsonObject()};
config["federated_server_address"] = kServerAddress;
config["federated_world_size"] = std::string("1");
config["federated_rank"] = Integer(0);
auto *comm = FederatedCommunicator::Create(config);
};
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, ThrowOnRankNotInteger) {
auto construct = []() {
Json config{JsonObject()};
config["federated_server_address"] = kServerAddress;
config["federated_world_size"] = 1;
config["federated_rank"] = std::string("0");
auto *comm = FederatedCommunicator::Create(config);
};
EXPECT_THROW(construct(), dmlc::Error);
}

TEST(FederatedCommunicatorSimpleTest, GetWorldSizeAndRank) {
FederatedCommunicator comm{6, 3, kServerAddress};
EXPECT_EQ(comm.GetWorldSize(), 6);
Expand All @@ -96,6 +121,17 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
EXPECT_TRUE(comm.IsDistributed());
}

TEST_F(FederatedCommunicatorTest, Create) {
Json config{JsonObject()};
config["federated_server_address"] = kServerAddress;
config["federated_world_size"] = 1;
config["federated_rank"] = Integer(0);
auto *comm = FederatedCommunicator::Create(config);
EXPECT_EQ(1, comm->GetWorldSize());
EXPECT_EQ(0, comm->GetRank());
delete comm;
}

TEST_F(FederatedCommunicatorTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
Expand Down