Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error message when world size and rank are set as strings #8316

Merged
merged 7 commits into from Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions plugin/federated/federated_communicator.h
Expand Up @@ -98,10 +98,16 @@ class FederatedCommunicator : public Communicator {
if (IsA<Integer const>(j_world_size)) {
world_size = static_cast<int>(get<Integer const>(j_world_size));
}
if (IsA<String const>(j_world_size)) {
world_size = std::stoi(get<String const>(j_world_size));
}
auto const &j_rank = config["federated_rank"];
if (IsA<Integer const>(j_rank)) {
rank = static_cast<int>(get<Integer const>(j_rank));
}
if (IsA<String const>(j_rank)) {
rank = std::stoi(get<String const>(j_rank));
}
auto const &j_server_cert = config["federated_server_cert"];
if (IsA<String const>(j_server_cert)) {
server_cert = get<String const>(j_server_cert);
Expand Down
25 changes: 25 additions & 0 deletions tests/cpp/plugin/test_federated_communicator.cc
Expand Up @@ -40,6 +40,9 @@ class FederatedCommunicatorTest : public ::testing::Test {
}

void TearDown() override {
while (!server_) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems odd, could you please share when it helps?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was somehow needed to get the create test to work. Seems more trouble than it's worth. Removed.

sleep(1);
}
server_->Shutdown();
server_thread_->join();
}
Expand Down Expand Up @@ -96,6 +99,28 @@ TEST(FederatedCommunicatorSimpleTest, IsDistributed) {
EXPECT_TRUE(comm.IsDistributed());
}

// Builds a communicator from a JSON config whose world size and rank are
// supplied as strings, then checks that both come back as the integers 1 and 0.
TEST_F(FederatedCommunicatorTest, Create) {
  Json cfg{JsonObject()};
  // The three keys are independent of each other, so insertion order is irrelevant.
  cfg["federated_rank"] = std::string("0");
  cfg["federated_world_size"] = std::string("1");
  cfg["federated_server_address"] = kServerAddress;

  auto *communicator = FederatedCommunicator::Create(cfg);
  EXPECT_EQ(1, communicator->GetWorldSize());
  EXPECT_EQ(0, communicator->GetRank());
  // Create() hands back a raw pointer that this test releases itself.
  delete communicator;
}

// Builds a communicator from a JSON config whose world size and rank are
// supplied as JSON integers, then checks that both values are reported back.
TEST_F(FederatedCommunicatorTest, CreateFromIntegers) {
  Json config{JsonObject()};
  config["federated_server_address"] = kServerAddress;
  // Spell out Integer(...) for both entries so the stored JSON value type is
  // explicit rather than depending on an implicit conversion from `int`.
  config["federated_world_size"] = Integer(1);
  config["federated_rank"] = Integer(0);
  auto *comm = FederatedCommunicator::Create(config);
  EXPECT_EQ(1, comm->GetWorldSize());
  EXPECT_EQ(0, comm->GetRank());
  // Create() hands back a raw pointer that this test releases itself.
  delete comm;
}

TEST_F(FederatedCommunicatorTest, Allreduce) {
std::vector<std::thread> threads;
for (auto rank = 0; rank < kWorldSize; rank++) {
Expand Down