Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove stop process. #143

Merged
merged 1 commit into from
Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 5 additions & 18 deletions include/rabit/internal/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,6 @@ namespace utils {
/*! \brief error message buffer length */
const int kPrintBuffer = 1 << 12;

/*! \brief we may want to keep the process alive when multiple workers
* are co-located in the same process */
extern bool STOP_PROCESS_ON_ERROR;

/* \brief Case-insensitive string comparison */
inline int CompareStringsCaseInsensitive(const char* s1, const char* s2) {
#ifdef _MSC_VER
Expand All @@ -89,26 +85,17 @@ inline bool StringToBool(const char* s) {
* \param msg error message
*/
inline void HandleAssertError(const char *msg) {
if (STOP_PROCESS_ON_ERROR) {
fprintf(stderr, "AssertError:%s, shutting down process\n", msg);
exit(-1);
} else {
fprintf(stderr, "AssertError:%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
fprintf(stderr,
"AssertError:%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
/*!
* \brief handling of Check error, caused by inappropriate input
* \param msg error message
*/
inline void HandleCheckError(const char *msg) {
if (STOP_PROCESS_ON_ERROR) {
fprintf(stderr, "%s, shutting down process\n", msg);
exit(-1);
} else {
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
fprintf(stderr, "%s, rabit is configured to keep process running\n", msg);
throw dmlc::Error(msg);
}
inline void HandlePrint(const char *msg) {
printf("%s", msg);
Expand Down
15 changes: 0 additions & 15 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
#include "allreduce_base.h"

namespace rabit {

namespace utils {
bool STOP_PROCESS_ON_ERROR = true;
}

namespace engine {
// constructor
AllreduceBase::AllreduceBase(void) {
Expand Down Expand Up @@ -46,7 +41,6 @@ AllreduceBase::AllreduceBase(void) {
env_vars.push_back("DMLC_TRACKER_URI");
env_vars.push_back("DMLC_TRACKER_PORT");
env_vars.push_back("DMLC_WORKER_CONNECT_RETRY");
env_vars.push_back("DMLC_WORKER_STOP_PROCESS_ON_ERROR");
}

// initialization function
Expand Down Expand Up @@ -197,15 +191,6 @@ void AllreduceBase::SetParam(const char *name, const char *val) {
if (!strcmp(name, "DMLC_WORKER_CONNECT_RETRY")) {
connect_retry = atoi(val);
}
if (!strcmp(name, "DMLC_WORKER_STOP_PROCESS_ON_ERROR")) {
if (!strcmp(val, "true")) {
rabit::utils::STOP_PROCESS_ON_ERROR = true;
} else if (!strcmp(val, "false")) {
rabit::utils::STOP_PROCESS_ON_ERROR = false;
} else {
throw std::runtime_error("invalid value of DMLC_WORKER_STOP_PROCESS_ON_ERROR");
}
}
if (!strcmp(name, "rabit_bootstrap_cache")) {
rabit_bootstrap_cache = utils::StringToBool(val);
}
Expand Down
13 changes: 7 additions & 6 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ int AllreduceRobust::GetBootstrapCache(const std::string &key, void* buf,
* \param size_prev_slice size of the previous slice i.e. slice of node (rank - 1) % world_size
* \param _file caller file name used to generate unique cache key
* \param _line caller line number used to generate unique cache key
* \param _caller caller function name used to generate unique cache key
*/
* \param _caller caller function name used to generate unique cache key
*/
void AllreduceRobust::Allgather(void *sendrecvbuf,
size_t total_size,
size_t slice_begin,
Expand Down Expand Up @@ -518,8 +518,8 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
}
// execute checkpoint; note: when a checkpoint already exists, load will not happen
_assert(RecoverExec(NULL, 0, ActionSummary::kCheckPoint,
ActionSummary::kSpecialOp, cur_cache_seq),
"check point must return true");
ActionSummary::kSpecialOp, cur_cache_seq),
"check point must return true");
// this is the critical region where we will change all the stored models
// increase version number
version_number += 1;
Expand Down Expand Up @@ -550,8 +550,9 @@ void AllreduceRobust::CheckPoint_(const Serializable *global_model,
delta = utils::GetTime() - start;
// log checkpoint ack latency
if (rabit_debug) {
utils::HandleLogInfo("[%d] checkpoint ack finished version %d, take %f seconds\n",
rank, version_number, delta);
utils::HandleLogInfo(
"[%d] checkpoint ack finished version %d, take %f seconds\n", rank,
version_number, delta);
}
}
/*!
Expand Down
5 changes: 0 additions & 5 deletions src/engine_empty.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
#include "rabit/internal/engine.h"

namespace rabit {

namespace utils {
bool STOP_PROCESS_ON_ERROR = true;
}

namespace engine {
/*! \brief EmptyEngine */
class EmptyEngine : public IEngine {
Expand Down
5 changes: 0 additions & 5 deletions src/engine_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
#include "rabit/internal/utils.h"

namespace rabit {

namespace utils {
bool STOP_PROCESS_ON_ERROR = true;
}

namespace engine {
/*! \brief implementation of engine using MPI */
class MPIEngine : public IEngine {
Expand Down
1 change: 1 addition & 0 deletions test/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ find_package(GTest REQUIRED)
add_executable(
unit_tests
test_io.cc
test_utils.cc
allreduce_robust_test.cc
allreduce_base_test.cc
allreduce_mock_test.cc
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/allreduce_mock_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ TEST(allreduce_mock, mock_allreduce)
char* argv[] = {cmd};
m.Init(1, argv);
m.rank = 0;
EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
EXPECT_THROW(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), dmlc::Error);
}

TEST(allreduce_mock, mock_broadcast)
Expand All @@ -32,5 +32,5 @@ TEST(allreduce_mock, mock_broadcast)
m.rank = 0;
m.version_number=1;
m.seq_counter=2;
EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
EXPECT_THROW(m.Broadcast(nullptr,0,0), dmlc::Error);
}
7 changes: 4 additions & 3 deletions test/cpp/allreduce_mock_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <string>
#include <iostream>
#include <dmlc/logging.h>
#include "../../src/allreduce_mock.h"

TEST(allreduce_mock, mock_allreduce)
Expand All @@ -17,7 +18,7 @@ TEST(allreduce_mock, mock_allreduce)
char* argv[] = {cmd};
m.Init(1, argv);
m.rank = 0;
EXPECT_EXIT(m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr), ::testing::ExitedWithCode(255), "");
EXPECT_THROW({m.Allreduce(nullptr,0,0,nullptr,nullptr,nullptr);}, dmlc::Error);
}

TEST(allreduce_mock, mock_broadcast)
Expand All @@ -32,7 +33,7 @@ TEST(allreduce_mock, mock_broadcast)
m.rank = 0;
m.version_number=1;
m.seq_counter=2;
EXPECT_EXIT(m.Broadcast(nullptr,0,0), ::testing::ExitedWithCode(255), "");
EXPECT_THROW({m.Broadcast(nullptr,0,0);}, dmlc::Error);
}

TEST(allreduce_mock, mock_gather)
Expand All @@ -47,5 +48,5 @@ TEST(allreduce_mock, mock_gather)
m.rank = 3;
m.version_number=13;
m.seq_counter=22;
EXPECT_EXIT(m.Allgather(nullptr,0,0,0,0), ::testing::ExitedWithCode(255), "");
EXPECT_THROW({m.Allgather(nullptr,0,0,0,0);}, dmlc::Error);
}
6 changes: 6 additions & 0 deletions test/cpp/test_utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#include <gtest/gtest.h>
#include <rabit/internal/utils.h>

// Checks that a failing rabit::utils::Assert (condition false, message "foo")
// is reported to the caller as a thrown dmlc::Error.
TEST(Utils, Assert) {
EXPECT_THROW({rabit::utils::Assert(false, "foo");}, dmlc::Error);
}